TF2.0:tf.cast()张量数据类型转换函数

2019-12-07  本文已影响0人  胜负55开

TF中:张量元素的数据类型:

所以:tf.cast()函数可以实现“元素数据类型”的互转。

函数:tf.cast(a, dtype = 新数据类型)
参数:a是原始数据,dtype是想要转换的新的数据类型 —— 上面类型均可互转!
返回值:返回一个和a的shape一样的,元素类型变为新数据类型的张量。

例子1:int与float互转

import tensorflow as tf

a = tf.random.normal( [3, 4] ) + 1.2

# 转为int:自动4舍5入
a_int = tf.cast(a, dtype = tf.int32)
print(a_int)

# 转为更高精度double:
a_double = tf.cast(a, dtype = tf.float64)
print(a_double)

# 结果:
<tf.Tensor: id=15, shape=(3, 4), dtype=int32, numpy=
array([[0, 2, 1, 1],
       [2, 0, 0, 0],
       [0, 1, 1, 0]])>

<tf.Tensor: id=16, shape=(3, 4), dtype=float64, numpy=
array([[-0.7194407 ,  2.5145092 ,  1.83174539,  1.16168237],
       [ 2.46861219,  0.52549803, -0.0511241 ,  0.56459582],
       [-0.27080107,  1.45568728,  1.62033534,  0.73094273]])>

例子2:bool与int互转 —— 统计预测正确的个数!

import tensorflow as tf

a = tf.constant( [True, False, True, True, False] )

# 转为int:
a_int = tf.cast(a, dtype = tf.int32)
print(a_int)

# 结果:
<tf.Tensor: id=20, shape=(5,), dtype=int32, numpy=array([1, 0, 1, 1, 0])>
上一篇 下一篇

猜你喜欢

热点阅读