利用 DeepSpeed-Chat 的 RLHF API 自定义你自己的 RLHF 训练流程

DeepSpeed-Chat 允许用户使用我们灵活的 API(如下所示)构建自己的 RLHF 训练流程,用户可以使用这些 API 重建自己的 RLHF 训练策略。我们希望这些功能可以为研究探索中创建各种 RLHF 算法提供通用接口和后端。

engine = DeepSpeedRLHFEngine(
  actor_model_name_or_path=args.actor_model_name_or_path,
  critic_model_name_or_path=args.critic_model_name_or_path,
  tokenizer=tokenizer,
  num_total_iters=num_total_iters,
  args=args)

trainer = DeepSpeedPPOTrainer(engine=engine, args=args)

for prompt_batch in prompt_train_dataloader:
  out = trainer.generate_experience(prompt_batch)
  actor_loss, critic_loss = trainer.train_rlhf(out)

原创文章,作者:校长,如若转载,请注明出处:https://www.yundongfang.com/Yun246069.html

(0)
打赏 微信扫一扫不于多少! 微信扫一扫不于多少! 支付宝扫一扫礼轻情意重 支付宝扫一扫礼轻情意重
上一篇 2023年4月12日
下一篇 2023年4月12日

相关推荐