scipy的稀疏矩阵转换成torch的sparse tensor
2021-01-10 本文已影响0人
heishanlaoniu
def scipy_sparse_mat_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix into a torch sparse COO tensor.

    Args:
        sparse_mx: A SciPy sparse matrix in any format that provides
            ``tocoo()`` (CSR, CSC, COO, ...).

    Returns:
        A ``torch.float32`` sparse COO tensor with the same shape and the
        same stored entries as ``sparse_mx``. The result is not coalesced
        here; duplicate coordinates, if any, are passed through as-is.
    """
    # COO exposes parallel row / col / data arrays, which is the layout
    # torch expects for sparse construction.
    coo = sparse_mx.tocoo().astype(np.float32)

    # torch wants a (2, nnz) int64 index tensor: first row holds the row
    # coordinates, second row the column coordinates.
    coords = np.vstack((coo.row, coo.col)).astype(np.int64)
    indices = torch.from_numpy(coords)
    values = torch.from_numpy(coo.data)

    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor(indices, values, shape).
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))