

How to release CPU memory in pytorch? (for large-scale inference)


I'm trying to do large-scale inference of a pretrained BERT model on a single machine and I'm running into CPU out-of-memory errors. Since the dataset is too big to score the model on the whole dataset at once, I'm trying to run it in batches, store the results in a list, and then concatenate those tensors together at the end. I understand that storing tensors in lists can quickly use up large amounts of CPU memory.

However, I am unable to figure out how to release this memory after the tensors are concatenated and therefore I'm running into OOM errors downstream.

Minimal Reproducible Example:
<!-- language: python -->

import gc, time, torch, pytorch_lightning as pl
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader

class EncoderModelPL(pl.LightningModule):
    def __init__(self, model: BertModel):
        super().__init__()
        self.model: BertModel = model

    def forward(self, x):
        # yes, I know this uses a ton of memory but my downstream application requires intermediate hidden states
        return self.model(x, output_hidden_states=True)

MODEL_ID = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_ID)
model = EncoderModelPL(BertModel.from_pretrained(MODEL_ID)).to("cuda")

dataset = torch.randint(low=0, high=30000, size=(800, 200), device="cuda")
tokens_dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

trainer = pl.Trainer(accelerator="gpu")
# CPU memory steadily increases here; this returns a num-batches-length list
# containing the BERT output for each batch, stored on the CPU
bert_outputs_per_batch: list = trainer.predict(
    model=model, dataloaders=tokens_dataloader
)
del bert_outputs_per_batch
gc.collect()

# ... <more downstream stuff> ...
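
For a rough sense of scale, the size of the stored hidden states can be estimated by hand. The sketch below assumes that `bert-base-uncased` has 12 encoder layers with hidden size 768, so that `output_hidden_states=True` yields 13 tensors per batch (the embedding output plus one per layer), and that values are float32. The function name is hypothetical, and the other fields of the model output are not counted.

```python
# Rough estimate of the memory held by the hidden states alone.
# Assumptions (not stated in the question): 12 encoder layers + 1 embedding
# output = 13 hidden-state tensors, hidden size 768, float32 (4 bytes each).

def hidden_states_bytes(num_sequences: int, seq_len: int,
                        num_hidden_states: int = 13,
                        hidden_size: int = 768,
                        bytes_per_value: int = 4) -> int:
    """Bytes needed to keep every hidden state for every token."""
    values = num_sequences * seq_len * num_hidden_states * hidden_size
    return values * bytes_per_value


# Shape taken from the example above: 800 sequences of 200 tokens.
total = hidden_states_bytes(num_sequences=800, seq_len=200)
print(f"{total / 2**30:.2f} GiB")
```

Under these assumptions the 800 x 200 toy dataset already accounts for about 5.95 GiB of hidden states, so the total grows linearly with the number of sequences regardless of the batch size used during inference.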

A few additional notes on what I've tried:

  • As shown in the code, I've tried deleting the list of outputs and then garbage collecting.
  • I also tried sleeping for a few seconds after that, in case it helps the garbage collection (it doesn't).
  • I don't believe this is caused by tracking the computation graph: I've confirmed the same behavior if I add .detach() to dataset.
  • I've tried setting pin_memory=False in the DataLoader; it appears to make no difference.
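
One thing worth ruling out for the `del` plus `gc.collect()` attempt: `del` removes only a name, and an object is freed only once no references to it remain. The sketch below illustrates this in plain Python. `Blob` and `Holder` are hypothetical stand-ins, and whether Lightning's `Trainer` actually keeps its own reference to the predictions is an assumption that is not verified here.

```python
import gc
import weakref


class Blob:
    """Stand-in for a large tensor (hypothetical, for illustration only)."""

    def __init__(self, n: int):
        self.data = bytearray(n)


class Holder:
    """Hypothetical object that keeps its own reference to the results."""

    def __init__(self):
        self.cache = []

    def run(self):
        out = [Blob(1024) for _ in range(3)]
        self.cache = out  # second reference to the same list
        return out


holder = Holder()
outputs = holder.run()
probe = weakref.ref(outputs[0])  # lets us observe whether the object is alive

del outputs
gc.collect()
alive_after_del = probe() is not None  # still reachable via holder.cache

holder.cache = []
gc.collect()
alive_after_clear = probe() is not None  # last reference is gone

print(alive_after_del, alive_after_clear)
```

If something else still holds the outputs, deleting the local list cannot free them. Even after the objects are freed, the process's resident memory may not shrink immediately, because the allocator can keep freed pages for reuse.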

In summary, my questions are:

  • How can I force the release of tensors from CPU memory?
  • More broadly, is there a better (e.g. more memory-efficient) approach to large-scale inference than batching?

Thanks for the help!


Score: 1



The memory used is a constant plus a term proportional to the batch size:
memory = a + (b * batch_size)

First try batch_size=1; if that works, try 2, 4, 8.
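
The linear model above can be made concrete with two measurements. This is a minimal sketch, assuming peak memory has been measured at two batch sizes; the function names and the numbers (in MB) are hypothetical and only illustrate the arithmetic.

```python
def fit_linear(m1, m2):
    """Solve memory = a + b * batch_size from two (batch_size, memory) pairs."""
    (bs1, mem1), (bs2, mem2) = m1, m2
    b = (mem2 - mem1) / (bs2 - bs1)  # memory cost per extra sample
    a = mem1 - b * bs1               # constant overhead (model, buffers, ...)
    return a, b


def max_batch_size(a: float, b: float, budget: float) -> int:
    """Largest batch size whose predicted memory stays within the budget."""
    return max(0, int((budget - a) // b))


# Hypothetical measurements: 1100 MB at batch_size=1, 1400 MB at batch_size=4.
a, b = fit_linear((1, 1100.0), (4, 1400.0))
largest = max_batch_size(a, b, budget=2000.0)
print(a, b, largest)
```

Note that this model describes memory used while a batch is being processed. Outputs accumulated across batches grow with the dataset size, not with the batch size.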

  • Published on April 11, 2023, 16:32:54
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75983889.html



