pytorch网络转libtorch常见问题
作者:互联网
目录
一、RuntimeError: all inputs of range must be ints, found Tensor in argument 0:
一、RuntimeError: all inputs of range must be ints, found Tensor in argument 0:
问题
参数类型不正确,函数的默认参数是tensor
解决措施
函数传入参数不是tensor需要注明类型
我的问题是传入参数npoint
是一个int
类型,没有注明会报错,更改如下:
由
def test(npoint):
...
更改为
def test(npoint: int):
...
二、RuntimeError: Sliced expression not yet supported for subscripted assignment. File a bug if you want this:
问题
不支持赋值给切片表达式
解决措施
根据自己需求,进行修改,可利用循环替代
我将view_shape[1:] = [1] * (len(view_shape) - 1)
更改为
for i in range(1, len(view_shape)):
view_shape[i] = 1
三、Tried to access nonexistent attribute or method 'len' of type 'torch.torch.nn.modules.container.ModuleList'. Did you forget to initialize an attribute in init()?
问题
forward
函数中好像不支持len(nn.ModuleList())
和下标访问
解决措施
如果是一个ModuleList()
可以用enumerate
函数,多个同维度的可以用zip
函数
我这里有两个ModuleList()
,所以采用zip
函数,更改如下:
由
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
更改为
for conv, bn in zip(self.mlp_convs, self.mlp_bns):
new_points = F.relu(bn(conv(new_points)))
ref: https://github.com/pytorch/pytorch/issues/16123
四、Expected integer literal for index
问题和解决方法类似第三个
标签:常见问题,torch,pytorch,len,libtorch,shape,new,ModuleList,view 来源: https://www.cnblogs.com/xiaxuexiaoab/p/15555066.html