[tf] How to use tf.gather_nd

2019-02-17  VanJordan

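Function signature. The "nd" means it can gather from an n-dimensional tensor: each innermost row of indices is an index into the leading dimensions of params.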

tf.gather_nd(
    params,
    indices,
    name=None
)
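The shape of the result follows a simple rule, based on my reading of the TensorFlow docs. If the last dimension of indices has length k, each index tuple consumes the first k dimensions of params. The result shape is therefore indices.shape[:-1] + params.shape[k:]. Below is a small pure-Python sketch of that rule. The helper name gather_nd_output_shape is hypothetical and is not a TF API. Newer TF versions also accept a batch_dims argument, which this sketch ignores (it assumes batch_dims=0).

```python
def gather_nd_output_shape(params_shape, indices_shape):
    """Result shape of tf.gather_nd(params, indices), assuming batch_dims=0.

    Hypothetical helper for illustration only, not part of TensorFlow.
    """
    k = indices_shape[-1]  # length of each index tuple
    if k > len(params_shape):
        raise ValueError("index tuple is longer than the rank of params")
    # Leading dims of indices are kept; the first k dims of params are consumed.
    return tuple(indices_shape[:-1]) + tuple(params_shape[k:])


# The three examples that follow, expressed as shapes:
print(gather_nd_output_shape((2, 2), (2, 2)))     # (2,)
print(gather_nd_output_shape((2, 2), (2, 1)))     # (2, 2)
print(gather_nd_output_shape((2, 2, 2), (1, 1)))  # (1, 2, 2)
```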
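Examples:

    # Simple indexing into a matrix: each index pair selects one element
    indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = ['a', 'd']

    # Slice indexing into a matrix: each single index selects a whole row
    indices = [[1], [0]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['c', 'd'], ['a', 'b']]

    # Slice indexing into a 3-D tensor: [1] selects the whole second matrix
    indices = [[1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['a1', 'b1'], ['c1', 'd1']]]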
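The examples above can be reproduced with a minimal pure-Python sketch of the gather_nd semantics on nested lists. The function gather_nd defined here is a hypothetical stand-in for illustration and is not TensorFlow's implementation. It only covers the batch_dims=0 case and assumes every index tuple has the same length.

```python
def gather_nd(params, indices):
    """Pure-Python sketch of tf.gather_nd semantics on nested lists.

    Hypothetical helper (batch_dims=0 only), not TensorFlow's implementation.
    """
    if isinstance(indices[0], int):
        # indices is a single index tuple: walk one dimension per component.
        item = params
        for i in indices:
            if not 0 <= i < len(item):
                # TensorFlow reports out-of-range indices as an error (on CPU);
                # plain Python lists would silently wrap negative indices.
                raise IndexError(f"index {i} out of range")
            item = item[i]
        return item
    # Otherwise recurse, so the leading dims of indices survive in the output.
    return [gather_nd(params, sub) for sub in indices]


print(gather_nd([['a', 'b'], ['c', 'd']], [[0, 0], [1, 1]]))  # ['a', 'd']
```

The recursion mirrors the shape rule. Every level of indices above the innermost tuple becomes a level of the output. The innermost tuple decides how deep into params each lookup goes.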
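A common use case is picking one element from each row. Take the per-row indices that tf.argmax returns for one tensor, and use them to select T[i, max_indices[i]] from another tensor T for every row i: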
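# Written for TensorFlow 1.x (InteractiveSession, .eval()). In TF 2.x (eager by
# default), drop the session and use print(result.numpy()) instead.
import tensorflow as tf

sess = tf.InteractiveSession()

# One-hot rows: the position of the 1 in each row is the index we want.
values = tf.constant([[0, 0, 0, 1],
                      [0, 1, 0, 0],
                      [0, 0, 1, 0]])

T = tf.constant([[0, 1, 2,  3],
                 [4, 5, 6,  7],
                 [8, 9, 10, 11]])

# Column index of the max in each row: [3, 1, 2] (int64 by default).
max_indices = tf.argmax(values, axis=1)
# Pair each row number with its column index -> [[0, 3], [1, 1], [2, 2]],
# then gather T[0, 3], T[1, 1], T[2, 2].
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0],
                                            dtype=max_indices.dtype),
                                   max_indices),
                                  axis=1))
print(result.eval())  # [ 3  5 10]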
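The same pipeline can be traced without TensorFlow. The pure-Python sketch below mirrors each step: argmax per row, stacking row numbers with those indices, then gathering. The function names argmax_index_pairs and pick_per_row are hypothetical and exist only for illustration. One caveat applies to ties: Python's max returns the first maximum, while tf.argmax does not guarantee which index wins.

```python
def argmax_index_pairs(values):
    """Mirror of tf.stack((tf.range(n), tf.argmax(values, axis=1)), axis=1).

    Hypothetical helper; ties resolve to the first maximum here, which
    tf.argmax does not guarantee.
    """
    max_indices = [max(range(len(row)), key=row.__getitem__) for row in values]
    return [[i, j] for i, j in enumerate(max_indices)]


def pick_per_row(T, values):
    """Mirror of tf.gather_nd(T, index_pairs) for a rank-2 T."""
    return [T[i][j] for i, j in argmax_index_pairs(values)]


values = [[0, 0, 0, 1],
          [0, 1, 0, 0],
          [0, 0, 1, 0]]
T = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11]]
print(argmax_index_pairs(values))  # [[0, 3], [1, 1], [2, 2]]
print(pick_per_row(T, values))     # [3, 5, 10]
```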