Tensorflow

TensorFlow学习笔记(13)tf.argmax浅析

2018-11-19  本文已影响0人  谢昆明

求这组数据的最大值?
[[0.06251886 0.2645436 0.04882399 0.09480914 0.04890436 0.15327263
0.0369646 0.22686356 0.0089916 0.05430767]]

这时候就是用tf.argmax的最好时候,测试代码

from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


sess = tf.Session()
a = tf.constant([1.,2.,3.,0.,9.,])
b = tf.constant([[1,2,3],
                 [3,2,1],
                 [4,5,6],
                 [6,5,4]])

col_max0 = sess.run(tf.argmax(a, 0))
print (col_max0)
#  4

col_max = sess.run(tf.argmax(b, 0) )  #当axis=0时返回每一列的最大值的位置索引
print (col_max)
#  [3 2 2]

row_max = sess.run(tf.argmax(b, 1) )  #当axis=1时返回每一行中的最大值的位置索引
print (row_max)
#  [2 0 2 0]
上一篇下一篇

猜你喜欢

热点阅读