
DDP Training

A single-node, multi-GPU MNIST training script built on PyTorch DistributedDataParallel (DDP), launched with torch.distributed.launch.


import os
import time
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, sampler
from torch.optim.lr_scheduler import StepLR
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

import shutil
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
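# every process seeds identically; DDP broadcasts rank 0's weights when the model is
# wrapped, so all ranks start from the same parameters regardless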


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
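        # 28x28 input -> 26x26 after conv1 -> 24x24 after conv2 -> 12x12 after pooling,
        # so the flattened feature size is 64 * 12 * 12 = 9216 (fc1's input size)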
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)

        return x


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for idx, (images, targets) in enumerate(train_loader):
        images, targets = images.to(device), targets.to(device)
        pred = model(images)
        loss = F.nll_loss(pred, targets)        # forward() returns log-probabilities, so use NLL loss rather than cross_entropy (which would apply log_softmax twice)
        optimizer.zero_grad()
        loss.backward()
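        # DDP all-reduces and averages gradients across processes inside backward(),
        # so every rank applies the same update in optimizer.step()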
        optimizer.step()
        # print("===>local_rank:{}".format(args.local_rank))
        if idx % args.log_interval == 0 and args.local_rank == 0:
            print("Train Time:{}, epoch: {}, step: {}, loss: {}".format(time.strftime("%Y-%m-%d%H:%M:%S"), epoch + 1, idx, loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    test_acc = 0

    with torch.no_grad():
        for (images, targets) in test_loader:
            images, targets = images.to(device), targets.to(device)
            pred = model(images)
            loss = F.nll_loss(pred, targets, reduction="sum")   # log-probabilities in, so NLL loss (see train)
            test_loss += loss.item()
            pred_label = torch.argmax(pred, dim=1, keepdim=True)
            test_acc += pred_label.eq(targets.view_as(pred_label)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_acc /= len(test_loader.dataset)

    print("Test Time:{}, loss: {}, acc: {}".format(time.strftime("%Y-%m-%d%H:%M:%S"), test_loss, test_acc))

    return test_acc

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    print("===> save state to {}\n".format(filename))
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

def main():
    parser = argparse.ArgumentParser(description="MNIST TRAINING")
    parser.add_argument('--device_ids', type=str, default='0', help="Training Devices")
    parser.add_argument('--epochs', type=int, default=10, help="Training Epoch")
    parser.add_argument('--log_interval', type=int, default=100, help="Log Interval")
    parser.add_argument('--resume', type=str, default="/home/ubuntu/suyunzheng_ws/mnist/code/checkpoint.pth.tar", help="checkpoint resume path")
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')          # note: when argparse is used, --local_rank must be declared, otherwise the argument injected by the launcher makes the run fail
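    # newer launchers (torchrun, or launch with --use_env) export LOCAL_RANK as an environment
    # variable instead of passing --local_rank; this script assumes the legacy launcher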
    args = parser.parse_args()

    device_ids = list(map(int, args.device_ids.split(',')))
    dist.init_process_group(backend='nccl')
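    # the launcher sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE in the environment, which
    # init_process_group reads by default; the nccl backend expects one process per GPU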
    device = torch.device('cuda:{}'.format(device_ids[args.local_rank]))            # this differs per process (GPU)
    print("===> devcice:{}\n".format(device))
    torch.cuda.set_device(device)           # set the current CUDA device for this process
    model = Net().to(device)
    model = DistributedDataParallel(module=model, device_ids=[device_ids[args.local_rank]], output_device=device_ids[args.local_rank], find_unused_parameters=True)         # tolerate parameters that exist but receive no gradient
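    # note: find_unused_parameters=True scans the autograd graph on every backward pass;
    # this model uses all of its parameters, so False would avoid that overhead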


    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    dataset_train = datasets.MNIST('/home/ubuntu/suyunzheng_ws/mnist/data', train=True, transform=transform, download=True)
    dataset_test = datasets.MNIST('/home/ubuntu/suyunzheng_ws/mnist/data', train=False, transform=transform, download=True)

    sampler_train = DistributedSampler(dataset=dataset_train, shuffle=True)
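    # each process gets a disjoint 1/world_size shard of the training set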

    train_loader = DataLoader(dataset_train, batch_size=8, num_workers=8, sampler=sampler_train)         # in DDP, batch_size is the per-GPU batch size; in DP it is the total across all GPUs
    test_loader = DataLoader(dataset_test, batch_size=8, shuffle=False, num_workers=8)
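    # no DistributedSampler for the test set: only rank 0 evaluates, over the full dataset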

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = StepLR(optimizer, step_size=1)

    start_epoch = 0
    best_acc = 0
    if args.resume:                         # every rank loads the checkpoint, so parameters stay in sync across processes
        print("===> resume, ..., device:{}".format(device))
        if os.path.isfile(args.resume):
            print("===> loading checkpoint :{}".format(args.resume))
            # loc = 'cuda:{}'.format(0)
            loc = 'cuda:{}'.format(device_ids[args.local_rank])
            checkpoint = torch.load(args.resume, map_location=loc)
            start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']

            model.load_state_dict(checkpoint['state_dict'])
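            # (the checkpoint was saved from the DDP-wrapped model, so its 'module.'-prefixed
            # keys match the wrapped model being loaded into above)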
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("===> loaded checkpoint {} (epoch {}, best_acc {})".format(args.resume, checkpoint['epoch'], checkpoint['best_acc']))


    for epoch in range(start_epoch, args.epochs):
        sampler_train.set_epoch(epoch=epoch)            # reshuffle each rank's shard every epoch so the split varies between epochs
        train(args, model, device, train_loader, optimizer, epoch)
        
        scheduler.step()

        if args.local_rank == 0:                        # evaluation and checkpointing run on rank 0 only
            acc = test(args, model, device, test_loader)
            is_best = acc > best_acc
            best_acc = max(acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                is_best=is_best
            )

        # if args.local_rank == 0:
        #     torch.save(model.state_dict(), 'train.pt')
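    # tear down the process group so all ranks exit cleanly
    dist.destroy_process_group()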

if __name__ == '__main__':
    main()

    # python -m torch.distributed.launch --nproc_per_node 2 --master_port 1234 train_ddp.py --device_ids=0,1 --epochs=2
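    # on newer PyTorch (>= 1.10), torchrun is the recommended launcher; it exports LOCAL_RANK
    # as an environment variable instead of passing --local_rank, so the script would need to
    # read os.environ["LOCAL_RANK"]. Equivalent command, under that assumption:
    # torchrun --nproc_per_node 2 --master_port 1234 train_ddp.py --device_ids=0,1 --epochs=2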

Source: https://blog.csdn.net/suyunzzz/article/details/122177945