其他分享
首页 > 其他分享> > 深度学习比赛入门——街景字符识别(三)

深度学习比赛入门——街景字符识别(三)

作者:互联网

本文是街景字符识别比赛第三阶段,关于模型搭建的相关问题

在baseline中,使用了比较简单的神经网络的搭建,并且为了加快模型的收敛,使用了预训练分类模型

class SVHN_Model1(nn.Module):
    def __init__(self):
        super(SVHN_Model1, self).__init__()

        model_conv = models.resnet18(pretrained=True)
        model_conv.avgpool = nn.AdaptiveAvgPool2d(1)
        model_conv = nn.Sequential(*list(model_conv.children())[:-1])
        self.cnn = model_conv

        self.fc1 = nn.Linear(512, 11)
        

        self.fc2 = nn.Linear(512, 11)
        
        self.fc3 = nn.Linear(512, 11)
        
        self.fc4 = nn.Linear(512, 11)
        

        self.fc5 = nn.Linear(512, 11)
        

    def forward(self, img):
        feat = self.cnn(img)
        # print(feat.shape)
        feat = feat.view(feat.shape[0], -1)
        c1 = self.fc1(feat)
        

        c2 = self.fc2(feat)
        

        c3 = self.fc3(feat)
        

        c4 = self.fc4(feat)
        

        c5 = self.fc5(feat)
        
        return c1, c2, c3, c4, c5

在原有的模型上,我们可以做一些更改,比如说加上正则化参数,增加网络的复杂性并扩充数据集,或者采用更复杂的高效的网络例如CRNN,多做一些尝试,目前正在尝试新的网络模型,待试验成功之后,再这里分享我的心得

标签:11,字符识别,入门,nn,街景,self,512,feat,Linear
来源: https://www.cnblogs.com/wushupei/p/12969374.html