How to release CPU memory in pytorch? (for large-scale inference)
Question
I'm trying to do large-scale inference of a pretrained BERT model on a single machine and I'm running into CPU out-of-memory errors. Since the dataset is too big to score the model on the whole dataset at once, I'm trying to run it in batches, store the results in a list, and then concatenate those tensors together at the end. I understand that storing tensors in lists can quickly use up large amounts of CPU memory.
However, I am unable to figure out how to release this memory after the tensors are concatenated and therefore I'm running into OOM errors downstream.
Minimal Reproducible Example:
import gc, time, torch, pytorch_lightning as pl
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader


class EncoderModelPL(pl.LightningModule):
    def __init__(
        self,
        model: BertModel,
    ):
        super(EncoderModelPL, self).__init__()
        self.model: BertModel = model

    def forward(self, x):
        return self.model(x, output_hidden_states=True)  # yes, I know this uses a ton of memory but my downstream application requires intermediate hidden states


MODEL_ID = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_ID)
model = EncoderModelPL(BertModel.from_pretrained(MODEL_ID)).to("cuda")

dataset = torch.randint(low=0, high=30000, size=(800, 200), device="cuda")
tokens_dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

trainer = pl.Trainer(accelerator="gpu")
bert_outputs_per_batch: list = trainer.predict(
    model=model, dataloaders=tokens_dataloader
)  # CPU memory steadily increases here, this outputs a num-batches-length list containing the Bert output for each batch, stored in CPU

del bert_outputs_per_batch
gc.collect()

...<more downstream stuff>...
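For completeness, the concatenation step mentioned above (omitted from the example) looks roughly like the following sketch; it assumes each per-batch output exposes last_hidden_state, as BertModel outputs do, and the intermediate hidden_states tuple would be concatenated the same way:

hidden = torch.cat(
    [out.last_hidden_state for out in bert_outputs_per_batch], dim=0
)  # shape: (num_sequences, seq_len, hidden_dim)

del bert_outputs_per_batch  # the per-batch copies are no longer needed...
gc.collect()                # ...but this still does not return the CPU memory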
A few additional notes on what I've tried:
- As shown in the code, I've tried deleting the list of outputs and then garbage collecting.
- I also tried sleeping for a few seconds after that in case that helps the garbage collection (it doesn't).
- I don't believe this is an issue caused by tracking the computation graph; I've confirmed the same behavior if I add .detach() to dataset.
- I've tried playing with pin_memory=False in the DataLoader; it appears to make no difference.
In summary, my questions are:
- How can I force the release of tensors from CPU memory?
- More broadly, is there a better (e.g. more memory-efficient) approach to large-scale inference than batching?
Thanks for the help!
Answer 1
Score: 1
The memory used is a constant plus something proportional to the batch size:
memory = a + (b * batch_size)
First try batch_size=1; if that works, try 2, 4, 8.
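A minimal sketch of this advice, reusing the setup from the question; the loop, the variable names, and the enable_progress_bar flag are illustrative choices rather than part of the original answer:

import gc
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import BertModel


class EncoderModelPL(pl.LightningModule):
    def __init__(self, model: BertModel):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x, output_hidden_states=True)


model = EncoderModelPL(BertModel.from_pretrained("bert-base-uncased")).to("cuda")
dataset = torch.randint(low=0, high=30000, size=(800, 200), device="cuda")

for batch_size in (1, 2, 4, 8):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    trainer = pl.Trainer(accelerator="gpu", enable_progress_bar=False)

    outputs = trainer.predict(model=model, dataloaders=loader)
    print(f"batch_size={batch_size}: collected {len(outputs)} batches of outputs")

    # Free the collected outputs before moving on to the next, larger batch size.
    del outputs
    gc.collect()

If even batch_size=1 runs out of CPU memory, then by the formula above the constant term a, rather than the per-batch term, is presumably the bottleneck.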