tf.nn.in_top_k/tf.nn.top_k
2018-07-16 本文已影响0人
yalesaleng
tf.nn.in_top_k
correct = tf.nn.in_top_k(logits, labels, k)
其中:
logits: a tensor of shape [batch_size, NUM_CLASSES]
labels: a tensor of shape [batch_size]
理解:
- 1.对于logits的某行logits[i],
找到其前k个最大的预测值的index_0, .., index_k-1,
如果发现对应的labels[i]在{index_0, …, index_k-1},
则返回True.(大致这个意思)- 2.当k=1时,等价于tf.equal(logits, labels)。
但是,equal()函数中的logits和labels的shape必须一样。
因此,通过read_data_sets(…, one_hot=False,…)读取数据时,
必须使得one_hot=Ture(默认为False).
下面为源码中对此的解释:
This outputs a batch_size bool array, an entry out[i] is true if the prediction for the target class is among the top k predictions among all predictions for example i.
tf.nn.top_k
tf.nn.top_k(input, k=1, sorted=sorted, name)
在top_k()函数中,返回的是两个值:
- values: input中最后一维部分的前k个最大的值。
- indices:与values中各个值对应的索引。
(sorted默认为1, 即根据输出values进行降序输出)。