Three Ways to Add Regularization Terms in TensorFlow

2018-10-27  cheerss
  1. Compute the regularization term by hand, add it to a collection with add_to_collection(), and fetch it later with get_collection().
    • Advantage: the regularization term can be any expression you like.
    • Disadvantage: writing it by hand is tedious, since the common regularizers are already built into TensorFlow.
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

def _variable_with_weight_decay(name, wd):
    a = tf.get_variable(name, [1], initializer=tf.constant_initializer(1.0))
    loss = tf.multiply(tf.nn.l2_loss(a), wd, name="weight_loss")
    tf.add_to_collection("loss", loss)
    tf.add_to_collection(tf.GraphKeys.WEIGHTS, a)
    return a

def main():
    with tf.Graph().as_default():
        a = _variable_with_weight_decay("a", 0.1)
        b = _variable_with_weight_decay("b", 0.1)
        init_op = tf.global_variables_initializer()
        all_weight_decay = tf.get_collection("loss")
        all_weights = tf.get_collection(tf.GraphKeys.WEIGHTS)
        print(all_weights)
        # prints [<tf.Variable 'a:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'b:0' shape=(1,) dtype=float32_ref>]
        loss_all = tf.add_n(tf.get_collection('loss'))
        
        with tf.Session() as sess:
            sess.run(init_op)
            res = sess.run(all_weight_decay)
            print(res)
            # prints [0.05, 0.05]
        
if __name__ == "__main__":
    main()
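The printed values follow directly from the definitions: tf.nn.l2_loss computes sum(a**2) / 2, and the weight-decay term multiplies that by wd. A plain-Python check of the 0.05 figure (no TensorFlow needed):

```python
# tf.nn.l2_loss(a) is sum(a**2) / 2; the weight-decay term scales it by wd.
# With a = [1.0] and wd = 0.1 this gives 0.1 * 0.5 = 0.05 per variable.
def l2_loss(values):
    return sum(v * v for v in values) / 2.0

wd = 0.1
a = [1.0]
weight_decay = wd * l2_loss(a)
print(weight_decay)  # 0.05
```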
  2. Specify the regularizer in variable_scope or in get_variable; APIs such as tf.layers.conv2d also accept one. The variable_scope regularizer effectively acts as a default for the variables created inside it. All such regularization terms are added to tf.GraphKeys.REGULARIZATION_LOSSES, which can likewise be fetched with get_collection().
    • Advantage: convenient; you can use the regularizers the API ships with directly.
    • Disadvantage: the built-in regularizers are limited.
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

def main():
    with tf.Graph().as_default():
        with tf.variable_scope("first", regularizer=tf.contrib.layers.l2_regularizer(0.1)) as scope:
            a = tf.get_variable("a", [1], initializer=tf.constant_initializer(1.0))
        init_op = tf.global_variables_initializer()
        print(tf.get_collection(tf.GraphKeys.WEIGHTS))
        # prints [] -- get_variable does not add variables to WEIGHTS automatically
        print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        
        with tf.Session() as sess:
            sess.run(init_op)
            res = sess.run(loss)
            print(res)
            # prints 0.05
        
if __name__ == "__main__":
    main()
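For intuition, the regularizer's behavior can be modeled in plain Python (a sketch of the semantics, not TensorFlow's actual implementation): l2_regularizer(scale) returns a function that maps a weight tensor to scale * sum(w**2) / 2.

```python
# Plain-Python model of tf.contrib.layers.l2_regularizer's semantics:
# l2_regularizer(scale)(weights) == scale * l2_loss(weights).
def l2_regularizer(scale):
    def regularize(weights):
        return scale * sum(w * w for w in weights) / 2.0
    return regularize

reg = l2_regularizer(0.1)
print(reg([1.0]))  # 0.05 -- matches the value the session prints above
```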
  3. Create a regularizer by hand and apply it to variables yourself. This approach is not limited to tf.variable_scope(); it works on any variables. Note that if tf.contrib.layers.apply_regularization() is called without the weights_list argument, it defaults to the variables in tf.GraphKeys.WEIGHTS.
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

def main():
    with tf.Graph().as_default():
        with tf.variable_scope("first") as scope:
            regularizer = tf.contrib.layers.l2_regularizer(0.1)
            a = tf.get_variable("a", [1], initializer=tf.constant_initializer(1.0))
            b = tf.get_variable("b", [1], initializer=tf.constant_initializer(1.0))
            
        tf.contrib.layers.apply_regularization(regularizer, weights_list=[a, b])
        init_op = tf.global_variables_initializer()
        loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        
        with tf.Session() as sess:
            sess.run(init_op)
            res = sess.run(loss)
            print(res)
            # prints 0.1
        
if __name__ == "__main__":
    main()
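The same modeling trick explains the printed sum: apply_regularization evaluates the regularizer on each variable in weights_list and adds them up (in TensorFlow it also puts that sum into REGULARIZATION_LOSSES). A plain-Python sketch of that behavior:

```python
# Plain-Python sketch of tf.contrib.layers.apply_regularization with an
# explicit weights_list: sum regularizer(w) over the listed variables.
def l2_regularizer(scale):
    return lambda weights: scale * sum(w * w for w in weights) / 2.0

def apply_regularization(regularizer, weights_list):
    return sum(regularizer(w) for w in weights_list)

reg = l2_regularizer(0.1)
total = apply_regularization(reg, weights_list=[[1.0], [1.0]])
print(total)  # 0.1 -- one 0.05 term per variable
```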