DeepSeek用的GRPO占用大量内存?有人给出了些方法

深入研究 GRPO,发现了意外收获。

DeepSeek用的GRPO占用大量内存?有人给出了些破解方法

原标题:DeepSeek用的GRPO占用大量内存?有人给出了些方法
文章来源:机器之心
内容字数:8253字

RTX 3080 移动版可训练的大模型及GRPO训练技巧

本文总结了使用RTX 3080移动版显卡(16GB显存)进行大型语言模型强化学习训练的经验,重点介绍了群组相对策略优化(GRPO)方法及其内存优化策略。

  1. 可训练模型大小及方法选择

    作者使用GRPO方法,在RTX 3080移动版上进行训练,发现模型大小和训练方式对显存需求影响很大。实验在参数量从5亿到140亿不等的模型上进行,比较了全参数微调和参数高效微调(PEFT,使用LoRA)。全参数微调比PEFT需要更多内存。在H100上进行的实验显示,全参数微调所需的VRAM超过80GB。

  2. GRPO的高内存需求原因

    GRPO的高内存需求源于其内部涉及多个模型(策略模型、参考模型和奖励模型),每个查询都会产生多个输出,导致内存占用迅速增加。即使奖励模型非参数化,内存需求依然很高。

  3. 内存优化策略

    为了降低内存占用,作者使用了两种技术:8位优化器(例如8-bit AdamW)和梯度检查点。8位优化器能更高效地存储优化器跟踪数据,而梯度检查点则通过在训练过程中拍摄快照来减少内存使用,虽然会降低训练速度(约20-30%),但能显著减少内存占用。

  4. 代码示例及参数设置

    作者提供了使用Hugging Face的trl库进行GRPO训练的代码示例,该代码简洁易懂,适合小型模型(如meta-llama/Llama-3.2-1B-Instruct)和数据集(如openai/GSM8K)。文中详细说明了各个参数(如`num_generations`、`batch_size`、`gradient_accumulation_steps`、`num_completions`、`max_prompt_length`、`max_completion_length`)对VRAM使用量的影响,并建议在内存瓶颈修复前使用`num_generations=4`。

  5. VRAM使用量估算

    作者给出了VRAM使用量的粗略估算方法,考虑了模型参数、梯度、优化器状态等因素,并指出PEFT可以减少梯度的显存占用。

  6. 实验结果及结论

    作者使用10亿参数的Llama 3.2模型进行了完整训练,结果显示GRPO显著提升了模型准确率(从19%提升到40.5%),展示了其强大潜力。

总而言之,本文为GPU资源有限的开发者提供了宝贵的GRPO训练经验,并通过内存优化策略和参数调整,帮助开发者在有限的硬件条件下训练更大的模型。


联系作者

文章来源:机器之心
作者微信:
作者简介:专业的人工智能媒体和产业服务平台

阅读原文
© 版权声明
问小白满血版DeepSeek免费不限次数使用

相关文章

问小白满血版DeepSeek免费不限次数使用

暂无评论

暂无评论...