TensorFlow 2.0 (Part 2): tf.Tensor Operations
2020-06-13
侠之大者_7d3f
Preface
Both PyTorch and TensorFlow use Tensors to store multi-dimensional data, and the Tensors of the two frameworks have a lot in common.
Requirements
- OS: Ubuntu 18.04
- Tensorflow >= 2.0.0
TensorFlow
A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `tf.Tensor` objects have a data type and a shape. Additionally, `tf.Tensor`s can reside in accelerator memory (like a GPU). TensorFlow offers a rich library of operations (`tf.add`, `tf.matmul`, `tf.linalg.inv` etc.) that consume and produce `tf.Tensor`s. These operations automatically convert native Python types.
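To illustrate the last sentence, here is a minimal sketch (the input values are arbitrary) showing TensorFlow ops accepting plain Python numbers and nested lists and returning `tf.Tensor` objects that carry a dtype and a shape:

```python
import tensorflow as tf

# Python ints are converted automatically; the result is an int32 scalar Tensor.
a = tf.add(1, 2)
# Nested Python lists of floats become float32 Tensors; (1x2) @ (2x1) -> (1x1).
b = tf.matmul([[1.0, 2.0]], [[3.0], [4.0]])
# A flat list of ints becomes a 1-D int32 Tensor.
c = tf.square([2, 3])

print(a.dtype, a.shape)  # dtype and shape of the scalar result
print(b)
print(c)
```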
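tf.Tensor is similar to NumPy's ndarray, but the two differ in a few ways:
- NumPy computes only on the CPU and has no GPU acceleration, whereas Tensors can be backed by accelerator memory (like GPU, TPU).
- The elements of a numpy.ndarray can be modified in place, but the data of a tf.Tensor cannot: Tensors are immutable.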
```python
import tensorflow as tf

# Create a Tensor
var1 = tf.ones(shape=[3, 4], dtype=tf.float32)
print(var1)
print(var1.shape)
print(var1.dtype)
print(var1[0, :])  # indexing returns a new tf.Tensor
# var1[0, :] = 10  # Error: tf.Tensor cannot be modified (Tensors are immutable.), TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
```
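Because a tf.Tensor cannot be assigned to, one way to get the effect of `var1[0, :] = 10` is to build a new tensor with `tf.tensor_scatter_nd_update`. This is a sketch of one option among several, not the author's method. The names `indices`, `updates` and `new_var1` are only illustrative:

```python
import tensorflow as tf

var1 = tf.ones(shape=[3, 4], dtype=tf.float32)

# Index depth 1 selects whole rows; here only row 0 is replaced.
indices = tf.constant([[0]])
# Shape must be indices.shape[:-1] + var1.shape[1:] = [1] + [4] = [1, 4].
updates = tf.fill([1, 4], 10.0)

# Returns a NEW tensor; var1 itself is left unchanged.
new_var1 = tf.tensor_scatter_nd_update(var1, indices, updates)
print(new_var1)
print(var1)
```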
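```python
import numpy as np
import tensorflow as tf
import torch

# TensorFlow
print('Tensorflow')
x = tf.ones([3, 3])
print(id(x))
x += 10  # creates a new Tensor and rebinds x; NOT an in-place add, so id(x) changes
print(id(x))

# PyTorch
print('torch')
x = torch.ones([3, 3])
print(id(x))
x += 10  # in-place add, so id(x) stays the same
print(id(x))

# NumPy
print('numpy')
x = np.ones([3, 3])
print(id(x))
x += 10  # in-place add, so id(x) stays the same
print(id(x))
```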
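To see why `id(x)` changes for TensorFlow but not for NumPy, keep a second reference to the original object before `x += 10`. The following is a minimal sketch, and the variable names `alias`, `y` and `alias_np` are only illustrative:

```python
import numpy as np
import tensorflow as tf

# TensorFlow: `+=` builds a new Tensor and rebinds x; the alias keeps the old value.
x = tf.ones([3, 3])
alias = x
x += 10
print(alias.numpy()[0, 0], x.numpy()[0, 0])  # 1.0 11.0

# NumPy: `+=` modifies the array in place, so the alias sees the change too.
y = np.ones([3, 3])
alias_np = y
y += 10
print(alias_np[0, 0], y[0, 0])  # 11.0 11.0
```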
Converting between tf.Tensor and np.ndarray
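```python
import numpy as np
import tensorflow as tf

# Converting between tf.Tensor and NumPy
# tf.Tensor and NumPy arrays can be converted into each other automatically
var2 = tf.zeros([3, 3])
var3 = var2.numpy()  # tensor ---> numpy
print(type(var3))  # <class 'numpy.ndarray'>
# numpy ---> tensor
var4 = tf.convert_to_tensor(np.arange(12).reshape([3, 4]))
print(var4)
```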
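The automatic conversion also works when you mix the two libraries directly. A TensorFlow op accepts an ndarray and returns a tf.Tensor, and a NumPy function accepts a tf.Tensor and returns an ndarray. Here is a small sketch with arbitrary example values:

```python
import numpy as np
import tensorflow as tf

arr = np.arange(6, dtype=np.float32).reshape([2, 3])

# TF op on an ndarray: the input is converted, the result is a tf.Tensor.
t = tf.multiply(arr, 2)
# NumPy function on a tf.Tensor: the input is converted, the result is an ndarray.
back = np.add(t, 1)

print(type(t))
print(type(back))
print(back)
```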
Specifying the compute device (CPU/GPU) for a tf.Tensor
```python
import tensorflow as tf

# CPU/GPU
# List the available CPUs/GPUs
print(tf.config.experimental.list_physical_devices('CPU'))
print(tf.config.experimental.list_physical_devices('GPU'))

# Check whether a Tensor lives on the GPU or the CPU
var5 = tf.random.uniform(shape=[3, 3])
print(var5.device)  # name of the device holding the Tensor
print(var5.device.endswith('GPU:0'))  # True if it is a GPU Tensor

# Run the computation on the CPU
with tf.device('CPU:0'):
    x1 = tf.random.normal(shape=[128, 64])
    x2 = tf.random.normal(shape=[64, 128])
    y = tf.matmul(a=x1, b=x2)
    assert y.device.endswith('CPU:0')
    print(y.shape)
```
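A common variation is to use the GPU when one is visible and fall back to the CPU otherwise. The following is a minimal sketch under that assumption. `device_name` is an illustrative name, and `'GPU:0'` assumes the first GPU is the one you want:

```python
import tensorflow as tf

# Pick GPU:0 if TensorFlow can see any GPU, otherwise fall back to CPU:0.
gpus = tf.config.experimental.list_physical_devices('GPU')
device_name = 'GPU:0' if gpus else 'CPU:0'

with tf.device(device_name):
    x1 = tf.random.normal(shape=[128, 64])
    x2 = tf.random.normal(shape=[64, 128])
    y = tf.matmul(x1, x2)  # (128x64) @ (64x128) -> (128x128)

print(device_name, y.device, y.shape)
```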