Fixing growing GPU memory usage during network training and validation
Author: Internet
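Recently, while training a network, I found that it ran out of memory (OOM) after a few epochs.
At first I assumed there simply wasn't enough memory. It turned out that GPU memory usage kept increasing as training went on.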
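To confirm that memory is growing from epoch to epoch, and is not just high, you can log PyTorch's CUDA memory counters at the end of each epoch. The sketch below uses the existing `torch.cuda.memory_allocated()` and `torch.cuda.memory_reserved()` functions. The helper name `log_gpu_memory` is hypothetical, and the helper reports zeros when no GPU is available.

```python
import torch


def log_gpu_memory(tag: str) -> dict:
    """Hypothetical helper: print and return CUDA memory usage in MiB."""
    if torch.cuda.is_available():
        # Memory currently occupied by live tensors
        allocated = torch.cuda.memory_allocated() / 2**20
        # Memory held by PyTorch's caching allocator (includes free cached blocks)
        reserved = torch.cuda.memory_reserved() / 2**20
    else:
        # No GPU: nothing to measure
        allocated = reserved = 0.0
    print(f"[{tag}] allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")
    return {"allocated_mib": allocated, "reserved_mib": reserved}


# Usage: call once per epoch; an "allocated" value that keeps rising suggests a leak
for epoch in range(2):
    log_gpu_memory(f"epoch {epoch}")
```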
To fix this growing memory usage, I went through various references and collected three approaches that helped:
1. Use .item() when accumulating the loss during training
Original code:
```python
import torch


def train_one_epoch(
    model, criterion, train_dataloader, optimizer, epoch, clip_max_norm
):
    model.train()
    device = next(model.parameters()).device
    train_loss = 0
    for i, d in enumerate(train_dataloader):
        d = d.to(device)
        optimizer.zero_grad()
        out_net = model(d)
        loss = criterion(out_net, d, epoch)
        # Problem: `loss` is a tensor with autograd history, so adding it to
        # train_loss keeps the history of every iteration alive
        train_loss += loss
        loss.backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()
```
After the change:
```python
import torch


def train_one_epoch(
    model, criterion, train_dataloader, optimizer, epoch, clip_max_norm
):
    model.train()
    device = next(model.parameters()).device
    train_loss = 0
    for i, d in enumerate(train_dataloader):
        d = d.to(device)
        optimizer.zero_grad()
        out_net = model(d)
        loss = criterion(out_net, d, epoch)
        # .item() returns a plain Python float, so no autograd history is kept
        train_loss += loss.item()
        loss.backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()
    # Average loss per batch over the epoch
    return train_loss / len(train_dataloader)
```
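For the underlying reason, see: https://zhuanlan.zhihu.com/p/85838274. In short, `loss` is a tensor that carries its autograd history. `train_loss += loss` therefore turns `train_loss` into a tensor that keeps the history of every iteration alive. `loss.item()` returns a plain Python number, so nothing is retained.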
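You can see the difference without a GPU. In the minimal sketch below, a toy `nn.Linear` model, `MSELoss`, and random data stand in for the article's model, criterion, and dataloader. These stand-ins are assumptions for illustration. Summing `loss` directly yields a tensor that still has a `grad_fn`, while summing `loss.item()` yields a plain float.

```python
import torch

torch.manual_seed(0)
# Illustrative stand-ins, not the article's model or criterion
model = torch.nn.Linear(4, 1)
criterion = torch.nn.MSELoss()

tensor_total = 0    # accumulates tensors (the problematic pattern)
float_total = 0.0   # accumulates Python floats via .item()

for _ in range(3):
    x = torch.randn(8, 4)
    y = torch.randn(8, 1)
    loss = criterion(model(x), y)
    tensor_total += loss         # result is a tensor linked to the autograd graph
    float_total += loss.item()   # result is a plain float with no graph attached
    loss.backward()

print(type(tensor_total).__name__, tensor_total.grad_fn is not None)
print(type(float_total).__name__)
```

Both totals hold the same value. Only the tensor version drags autograd history along with it, and in a real training loop that history grows with every batch.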
2. Move model.eval() inside the with torch.no_grad() block:
Original code:
```python
import torch


def test_epoch(epoch, test_dataloader, model, criterion):
    model.eval()
    device = next(model.parameters()).device
    valid_loss = 0
    with torch.no_grad():
        for d in test_dataloader:
            d = d.to(device)
            out_net = model(d)
            loss = criterion(out_net, d, epoch)
            valid_loss += loss.item()
    # Average loss per batch over the validation set
    return valid_loss / len(test_dataloader)
```
After the change:
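```python
import torch


def test_epoch(epoch, test_dataloader, model, criterion):
    valid_loss = 0
    with torch.no_grad():
        model.eval()
        device = next(model.parameters()).device
        for d in test_dataloader:
            d = d.to(device)
            out_net = model(d)
            loss = criterion(out_net, d, epoch)
            valid_loss += loss.item()
    # Average loss per batch over the validation set
    return valid_loss / len(test_dataloader)
```
Note that model.eval() and torch.no_grad() do different jobs. model.eval() only switches layers such as Dropout and BatchNorm to inference behavior. torch.no_grad() is what stops autograd from recording a graph during the forward pass. Both versions above already run the forward pass inside torch.no_grad(), so moving model.eval() should not by itself change memory usage. What matters is that validation runs under torch.no_grad().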
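The following check shows what each call controls. It uses a toy model with a Dropout layer, which is an assumption for illustration. model.eval() changes how Dropout behaves but does not stop autograd from recording. torch.no_grad() stops recording regardless of the mode.

```python
import torch

torch.manual_seed(0)
# Toy model: Dropout behaves differently in train and eval mode
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
x = torch.randn(2, 4)

model.eval()
out_eval = model(x)                # eval mode, but autograd still records a graph
with torch.no_grad():
    out_eval_no_grad = model(x)    # eval mode and no graph
model.train()
with torch.no_grad():
    out_train_no_grad = model(x)   # train mode (Dropout active), still no graph

print(out_eval.requires_grad)           # True
print(out_eval_no_grad.requires_grad)   # False
print(out_train_no_grad.requires_grad)  # False
print(torch.equal(out_eval, out_eval_no_grad))  # True: same eval-mode computation
```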
3. Use torch.cuda.empty_cache() to release unused cached GPU memory:
```python
import torch


def test_epoch(epoch, test_dataloader, model, criterion):
    valid_loss = 0
    with torch.no_grad():
        model.eval()
        device = next(model.parameters()).device
        for d in test_dataloader:
            d = d.to(device)
            out_net = model(d)
            loss = criterion(out_net, d, epoch)
            valid_loss += loss.item()
    # Return cached blocks that no tensor is using to the GPU driver
    torch.cuda.empty_cache()
    return valid_loss / len(test_dataloader)
```
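empty_cache() only releases cached blocks that no tensor is using any more. It does not free memory held by live tensors, and calling it too often can slow training down. Usage reference: https://www.i4k.xyz/article/zxyhhjs2017/92795831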
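To see what torch.cuda.empty_cache() does and does not release, compare two counters around the call:

- `torch.cuda.memory_allocated()`: memory used by live tensors.
- `torch.cuda.memory_reserved()`: memory held by PyTorch's caching allocator.

The helper `empty_cache_demo` below is a hypothetical sketch, and the ~256 MiB test size is an arbitrary choice. It needs a CUDA device and returns None otherwise.

```python
import torch


def empty_cache_demo(num_floats: int = 64 * 1024 * 1024):
    """Hypothetical demo: (allocated, reserved) bytes before and after empty_cache()."""
    if not torch.cuda.is_available():
        return None  # nothing to demonstrate without a GPU
    x = torch.empty(num_floats, device="cuda")  # about 256 MiB of float32
    del x  # the tensor is freed, but the allocator keeps its block cached
    before = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
    torch.cuda.empty_cache()  # return unused cached blocks to the driver
    after = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
    return before, after


result = empty_cache_demo()
if result is not None:
    (alloc_before, res_before), (alloc_after, res_after) = result
    print(f"allocated: {alloc_before} -> {alloc_after} bytes")
    print(f"reserved:  {res_before} -> {res_after} bytes")
```

Allocated memory stays the same across the call; only reserved memory can drop. This is why empty_cache() can lower the usage shown by tools such as nvidia-smi but cannot fix a real leak like the one in method 1.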
Source: https://blog.csdn.net/ChandelerGause/article/details/121303409