初识pytorch
2019-07-08 本文已影响0人
菜田的守望者w
初始pytorch
定义tensor
- torch.FloatTensor()
- torch.Tensor()
上面两种方法一般给定一个维度直接生成随机数,也可以给定现成数据,不过现成数据一般使用下面的方法。
- torch.tensor()
torch一般使用数据类型为Double使用时需要定义torch.set_default_tensor_type(torch.DoubleTensor) - 生成从min到max的数使用randint(1,10)
- torch.full([ ],7)生成7的标量
- torch.full([1],7)生成一维的标量
变量的shape
- a.dim直接表示几维的变量
- a.size表示具体形状
- a.shape表示其形状比如[2,3]二行三列
类似numpy的torch变量
- torch.arange(0,10)
- torch.linspace(0,10,steps=4)等分切成4份
- torch.logspace(0,-1,stept=10) 从10的0此方到10的-1次方,生成10个数
- torch.ones()
torch.zeros()
torch.eye()
注意:eye只能接受一个或者两个参数,不能接受三个参数,就是只能适合于一个矩阵。
索引
a=torch.rand(4,3,28,28)
a[0].shape------->torch.size([3,28,28])
a[0,0].shape------->torch.size([28,28])
a[0,0,2,4]---------->tensor(0.8282)第一个图片第一个通道,二行四列的数据
切片
- :all
- :n-> n:<-
-
[start,end]
取前两张图片啊a[1:3,:,:,:]或者a[1:3] QQ截图20190708132914.jpg
QQ截图20190708132929.jpg
QQ截图20190708143116.jpg
QQ截图20190708144222.jpg
维度转换
- view的使用
b=a.view(4,28*28)
print(b.shape)
a=torch.tensor([1.2,2.3])
c=a.unsqueeze(-1)
f=torch.rand(4,32,14,14)
b=b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
b.shape
torch.Size([1,32,1,1])
相反squeeze方法作用是将某一维度合并,但这一维度必须为1
expand操作方法
b.shape
torch.Size([1,32,1,1])
b.expand(4,32,14,14).shape
torch.Size([4,32,14,14])
b.expand(-1,32,-1,-1).shape
torch.Size([1,32,1,1])
repeate重复次数
b.shape
torch.Size([1,32,1,1])
b.repeat(4,32,1,1).shape
torch.Size([4,1024,1,1])
b.repeat(4,1,1,1).shape
torch.Size([4,32,1,1])
b.repeat(4,1,32,32)
torch.Size([4,32,32,32])
转制操作.t
a=torch.randn(3,4)
a.t()
生成4X3的矩阵只能使用2D矩阵
- 使用索引直接转置
b=torch.rand(4,3,28,32)
b.permute(0,2,3,1).shape
torch.Size([4,28,32,3]) - 将两个向量相加
a=torch.rand(3,32,8)
b=torch.rand(4,32,8)
print(torch.cat([a,b],dim=).shape)
(7,32,8) - 使用stack时两个张量必须维度相同
a=torch.tensor(3.14)
a.floor(),a.ceil(),a.trunc(),a.frac()
tensor(3.),tensor(4.),tensor(3.),tensor(0.1400)
四舍五入
a=torch.tensor(3.499)
a.round()
tensor(3.0)
a=torch.tensor(3.5)
a.round()
tensor(4.)
- 数据操作
a=torch.arange(8).view(2,4).float()
tensor([[0,1,2,3],
[4,5,6,7]])
a.min(), a.max(), a.mean(), a.prod()
最小值:0,最大值:7,平均值:3.5,累乘0
a.sum()
累加28
a.argmax(), a.argmin()
不指定dim的话转化为一维索引为7,0