My Cats vs Dogs Classification
Cats vs Dogs based on VGG16
1. Prepare the dataset
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My Drive
!unzip cat_dog.zip
2. Data preprocessing
# First organize the dataset into a folder structure that ImageFolder can read
import shutil
import os

os.makedirs('train_/cat/', exist_ok=True)
os.makedirs('train_/dog/', exist_ok=True)
for f in os.listdir('cat_dog/train'):
    if f.split('_')[0] == 'cat':
        shutil.move('cat_dog/train/' + f, 'train_/cat/' + f)
    else:
        shutil.move('cat_dog/train/' + f, 'train_/dog/' + f)
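ImageFolder (used below) infers each image's class label from the name of the subfolder it sits in, which is why the images are moved into train_/cat and train_/dog. A quick sanity check of the split (not in the original post):

# Not in the original post: confirm both class folders were populated
print(len(os.listdir('train_/cat')), 'cat images')
print(len(os.listdir('train_/dog')), 'dog images')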
train_dir = 'train_'

# Data augmentation for the training set
import torch
from torchvision import datasets, transforms

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
)
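ImageFolder assigns label indices by sorting the subfolder names, so 'cat' should map to 0 and 'dog' to 1; the title logic in the visualization cell below relies on exactly that mapping. A minimal check (not in the original post):

# Not in the original post: verify the label mapping used later
print(train_dataset.class_to_idx)   # expected: {'cat': 0, 'dog': 1}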
Inspect an image
import numpy as np
import matplotlib.pyplot as plt

def imshow(image, ax=None, title=None, normalize=True):
    """Imshow for Tensor."""
    if ax is None:
        fig, ax = plt.subplots()
    image = image.numpy().transpose((1, 2, 0))

    if normalize:
        # Undo the ImageNet normalization applied by the transform
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = std * image + mean
        image = np.clip(image, 0, 1)

    ax.imshow(image)
    if title is not None:
        ax.set_title(title)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')

    return ax
images, labels = next(iter(train_dataloader))
title = 'Dog' if labels[0].item() == 1 else 'Cat'
imshow(images[0], title=title)
3. Train the model
# Load the pretrained VGG16
from torch import nn, optim
from torchvision import models

model = models.vgg16(pretrained=True)
model
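Printing the model shows that VGG16's original classifier starts with a linear layer that takes 25088 input features (512 × 7 × 7 from the last conv block), which is why the replacement classifier below also uses 25088 inputs. A quick way to confirm this (not in the original post):

# Not in the original post: confirm the input size expected by the classifier
print(model.classifier[0].in_features)   # 25088 for VGG16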
Freeze the parameters of the feature-extraction part
for param in model.parameters():
    param.requires_grad = False
Replace the classifier with our own
from collections import OrderedDict

# Replace the classifier with two new layers
classifier = nn.Sequential(OrderedDict([
    # Layer 1
    ('dropout1', nn.Dropout(0.3)),
    ('fc1', nn.Linear(25088, 500)),
    ('relu', nn.ReLU()),
    # Output layer
    ('fc2', nn.Linear(500, 2)),
    ('output', nn.LogSoftmax(dim=1))
]))
model.classifier = classifier
# Use NLLLoss as the loss function (the classifier already ends in LogSoftmax)
criterion = nn.NLLLoss()
# Use Adam as the optimizer, updating only the new classifier's parameters
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
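LogSoftmax followed by NLLLoss is numerically the same as applying CrossEntropyLoss to raw logits, which is why this pairing works. A minimal illustration of that equivalence (not in the original post; the tensors are made up):

# Not in the original post: LogSoftmax + NLLLoss equals CrossEntropyLoss on logits
logits = torch.randn(4, 2)            # fake batch of raw scores
targets = torch.tensor([0, 1, 1, 0])  # fake labels
a = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
b = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(a, b))           # True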
Training loop
from tqdm import tqdm

epochs = 5
for e in range(epochs):
    running_loss, total, correct = 0, 0, 0
    model.train()
    for images, labels in tqdm(train_dataloader):
        # Move the batch to the GPU (if available)
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Metrics
        running_loss += loss.item()
        total += labels.size(0)
        _, predicted = torch.max(torch.exp(outputs).data, 1)
        correct += (predicted == labels).sum().item()
    # Per-epoch log
    print(f'Epoch {e} Training: Loss={running_loss:.5f} Acc={correct/total * 100:.2f}')
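The accuracy logged above is training accuracy; the post does not show how the figures quoted in section 7 were measured. A minimal sketch of an evaluation pass, assuming a hypothetical val_dataloader built the same way as train_dataloader but from a held-out folder:

# Hypothetical evaluation pass; val_dataloader is an assumption, not from the post
def evaluate(model, dataloader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            preds = torch.argmax(model(images), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total * 100
# print(f'Val Acc={evaluate(model, val_dataloader):.2f}')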
4. Save the model
checkpoints = {
    'pre-trained': 'vgg16',
    'classifier': nn.Sequential(OrderedDict([
        # Modified Layer 1
        ('dropout1', nn.Dropout(0.3)),
        ('fc1', nn.Linear(25088, 500)),
        ('relu', nn.ReLU()),
        # Modified output layer
        ('fc2', nn.Linear(500, 2)),
        ('output', nn.LogSoftmax(dim=1))
    ])),
    'state_dict': model.state_dict()
}
torch.save(checkpoints, 'vgg16_catsVdogs.pth')
def load_saved_model(path):
    checkpoint = torch.load(path)
    model = models.vgg16(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    model.classifier = checkpoint['classifier']
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model
loaded_model = load_saved_model('vgg16_catsVdogs.pth')
loaded_model.to(device)
5. Test the model
from PIL import Image

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

predictions = []
with torch.no_grad():
    for i in tqdm(range(0, 2000)):
        path = 'cat_dog/test/' + str(i) + '.jpg'
        X = Image.open(path).convert('RGB')
        X = test_transform(X)[:3, :, :]
        X = X.unsqueeze(0)
        X = X.to(device)
        outputs = loaded_model(X)
        predictions.append(torch.argmax(outputs).item())
6. Generate the submission
# Predictions for the 2000 test images
import pandas as pd

data = {'id': list(range(0, 2000)), 'label': predictions}
df = pd.DataFrame(data)
df.to_csv('cats-dogs-submission.csv', index=False)
7. Summary and analysis of results
- Training only the final layer gives an accuracy of 96.5%.
- Also training one more layer (Layer 1) raises the accuracy to 97.45% (see the sketch after this list).
- Adding data augmentation (horizontal/vertical shifts, flips, and scaling of the original images), switching the loss function to NLLLoss, and increasing the number of epochs to 5 gives an accuracy of 98.45%.
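The post does not include the code for these two variants. A minimal sketch of how the "final layer only" and "one more layer (Layer 1)" setups could differ, assuming the difference lies in which classifier parameters are handed to the optimizer (this replaces the optimizer line in section 3 and is an interpretation, not the author's confirmed code):

# Hypothetical sketch of the two training variants described above
# Variant 1: optimize only the output layer
optimizer = optim.Adam(model.classifier.fc2.parameters(), lr=0.001)
# Variant 2: also optimize Layer 1 (fc1)
optimizer = optim.Adam(
    list(model.classifier.fc1.parameters()) + list(model.classifier.fc2.parameters()),
    lr=0.001,
)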