to_numpy() and to_torch()
作者:互联网
def to_numpy(self) -> 'Batch':
    """Convert every ``torch.Tensor`` value to a ``numpy.ndarray`` in place.

    Values that are not ``torch.Tensor`` instances are left untouched.
    Each tensor is detached from the autograd graph and copied to the CPU
    before conversion, so tensors that require grad or live on another
    device are handled as well.

    Returns:
        ``self``, so that calls can be chained.
    """
    # Iterate over a snapshot: entries are reassigned inside the loop, and
    # mutating a container while walking its live items view is fragile.
    for key, value in list(self.items()):
        if isinstance(value, torch.Tensor):
            self[key] = value.detach().cpu().numpy()
    return self
def to_torch(self, dtype: torch.dtype = torch.float32, device: str = "cpu") -> 'Batch':
    """Convert every value to a ``torch.Tensor`` in place.

    Each value is passed through ``torch.as_tensor``, so this is not
    limited to ``numpy.ndarray``: anything ``torch.as_tensor`` accepts is
    converted, and existing tensors are cast to ``dtype`` and moved to
    ``device`` when they differ.

    Args:
        dtype: Target dtype of every resulting tensor.
        device: Target device of every resulting tensor.

    Returns:
        ``self``, so that calls can be chained.
    """
    # Iterate over a snapshot: entries are reassigned inside the loop, and
    # mutating a container while walking its live items view is fragile.
    for key, value in list(self.items()):
        self[key] = torch.as_tensor(value, dtype=dtype, device=device)
    return self
Below is how to use them:
# Convert the batch's values to float32 tensors on the device named in self.args["device"].
batch = batch.to_torch(dtype=torch.float32, device=self.args["device"])
Source:
offlinerl / neorl
标签:Tensor,dtype,self,torch,device,numpy 来源: https://www.cnblogs.com/leifzhang/p/16198435.html