
VGG16 Reimplementation


This was mainly practice with data loading.

CIFAR-10 is used this time; each batch file unpickles into a dictionary, and only the first 100 samples are taken for training.

Each 3072-value row is stored channel-first (1024 R, then 1024 G, then 1024 B values), so it first has to be reshaped to 3 × 32 × 32 and then transposed to 32 × 32 × 3:

self.data = self.data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

In __getitem__, call Image.fromarray() on the row before passing it to the transforms.
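
As a quick sanity check on the decoding (a minimal sketch, reusing the batch-file path that appears in the config below), one row can be displayed directly:

import pickle
import matplotlib.pyplot as plt

with open('H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1', 'rb') as fo:
    batch = pickle.load(fo, encoding='bytes')

row = batch[b'data'][0]                          # 3072 uint8 values: 1024 R, 1024 G, 1024 B
img = row.reshape(3, 32, 32).transpose(1, 2, 0)  # channel-first -> 32 x 32 x 3
plt.imshow(img)
plt.title(str(batch[b'labels'][0]))
plt.show()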

 

VGG_dataset:

from torch.utils import data
from PIL import Image
import random
import torchvision.transforms as T
import matplotlib.pyplot as plt

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

# imgs = unpickle('H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1')
# print(imgs[b'data'].reshape(-1, 3, 32, 32))



class Dataset(data.Dataset):
    def __init__(self, root, train = True, test = False):
        self.test = test
        self.train = train
        imgs = unpickle(root)
        self.data = imgs[b'data'][: 100, :]  # keep only the first 100 samples for a quick run
        # CIFAR-10 rows are channel-first (R, G, B planes): reshape to NCHW, then move channels last
        self.data = self.data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        self.label = imgs[b'labels'][: 100]

        if self.train:
            self.transforms = T.Compose([
                T.Resize(random.randint(256, 384)),  # scale jitter; note the size is sampled once, at construction
                T.RandomCrop(224),
                T.ToTensor(),
                T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
            ])
        elif self.test:
            self.transforms = T.Compose([
                T.Resize(224),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def __getitem__(self, index):
        # convert the HWC uint8 array to a PIL image so the torchvision transforms can be applied
        data = Image.fromarray(self.data[index])
        data = self.transforms(data)
        return data, self.label[index]

    def __len__(self):
        return len(self.label)
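
A quick way to smoke-test the dataset before wiring it into the training script (a minimal sketch, appended to VGG_dataset.py and using the train_root path from the config below):

if __name__ == '__main__':
    dataset = Dataset('H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1')
    img, label = dataset[0]
    print(len(dataset))    # 100, since only the first 100 rows are kept
    print(img.shape)       # torch.Size([3, 224, 224]) after the train transforms
    print(label)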

 

config:

class configuration:
    train_root = 'H:/DataSet/cifar-10-python/cifar-10-batches-py/data_batch_1'
    test_root = 'H:/DataSet/cifar-10-python/cifar-10-batches-py/test_batch'
    label_nums = 10
    batch_size = 4
    epochs = 10
    lr = 0.01

VGG:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.functional as F
from config import configuration
from VGG_dataset import Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'

con = configuration()

class vgg(nn.Module):
    # VGG16 (configuration D): 13 convolutional layers followed by 3 fully connected layers
    def __init__(self):
        super(vgg, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size = 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64,kernel_size = 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size = 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size = 3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size = 3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size = 3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(256, 512, kernel_size = 3, stride=1, padding=1)
        self.conv9 = nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding=1)
        self.conv10 = nn.Conv2d(512, 512,  kernel_size=3, stride=1, padding=1)
        self.conv11 = nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding=1)
        self.conv12 = nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding=1)
        self.conv13 = nn.Conv2d(512, 512,  kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(512 * 7 * 7, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, con.label_nums)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(F.relu(self.conv4(x)), 2)
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.max_pool2d(F.relu(self.conv7(x)), 2)
        x = F.relu(self.conv8(x))
        x = F.relu(self.conv9(x))
        x = F.max_pool2d(F.relu(self.conv10(x)), 2)
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.max_pool2d(F.relu(self.conv13(x)), 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
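
# Feature-map trace for a 224 x 224 input: each of the five max_pool2d calls halves the
# spatial size, 224 -> 112 -> 56 -> 28 -> 14 -> 7, which is why fc1 expects 512 * 7 * 7 inputs.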

# img = Image.open('H:/C5AM385_Intensity.jpg')
# print(np.array(img).shape)


if __name__ == '__main__':
    model = vgg()
    model.to(device)
    train_dataset = Dataset(con.train_root)
    test_dataset = Dataset(con.test_root, False, True)
    train_dataloader = DataLoader(train_dataset, batch_size = con.batch_size, shuffle = True, num_workers = 4)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = con.lr)

    for epoch in range(con.epochs):
        total_loss = 0
        cnt = 0
        true_label = 0
        for data, label in tqdm(train_dataloader):
            # print(np.array(data[0]).shape)
            # plt.imshow(data[0])
            # plt.show()

            optimizer.zero_grad()
            # .to(device) returns a new tensor, so it has to be assigned back
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            loss_value = loss(output, label)
            loss_value.backward()
            optimizer.step()
            output = torch.max(output, 1)[1]
            # use .item() so the running totals are plain numbers rather than graph-holding tensors
            total_loss += loss_value.item()
            true_label += torch.sum(output == label).item()
            cnt += 1
        loss_mean = total_loss / float(cnt)
        accuracy = true_label / float(len(train_dataset))
        print('Loss:{:.4f}, Accuracy:{:.2f}'.format(loss_mean, accuracy))
    print('Training finished!')
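
test_dataset is built above but never used; a minimal sketch of an evaluation pass, appended inside the same __main__ block after the training loop, could look like this:

    test_dataloader = DataLoader(test_dataset, batch_size = con.batch_size, shuffle = False)
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, label in test_dataloader:
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            pred = torch.max(output, 1)[1]
            correct += torch.sum(pred == label).item()
    print('Test Accuracy:{:.2f}'.format(correct / float(len(test_dataset))))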

 
