Tensorflow添加正则项的三种方式
2018-10-27 本文已影响0人
cheerss
- 手写计算出正则项的大小,并通过
add_to_collection()
方法把它加入到collection中,在需要的时候再通过get_collection()
方法取出来。- 优势:正则项的计算表达式可以随意定义
- 劣势:需要手写比较麻烦,毕竟常用的正则项Tensorflow中都是有的
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
def _variable_with_weight_decay(name, wd):
a = tf.get_variable(name, [1], initializer=tf.constant_initializer(1.0))
loss = tf.multiply(tf.nn.l2_loss(a), wd, name="weight_loss")
tf.add_to_collection("loss", loss)
tf.add_to_collection(tf.GraphKeys.WEIGHTS, a)
return a
def main():
with tf.Graph().as_default():
a = _variable_with_weight_decay("a", 0.1)
b = _variable_with_weight_decay("b", 0.1)
init_op = tf.global_variables_initializer()
all_weight_decay = tf.get_collection("loss")
all_weights = tf.get_collection(tf.GraphKeys.WEIGHTS)
print(all_weights)
# 输出 [<tf.Variable 'a:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'b:0' shape=(1,) dtype=float32_ref>]
loss_all = tf.add_n(tf.get_collection('loss'))
with tf.Session() as sess:
sess.run(init_op)
res = sess.run(all_weight_decay)
print(res)
# 输出 [0.05, 0.05]
if __name__ == "__main__":
main()
- 在variable_scope中指定,或者在get_variable中指定,当然,
tf.layers.conv2d
等API中也都可以指定正则项。其实variable_scope就像一个默认值。所有的正则项会被加入到tf.GraphKeys.REGULARIZATION_LOSSES
中,这个东西也可以通过get_collection()
获取。- 优势:方便,可以直接调用API中自带的正则项
- 劣势:自带的regularizer有限
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
def main():
with tf.Graph().as_default():
with tf.variable_scope("first", regularizer=tf.contrib.layers.l2_regularizer(0.1)) as scope:
a = tf.get_variable("a", [1], initializer=tf.constant_initializer(1.0))
init_op = tf.global_variables_initializer()
print(tf.get_collection(tf.GraphKeys.WEIGHTS))
print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
with tf.Session() as sess:
sess.run(init_op)
res = sess.run(loss)
print(res)
if __name__ == "__main__":
main()
- 手动创建一个regularizer后作用于变量。这种方式可以不仅限于
tf.variable_scope()
,而是所有变量。其中tf.contrib.layers.apply_regularization()
这个函数如果不指定weights_list
参数则默认作用于tf.GraphKeys.WEIGHTS
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
def main():
with tf.Graph().as_default():
with tf.variable_scope("first") as scope:
regularizer = tf.contrib.layers.l2_regularizer(0.1)
a = tf.get_variable("a", [1], initializer=tf.constant_initializer(1.0))
b = tf.get_variable("b", [1], initializer=tf.constant_initializer(1.0))
tf.contrib.layers.apply_regularization(regularizer, weights_list=[a, b])
init_op = tf.global_variables_initializer()
loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
with tf.Session() as sess:
sess.run(init_op)
res = sess.run(loss)
print(res)
if __name__ == "__main__":
main()