Deep-Learning-with-PyTorch-3.5.3
2020-10-05 本文已影响0人
追求科技的足球
3.5.3 管理张量的dtype属性
为了分配正确的数字类型的张量,我们可以指定适当的dtype作为构造函数的参数。 例如:
# In[47]:
double_points = torch.ones(10, 2, dtype=torch.double)
short_points = torch.tensor([[1, 2], [3, 4]], dtype=torch.short
)
我们可以通过访问相应的属性来找到张量的dtype:
# In[48]:
short_points.dtype
# Out[48]:
torch.int16
我们还可以使用相应的转换方法将张量创建函数的输出转换为正确的类型,例如
# In[49]:
double_points = torch.zeros(10, 2).double()
short_points = torch.ones(10, 2).short()
或更方便的to方法:
# In[50]:
double_points = torch.zeros(10, 2).to(torch.double)
short_points = torch.ones(10, 2).to(dtype=torch.short)
在后台,检查是否有必要进行转换,如果有必要,则进行转换。 以dtype命名的强制转换方法(例如float)是to的简写,但是to方法可以采用其他参数,我们将在3.9节中讨论。
在操作中混合输入类型时,输入将自动转换为较大的类型。 因此,如果我们要进行32位计算,则需要确保所有输入(最多)是32位:
# In[51]:
points_64 = torch.rand(5, dtype=torch.double) # rand将张量元素初始化为0到1之间的随机数。
points_short = points_64.to(torch.short)
points_64 * points_short # works from PyTorch 1.3 onwards
# Out[51]:
tensor([0., 0., 0., 0., 0.], dtype=torch.float64)