PyTorch: torch.cuda.OutOfMemoryError: 在设备 0 上第 0 个副本中捕获到 OutOfMemoryError

huangapple go评论56阅读模式
英文:

PyTorch: torch.cuda.OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0

问题

在使用多个GPU同时训练多个模型（如 mobilenet 和 mobilenetv2）时，您可能会遇到内存不足的问题。以下是可能的解决方案：

  1. 在每个模型训练之间进行内存清理:

    import gc
    # Collect unreachable Python objects first, so CUDA tensors that only they
    # referenced go back to PyTorch's caching allocator...
    gc.collect()
    # ...then release the allocator's unoccupied cached blocks. Memory that is
    # still referenced by live tensors (model, optimizer state, activations)
    # is NOT freed, so this cannot help when one batch alone exceeds the GPU.
    torch.cuda.empty_cache()  # PyTorch's memory cleanup method
    

这些代码可以插入到相邻两个模型的训练之间：gc.collect() 先回收不再被引用的 Python 对象，torch.cuda.empty_cache() 再把缓存分配器中未被占用的显存归还给驱动，从而减少显存占用。注意，仍被存活张量引用的显存不会被释放。

请确保在每个模型训练之前和之后都执行这些操作，以便及时释放缓存的显存。不过，如果单个批量本身就超出了GPU显存（见下方回答），这些操作并不能避免 torch.cuda.OutOfMemoryError 错误。

英文:

I am training multiple models, such as mobilenet and mobilenetv2, on multiple GPUs at the same time. After training and evaluating the first model, I get the error "torch.cuda.OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0." I have tried various solutions, such as the one below:

Code

import time
import pathlib
from os.path import isfile
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import models
from utils import *
from config import config
from data import DataLoader
# Silence the "(Possibly) corrupt EXIF data" UserWarning that PIL raises for
# some ImageNet images; it is noise for training purposes.
import warnings
warnings.filterwarnings(action="ignore",
                        message="(Possibly )?corrupt EXIF data",
                        category=UserWarning)

# Best validation Acc@1 seen so far; main() declares it global and updates it
# after every epoch.
best_acc1 = 0
def _format_acc(value):
    """Render an accuracy value (0-d tensor or plain number) as 'xx.xxxx'.

    Replaces the old ``str(value)[7:-18]`` slice, which only worked for a
    repr of exactly the form ``tensor(xx.xxxx, device='cuda:N')`` with a
    single-digit N, and produced an empty string for CPU tensors or floats.
    """
    return '{:.4f}'.format(float(value))


def _format_hms(seconds):
    """Format a duration in seconds as '{h}h {m}m {s:.2f}s'."""
    return '{}h {}m {:.2f}s'.format(
        int(seconds // 3600), int((seconds % 3600) // 60), seconds % 60)


def _load_checkpoint(model, optimizer, ckpt_file, opt):
    """Restore model and optimizer state from ``ckpt_file``.

    ``load_model`` (from utils) loads the weights into ``model`` and returns
    the checkpoint dict; the optimizer state is restored from its
    ``'optimizer'`` entry.

    Returns:
        The checkpoint's ``'epoch'`` value (the epoch to resume from).
    """
    print('==> Loading Checkpoint \'{}\''.format(opt.ckpt))
    checkpoint = load_model(model, ckpt_file, opt)
    epoch = checkpoint['epoch']
    optimizer.load_state_dict(checkpoint['optimizer'])
    print('==> Loaded Checkpoint \'{}\' (epoch {})'.format(opt.ckpt, epoch))
    return epoch


def main():
    """Build the model selected by ``config()``, then evaluate or train it.

    Stores the parsed options in the module-level ``opt``; ``train`` and
    ``validate`` read that global. Creates the architecture named by
    ``opt.arch`` and, when ``opt.cuda`` is set, wraps it in
    ``nn.DataParallel`` over ``opt.gpuids``.

    - With ``opt.evaluate``, loads ``opt.ckpt``, runs ``validate`` once and
      records the result with ``save_eval``.
    - Otherwise, trains from ``start_epoch`` (0, or the resumed checkpoint
      epoch) up to ``opt.epochs``. After every epoch it calls
      ``save_model`` and ``save_summary``.

    Raises:
        RuntimeError: ``opt.cuda`` is set but CUDA is not available.
    """
    global opt, start_epoch, best_acc1
    opt = config()
    if opt.cuda and not torch.cuda.is_available():
        raise RuntimeError('No GPU found, please run without --cuda')

    print('\n=> creating model \'{}\''.format(opt.arch))
    if opt.arch == 'shufflenet':
        model = models.__dict__[opt.arch](opt.dataset, opt.width_mult, opt.groups)
    else:
        model = models.__dict__[opt.arch](opt.dataset, opt.width_mult)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=opt.lr,
                          momentum=opt.momentum, weight_decay=opt.weight_decay,
                          nesterov=True)
    start_epoch = 0

    if opt.cuda:
        torch.cuda.set_device(opt.gpuids[0])
        with torch.cuda.device(opt.gpuids[0]):
            model = model.cuda()
            criterion = criterion.cuda()
        # nn.DataParallel scatters every input batch across opt.gpuids, so each
        # replica holds activations for roughly batch_size / len(gpuids)
        # samples. An OutOfMemoryError "in replica N" therefore usually means
        # opt.batch_size is too large for this architecture.
        model = nn.DataParallel(model, device_ids=opt.gpuids,
                                output_device=opt.gpuids[0])
        cudnn.benchmark = True

    ckpt_dir = pathlib.Path('checkpoint')
    ckpt_file = ckpt_dir / opt.arch / opt.dataset / opt.ckpt

    # Resume: a missing checkpoint aborts before any data is loaded.
    if opt.resume:
        if not isfile(ckpt_file):
            print('==> no checkpoint found \'{}\''.format(opt.ckpt))
            return
        start_epoch = _load_checkpoint(model, optimizer, ckpt_file, opt)

    print('==> Load data..')
    train_loader, val_loader = DataLoader(opt.batch_size, opt.workers,
                                          opt.dataset, opt.datapath,
                                          opt.cuda)

    # Evaluation-only run: load the checkpoint, validate once, and stop.
    if opt.evaluate:
        if not isfile(ckpt_file):
            print('==> no checkpoint found \'{}\''.format(opt.ckpt))
            return
        start_epoch = _load_checkpoint(model, optimizer, ckpt_file, opt)
        print('\n===> [ Evaluation ]')
        start_time = time.time()
        acc1, acc5 = validate(val_loader, model, criterion)
        save_eval(['{}-{}-{}'.format(opt.arch, opt.dataset, opt.ckpt[:-4]),
                   _format_acc(acc1), _format_acc(acc5)], opt)
        elapsed_time = time.time() - start_time
        print('====> {:.2f} seconds to evaluate this model\n'.format(
            elapsed_time))
        return

    # Training loop.
    train_time = 0.0
    validate_time = 0.0
    for epoch in range(start_epoch, opt.epochs):
        adjust_learning_rate(optimizer, epoch, opt.lr)
        print('\n==> {}/{} training'.format(opt.arch, opt.dataset))
        print('==> Epoch: {}, lr = {}'.format(
            epoch, optimizer.param_groups[0]["lr"]))

        print('===> [ Training ]')
        start_time = time.time()
        acc1_train, acc5_train = train(train_loader,
                                       epoch=epoch, model=model,
                                       criterion=criterion, optimizer=optimizer)
        elapsed_time = time.time() - start_time
        train_time += elapsed_time
        print('====> {:.2f} seconds to train this epoch\n'.format(
            elapsed_time))

        print('===> [ Validation ]')
        start_time = time.time()
        acc1_valid, acc5_valid = validate(val_loader, model, criterion)
        elapsed_time = time.time() - start_time
        validate_time += elapsed_time
        print('====> {:.2f} seconds to validate this epoch\n'.format(
            elapsed_time))

        # Remember the best Acc@1, then save the checkpoint and a summary row.
        is_best = acc1_valid > best_acc1
        best_acc1 = max(acc1_valid, best_acc1)
        state = {'epoch': epoch + 1,
                 'model': model.state_dict(),
                 'optimizer': optimizer.state_dict()}
        summary = [epoch,
                   _format_acc(acc1_train), _format_acc(acc5_train),
                   _format_acc(acc1_valid), _format_acc(acc5_valid)]
        save_model(state, epoch, is_best, opt)
        save_summary(summary, opt)

    epochs_run = opt.epochs - start_epoch
    # Resuming from a checkpoint whose epoch is already >= opt.epochs runs no
    # epochs; skip the per-epoch averages instead of dividing by zero.
    if epochs_run > 0:
        avg_train_time = train_time / epochs_run
        avg_valid_time = validate_time / epochs_run
        print('====> average training time per epoch: {:,}m {:.2f}s'.format(
            int(avg_train_time // 60), avg_train_time % 60))
        print('====> average validation time per epoch: {:,}m {:.2f}s'.format(
            int(avg_valid_time // 60), avg_valid_time % 60))
    total_train_time = train_time + validate_time
    print('====> training time: {}'.format(_format_hms(train_time)))
    print('====> validation time: {}'.format(_format_hms(validate_time)))
    print('====> total training time: {}'.format(_format_hms(total_train_time)))
def train(train_loader, **kwargs):
    """Run one training epoch over ``train_loader``.

    Keyword Args:
        epoch: current epoch index; used only in the progress-line prefix.
        model: network to train (switched to ``train()`` mode here).
        criterion: loss applied to ``(output, target)``.
        optimizer: optimizer stepped once per batch.

    Reads the module-level ``opt`` (``opt.cuda``, ``opt.print_freq``) that
    ``main`` sets.

    Returns:
        ``(top1.avg, top5.avg)``: the Acc@1 / Acc@5 meters' running averages.
    """
    epoch = kwargs.get('epoch')
    model = kwargs.get('model')
    criterion = kwargs.get('criterion')
    optimizer = kwargs.get('optimizer')

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, data_time,
                             losses, top1, top5, prefix="Epoch: [{}]".format(epoch))

    model.train()
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        # Only the targets are moved here. The inputs are left as produced by
        # the loader; when main wraps the model in nn.DataParallel, that
        # wrapper scatters them to the GPUs.
        if opt.cuda:
            target = target.cuda(non_blocking=True)

        # Release the previous step's gradients *before* the forward pass.
        # set_to_none=True frees the .grad tensors instead of zero-filling
        # them, so they do not occupy GPU memory next to this batch's
        # activations. Parameters that receive no gradient this step keep
        # grad=None, and the optimizer skips them.
        optimizer.zero_grad(set_to_none=True)

        output = model(images)
        loss = criterion(output, target)

        # Accuracy and loss are recorded, weighted by the batch size.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        if i % opt.print_freq == 0:
            progress.print(i)
        end = time.time()

    print('====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg, top5.avg
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` with autograd disabled.

    Switches the model to ``eval()`` mode. Accumulates loss, Acc@1 and Acc@5
    weighted by batch size, and prints a progress line every
    ``opt.print_freq`` batches (``opt`` is the module-level options object
    that ``main`` sets).

    Returns:
        ``(top1.avg, top5.avg)`` accumulated over the whole loader.
    """
    batch_time, losses, top1, top5 = (
        AverageMeter('Time', ':6.3f'),
        AverageMeter('Loss', ':.4e'),
        AverageMeter('Acc@1', ':6.2f'),
        AverageMeter('Acc@5', ':6.2f'),
    )
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
                             prefix='Test: ')

    model.eval()
    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            if opt.cuda:
                target = target.cuda(non_blocking=True)

            output = model(images)
            loss = criterion(output, target)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1[0], n)
            top5.update(acc5[0], n)

            batch_time.update(time.time() - tick)
            if step % opt.print_freq == 0:
                progress.print(step)
            tick = time.time()

    print('====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg, top5.avg
if __name__ == '__main__':
    # Wall-clock time for the whole run (setup, training/evaluation, saving).
    start_time = time.time()
    main()
    elapsed_time = time.time() - start_time
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print('====> total time: {}h {}m {:.2f}s'.format(
        int(hours), int(minutes), seconds))

Solutions

# Collect unreachable Python objects first, so CUDA tensors that only they
# referenced go back to PyTorch's caching allocator...
gc.collect()
# ...then release the allocator's unoccupied cached blocks. Memory still held
# by live tensors (model, optimizer state, activations) is not freed, so this
# does not help when a single batch is too large for the GPU.
torch.cuda.empty_cache() # PyTorch CUDA cache cleanup

Trace-back

==> mobilenet/cifar10 training
==> Epoch: 17, lr = 0.07093217661806457
===> [ Training ]
Epoch: [17][0/9]	Time  2.638 ( 2.638)	Data  2.527 ( 2.527)	Loss 1.1166e+00 (1.1166e+00)	Acc@1  59.76 ( 59.76)	Acc@5  95.52 ( 95.52)
====> Acc@1 61.468 Acc@5 95.854
====> 4.97 seconds to train this epoch
===> [ Validation ]
Test: [0/2]	Time  1.674 ( 1.674)	Loss 1.1883e+00 (1.1883e+00)	Acc@1  57.50 ( 57.50)	Acc@5  95.46 ( 95.46)
====> Acc@1 57.620 Acc@5 95.300
====> 1.84 seconds to validate this epoch
==> mobilenet/cifar10 training
==> Epoch: 18, lr = 0.06951353308570328
===> [ Training ]
Epoch: [18][0/9]	Time  2.582 ( 2.582)	Data  2.467 ( 2.467)	Loss 1.0763e+00 (1.0763e+00)	Acc@1  61.83 ( 61.83)	Acc@5  96.33 ( 96.33)
====> Acc@1 62.808 Acc@5 96.350
====> 4.92 seconds to train this epoch
===> [ Validation ]
Test: [0/2]	Time  1.721 ( 1.721)	Loss 1.1518e+00 (1.1518e+00)	Acc@1  58.51 ( 58.51)	Acc@5  95.67 ( 95.67)
====> Acc@1 58.540 Acc@5 95.560
====> 1.88 seconds to validate this epoch
==> mobilenet/cifar10 training
==> Epoch: 19, lr = 0.06812326242398921
===> [ Training ]
Epoch: [19][0/9]	Time  2.441 ( 2.441)	Data  2.314 ( 2.314)	Loss 1.0599e+00 (1.0599e+00)	Acc@1  62.20 ( 62.20)	Acc@5  96.34 ( 96.34)
====> Acc@1 63.502 Acc@5 96.530
====> 4.75 seconds to train this epoch
===> [ Validation ]
Test: [0/2]	Time  1.664 ( 1.664)	Loss 1.1191e+00 (1.1191e+00)	Acc@1  59.76 ( 59.76)	Acc@5  96.39 ( 96.39)
====> Acc@1 59.460 Acc@5 96.060
====> 1.83 seconds to validate this epoch
====> average training time per epoch: 0m 6.81s
====> average validation time per epoch: 0m 1.88s
====> training time: 0h 2m 16.22s
====> validation time: 0h 0m 37.55s
====> total training time: 0h 2m 53.77s
====> total time: 0h 3m 18.80s
=> creating model 'mobilenet'
==> Load data..
Files already downloaded and verified
Files already downloaded and verified
==> Loading Checkpoint '/home2/coremax/Documents/BoxMix/checkpoint/mobilenet/cifar10/ckpt_best.pth'
==> Loaded Checkpoint '/home2/coremax/Documents/BoxMix/checkpoint/mobilenet/cifar10/ckpt_best.pth' (epoch 20)
===> [ Evaluation ]
Test: [ 0/40]	Time  1.680 ( 1.680)	Loss 1.0908e+00 (1.0908e+00)	Acc@1  64.45 ( 64.45)	Acc@5  96.09 ( 96.09)
====> Acc@1 59.460 Acc@5 96.060
====> 2.21 seconds to evaluate this model
====> total time: 0h 0m 6.03s
=> creating model 'mobilenetv2'
==> Load data..
Files already downloaded and verified
Files already downloaded and verified
==> mobilenetv2/cifar10 training
==> Epoch: 0, lr = 0.1
===> [ Training ]
Traceback (most recent call last):
File "/home2/coremax/Documents/BoxMix/main.py", line 257, in <module>
main()
File "/home2/coremax/Documents/BoxMix/main.py", line 117, in main
acc1_train, acc5_train = train(train_loader,
File "/home2/coremax/Documents/BoxMix/main.py", line 187, in train
output = model(input)
File "/home2/coremax/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home2/coremax/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home2/coremax/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home2/coremax/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
output.reraise()
File "/home2/coremax/anaconda3/lib/python3.9/site-packages/torch/_utils.py", line 543, in reraise
raise exception
torch.cuda.OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.

答案1

得分: 0

我正在使用两块GPU（Tesla V100 16GB）训练 mobilenet 模型，批量大小为6096，虽然很大，但仍然可以顺利训练。当我同时训练多个模型（如 mobilenet 和 mobilenetv2）时，mobilenetv2 出现了 replica 错误。我尝试了 gc.collect() 和 torch.cuda.empty_cache() 的方法，但对我不起作用。

我通过将批量大小从6096显著减小到256来解决了上述问题。

英文:

I am training mobilenet on two GPUs (Tesla V100, 16 GB) with a batch size of 6096. That is very large, but I can still train this model without problems. When I trained multiple models, such as mobilenet and mobilenetv2, at the same time, I got the replica error in mobilenetv2. I tried the gc.collect() and torch.cuda.empty_cache() solution, but it didn't work for me.

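I solved the above problem by significantly decreasing the batch size from 6096 to 256. nn.DataParallel splits each batch across the GPUs, so with two GPUs every replica still had to hold the activations for about 3,048 samples. mobilenetv2 evidently needs more memory per sample than mobilenet, so it ran out of memory where mobilenet did not.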

huangapple
  • 本文由 发表于 2023年1月9日 15:14:30
  • 转载请务必保留本文链接:https://go.coder-hub.com/75054128.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定