Tensorflow模型中的trainable和training

2021-10-03  本文已影响0人  LabVIEW_Python

在模型或任何层上设置 trainable = False,则模型或所有子层也将变为不可训练,该操作叫冻结层。层被冻结后,可训练参数在训练的过程中,将不会被更新。

模型的_ call _()中有一个参数,training=None, 其指示网络的运行的过程中处于training模式还是inference模式

training参数
有些数据增强层,在inference模式下,直接恒等输出 区分training状态的网络层
数据增强层,在保存成模型文件后,存在于模型中的,例如: 数据增强层
所以建议将数据增强剥离出模型外,仅仅作用于数据集,data augmentation should only be applied to the training set.例如:
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
  # Resize and rescale all datasets
  ds = ds.map(lambda x, y: (resize_and_rescale(x), y), 
              num_parallel_calls=AUTOTUNE)

  if shuffle:
    ds = ds.shuffle(1000)

  # Batch all datasets
  ds = ds.batch(batch_size)

  # Use data augmentation only on the training set
  if augment:
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), 
                num_parallel_calls=AUTOTUNE)

  # Use buffered prefecting on all datasets
  return ds.prefetch(buffer_size=AUTOTUNE)
上一篇下一篇

猜你喜欢

热点阅读