pytorch各种操作

2020-03-24  本文已影响0人  啊啊啊啊啊1231

pytorch为了加快数据提取,定义建立train_data,train_loader

train_data = TensorDataset(torch.tensor(train_data, dtype=torch.int))

train_loader = DataLoader(train_data, batch_size= batch_size,shuffle=True, drop_last=True)

上一篇 下一篇

猜你喜欢

热点阅读