tf2 随机种子设置

2021-08-23  本文已影响0人  见喉

全局种子

tf.random.set_seed(116)

针对程序重新运行或者tf.function(类似于re-run of a program),保证随机操作顺序相同

例如1

重新运行程序

tf.random.set_seed(1234)

print(tf.random.uniform([1]))  # generates 'A1'

print(tf.random.uniform([1]))  # generates 'A2'

(now close the program and run it again)

tf.random.set_seed(1234)

print(tf.random.uniform([1]))  # generates 'A1'

print(tf.random.uniform([1]))  # generates 'A2'

例如2

定义函数tf.function

tf.random.set_seed(1234)

@tf.function

def f():

  a = tf.random.uniform([1])

  b = tf.random.uniform([1])

  return a, b

@tf.function

def g():

  a = tf.random.uniform([1])

  b = tf.random.uniform([1])

  return a, b

print(f())  # prints '(A1, A2)'

print(g())  # prints '(A1, A2)'

output

(<tf.Tensor: id=20, shape=(1,), dtype=float32, numpy=array([0.96046877], dtype=float32)>, <tf.Tensor: id=21, shape=(1,), dtype=float32, numpy=array([0.85591054], dtype=float32)>)

(<tf.Tensor: id=41, shape=(1,), dtype=float32, numpy=array([0.96046877], dtype=float32)>, <tf.Tensor: id=42, shape=(1,), dtype=float32, numpy=array([0.85591054], dtype=float32)>)




操作种子

tf.random.truncated_normal([4,3], stddev=0.1, seed=1)

例如1

内部计数器,每次执行时会增加,产生不同的结果

print(tf.random.uniform([1], seed=1))  # generates 'A1'

print(tf.random.uniform([1], seed=1))  # generates 'A2'

(now close the program and run it again)

print(tf.random.uniform([1], seed=1))  # generates 'A1'

print(tf.random.uniform([1], seed=1))  # generates 'A2'

例如2

多个相同操作种子包含在tf.funtion中,因操作时间不长,共享相同的计数器

@tf.function

def foo():

  a = tf.random.uniform([1], seed=1)

  b = tf.random.uniform([1], seed=1)

  return a, b

print(foo())  # prints '(A1, A1)'

print(foo())  # prints '(A2, A2)'

output

(<tf.Tensor: id=20, shape=(1,), dtype=float32, numpy=array([0.2390374], dtype=float32)>, <tf.Tensor: id=21, shape=(1,), dtype=float32, numpy=array([0.2390374], dtype=float32)>)

(<tf.Tensor: id=22, shape=(1,), dtype=float32, numpy=array([0.22267115], dtype=float32)>, <tf.Tensor: id=23, shape=(1,), dtype=float32, numpy=array([0.22267115], dtype=float32)>)

@tf.function

def bar():

  a = tf.random.uniform([1])#不设置操作种子

  b = tf.random.uniform([1])

  return a, b

print(bar())  # prints '(A1, A2)'

print(bar())  # prints '(A3, A4)'


全局种子+操作种子

全局种子会重置计数器tf.random.set_seed()

tf.random.set_seed(1234)

print(tf.random.uniform([1], seed=1))  # generates 'A1'

print(tf.random.uniform([1], seed=1))  # generates 'A2'

tf.random.set_seed(1234)

print(tf.random.uniform([1], seed=1))  # generates 'A1'

print(tf.random.uniform([1], seed=1))  # generates 'A2'

相当于关闭了程序re-run


附注

以下三种随机操作顺序不同:

1全局+操作

tf.random.set_seed(1234)

print(tf.random.uniform([1], seed=1)) 

output

tf.Tensor([0.1689806], shape=(1,), dtype=float32)

2全局

tf.random.set_seed(1234)

print(tf.random.uniform([1)) 

output

tf.Tensor([0.5380393], shape=(1,), dtype=float32)

3操作

print(tf.random.uniform([1], seed=1)) 

output

tf.Tensor([0.2390374], shape=(1,), dtype=float32)

上一篇下一篇

猜你喜欢

热点阅读