StackLLaMA: A hands-on guide to train LLaMA with RLHF

阿里云国内75折 回扣 微信号:monov8
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6

Paper name

StackLLaMA: A hands-on guide to train LLaMA with RLHF

Paper Reading Note

Project URL: https://huggingface.co/blog/stackllama
Code URL: https://huggingface.co/docs/trl/index

TL;DR

  • Huggingface 公司开发的 RLHF 训练代码已集成到 huggingface 的 trl 库中在 Stack Exchange 数据集对 LLaMA 模型进行了微调。博客详细介绍了 SFT有监督微调、RM奖励/偏好建模和 RLHF人类反馈的强化学习的训练细节并介绍了一些训练中可能遇到的问题及解决思路

Introduction

背景

  • ChatGPT、GPT-4 和 Claude 等模型是功能强大的语言模型它们人类反馈强化学习 (RLHF) 的方法进行了微调以使得它们的行为方式更好地符合我们的期望

本文方案

  • 在这篇博客文章中我们展示了使用 SFT有监督微调、RM奖励/偏好建模和 RLHF人类反馈的强化学习相结合的方法训练 LlaMa 模型回答 Stack Exchange (一个问答网站每个答案有对应的用户点赞数目标注) 上的问题的所有步骤。
    在这里插入图片描述
  • 经过以上微调训练本文训练了一个 StackLLaMA 模型开源到了 Hub 上整个训练流程也开源到了 trl

Dataset/Algorithm/Model/Experiment Detail

实现方式

LLaMA 模型

  • 在进行RLHF时从一个有能力的模型开始非常重要RLHF 步骤只是微调模型以使其与本文想要与其交互和期望其响应的方式相一致。因此本文选择使用最近推出的性能出色的 LLaMA 模型。LLaMA 模型是由 Meta AI 开发的最新大型语言模型大小从 7B 到 65B 参数不等并在 1T 到 1.4T 个 token 数据集之间进行了训练使其性能很强。本文使用 7B 模型作为所有后续步骤的基础

Stack Exchange 数据集

  • 收集人类反馈是一项复杂而昂贵的工作。为了引导这个例子的过程同时仍然建立一个有用的模型使用 Stack Exchange 数据集数据集包括来自 StackExchange 平台的问题及其相应的答案包括用于代码和许多其他主题的 StackOverflow。这个数据集信息量很大回复的答案与赞成票的数量和已接受答案的标签都有

  • 本文使用 A General Language Assistant as a Laboratory for Alignment 中提到的方法来给每个答案进行打分

    • score = round(log2 (1 + upvotes)) 注这里用 log 的原因是人们一般优先看高赞回答导致强者恒强这里希望用 log 稍微拉低高赞回答的分数
    • 被提问者接受的答案分数再加上 1
    • upvotes 为负的分数设置为 -1
  • 对于 reward model每个问题需要两个回答用于对比。一些问题有几十个回复导致有很多个匹配答案对本文对每个问题最多采样 10 个答案对以限制每个问题的数据数量。最后通过将 HTML 转换为 markdown 得到格式干净的数据数据示例和处理脚本在stack-exchange-paired

高效训练策略

  • 即便训练最小的 LLaMA 模型也需要大量的显存消耗简单计算
    • 基于 bf16 进行参数存储每个参数占用 2 bytesAdam 优化器暂用 8 bytes所以一个 7B 参数模型会消耗 (2+8)*7B=70GB 左右显存计算注意力分数等中间值时可能需要更多显存
  • 本文使用 Parameter-Efficient Fine-Tuning (PEFT) 技巧比如在 8 bit load 的模型上使用 LoRA
    • 以 8 bit 加载模型可显著减少显存占用因为每个参数只需要一个 byte 例如 7B LlaMa 在显存中占用 7GB
    • 在这种配置下一般 1B 的参数需要 1.2~1.4Gb 的显存 (取决于批量大小和序列长度)80GB A100 一般可以训练 50-60B 的模型
  • 同时使用 dp 进行加速
    在这里插入图片描述

Supervised fine-tuning

  • 开始训练奖励模型和通过强化学习调整模型之前如果模型在我们感兴趣的领域中表现良好那么这会有所帮助。在本文的情况下希望它能够回答问题而对于其他用例可能希望它能够遵循指令这种情况下需要进行指令调整。实现这一点最简单的方法是使用来自该领域或任务的文本继续使用语言建模目标对语言模型进行训练。StackExchange 数据集非常庞大超过 1000 万条指令因此可以轻松地在其中的一个子集上训练语言模型。
  • 利用与预训练阶段一样的 causal language modeling objective 损失来仅模型微调。为了有效地使用数据本文使用了一种叫做“packing”的技术不是在批次中每个样本都有一个文本然后填充到模型的最长文本或最大上下文而是将许多文本连接在一起用 EOS token 分隔并切割上下文大小的块来填充批次无需任何填充。
    在这里插入图片描述
    采用这种方法训练效率要高得多因为每个通过模型的 token 都会被训练而传统的数据读取方法会在损失计算中将填充的 token 排除掉。如果没有太多的数据并且不希望有偶尔截断一些溢出上下文的 token 这种问题也可以使用传统的数据加载器。上面描述的数据预处理方法在代码中是 ConstantLengthDataset 实现的
  • 模型使用 LoRA 方式进行训练因为之后还需要使用不同的 loss 对模型进行训练这里训练完成之后需要将 LoRA 的模型参数合入到原始模型中

