
to_numpy() and to_torch()

Author: Internet

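```python
import torch

# Methods of the dict-like Batch class (shown without the enclosing class body).

def to_numpy(self) -> 'Batch':
    """Change all torch.Tensor to numpy.ndarray in-place."""
    for k, v in self.items():
        if isinstance(v, torch.Tensor):
            # detach() drops the autograd graph; cpu() moves GPU tensors to host memory
            self[k] = v.detach().cpu().numpy()
    return self

def to_torch(self, dtype: torch.dtype = torch.float32, device: str = "cpu") -> 'Batch':
    """Change all numpy.ndarray to torch.Tensor in-place."""
    for k, v in self.items():
        # Every value is converted with the same dtype and device
        self[k] = torch.as_tensor(v, dtype=dtype, device=device)
    return self
```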
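A short standalone check of the building blocks used by the two methods, written with plain PyTorch and NumPy only (no Batch object; the variable names are illustrative). It shows why to_numpy() calls detach() and cpu() before numpy(), and how torch.as_tensor in to_torch() shares memory when the dtype and device already match, makes a copy when the dtype changes, and casts integer arrays to the requested float dtype.

```python
import numpy as np
import torch

# --- to_numpy(): why detach() and cpu() are needed ---
x = torch.ones(3, requires_grad=True)
y = x * 2  # y is part of an autograd graph
direct_failed = False
try:
    y.numpy()  # raises RuntimeError because y requires grad
except RuntimeError:
    direct_failed = True
# detach() drops the graph; cpu() is a no-op here but required for CUDA tensors
y_np = y.detach().cpu().numpy()

# --- to_torch(): what torch.as_tensor does with the dtype/device arguments ---
a32 = np.zeros(3, dtype=np.float32)
a64 = np.zeros(3, dtype=np.float64)
labels = np.array([0, 1, 2])  # integer array

t32 = torch.as_tensor(a32, dtype=torch.float32, device="cpu")
t64 = torch.as_tensor(a64, dtype=torch.float32, device="cpu")
t_labels = torch.as_tensor(labels, dtype=torch.float32, device="cpu")

t32[0] = 1.0  # same dtype on CPU: memory is shared, so a32[0] changes too
t64[0] = 1.0  # dtype changed: t64 is a copy, so a64[0] is untouched
# t_labels is float32 as well, because one dtype is applied to every value
```

Because to_torch() passes a single dtype for all keys, integer data such as discrete actions or indices ends up as floats; if that matters, those keys would need a separate conversion.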


Below is how to use them (called from a method whose `self.args["device"]` holds the target device):

```python
batch = batch.to_torch(dtype=torch.float32, device=self.args["device"])
```
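A self-contained sketch of the round trip. It assumes a hypothetical minimal `Batch` that is just a `dict` subclass carrying these two methods (the real class offers more), hypothetical keys `obs`, `act` and `rew`, and a device picked with `torch.cuda.is_available()` instead of `self.args["device"]`.

```python
import numpy as np
import torch


class Batch(dict):
    """Hypothetical minimal stand-in for the Batch class: a dict of arrays."""

    def to_numpy(self) -> "Batch":
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self[k] = v.detach().cpu().numpy()
        return self

    def to_torch(self, dtype: torch.dtype = torch.float32, device: str = "cpu") -> "Batch":
        for k, v in self.items():
            self[k] = torch.as_tensor(v, dtype=dtype, device=device)
        return self


# Use a GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

batch = Batch(obs=np.random.rand(4, 3), act=np.random.rand(4, 2), rew=np.zeros(4))
batch = batch.to_torch(dtype=torch.float32, device=device)
obs_tensor = batch["obs"]  # keep a reference to inspect the intermediate tensor
# ... feed batch["obs"], batch["act"], ... to a network here ...
batch = batch.to_numpy()
```

Both methods modify the batch in place and also return it, which is why the assignment form `batch = batch.to_torch(...)` works and can be chained.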


Code taken from: offlinerl / neorl

Tags: Tensor, dtype, self, torch, device, numpy
Source: https://www.cnblogs.com/leifzhang/p/16198435.html