TF2.0:获取网络中间卷积层中所有卷积核内的参数

2020-07-26  本文已影响0人  胜负55开

说明:以Unet网络为例,想要查看训练好的网络中某个卷积层内所有已训练好的卷积核内的数值,是可以实现的!

关键函数:

实例:

# 需要用到的各种包:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# 导入训练并保存好的完整原Unet网络:
model = tf.keras.models.load_model( 'model_1.h5' )

# 获取已训练好的网络中的某中间卷积层的权重参数:
# conv2d_1_1是我在原网络搭建时为某卷积层起的名字
tmp = model.get_layer('conv2d_1_1').get_weights()

# tmp是个list,查看其长度:
len(tmp)
2   # 已分类:第一个元素存所有卷积核,第二个元素存所有卷积核对应的偏置

 # 获取所有的卷积核:共64个
tmp1 = tmp[0]
tmp1.shape, type(tmp1)
# 打印结果:
((3, 3, 3, 64), numpy.ndarray)   # 每个卷积核是(3,3,3),共64个 —— 正确!

# 获取每一个卷积核的偏置:共64个
tmp2 = tmp[1]
tmp2.shape, type(tmp2)
# 打印结果:
((64,), numpy.ndarray)  # 没毛病!

# 查看第一个卷积核的值:
tmp1_1 = tmp1[:, :, :, 0]
tmp1_1
# 打印结果:
array([[[ 0.07278805, -0.02773886,  0.00090645],
        [ 0.00616201,  0.0493045 ,  0.09584951],
        [ 0.03114345, -0.15490085, -0.00721605]],

       [[ 0.15069817,  0.01782583,  0.0167292 ],
        [ 0.04765529,  0.0310829 ,  0.20920704],
        [ 0.00184333, -0.09529579,  0.04455388]],

       [[ 0.09400825, -0.03355229,  0.04949141],
        [ 0.039369  ,  0.00072069,  0.10736698],
        [ 0.02848012, -0.16586785,  0.0236699 ]]], dtype=float32)

总结:不论是获取中间层的输出,还是获取中间层中权重参数,都需要先从原模型中拿到那个层!此时用到的函数就是:model.get_layer('层名')。之后想进一步查看该层中的各种信息,就直接使用各层的各种“属性”即可,例如:output属性、get_weights属性、等等有很多!

上一篇下一篇

猜你喜欢

热点阅读