
Implementing a Prisma-like App on iOS (Part 2)

2017-05-26  Jiao123

Preface


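After my earlier article on using TensorFlow (TF) to build a Prisma-like app on iOS, many readers got in touch to discuss the approach or to ask for a demo. Clearly a lot of people are interested in this feature.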

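That article did not go into how Google's method actually works. It only got the trained parameters and the computation graph running on an iOS device. Building a project with TF is complicated in its own right, so I did not publish the source code for download.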

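This time I worked through the logic described in Google's paper and then reimplemented Prisma-style image rendering with Apple's new Metal framework, using the device GPU to accelerate part of the computation. The network parameters are still the ones Google trained. I ran into plenty of pitfalls along the way, and I'm sharing them here in the hope that they help with your own experiments 😅.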

2017 Strawberry Music Festival (Onmyoji stage)

Architecture


If you have read the paper "A Learned Representation for Artistic Style", you should already be familiar with the overall network architecture Google proposed, shown below:

Network architecture

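The dashed part at the end is the VGG network. It is used during training to compute the losses, the same way as in the 2015 paper "A Neural Algorithm of Artistic Style", and it is not the focus of Google's improvements here.
The key part is the Style transfer network at the front, and all the trained parameters we obtained belong to this part. It is a feed-forward network that generates the image. Once its parameters are trained, producing a stylized image takes a single forward pass. This saves a great deal of time compared with iteratively optimizing the image against the VGG network. And because synthesis is that fast, the network can also run locally on a mobile device.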

Below is the structure of the Style transfer network:

network

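The network consists of 3 convolutions, then 5 Residual Blocks, then 2 upsampling layers, then 1 final convolution. A Residual Block is just two convolutions whose output is added to the block's input. An upsampling layer first enlarges the image with Nearest-Neighbor interpolation and then applies a convolution. That makes 16 convolutions in total (3 + 5 × 2 + 2 + 1). Each convolution is followed first by Batch-normalization and then by the activation function. I got this order wrong at first: Apple's MPSCNN library can apply the activation directly as part of the convolution, so my implementation put BN after the activation, and the generated images kept coming out wrong 😭.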
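To make the layer count concrete, here is a small C sketch. It lists the 16 convolutions with the channel counts and size changes used in the forward code later in this article, and traces a feature-map shape through them. The `ConvLayer` struct, the layer names and `trace_shapes` are hypothetical. Kernel sizes are left out because the article doesn't state them. The trace assumes the input width and height are divisible by 4.

```c
#include <stdio.h>

/* Hypothetical description of one convolution. Every layer is
 * conv -> BN -> activation (ReLU, or sigmoid for the last one).
 * Output spatial size = input size * scale_num / scale_den; for the two
 * expand layers this folds the 2x nearest-neighbor resize into the layer. */
typedef struct {
    const char *name;
    int in_channels;
    int out_channels;
    int scale_num;
    int scale_den;
} ConvLayer;

static const ConvLayer kLayers[] = {
    {"contract1", 3, 32, 1, 1},
    {"contract2", 32, 64, 1, 2},
    {"contract3", 64, 128, 1, 2},
    {"res1_conv1", 128, 128, 1, 1}, {"res1_conv2", 128, 128, 1, 1},
    {"res2_conv1", 128, 128, 1, 1}, {"res2_conv2", 128, 128, 1, 1},
    {"res3_conv1", 128, 128, 1, 1}, {"res3_conv2", 128, 128, 1, 1},
    {"res4_conv1", 128, 128, 1, 1}, {"res4_conv2", 128, 128, 1, 1},
    {"res5_conv1", 128, 128, 1, 1}, {"res5_conv2", 128, 128, 1, 1},
    {"expand1", 128, 64, 2, 1},   /* nearest-neighbor 2x, then conv */
    {"expand2", 64, 32, 2, 1},    /* nearest-neighbor 2x, then conv */
    {"expand3", 32, 3, 1, 1},
};

int layer_count(void) {
    return (int)(sizeof(kLayers) / sizeof(kLayers[0]));
}

/* Prints each layer's output shape and returns the final channel count,
 * or -1 if a layer's input channels don't match the previous output. */
int trace_shapes(int width, int height, int *out_w, int *out_h) {
    int w = width, h = height, c = 3;
    for (int i = 0; i < layer_count(); i++) {
        const ConvLayer *l = &kLayers[i];
        if (l->in_channels != c) return -1;
        w = w * l->scale_num / l->scale_den;
        h = h * l->scale_num / l->scale_den;
        c = l->out_channels;
        printf("%-11s -> %4d x %4d x %3d\n", l->name, w, h, c);
    }
    *out_w = w;
    *out_h = h;
    return c;
}
```

For a 256 × 256 input, the trace shows the 128-channel residual section running at 64 × 64. The two expand layers bring the output back to 256 × 256 with 3 channels.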

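Padding Mode: the paper specifies Reflect padding. Apple's Metal does not support that mode, so I wrote my own 😢. In the end, though, Zero Padding turned out to be what produced correct results. I'm not sure whether Apple's convolution implementation differs in some way, or whether this padding mode only applies during training. I'll look into it again when I have time.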
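The two padding modes differ only in which values the kernel sees past the image border. The C sketch below pads a single row both ways. `reflect_index` follows TensorFlow's REFLECT convention, which mirrors around the edge pixel without repeating it. The function names are hypothetical, and the sketch assumes `pad < n`.

```c
/* Map an out-of-range index into [0, n) the way TensorFlow's REFLECT
 * padding does: for n = 5, index -1 -> 1 and index 5 -> 3.
 * Assumes the padding width is smaller than n. */
int reflect_index(int i, int n) {
    if (i < 0) return -i;
    if (i >= n) return 2 * (n - 1) - i;
    return i;
}

/* Pad a row of n values with `pad` values on each side into dst
 * (length n + 2 * pad). reflect != 0 selects REFLECT, otherwise zeros. */
void pad_row(const float *src, int n, int pad, int reflect, float *dst) {
    for (int i = -pad; i < n + pad; i++) {
        float v;
        if (i >= 0 && i < n)
            v = src[i];
        else if (reflect)
            v = src[reflect_index(i, n)];
        else
            v = 0.0f;
        dst[i + pad] = v;
    }
}
```

For the row {1, 2, 3} with a padding of 2, REFLECT gives {3, 2, 1, 2, 3, 2, 1} and zero padding gives {0, 0, 1, 2, 3, 0, 0}. Only pixels within the kernel radius of the border are affected.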

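That is the core network structure. In principle, once we have the parameters and know how the network is built, we can implement it ourselves without the TF computation graph. This avoids the tedious work of integrating and compiling TF, and it makes debugging the network and controlling memory much easier.
It is not that simple, though. Apple's Metal framework still lacks kernels for many deep neural network operations and only partially wraps convolution. Below I share how I implemented a few of the more important algorithms.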

Batch-Normalization


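BN (Batch-Normalization) is really the core of this network, because it is where the different styles are distinguished. Whichever style you choose, every convolution is the same, but the BN parameters differ, and that is what changes the style of the generated image. Strictly speaking, the paper calls this conditional instance normalization: each channel is normalized over the spatial dimensions of a single image, then scaled and shifted by style-specific parameters.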
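For each channel c, the normalization computes y = γ_c · (x − μ_c) / σ_c + β_c. Here μ_c and σ_c are the mean and standard deviation of that channel over the whole image, and γ_c, β_c come from the selected style. The `batch_norm` method below builds γ and β with `vDSP_mmul` by multiplying the 1 × styleNum `styles` vector by a styleNum × featureNum table.

Below is a plain-C sketch of that product. The function name is hypothetical. The style-major table layout is inferred from the `vDSP_mmul` arguments.

```c
/* Hypothetical plain-C version of the vDSP_mmul call in batch_norm:
 * out[c] = sum over s of weights[s] * table[s * channels + c].
 * table is style-major: row s holds one style's value for every channel.
 * One-hot weights pick a single style; fractional weights mix styles. */
void blend_style_params(const float *weights, const float *table,
                        int num_styles, int channels, float *out) {
    for (int c = 0; c < channels; c++) {
        float acc = 0.0f;
        for (int s = 0; s < num_styles; s++)
            acc += weights[s] * table[s * channels + c];
        out[c] = acc;
    }
}
```

The same routine is applied twice per layer, once to the β table and once to the γ table.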

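I first hoped Metal would provide a BN implementation, but I couldn't find one. I considered writing a kernel so that BN could run on the GPU together with the convolutions, but learning kernel programming from scratch turned out to be too complicated. So I implemented BN on the CPU instead. After each convolution, the image is copied out, BN is computed on the CPU, and then the activation is applied. I'm still hoping Apple will provide a BN kernel later 😊.
The implementation is as follows: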

- (void)batch_norm:(MPSImage *)image styles:(float *)styles shift:(float *)shift
{
    NSUInteger w = image.texture.width;
    NSUInteger h = image.texture.height;
    NSUInteger featureNum = image.featureChannels;
    // Blend the per-style parameters with the style weights:
    // (1 x styleNum) * (styleNum x featureNum). shift holds all beta values
    // first, followed by all gamma values. styleNum is an instance variable.
    float *gamma = calloc(featureNum, sizeof(float));
    float *beta = calloc(featureNum, sizeof(float));
    vDSP_mmul(styles, 1, shift, 1, beta, 1, 1, featureNum, styleNum);
    vDSP_mmul(styles, 1, shift + featureNum * styleNum, 1, gamma, 1, 1, featureNum, styleNum);

    // MPSImage layout: channel c is stored in texture slice c / 4, component c % 4.
    // Images with 1 or 2 channels use a single slice with that many components;
    // everything else uses 4-component (RGBA) slices.
    NSUInteger numSlices = (featureNum + 3) / 4;
    NSUInteger numComponents = featureNum < 3 ? featureNum : 4;
    NSUInteger channels = featureNum < 3 ? featureNum : numSlices * 4;
    NSUInteger sliceLength = w * h * numComponents;
    NSUInteger bytesPerRow = w * numComponents * sizeof(float16_t);
    MTLRegion region = MTLRegionMake3D(0, 0, 0, w, h, 1);

    // Copy the half-precision texture data to the CPU and convert it to float.
    float16_t *htemp = calloc(w * h * channels, sizeof(float16_t));
    for (NSUInteger i = 0; i < numSlices; i++) {
        [image.texture getBytes:htemp + sliceLength * i bytesPerRow:bytesPerRow bytesPerImage:0 fromRegion:region mipmapLevel:0 slice:i];
    }
    float *temp = calloc(w * h * channels, sizeof(float));
    [self halfTofloat:htemp floatp:temp width:w height:h channel:channels];

    float mean, stddev;
    for (NSUInteger i = 0; i < featureNum; i++) {
        // Strided view of channel i inside its slice.
        float *channel = temp + (i / 4) * sliceLength + (i % 4);
        // Normalize to zero mean and unit standard deviation
        // (vDSP_normalize reports the standard deviation, not the variance).
        vDSP_normalize(channel, numComponents, channel, numComponents, &mean, &stddev, w * h);
        if (stddev == 0) {
            // Constant channel: normalization would divide by zero, so set it to 0.
            float zero = 0;
            vDSP_vfill(&zero, channel, numComponents, w * h);
        }
        // Scale and shift: x * gamma + beta.
        vDSP_vsmul(channel, numComponents, &gamma[i], channel, numComponents, w * h);
        vDSP_vsadd(channel, numComponents, &beta[i], channel, numComponents, w * h);
    }

    // Convert back to half precision and write the result into the texture.
    [self floatToHalf:temp halfp:htemp width:w height:h channel:channels];
    for (NSUInteger i = 0; i < numSlices; i++) {
        [image.texture replaceRegion:region mipmapLevel:0 slice:i withBytes:htemp + sliceLength * i bytesPerRow:bytesPerRow bytesPerImage:0];
    }
    free(temp);
    free(htemp);
    free(gamma);
    free(beta);
}

Nearest-Neighbor

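Apple doesn't provide this interpolation directly either. BlitCommandEncoder has some related methods, but they felt cumbersome for what is a very simple algorithm. BN was already running on the CPU and this step is only called twice, so I implemented it on the CPU as well.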

The principle is simple. When the image is enlarged, the area around each original pixel is filled with that pixel's value.


Nearest-Neighbor
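For an integer scale factor, nearest-neighbor upscaling simply maps output pixel (x, y) back to source pixel (x / scale, y / scale). Here is a single-channel C sketch; the function name is hypothetical.

```c
/* Nearest-neighbor upscaling of a single-channel w x h image by an integer
 * factor. dst must hold (w * scale) * (h * scale) values; each source pixel
 * fills a scale x scale block of the output. */
void upscale_nearest(const float *src, int w, int h, int scale, float *dst) {
    int w2 = w * scale;
    int h2 = h * scale;
    for (int y = 0; y < h2; y++)
        for (int x = 0; x < w2; x++)
            dst[y * w2 + x] = src[(y / scale) * w + (x / scale)];
}
```

Upscaling {1, 2, 3, 4} (2 × 2) by 2 gives the rows {1, 1, 2, 2}, {1, 1, 2, 2}, {3, 3, 4, 4}, {3, 3, 4, 4}.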

Implementation:

- (void)ResizeNearestNeighbor:(MPSImage *)source destinationImage:(MPSImage *)destinationImage
{
    NSUInteger w = source.texture.width;
    NSUInteger h = source.texture.height;
    NSUInteger w2 = destinationImage.texture.width;
    NSUInteger h2 = destinationImage.texture.height;
    NSUInteger featureNum = source.featureChannels;
    // Same slice layout as in batch_norm: channel k is in slice k / 4, component k % 4.
    NSUInteger numSlices = (featureNum + 3) / 4;
    NSUInteger numComponents = featureNum < 3 ? featureNum : 4;
    NSUInteger channels = featureNum < 3 ? featureNum : numSlices * 4;
    float16_t *htemp1 = calloc(w * h * channels, sizeof(float16_t));
    float16_t *htemp2 = calloc(w2 * h2 * channels, sizeof(float16_t));
    for (NSUInteger i = 0; i < numSlices; i++) {
        [source.texture getBytes:htemp1 + w * h * numComponents * i bytesPerRow:w * numComponents * sizeof(float16_t) bytesPerImage:0 fromRegion:MTLRegionMake3D(0, 0, 0, w, h, 1) mipmapLevel:0 slice:i];
    }

    // Source/destination size ratios in 16.16 fixed point. The +1 compensates
    // for truncating the ratio, so exact hits are not rounded down a pixel.
    int x_ratio = (int)((w << 16) / w2) + 1;
    int y_ratio = (int)((h << 16) / h2) + 1;

    for (NSUInteger k = 0; k < featureNum; k++) {
        NSUInteger slice = k / 4;
        NSUInteger stride = k % 4;
        float16_t *src = htemp1 + slice * w * h * numComponents + stride;
        float16_t *dst = htemp2 + slice * w2 * h2 * numComponents + stride;
        for (int i = 0; i < h2; i++) {
            int y2 = (i * y_ratio) >> 16;       // nearest source row
            for (int j = 0; j < w2; j++) {
                int x2 = (j * x_ratio) >> 16;   // nearest source column
                dst[(i * w2 + j) * numComponents] = src[(y2 * w + x2) * numComponents];
            }
        }
    }

    for (NSUInteger i = 0; i < numSlices; i++) {
        [destinationImage.texture replaceRegion:MTLRegionMake3D(0, 0, 0, w2, h2, 1) mipmapLevel:0 slice:i withBytes:htemp2 + w2 * h2 * numComponents * i bytesPerRow:w2 * numComponents * sizeof(float16_t) bytesPerImage:0];
    }
    free(htemp1);
    free(htemp2);
}
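`ResizeNearestNeighbor` avoids floating point by computing the size ratio in 16.16 fixed point: `x_ratio = (w << 16) / w2 + 1`, and destination column j reads source column `(j * x_ratio) >> 16`. The C sketch below reproduces that mapping so it can be checked in isolation; the helper names are hypothetical.

The `+ 1` compensates for truncating the ratio. Without it, a 2-pixel to 6-pixel resize would compute 65535 / 65536 for destination index 3 and read source index 0 instead of 1.

```c
#include <stdint.h>

/* 16.16 fixed-point step from destination to source coordinates,
 * computed the same way as x_ratio / y_ratio in ResizeNearestNeighbor. */
int fixed_ratio(int src_size, int dst_size) {
    return (int)(((int64_t)src_size << 16) / dst_size) + 1;
}

/* Source index read for destination index j. */
int fixed_source_index(int j, int ratio) {
    return (int)(((int64_t)j * ratio) >> 16);
}

/* Returns 1 if, for a 2x upscale of src_size, every destination index j
 * maps to exactly j / 2, i.e. what an exact 2x nearest-neighbor reads. */
int check_2x_mapping(int src_size) {
    int dst_size = 2 * src_size;
    int ratio = fixed_ratio(src_size, dst_size);
    for (int j = 0; j < dst_size; j++)
        if (fixed_source_index(j, ratio) != j / 2) return 0;
    return 1;
}
```

For the 2x upscales in this network, the fixed-point mapping matches j / 2 exactly as long as the destination is at most 32768 pixels wide, which covers any realistic photo.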

Full Network Implementation

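Finally, the whole network is wired up following the structure and layer order in the paper. All the convolutions are subclasses of MPSCNNConvolution. The code is rather long: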

- (MPSImage *)forward:(CGImageRef)srcImage width:(int)width height:(int)height styles:(float *)styles
{
    id<MTLCommandBuffer> commandbuffer = [commandQueue commandBuffer];
    int w = width;
    int h = height;
    MTKTextureLoader *loader = [[MTKTextureLoader alloc] initWithDevice:mtDevice];
    id<MTLTexture> srcTexture = [loader newTextureWithCGImage:srcImage options:nil error:nil];
    MPSImage *cc1Image = [[MPSImage alloc] initWithTexture:srcTexture featureChannels:3];
    
    // contract: three convolutions (the 2nd and 3rd halve width and height),
    // each followed by BN on the CPU and then ReLU
    MPSImageDescriptor *cc2Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:32];
    MPSImage *cc2Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:cc2Des];
    [contractConv1 encodeToCommandBuffer:commandbuffer sourceImage:cc1Image destinationImage:cc2Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:cc2Image styles:styles shift:cc1Shift];
    
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:cc2Image destinationImage:cc2Image];
    w /= 2;
    h /= 2;
    MPSImageDescriptor *cc3Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:64];
    MPSImage *cc3Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:cc3Des];
    [contractConv2 encodeToCommandBuffer:commandbuffer sourceImage:cc2Image destinationImage:cc3Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:cc3Image styles:styles shift:cc2Shift];

    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:cc3Image destinationImage:cc3Image];
    w /= 2;
    h /= 2;
    MPSImageDescriptor *rcDes = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:128];
    MPSImage *rc11Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [contractConv3 encodeToCommandBuffer:commandbuffer sourceImage:cc3Image destinationImage:rc11Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc11Image styles:styles shift:cc3Shift];
    
    // residual: 5 blocks of conv, BN, ReLU, conv, BN, then add the block's input
    // (no ReLU after the addition)
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc11Image destinationImage:rc11Image];
    MPSImage *rc12Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual1Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc11Image destinationImage:rc12Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc12Image styles:styles shift:rc11Shift];
    
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc12Image destinationImage:rc12Image];
    MPSImage *rc21Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual1Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc12Image destinationImage:rc21Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc21Image styles:styles shift:rc12Shift];
    [self addImage:rc11Image B:rc21Image C:rc21Image];
    
    commandbuffer = [commandQueue commandBuffer];
    MPSImage *rc22Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual2Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc21Image destinationImage:rc22Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc22Image styles:styles shift:rc21Shift];
    
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc22Image destinationImage:rc22Image];
    MPSImage *rc31Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual2Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc22Image destinationImage:rc31Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc31Image styles:styles shift:rc22Shift];
    [self addImage:rc21Image B:rc31Image C:rc31Image];
    
    commandbuffer = [commandQueue commandBuffer];
    MPSImage *rc32Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual3Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc31Image destinationImage:rc32Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc32Image styles:styles shift:rc31Shift];
    
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc32Image destinationImage:rc32Image];
    MPSImage *rc41Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual3Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc32Image destinationImage:rc41Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc41Image styles:styles shift:rc32Shift];
    [self addImage:rc31Image B:rc41Image C:rc41Image];
    
    commandbuffer = [commandQueue commandBuffer];
    MPSImage *rc42Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual4Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc41Image destinationImage:rc42Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc42Image styles:styles shift:rc41Shift];
    
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc42Image destinationImage:rc42Image];
    MPSImage *rc51Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual4Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc42Image destinationImage:rc51Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc51Image styles:styles shift:rc42Shift];
    [self addImage:rc41Image B:rc51Image C:rc51Image];
    
    commandbuffer = [commandQueue commandBuffer];
    MPSImage *rc52Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual5Conv1 encodeToCommandBuffer:commandbuffer sourceImage:rc51Image destinationImage:rc52Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:rc52Image styles:styles shift:rc51Shift];
    
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:rc52Image destinationImage:rc52Image];
    MPSImage *temp = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:rcDes];
    [residual5Conv2 encodeToCommandBuffer:commandbuffer sourceImage:rc52Image destinationImage:temp device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:temp styles:styles shift:rc52Shift];
    [self addImage:rc51Image B:temp C:temp];
    
    // upsampling: nearest-neighbor 2x resize on the CPU, then conv, BN, ReLU
    commandbuffer = [commandQueue commandBuffer];
    w *= 2;
    h *= 2;
    MPSImageDescriptor *ec1Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:128];
    MPSImage *ec1Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:ec1Des];
    [self ResizeNearestNeighbor:temp destinationImage:ec1Image];
    
    MPSImageDescriptor *temp2Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:64];
    MPSImage *temp2 = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:temp2Des];
    [expandConv1 encodeToCommandBuffer:commandbuffer sourceImage:ec1Image destinationImage:temp2 device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:temp2 styles:styles shift:ec1Shift];
    
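    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:temp2 destinationImage:temp2];
    // The resize below reads temp2 on the CPU, so the ReLU must finish on the
    // GPU first; encoding it into the command buffer does not execute it.
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];

    commandbuffer = [commandQueue commandBuffer];
    w *= 2;
    h *= 2;
    MPSImageDescriptor *ec2Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:64];
    MPSImage *ec2Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:ec2Des];
    [self ResizeNearestNeighbor:temp2 destinationImage:ec2Image];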
    
    MPSImageDescriptor *ec3Des = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:32];
    MPSImage *ec3Image = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:ec3Des];
    [expandConv2 encodeToCommandBuffer:commandbuffer sourceImage:ec2Image destinationImage:ec3Image device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:ec3Image styles:styles shift:ec2Shift];
    
    commandbuffer = [commandQueue commandBuffer];
    [relu encodeToCommandBuffer:commandbuffer sourceImage:ec3Image destinationImage:ec3Image];
    MPSImageDescriptor *destDes = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16 width:w height:h featureChannels:3];
    MPSImage *destImage = [[MPSImage alloc] initWithDevice:mtDevice imageDescriptor:destDes];
    [expandConv3 encodeToCommandBuffer:commandbuffer sourceImage:ec3Image destinationImage:destImage device:mtDevice];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    [self batch_norm:destImage styles:styles shift:ec3Shift];
    
    commandbuffer = [commandQueue commandBuffer];
    [sigmoid encodeToCommandBuffer:commandbuffer sourceImage:destImage destinationImage:destImage];
    [commandbuffer commit];
    [commandbuffer waitUntilCompleted];
    return destImage;
}
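The forward pass also calls an `addImage:B:C:` helper for the residual connections, and its source isn't shown in this article. Assuming it is a plain element-wise sum done on the CPU like the other helpers, a C sketch could look like this. `residual_add` and `mps_float_count` are hypothetical names, and the element count mirrors the slice layout used by `batch_norm`. The output may alias the second input, as in `[self addImage:rc11Image B:rc21Image C:rc21Image]`.

```c
#include <stddef.h>

/* Number of floats in a CPU copy of an MPSImage, using the slice layout
 * from batch_norm: 1 or 2 channels -> one slice with that many components,
 * otherwise ceil(channels / 4) four-component slices. */
size_t mps_float_count(size_t w, size_t h, size_t feature_channels) {
    size_t slices = (feature_channels + 3) / 4;
    size_t channels = feature_channels < 3 ? feature_channels : slices * 4;
    return w * h * channels;
}

/* Hypothetical CPU residual add: c[i] = a[i] + b[i]. c may be the same
 * buffer as b, because each element is read before it is written. */
void residual_add(const float *a, const float *b, float *c, size_t n) {
    for (size_t i = 0; i < n; i++)
        c[i] = a[i] + b[i];
}
```

Because a helper like this runs on the CPU, it is only safe after `waitUntilCompleted` has returned for the command buffer that produced its inputs. The forward pass above ensures this by committing and waiting before every `batch_norm` and `addImage:B:C:` call.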

Closing


I don't have anything else to add ☠️, so here are a few screenshots of the app running 😊.

Runtime memory, GPU, and CPU usage