其他分享
首页 > 其他分享> > pytorch flatten()

pytorch flatten()

作者:互联网

 

 torch.flatten(input, start_dim, end_dim).

举例:一个tensor 3*2* 2

start_dim=1  output 3*4

start_dim=0 end_dim=1.    6*2

如果没有后面两个参数直接变为一维的

 

标签:dim,end,tensor,torch,start,pytorch,flatten
来源: https://www.cnblogs.com/h694879357/p/15855063.html