[tf]tf.gather_nd的用法
2019-02-17 本文已影响0人
VanJordan
函数原型,nd
的意思是可以收集n dimension
的tensor
tf.gather_nd(
params,
indices,
name=None
)
- 意思是要收集
[params[0][0],params[1][1]]
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
- 意思是要收集
[params[1],params[0]]
indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
output = [['c', 'd'], ['a', 'b']]
- 意思是要收集
[params[1]]
indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]
- 我们使用这个函数的一般是想完成这样一个功能:T是一个二维
tensor
,我们想要根据另外一个二维tensor
value的最后一维最大元素的下标选出tensor
T 中最后一维最大的元素,组成一个新的一维的tensor
,那么就可以首先选出最后一维度的下标[1,2,3]
,然后将将其扩展成[[0,1],[1,2],[2,3]]
,然后使用这个函数选择即可。
max_indicies = tf.argmax(T, 1)
import tensorflow as tf
sess = tf.InteractiveSession()
values = tf.constant([[0, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0]])
T = tf.constant([[0, 1, 2 , 3],
[4, 5, 6 , 7],
[8, 9, 10, 11]])
max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0],
dtype=max_indices.dtype),
max_indices),
axis=1))
print(result.eval())