tensor的slice赋值

2018-10-12  本文已影响0人  CodePlayHu

我们经常在numpy中会用到类似于label[:, :, :, :, 0] = 0这样的切片赋值操作,那么在TensorFlow中应该如何实现呢?

a = tf.Variable([[[1, 1, 1], [2, 2, 2]],[[3, 3, 3], [4, 4, 4]],[[5, 5, 5], [6, 6, 6]]])
with tf.Session() as sess:
    sess.run(tf.global_variable_initializer())
    sess.run(a[:2,:1,:].assign(-1*tf.ones_like(a[:2,:1,:]))) # 将a[:2,:1,:] 中的数值赋值为-1,注意assign函数中的参数不能直接赋-1,会报错说不支持broadcast
    sess.run(a)

输出

**第一个输出
array([[[-1, -1, -1],
        [ 2,  2,  2]],

       [[-1, -1, -1],
        [ 4,  4,  4]],

       [[ 5,  5,  5],
        [ 6,  6,  6]]], dtype=int32)
**第二个输出
array([[[-1, -1, -1],
        [ 2,  2,  2]],

       [[-1, -1, -1],
        [ 4,  4,  4]],

       [[ 5,  5,  5],
        [ 6,  6,  6]]], dtype=int32)

说明tensor-a中的数值已经被成功修改了。大功告成。

补充一个conditional slice assignment
例如我们需要让a tensor中所有等于-1的地方都变为0,怎么操作呢

contition = tf.equal(a,-1)
sess.run(tf.where(condition, tf.zeros_like(condition, dtype=tf.int32), a))
# 这里的tf.where中的三个参数分别代表判断条件,满足条件的位置赋值矩阵以及不满足条件的位置的赋值矩阵,要注意前两个矩阵要同大小。
# 这里的意思就是,满足a == -1的位置赋值为0,其他位置还是a原来的值
上一篇 下一篇

猜你喜欢

热点阅读