Contextual Attention论文与源码解析与结果
2020-04-13 本文已影响0人
抬头挺胸才算活着
参考资料:
[1] Generative Image Inpainting with Contextual Attention
[2] question about the inpaint_ops.py
[3] about the contextual attention
[4] inpaint_ops / contextual attention
- Contextual Attention
在GAN生成图片中,很多都不能从自己图片中复制到需要的内容,传统方法PatchMatch这点做的很好,但是却不能产生图中没有的内容,所以[1]构造的模型提出了Contextual Attention,可以很好地复制图片中的内容。
下面是我看contextual_attention函数做的注释,大致弄懂了。
跟上图不一样的是f和b实际上大小是一样的,mask指定b中被污染的地方,f是前一阶段产生出来的大致的预测,经过图中的卷积,可以得出f中大概跟背景那块比较相似,后面再接一个deconv层可以把b中相似的地方给"借"过来,做实验可以看到输出的y跟输入的b是完全一样的,说明经过这个操作,f可以借到b中任何想要的内容。
def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
fuse_k=3, softmax_scale=10., training=True, fuse=True):
""" Contextual attention layer implementation.
Contextual attention is first introduced in publication:
Generative Image Inpainting with Contextual Attention, Yu et al.
Args:
x: Input feature to match (foreground).
t: Input feature for match (background).
mask: Input mask for t, indicating patches not available.
跟b一样大小全零代表一张完整的好图,有缺陷的地方为1
ksize: Kernel size for contextual attention.
stride: Stride for extracting patches from t.
rate: Dilation for matching.
softmax_scale: Scaled softmax for attention.
training: Indicating if current graph is training or inference.
Returns:
tf.Tensor: output
"""
# sess = tf.InteractiveSession()
# get shapes
raw_fs = tf.shape(f)
raw_int_fs = f.get_shape().as_list()
raw_int_bs = b.get_shape().as_list()
# extract patches from background with stride and rate
kernel = 2*rate
# 这里跟下面的不同是,卷积核是像空洞卷积一样还是连在一起,这个模型选择了连在一起
# raw_w = tf.extract_image_patches(
# b, [1,kernel,kernel,1], [1,stride,stride,1], [1,rate,rate,1], padding='SAME')
raw_w = tf.extract_image_patches(
b, [1,kernel,kernel,1], [1,rate*stride,rate*stride,1], [1,1,1,1], padding='SAME')
# 这里两步不能直接用reshape搞定,看test1
raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
# downscaling foreground option: downscaling both foreground and
# background for matching and use original background for reconstruction.
f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
# 为什么这里改为to_shape了??
b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor) # https://github.com/tensorflow/tensorflow/issues/11651
if mask is not None:
mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
# 缩放后重新获取shape
fs = tf.shape(f)
int_fs = f.get_shape().as_list()
# 每张图片切成一份tensor
f_groups = tf.split(f, int_fs[0], axis=0)
# from t(H*W*C) to w(b*k*k*c*h*w)
bs = tf.shape(b)
int_bs = b.get_shape().as_list()
w = tf.extract_image_patches(
b, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
# process mask
# mask跟b的区别:只有一个batch,只有一个通道,图片大小一样
if mask is None:
mask = tf.zeros([1, bs[1], bs[2], 1])
m = tf.extract_image_patches(
mask, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
m = tf.reshape(m, [1, -1, ksize, ksize, 1])
m = tf.transpose(m, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
m = m[0]
# 每个采样出来的patch对应一个mean,mm的大小跟b的像素点个数一样多
mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0,1,2], keep_dims=True), 0.), tf.float32)
# 每张图片对应一组w,k*k*c*hw
w_groups = tf.split(w, int_bs[0], axis=0)
raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
y = []
offsets = []
k = fuse_k
scale = softmax_scale
fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
# 这里跟一般的卷积还不一样,一般的卷积都是每个batch都是一样的系数
# 这里每幅图片都是各自的卷积核系数
for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
# conv for compare
wi = wi[0]
# 对权重进行归一化
wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0,1,2])), 1e-4)
# 最关键的一步卷积
yi = tf.nn.conv2d(xi, wi_normed, strides=[1,1,1,1], padding="SAME")
# conv implementation for fuse scores to encourage large patches
if fuse:
yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
yi = tf.transpose(yi, [0, 2, 1, 4, 3])
yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
yi = tf.transpose(yi, [0, 2, 1, 4, 3])
# 保留f的长宽
yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])
# lyc1 = yi.eval()[0,0,0,:]
# 如果mask所在领域有任意一个为1,代表有缺陷的值,那么这个邻域代表的filter卷积出来的跟f一样大小的feature map
# 会乘完之后变成0
# softmax to match
yi *= mm # mask
# 对应f的每个像素点所在邻域,所有filter一起算,看那个出来的值比较大。
yi = tf.nn.softmax(yi*scale, 3)
yi *= mm # mask
# lyc1 = yi.eval()[0,:,:,0]
# lyc2 = yi.eval()[0,:,:,1]
# 在b中的偏移,//和%运算符是因为offset这个轴的长度是fs[1]*fs[2]
offset = tf.argmax(yi, axis=3, output_type=tf.int32)
offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)
# deconv for patch pasting
# 3.1 paste center
wi_center = raw_wi[0]
# lyc3 = wi_center.eval()[:,:,0,:]
# lyc4 = wi_center.eval()[:,:,0,0]
# print(repr(lyc3))
# print(repr(lyc4))
# 为什么除以4??
yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0), strides=[1,rate,rate,1]) / 4.
# lyc2 = yi.eval()[0,:,:,0]
y.append(yi)
offsets.append(offset)
y = tf.concat(y, axis=0)
y.set_shape(raw_int_fs)
offsets = tf.concat(offsets, axis=0)
offsets.set_shape(int_bs[:3] + [2])
# case1: visualize optical flow: minus current position
h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]), [bs[0], 1, bs[2], 1])
w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]), [bs[0], bs[1], 1, 1])
# 由绝对偏移转向相对偏移
offsets = offsets - tf.concat([h_add, w_add], axis=3)
# to flow image
flow = flow_to_image_tf(offsets)
# # case2: visualize which pixels are attended
# flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
if rate != 1:
flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
return y, flow
-
wgan loss
wgan loss只有两行,是根据下面公式计算出来的
def gan_wgan_loss(pos, neg, name='gan_wgan_loss'):
"""
wgan loss function for GANs.
- Wasserstein GAN: https://arxiv.org/abs/1701.07875
"""
with tf.variable_scope(name):
d_loss = tf.reduce_mean(neg-pos)
g_loss = -tf.reduce_mean(neg)
scalar_summary('d_loss', d_loss)
scalar_summary('g_loss', g_loss)
scalar_summary('pos_value_avg', tf.reduce_mean(pos))
scalar_summary('neg_value_avg', tf.reduce_mean(neg))
return g_loss, d_loss