How to fix GPU out-of-memory errors in PyTorch

Question

I want to train a wav2vec2 model for Persian. I have 2 hours of audio (7k records), and I use this code for training:

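from transformers import Trainer, TrainingArguments, Wav2Vec2ForCTC

# processor, data_collator, compute_metrics, common_voice_train and common_voice_test
# are assumed to be defined earlier in the notebook (not shown here).
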
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base", 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

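# freeze_feature_extractor() is deprecated in recent transformers releases;
# freeze_feature_encoder() is its replacement
model.freeze_feature_encoder()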

training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/model-output",
    group_by_length=True,
    per_device_train_batch_size=4,
    evaluation_strategy="steps",
    num_train_epochs=30,
    fp16=True,
    gradient_checkpointing=True, 
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

trainer.train()

When I run it, I get this error:

CUDA out of memory. Tried to allocate 1.22 GiB (GPU 0; 14.75 GiB total capacity; 12.59 GiB already allocated; 296.81 MiB free; 13.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

This is the full output:

/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
/usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/processing_wav2vec2.py:155: UserWarning: `as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your labels by using the argument `text` of the regular `__call__` method (either in the same call as your audio inputs, or in a separate call.
warnings.warn(
[ 501/43430 01:23 < 1:59:11, 6.00 it/s, Epoch 0.12/10]
Step	Training Loss	Validation Loss
[50/61 00:24 < 00:05, 2.04 it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 1>:1                                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1664 in train                    │
│                                                                                                  │
│   1661 │   │   inner_training_loop = find_executable_batch_size(                                 │
│   1662 │   │   │   self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size  │
│   1663 │   │   )                                                                                 │
│ ❱ 1664 │   │   return inner_training_loop(                                                       │
│   1665 │   │   │   args=args,                                                                    │
│   1666 │   │   │   resume_from_checkpoint=resume_from_checkpoint,                                │
│   1667 │   │   │   trial=trial,                                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2019 in _inner_training_loop     │
│                                                                                                  │
│   2016 │   │   │   │   │   self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epo  │
│   2017 │   │   │   │   │   self.control = self.callback_handler.on_step_end(args, self.state, s  │
│   2018 │   │   │   │   │                                                                         │
│ ❱ 2019 │   │   │   │   │   self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_k  │
│   2020 │   │   │   │   else:                                                                     │
│   2021 │   │   │   │   │   self.control = self.callback_handler.on_substep_end(args, self.state  │
│   2022                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2300 in _maybe_log_save_evaluate │
│                                                                                                  │
│   2297 │   │   │   │   │   )                                                                     │
│   2298 │   │   │   │   │   metrics.update(dataset_metrics)                                       │
│   2299 │   │   │   else:                                                                         │
│ ❱ 2300 │   │   │   │   metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)                 │
│   2301 │   │   │   self._report_to_hp_search(trial, self.state.global_step, metrics)             │
│   2302 │   │   │                                                                                 │
│   2303 │   │   │   # Run delayed LR scheduler now that metrics are populated                     │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3029 in evaluate                 │
│                                                                                                  │
│   3026 │   │   start_time = time.time()                                                          │
│   3027 │   │                                                                                     │
│   3028 │   │   eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else se  │
│ ❱ 3029 │   │   output = eval_loop(                                                               │
│   3030 │   │   │   eval_dataloader,                                                              │
│   3031 │   │   │   description="Evaluation",                                                     │
│   3032 │   │   │   # No point gathering the predictions if there are no metrics, otherwise we d  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3210 in evaluation_loop          │
│                                                                                                  │
│   3207 │   │   │   │   │   batch_size = observed_batch_size                                      │
│   3208 │   │   │                                                                                 │
│   3209 │   │   │   # Prediction step                                                             │
│ ❱ 3210 │   │   │   loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_o  │
│   3211 │   │   │   inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inp  │
│   3212 │   │   │                                                                                 │
│   3213 │   │   │   if is_torch_tpu_available():                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3466 in prediction_step          │
│                                                                                                  │
│   3463 │   │   │   else:                                                                         │
│   3464 │   │   │   │   if has_labels or loss_without_labels:                                     │
│   3465 │   │   │   │   │   with self.compute_loss_context_manager():                             │
│ ❱ 3466 │   │   │   │   │   │   loss, outputs = self.compute_loss(model, inputs, return_outputs=  │
│   3467 │   │   │   │   │   loss = loss.mean().detach()                                           │
│   3468 │   │   │   │   │                                                                         │
│   3469 │   │   │   │   │   if isinstance(outputs, dict):                                         │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2767 in compute_loss             │
│                                                                                                  │
│   2764 │   │   │   labels = inputs.pop("labels")                                                 │
│   2765 │   │   else:                                                                             │
│   2766 │   │   │   labels = None                                                                 │
│ ❱ 2767 │   │   outputs = model(**inputs)                                                         │
│   2768 │   │   # Save past state if it exists                                                    │
│   2769 │   │   # TODO: this needs to be fixed and made cleaner later.                            │
│   2770 │   │   if self.args.past_index >= 0:                                                     │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:1684   │
│ in forward                                                                                       │
│                                                                                                  │
│   1681 │   │                                                                                     │
│   1682 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return  │
│   1683 │   │                                                                                     │
│ ❱ 1684 │   │   outputs = self.wav2vec2(                                                          │
│   1685 │   │   │   input_values,                                                                 │
│   1686 │   │   │   attention_mask=attention_mask,                                                │
│   1687 │   │   │   output_attentions=output_attentions,                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:1320   │
│ in forward                                                                                       │
│                                                                                                  │
│   1317 │   │   │   hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention  │
│   1318 │   │   )                                                                                 │
│   1319 │   │                                                                                     │
│ ❱ 1320 │   │   encoder_outputs = self.encoder(                                                   │
│   1321 │   │   │   hidden_states,                                                                │
│   1322 │   │   │   attention_mask=attention_mask,                                                │
│   1323 │   │   │   output_attentions=output_attentions,                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:798 in │
│ forward                                                                                          │
│                                                                                                  │
│    795 │   │   │   │   │   │   attention_mask,                                                   │
│    796 │   │   │   │   │   )                                                                     │
│    797 │   │   │   │   else:                                                                     │
│ ❱  798 │   │   │   │   │   layer_outputs = layer(                                                │
│    799 │   │   │   │   │   │   hidden_states, attention_mask=attention_mask, output_attentions=  │
│    800 │   │   │   │   │   )                                                                     │
│    801 │   │   │   │   hidden_states = layer_outputs[0]                                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:679 in │
│ forward                                                                                          │
│                                                                                                  │
│    676 │                                                                                         │
│    677 │   def forward(self, hidden_states, attention_mask=None, output_attentions=False):       │
│    678 │   │   attn_residual = hidden_states                                                     │
│ ❱  679 │   │   hidden_states, attn_weights, _ = self.attention(                                  │
│    680 │   │   │   hidden_states, attention_mask=attention_mask, output_attentions=output_atten  │
│    681 │   │   )                                                                                 │
│    682 │   │   hidden_states = self.dropout(hidden_states)                                       │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:596 in │
│ forward                                                                                          │
│                                                                                                  │
│    593 │   │   │   attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + at  │
│    594 │   │   │   attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)      │
│    595 │   │                                                                                     │
│ ❱  596 │   │   attn_weights = nn.functional.softmax(attn_weights, dim=-1)                        │
│    597 │   │                                                                                     │
│    598 │   │   if layer_head_mask is not None:                                                   │
│    599 │   │   │   if layer_head_mask.size() != (self.num_heads,):                               │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:1843 in softmax                   │
│                                                                                                  │
│   1840 │   if dim is None:                                                                       │
│   1841 │   │   dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)                       │
│   1842 │   if dtype is None:                                                                     │
│ ❱ 1843 │   │   ret = input.softmax(dim)                                                          │
│   1844 │   else:                                                                                 │
│   1845 │   │   ret = input.softmax(dim, dtype=dtype)                                             │
│   1846 │   return ret                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 6.86 GiB (GPU 0; 14.75 GiB total capacity; 5.06 GiB already
allocated; 5.60 GiB free; 8.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try 
setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and 
PYTORCH_CUDA_ALLOC_CONF
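
The traceback ends in `nn.functional.softmax(attn_weights, dim=-1)`, where `attn_weights` has just been reshaped to `(bsz * self.num_heads, tgt_len, src_len)`. In self-attention both lengths equal the number of encoder frames, so this tensor grows with the square of the clip length. The sketch below is a rough per-layer estimate, not a measurement; it assumes the 12 attention heads of `facebook/wav2vec2-base`, roughly 320 input samples per encoder frame at 16 kHz, and 4-byte (fp32) scores. The helper names are illustrative only.

```python
# Rough estimate of one layer's attention-score tensor, shape
# (batch * heads, frames, frames), as seen in the softmax call in the traceback.
# All constants below are assumptions, not values read from the model.

SAMPLES_PER_FRAME = 320  # approx. total stride of the wav2vec2 conv feature encoder
NUM_HEADS = 12           # attention heads in facebook/wav2vec2-base
BYTES_PER_ELEMENT = 4    # assumes fp32 scores; fp16 would halve this


def frames_for(seconds: float, sampling_rate: int = 16_000) -> int:
    """Approximate number of encoder frames for a clip of the given length."""
    return int(seconds * sampling_rate) // SAMPLES_PER_FRAME


def attention_scores_gib(batch_size: int, seconds: float) -> float:
    """Approximate size in GiB of the (batch*heads, T, T) score tensor."""
    t = frames_for(seconds)
    n_bytes = batch_size * NUM_HEADS * t * t * BYTES_PER_ELEMENT
    return n_bytes / 2**30


if __name__ == "__main__":
    # Compare a few clip lengths at a batch of 8 clips padded to the same length.
    for seconds in (5, 15, 30, 60):
        print(f"{seconds:>3} s, batch 8: {attention_scores_gib(8, seconds):.2f} GiB")
```

Under these assumptions, doubling the clip length quadruples the estimate while halving the batch size only halves it, which is why a few very long clips can dominate peak memory even at small batch sizes.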

Answer 1

Score: 1

I found some information that might be relevant: one user reported that they solved their memory problems with wav2vec2 by preventing longer sound clips from being used as input. Try experimenting with excluding clips that exceed a certain length.
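
A minimal sketch of that idea, assuming each example still has a Common Voice-style `audio` column with `array` and `sampling_rate` fields when the filter runs. The 6-second cut-off (`MAX_SECONDS`) and the helper names are hypothetical; the right threshold depends on your GPU and data.

```python
from collections.abc import Mapping
from typing import Any

# Hypothetical cut-off in seconds; tune it for your GPU and dataset.
MAX_SECONDS = 6.0


def clip_seconds(example: Mapping[str, Any]) -> float:
    """Duration of one example, assuming an
    {"audio": {"array": ..., "sampling_rate": ...}} layout."""
    audio = example["audio"]
    return len(audio["array"]) / audio["sampling_rate"]


def is_short_enough(example: Mapping[str, Any], max_seconds: float = MAX_SECONDS) -> bool:
    """Predicate for a dataset filter: keep clips no longer than max_seconds."""
    return clip_seconds(example) <= max_seconds


if __name__ == "__main__":
    # Two synthetic 16 kHz examples: 2 s and 10 s of silence.
    short_clip = {"audio": {"array": [0.0] * 32_000, "sampling_rate": 16_000}}
    long_clip = {"audio": {"array": [0.0] * 160_000, "sampling_rate": 16_000}}
    print(is_short_enough(short_clip), is_short_enough(long_clip))
```

With a Hugging Face `datasets.Dataset`, the predicate could be applied as `common_voice_train = common_voice_train.filter(is_short_enough)`, before preprocessing drops the `audio` column. Since the traceback shows the crash inside evaluation, filtering `common_voice_test` the same way, or setting a smaller `per_device_eval_batch_size` (it defaults to 8 when not set) in `TrainingArguments`, may also help; note that filtering the test set changes what the reported metrics measure.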

huangapple
  • This article was posted on May 22, 2023 at 22:14:48
  • If you repost this article, please keep its link: https://go.coder-hub.com/76307106.html