Reward modeling and human preferences

  • 原则上可以直接使用人类标注来进行 RLHF 微调模型。然而这将需要在每次优化迭代之后向人类发送一些样本进行评分。由于收敛所需的训练样本数量较大以及人类阅读和标注速度的固有延迟这是昂贵而缓慢的。一般是训练一个奖励模型 (reward model) 来代替人类标注。奖励模型的目标是模仿人类来评价一段文本。有几种可能的策略来构建奖励模型最直接的方法是预测人类标注结果例如评分分数或“好/坏”的二进制值。在实践中更好的方法是预测两个答案的排名其中奖励模型输入为一个给定的 prompt x以及两个基于 x 输入的回复 (yk, yj)奖励模型来预测哪一个会被人类注释者评价更高。奖励函数的 loss 设计为
    在这里插入图片描述
    其中 r 是模型的输出分时yj 是两个回复中更好的回复也即期望奖励模型对于更好的回复的打分需要尽量高更差的回复的打分需要尽量低。loss 的代码实现如下
class RewardTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_j = model(input_ids=inputs["input_ids_j"],  attention_mask=inputs["attention_mask_j"])[0]
        rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss
  • 实验配置
    • 训练数据使用了 100000 个候选对评测使用了 50000 数据
    • batchsize 41 epoch
    • AdamBF16
    • Lora rank 8alpha 32
    • 8xA100 训练需要几个小时
  • 实验结果67% 的准确率

Reinforcement Learning from Human Feedback

  • 基于前述的微调后的模型以及奖励模型进行强化学习训练包含以下步骤
    • 基于 prompt 输入生成回复
    • 使用奖励模型对回复进行评级
    • 使用评级进行 reinforcement learning policy-optimization 更新

在这里插入图片描述

  • 查询和响应提示在被 token 化并传递给模型之前按如下方式模板化该模板在 SFTRM 和 RLHF 三个步骤中保持一致

    Question: <Query>
    Answer: <Response>
    
  • 使用 RL 训练语言模型的一个常见问题是该模型可以通过生成完整的乱码来学习利用奖励模型这会导致奖励模型分配高奖励。为了平衡这一点在奖励中增加了一个惩罚保留了一个没有训练的模型 (即 SFT 后的模型) 作为参考并通过计算 KL-divergence 来对新模型的生成与参考模型的生成的相似性进行约束
    在这里插入图片描述

  • 整个 RLHF 的代码示例如下

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    question_tensors = batch["input_ids"]
        
    # sample from the policy and generate responses
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    # Log stats to WandB
    ppo_trainer.log_stats(stats, batch, rewards
  • 实验配置
    • 3x8 A100-80GB 需要 20 h 的训练时间

实验结果

奖励模型训练

  • 准确率为 67%作者的解释是任务比较难人也不一定能做好

RL 模型训练

  • 训练过程中每个 batch 的 reward

在这里插入图片描述

  • 训练后的模型可以模仿人的回复虽然不应该相信它关于 LLaMA 问题的建议但答案看起来连贯甚至提供了一个谷歌链接这个直接在官网测试发现回复的字数会多很多
    在这里插入图片描述

训练过程中的挑战

  • 高的 reward 不一定代表更好的性能
    在这里插入图片描述
    一般来说在 RL 中希望获得最高的奖励。在 RLHF 中因为使用了一个不完美的奖励模型如果有机会PPO 算法将利用这些不完美。这可能表现为奖励的突然增加但是当查看策略生成的文本时它们主要包含字符串 ( ```) 的重复因为奖励模型发现 stack exchange 上包含代码块的答案通常比没有代码块的排名更高。这个可以通过 KL 惩罚来一定程度缓解

  • KL 在这里的实现不一定是正的值因为本文采用了 KL 的估计值
    在这里插入图片描述
    可以看出来当 policy 模型采样的 token 比 SFT 模型的概率低时估计的 KL 值为负。但平均而言它将是正的否则将无法从 policy 中正确抽样。然而一些生成策略会强制生成一些 token 或则强行抑制一些 token。例如当批量生成时完成的序列会被 pad这时设置小的长度会导致 EOS token 被抑制。模型可以为那些导致负 KL 的 token 分配非常高或低的概率。由于 PPO 算法针对奖励进行优化它会追逐这些负惩罚导致不稳定
    在这里插入图片描述
    生成响应时需要小心建议在求助于更复杂的生成方法之前始终先使用简单的采样策略

  • ppo 的 loss 有不稳定的现象暂时还没有解决
    在这里插入图片描述

Thoughts

  • 作者认为后续一些可以研究的点
    • 有了训练好的模型后可以与其他模型进行对比评测
    • 有了评测基建后可以尝试在数据集上做修改比如过滤一些数据或增加一些数据
    • 不同模型架构和尺寸的对比
阿里云国内75折 回扣 微信号:monov8
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6