清华提出4比特优化器 显著减少LLaMA微调的显存需求
要点:
1、针对优化器状态的量化方法,将优化器状态的数值精度降低至4比特
2、提出了针对一阶矩和二阶矩的量化策略,处理了零点问题等难点
3、在多个微调任务中达到全精度优化器的性能,同时可将LLaMA微调的显存需求减少超过50%
站长之家(ChinaZ.com)9月8日 消息:随着大模型规模的不断增大,显存需求成为模型训练的主要瓶颈之一。优化器状态中的一阶矩和二阶矩是占用大量显存的重要因素。为降低显存使用,清华大学朱军、陈键飞团队在ICLR2022的工作基础上,进一步将优化器状态的比特数降低到4比特,同时针对一阶矩和二阶矩的不同特点,提出了相应的量化策略。
项目地址:https://github.com/thu-ml/low-bit-optimizers
对于一阶矩,由于存在按行或列分布的异常值,提出采用更小的128大小的分块进行归一化。对二阶矩,确定零点问题是主要难点,去除零点的线性映射取得了很好效果,同时提出rank-1归一化更好地处理异常值。最后,提出了4比特AdamW和Factor两种低精度优化器。
在多个经典的微调任务中进行评估,结果表明4比特优化器能够匹配甚至超过32比特AdamW的性能。同时显著减少了优化器状态的显存需求,在LLaMA-7B的微调中最高可节省57.7%的显存。提供了开箱即用的PyTorch接口,只需要一行代码即可使用。
本研究工作展示了通过压缩的思路显著减少大模型微调中的显存瓶颈的可能性。同时优化器状态的低比特设计也为进一步探索内存高效的训练算法提供了有价值的经验。这些成果将促进大模型在有限硬件条件下的高效训练与应用。