Tensorflow——tf.argmax()解析
2018-10-29 本文已影响47人
SpareNoEfforts
简介
tf.argmax就是返回最大的那个数值所在的下标。
定义如下:
def argmax(self, axis=None, fill_value=None, out=None):
tf.argmax()的参数:比如,tf.argmax(array, 1)和tf.argmax(array, 0)有啥区别呢?
例子
test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
np.argmax(test, 0) #输出:array([3, 3, 1]
np.argmax(test, 1) #输出:array([2, 2, 0, 0]
tf.argmax(array, 1)
等于1的时候,比较范围缩小了,只会比较每个数组内的数的大小,结果也会根据有几个数组,产生几个结果。
test[0] = array([1, 2, 3]) #2
test[1] = array([2, 3, 4]) #2
test[2] = array([5, 4, 3]) #0
test[3] = array([8, 7, 2]) #0
tf.argmax(array, 0)
你就这么想,0是最大的范围,所有的数组都要进行比较,只是比较的是这些数组相同位置上的数:
test[0] = array([1, 2, 3])
test[1] = array([2, 3, 4])
test[2] = array([5, 4, 3])
test[3] = array([8, 7, 2])
# output : [3, 3, 1]