- LLM 大模型训练之路
- 结束
LLM 大模型训练之路
训练组件
LLM 训练模式
【2024-3-8】LLM-SFT-trick
微调是指在已经预训练好的大型语言模型基础上,使用特定数据集进行进一步的训练,使模型适应特定任务或领域。
- 微调主要目的是,完成 知识注入、指令对齐
大模型应用中,指令微调已成为预训练大模型在实际业务应用最重要方式。许多垂直领域模型,都在预训练模型的基础上,通过针对性的指令微调,可以更好地适应最终任务和对齐用户偏好。
指令微调时,会将 Instruction(指令) 及对应的answer拼接成文本
- 拼接过程中一般会加入【USER】、【BOT】等角色
- 同时会加入开始、结束的special token
这样可以转换成一个chat式任务
如翻译任务
# instruction:
【USER】:将下列内容翻译成英语:{待翻译文本}
# answer:
【BOT】:{翻译结果}
# 拼接后的文本:
<bos_token>【USER】:将下列内容翻译成英语:{待翻译文本}<special token>【BOT】:{翻译结果} <eos_token>
将拼接文本采用预训练任务方式进行自回归预测
- 与预训练的区别:loss的计算,同样使用Cross-Entropy作为loss,指令微调时只会计算answer部分,Instruction部分通过设置ignore_index隐掉。
- 上面的案例中,只会计算 【BOT】: 之后的loss。
特定任务改造
- 分类任务: 模型最后添加softmax层。典型案: reward模型。
通过生成式模式解决判别式任务
- 如多目标文本分类问题,采用指令微调方式去解决,效果非常好。
- 甚至在7B、3B的base模型上,去生成一个复杂json结构(包含多层结构的标签)依然有效。
微调方法
- 微调方法分为全参数微调(Full Fine-tuning)、部分参数微调(Repurposing)
- 全微调方法:SFT
- 部分微调方法:LoRA、Adapter、Prefix-tuning、P-tuning、Prompt-tuning 、Freeze-tuning 等。
受GPT论文影响,大模型通用训练模式是三阶段训练模式:第一阶段 pre-train
,第二阶段 SFT
,第三阶段 RLHF
。
- 三阶段训练分别得到 base模型 以及 chat模型
- chat模型在base模型基础进行通用任务的
SFT
以及RLHF
,使模型具备了对话能力、推理能力、用户偏好对齐、以及其他的NLU的能力。
SFT 训练模式
- 模式一:基于 base模型 + 领域任务的SFT;
- 模式二:基于 base模型 + 领域数据 continue pre-train + 领域任务SFT;
- 模式三:基于 base模型 + 领域数据 continue pre-train + 通用任务SFT + 领域任务SFT;
- 模式四:基于 base模型 + 领域数据 continue pre-train + 通用任务与领域任务混合SFT;
- 模式五:基于 base模型 + 领域数据 continue pre-train(混入SFT数据) + 通用任务与领域任务混合SFT;
- 模式六:基于 chat模型 + 领域任务SFT;
- 模式六:基于 chat模型 + 领域数据 continue pre-train + 领域任务SFT
根据领域任务、领域样本、业务需求选择合适的训练模式。
- a. 是否需要 continue pre-train
- 大模型的知识来自 pre-train 阶段
- 如果领域任务数据集与 pre-train 数据集差异较大(如领域任务数据来自公司内部),pre-train 训练样本基本不可能覆盖到,那一定要进行 continue pre-train。
- 如果领域任务数据量较大(token在1B以上),并只追求领域任务效果,不考虑通用能力,建议进行continue pre-train。
- b. 选择 chat模型 还是 base模型
- 如果有好的base模型,在base模型基础进行领域数据的SFT, 与在chat模型上进行SFT,效果上差异不大。
- 基于chat模型进行领域SFT,很容导致灾难性遗忘,进行领域任务SFT之后,模型通用能力会降低,如只追求领域任务的效果,则不用考虑。
- 如果领域任务与通用任务有很大相关性,那这种二阶段SFT会提升领域任务效果。
- 如果既追求领域任务的效果,并且希望通用能力不下降,建议选择 base模型 作为基座模型。在base模型上进行多任务混合训练,混合训练的时候需要关注各任务间的数据配比。
- c. 其他
- 资源运行的情况下,如只考虑领域任务效果,选择模式二;
- 资源运行的情况下,如考虑模型综合能力,选择模式五;
- 资源不允许的情况下,考虑模式六;
SFT-训练参数
学习率
- 学习率非常重要,如果设置不当,很容易让SFT模型烂掉。
- SFT数据集不大时,建议设置较小学习率,一般为pre-train阶段学习率的0.1左右,如在pre-train阶段的学习率为9e-5,则SFT学习率设置为9e-6。
- 在10万SFT样本上,采用与pre-train一样的学习率,发现loss一直不收敛,在调低学习率至原来0.1之后,loss在两个epoch之后就收敛。
warmup_ratio
- 通常 pre-train 训练的
warmup_ratio
0.01~0.015之间,warmup-steps
在2000左右。 - SFT 时,建议用更小的ratio,因为相较于pre-train,SFT样本非常小,较小
warmup_ratio
可以使模型收敛更平滑。 - 但如果学习率设置较大,那可增大 warmup_ratio,两者呈正相关。
- 通常 pre-train 训练的
Epoch
- Epoch 可根据loss收敛情况设置
- 如果SFT样本较少,可设置较大epoch,在较小的epoch上loss会不收敛,指令都很难遵循。较大epoch会容易导致过拟合,但过拟合要优于欠拟合。
- 如果SFT样本数量较多,如在十万以上,一般2个epoch即可收敛。
其它
- 如果SFT任务类型较多,添加 system_prompt,不同任务使用不同 system_prompt;
- 好的基座模型非常重要
- SFT 时,loss依然是最重要的指标,一般在SFT过程中,loss会先升后降;
- 尝试多种模式训练方案,如 continue pre-train 中添加SFT数据,在SFT数据添加高质量的pre-train数据;
- 模型参数量非常重要
二次开发
- 1、
领域知识注入
:Continue PreTraining(增量预训练
): 一般垂直大模型是基于通用大模型进行二次开发,用领域内的语料进行继续预训练。 - 2、
知识召回
(激发):SFT( Supervised Finetuning,有监督微调
): 通过SFT激发大模型理解领域内的各种问题, 并进行回答的能力。 - 3、基础
偏好对齐
:奖励模型(RM)、强化学习(RL),让大模型的回答对齐人们的偏好,比如行文风格。 - 4、高阶
偏好对齐
:RLHF
(人类反馈强化学习训练)、DPO
(直接偏好优化)。
3个阶段:
- (1)、第一阶段:
CPT
(Continue PreTraining)增量预训练,在海量领域文档数据上二次预训练GPT模型,以注入领域知识。 - (2)、第二阶段:
SFT
(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图。 - (3)、第三阶段 :
RLHF
和DPO
二选一。
Post-pretraining
Post-pretraining(后期预训练)是一种在模型的初始预训练和最终微调之间进行的训练方法。这种方法通常用于进一步适应模型以处理特定类型的数据或任务。
- 在通用预训练模型的基础上,对模型进行额外训练,使模型更好地适应特定的领域或任务
- 数据集: 某个领域,但比微调阶段使用的数据集更大、更广泛。
- 训练方法: 监督学习,自监督学习,取决于数据类型和训练目标, 如语言建模、文本分类、实体识别等
Post-pretraining 允许模型在保持通用性的同时,增强对特定领域的理解,有助于模型在后续的微调阶段更快速地适应特定任务。
- 与 SFT 相比,Post-pretraining 在微调之前提供了一个中间步骤,有助于模型更平滑地过渡到特定任务上。
- 与 RLHF 相比,Post-pretraining 不依赖于复杂的奖励机制或人类反馈,而是通过大量的领域特定数据来提升模型性能。
总结
- Post-pretraining 是一个介于
预训练
和微调
之间的训练阶段 - 使用大量的领域特定数据来进一步调整模型,使其更好地理解特定领域的语言和任务。
- 这个阶段不需要复杂的奖励机制,而是通过传统的监督或自监督学习方法来实现模型性能的提升。
增量预训练
增量预训练
是属于后期预训练
(Post-pretraining)
增量预训练
也叫领域自适应预训练
(domain-adapter pretraining),即在所属领域数据上继续预训练。
自适应预训练
(domain-adapter pretraining)的方法可以分为三类:Prompt-based方法、representation-based方法和model mixed-based方法。
- Prompt-based 方法
- representation-based 方法
- model mixed-based 方法
1. Prompt-based 方法
使用模型全局tuning的方式适应下游任务时,预训练模型的泛化性能会被严重削弱
因此, Prompt-based方法在保持预训练模型参数权重不变的条件下, 增加额外可学习的 Prompt tuning 模块来实现对下游任务的泛化,这样就能较好地保持原模型的泛化性能。
VPT
虽然可以较好地保留模型的泛化性,但是面对新的任务时,以往的Prompt模块的知识同样被覆盖,依旧遭遇了灾难性遗忘
问题。
为此,有学者提出了Prompt Pool
概念,设计了Prompt模块的集合,即P={P1,P2,…,Pm}
(m表示该Pool的最大尺寸)。
Prompt Pool 有效避免了单一Prompt的问题,但是Pool的设计使得其需要进行Prompt Selection操作,也就是需要将特定任务与其对应的Prompt模块进行索引匹配。
L2P
算法是一种较为常用的 Prompt selection算法,该算法设计了一种Key-Query的Prompt匹配方法,为每一个Prompt提供一个可学习的索引键k,即 P={(k1,P1),(k2,P2),…,(km,Pm)}
。
L2P利用预训练模型将输入特征编码到Key对用的嵌入空间中,然后利用余弦距离损失函数在已有的Pool中搜索最近似的Key。接着,利用如交叉熵损失等方法对搜索到的Key对应的Prompt进行进行优化。
类似的Prompt Selection 算法很多,如DualPrompt算法,该算法将Prompt进行解耦,分化为General Prompt和Expert Prompt。General Prompt面向所有任务,为所有任务中共享信息,而Expert Prompt针对独立任务,数量与任务量一致。其采用了和L2P相同的key-query匹配策略。
Prompt Selection虽然可行,但仍是硬匹配,选项有限。基于注意力信息加权的Prompt Combination方法则有效缓解了该问题。如CODA-Prompt通过对Prompt Pool进行注意力机制嵌入,为每个注意力赋予自适应权重,进而求算全局Key-Query的加权和,实现可学习式Prompt组合。我觉得稀疏式注意力Prompt combination应该也是很有趣的研究。
从根本上来说Prompt Combination仍受制于Prompt Pool的范围。为此, 许多学者则开展Prompt Generation有关的研究,如DAP,其利用MLP进行特定任务提示信息的编码生成。
优点:
- Prompt 有助于弥合domain gap,并可有效地对特定任务的知识进行编码。
- Prompt Design 属于lightweight模块,与input feature具有相同的维度,因此保存Prompt是parameter-efficient,适用于边缘场景。
- Prompt Pool作为预训练模型的外部存储器,其支持自适应知识的检索和特定实例的预测。
缺点:
- 一些研究发现L2P中的prompt selection过程收敛到一个单点,使得prompt selection只集中在特定子集上。
- 由于key和query在整个学习过程中不断变化,这些参数的更新将会消除先前任务的参数,导致matchimg-level和prompt-level的遗忘,使prompt selection成为CL的瓶颈。
- 固定大小的Prompt Pool会使得模型的表示能力受限。但是,若Prompt Pool随着数据的发展而增长,可能会为旧任务检索新的提示,导致训练和测试之间的不匹配。
- 最后,一些研究发现prompt-based CL的性能低于简单的representation-based的baseline性能。并且批量提示有损比较的公平性。
2. Representation-based 方法
representation-based 方法直接利用预训练模型强大的泛化性和通用性来实现持续学习。
- 比如Simple-CIL方法,是ADAM算法原文中提出的Baseline,Simple-CIL冻结预训练模型参数,并通过求算类别中心的方式来构建Classifier。在面对很多类别时,计算同类的embedding或features的平均值,并将该平均值作为该类别的标准(prototype),最后结合类别标准与余弦比较的方法替换模型的原始Classifier。
虽然基于prototype的方法存在一定的作用,但是并未很好地适应下游任务。为此,一些研究在基于prototype方法的基础上结合了外置参数高效调节模块或者外置适配器来使得预训练模型更加适应下游任务,如ADAM等。
ADAM等算法在进行类别标准设定时,类别标准之间的仍存在联系,导致任务效果降低。为此,RanPAC算法则采用online LDA classifier来去除原始方法prototype计算结果之间的相关性,加大类别间的分布差异。此外,RanPAC算法利用Random Projection layer将features映射到高维空间中,并在高维空间中进行prototype的计算,以使得特征分布符合高斯拟合。
相较于前面将预训练模型的通用语和适应性分离处理的方式,SLCA算法采用了差异学习率调整和特征经验重播的方式进行持续学习研究。该算法使用较小的learn rate调整模型主体部分,而使用较大的learn rate 调节模型的classifier,以实现模型的逐步微调和classifier的快速适应。为了避免忘记以前的分类器,SLCA还对分类特征分布进行建模,并重播它们以校准classifier。
优点:
- 由于class prototype代表了对应类别最常见的标准格式,因此利用其构建模型具有直观和可解释性。
- Representation-based 方法主要是冻结backbone和更新classifier权重。lightweight的更新成本增加了其现实应用的可行性。
缺点:
- 将不同模型的特征连接起来形成class prototype,容易造成模型信息冗余。例如,不同的backbone中存在重复提取共享特征。
- 当下游任务涉及多个领域时,在第一阶段调整模型不足以弥合数据集之间的领域差距。在这种情况下,不断调整backbone可能更适合提取特定于任务的特征。
3. Model Mixture-based 方法
Model Mixture-based 方法在持续学习工程中构建了一组模型,然后再推理阶段通过Model Ensemble和Model Merge来进行信息综合决策。
Model Ensemble中,ESN算法凭借预训练模型强大的通用性,构建多个classifier,在面对新任务重新初始化和训练一个新的classifier。在推理时,采用投票策略来整合多个模型的结果进行最终决策。
由于Model Ensemble的核心因素取决于模型的方差,一些研究通过增强模型之间的多样性来替代使用相同的预训练模型构建不同的classifier。如PromptFusion利用预训练的ViT和CLIP,并在推理过程中动态地对logit进行组合,即f(x) = λ fvit (x) +(1−λ)fclip(x)。
与多个backbone的集成不同,PROOF采用了仅使用单个CLIP的更全面的推理方法。由于CLIP支持视觉和文本特征的跨模态匹配,因此PROOF设计了一个三层集成,考虑image-to-text、image-to-image prototype、image-to-adjusted text的跨模态融合。
Model Merge将多个不同的模型合并为一个统一的模型,无需要额外的训练。LAE定义了online和offline学习协议,online模型通过交叉熵损失进行更新,目的是在新的任务中获取新的知识。离线模型则通过Model Merge进行更新,例如指数移动平均(EMA): θ offline←α·θ offline +(1−α)·θ Online,其中α为权衡参数。LAE仅将EMA应用于参数高效调谐模块(如prompt),其利用online和offline模型的最大logit进行推断。
与LAE一样,ZSCL将合并技术应用于CLIP模型,目的是在持续学习过程中保持其zero-shot性能。然而,随着EMA中权衡参数的改变,CLIP性能不再具有鲁棒性。因此,ZSCL建议每隔几次迭代合并参数,从而在模型训练期间创建平滑的损失轨迹。
此外,CoFiMA注意到EMA在Merge过程中对每个参数的重要性是相等的,CoFiMA 在Merge过程中插入Fisher information(费雪信息)作为每个参数的估计重要性。
优点:
- 学习多个模型可以做出不同的决策。因此,使用Model Ensemble和Model Merge自然会产生更健壮的结果。
- 由于直接合并模型进行统一预测,因此可以调整前模型和后模型的权重,以突出不同阶段之间知识共享的重要性。
- 由于模型集将在推理过程中合并,因此最终的推理成本不会随着模型集中添加更多模型而增加。
缺点:
- Model Ensemble需要保存所有的历史模型,并消耗大量的内存缓冲区。虽然基于Model Merge不需要这么大的成本,但合并大型backbone的权重也需要大量的额外计算。
- 决定Merge哪些参数仍然是问题。
微调 (Fine-tuning)
这个阶段,预训练模型(可能经过了Post-pretraining)被进一步训练,以优化特定任务上的表现。
微调通常在一个相对较小的、特定任务的数据集上进行,这个数据集包含了明确的标签,模型通过监督学习来进行优化。
微调目的: 调整模型的参数,使其能够在特定任务上做出准确的预测。
SFT 监督微调
SFT (Supervised Fine-Tuning) 是微调的一种形式,强调在有监督的环境下进行。
SFT阶段,用特定领域数据或私有化数据, 对预训练模型进行改良。
这一阶段需要指令微调数据,数据集通常由输入(用户问题)和输出(标准答案)两个字段构成。标准答案通常由专家标注获得。
- 1、SFT是一种简单的微调方法,它使用带有正确答案的数据集来继续训练一个预训练的模型。
- 2、这种方法依赖于大量的标注数据,即每个输入都有一个预先定义的正确输出。
- 3、微调的目的是使模型更好地适应特定的任务或领域【垂直领域】,比如特定类型的语言理解或生成任务。
- 4、SFT通常不涉及复杂的策略或奖励函数,只是简单地最小化预测输出和真实输出之间的差异。
SFT VS Pretrain
【2024-10-22】细谈大模型监督微调SFT:实战经验技巧和debug分析思路
SFT 和 pretrain 在训练方式上没有任何区别,主要区别在于数据组成形式上:
- pretrain 每条数据都是满编 4K / 8K,SFT 每条数据原本多长就是多长;
- SFT 会引入 pretrain 阶段未见过的 special_token,来让它们学习全新的语义;
- SFT 会让模型见到最重要的 eos_token,pretrain 模型因为没见过该 token 而无法停止生成;
- 借助 special_token,SFT 会把语料切分成不同的角色,标配的有 system、user、assistant,根据业务需求也可以有“背景”、“旁白”、“事件”等等;
- SFT 的 prompt 不做 loss,但这并不是说它不能做 loss。主要原因是 prompt 的同质化比较严重,不做 loss_mask 的话,同样的一句话会被翻来覆去的学,但如果你能保证你的每条 prompt 都是独一无二的,就完全可以省去 prompt 的 loss_mask 环节。对了,session 数据一定要想清楚是每一个 answer 都算 loss,还是只对最后一轮的 answer 算 loss。
除此之外,训练目的也不一样。
- pretrain 是在背书,纯粹的学习知识;
- sft 则是在做题,学习的是指令 follow 能力。
切勿在 sft 阶段强行给模型做知识注入,比如训个 50W 条的 code 数据,所有的知识注入工作应该采用 continue-pretrain 的思路进行,否则都会使得模型的通用能力掉点明显(SFT 做知识注入基本上是 100% 某个知识,但 continue-pretrain 做知识注入会控制在 10% ~ 20% 左右的比例)。
RLHF 人类反馈强化学习
RLHF 利用人类反馈来训练强化学习模型。
在RLHF中,模型通过与人类交互获得反馈,这些反馈作为奖励信号来指导模型的行为。RLHF通常用于训练能够生成更自然、更符合人类偏好的文本或其他输出的模型。这种方法特别适用于需要模型理解和适应人类偏好的场景。
- 1、RLHF (Reinforcement Learning from Human Feedback) 是一种更复杂的训练方法,结合了监督学习和强化学习。
- 2、在RLHF中,模型首先通过
监督学习
进行预训练,然后通过人类提供的反馈来进行强化学习。 - 3、人类反馈可以直接对模型输出评分,或模型输出之间做出选择的偏好。
- 4、强化学习部分涉及到定义一个
奖励函数
,根据人类反馈来调整模型的行为,以优化长期的奖励。 - 5、RLHF目标: 训练出一个在没有明确标签的复杂任务中表现良好的模型,这些任务可能需要更细致的判断和调整。
思考
对齐
instruction following 是 alignment (对齐)的一个特殊形式,但它并不构成对齐的全部内容。
对齐问题原本称为价值对齐
(value alignment)指一个 AI 系统训练目标可能与其实际需要面对的核心价值并不一致。
- 训练目标与真正希望 AI 满足的目标之间存在不匹配,而如何解决这个不匹配的问题被称作 value alignment problem。
OpenAI 2024年初提出 “Super-Alignment”, 探讨了 AGI 的水平远远超越人类,人类将如何是好。
OpenAI 当时提出了一个概念,即 “Weak-to-Strong Generalization”,如果目前机器智能尚不及人类,人类尚能与之互动;但若其智能发展至极高水平,人类似乎难以与其沟通。那么也就产生了一个问题,人们应该如何训练 AI,是否应该采用特定的方式?Next Token Prediction 或是 instruction following 是不是一个好的对齐方法?
alignment 问题核心假设:
- 因为人类很多时候并不清楚自己到底想要什么,因此很难给出一个完全具体的价值观描述,且不同人的价值观都有区分。
- 如果人类给出的指令永远不是特别准确,那么 AI 系统在执行任务时需要保持一定的不确定性。
框架 Cooperative Inverse Reinforcement Learning,来源于师兄 Dylan Hadfield-Menell(目前在MIT任教)和导师做的一个研究。
- 假设每个人都有一个 hidden reward function。当人与 AI 交互时,人可能想的是 AI 帮我递个咖啡,但人给 AI 的具体指令可能并不是这样,比如人可能只是说了“给我个喝的”,AI 需要不断去推断人类的真正意图。
在这样的定义下,人类的真正意图可以被建模成一个隐藏的奖励函数,机器人需要不断地根据人给出的所有信息来主动推断人类的真正意图。如果不确定时,最优策略是 AI 去问人类。
post-training 让模型更聪明
【2024-8-23】RL 是 LLM 的新范式
曾在 OpenAI 负责 post-traning 的 John Schulman: (RL 拥趸和布道者)
- post-training 是模型变得越来越聪明的重要原因,而
RLHF
是最重要的技术 tricks。
John Schulman 对 RLHF 的信仰来自 OpenAI 的亲身实践:
- GPT-4 的 Elo 分数之所以能比第一代 GPT 高出 100 分也和 post-traning 的提升相关。
Scaling law 让 AI 更聪明,而 RL 让 AI 更有用
InstructGPT 核心思想
- 利用人类的判断来指导模型的训练,因为这些 instruction following 的任务本身就是人类给出的指令。
- InstructGPT 能够处理复杂的指令,包括写代码等任务,很多在 zero-shot 设定上 GPT-3 做不了的任务都可以被完成。
InstructGPT 目标: 微调 GPT 模型,使其能够产生满足人类指令的输出。
为了使 GPT 完成指令遵从,技术挑战集中在:如何收集数据?
为了实现这一目标,需要完成两件事情:
- 指令,fine-tuning 首先需要收集指令,即人类的 prompts 或 instructions。
- 反馈,需要收集好的反馈来满足 human instructions。
从训练语言模型的角度来看,收集大量的人类指令(human instructions),以及对应的人类反馈。这些对应好的数据将被作为 Next Token Prediction 的训练数据,通过传统语言模型训练方法,即 SFT (Supervised Fine-Tuning),来进行训练。
于是, InstructGPT 训练过程:
- • 第一步,通过 SFT 收集 human demostration data 进行 SFT。
- • 第二步,收集人类偏好数据,利用数据学习一个奖励模型。
- • 第三步,使用 reward model 进行强化学习的 RLHF 训练。
最终就可以得到优化后的 InstructGPT 模型。
之后的 ChatGPT 总体训练流程概括为两个主要部分。
Pre-training
:涉及使用大量数据,通过语言模型的训练方法来训练一个基础模型。Post-training
:InstructGPT
和ChatGPT
所执行的步骤,即利用人类的标注数据或高质量的人类反馈数据进行后训练。
Post-training
通常包括至少两个步骤:
- 1)SFT 步骤,通过 human demonstration 的方法进行
监督学习
; - 2)RLHF 步骤,通过 human preference data 的方法进行
奖励学习
。
预训练与后训练之间也存在区别:
- • 数据方面:预训练和后训练在数据的质量和数量上存在差异。
- 预训练阶段需要处理海量数据,这可能需要大量的计算资源和较长的时间。
- 而在后训练部分,大量的数据是人类标注或通过某种方式构造出来的数据,数据质量通常较高,但与预训练阶段相比,数量会少很多。
- • 训练目标方面:
- 预训练的目标是压缩和
Next Token Prediction
; - 后训练的目标是
instruction following
。通过训练激发大模型的能力与智能,使模型 usable,能够尊从人类指令。
- 预训练的目标是压缩和
- • 训练过程方面 (dynamics):
- 预训练通常是固定的,需要收集一个庞大的数据集进行训练,这些数据通常是静态的。
- 对应 post-training,尤其是 RLHF ,其反馈是在线的,需要不断收集人的反馈,不断迭代,逐渐进化模型,这是一个动态的在线过程。
最后, post-training phase 也被称为对齐
(alignment phase), 将 LLM 的能力和人类的偏好保持一致,希望大模型的输出能够满足人类的价值取向和意图,确保模型的输出与人类的偏好一致。
SFT < RLHF ?
【2024-8-23】RL 是 LLM 的新范式
为什么 RLHF
效果优于 SFT
?
PPO 算法提出者 John Schulman
,曾经在 OpenAI 工作,Berkeley 的PhD, 2024年4月, 到 Berkeley 做过一场讲座,仔细讨论了 RLHF PPO 的重要性,两个观点:
- 第一, SFT 会导致幻觉 hallucination :
- 第二, RLHF helps uncertainty awareness,让大模型“知道”自己“确实不知道”。
进一步完善, RLHF 过程三点好处:
- 使用 负向反馈 进行
对比学习
,通过对比过程帮助模型降低幻觉 halluciation。 - 强化学习不是一个固定的过程。允许模型随着能力的不断提升,通过不断地问问题、不断地给出答案、不断地评判,从而让模型不停地从当前能力的边界进行主动探索,并不断拓宽自己的能力边界。
- 这两个因素共同作用能够形成 反事实推理 counter-factual reasoning 的作用,有可能解锁
因果学习
(casual learning)的巨大潜力,让模型具备更强的 reasoning 能力。
SFT 会导致幻觉
John Schulman 认为,大型模型之所以会产生幻觉,是因为 SFT 阶段学到了一些不正确的认知。
举例
- 当 GPT-3 被要求 “ write a bio of AI researcher John Schulman”时,GPT 错误地输出:John 从 2009 年开始在 CMU 任职 associate professor,从 2012 年开始任职 professor。但是真实情况是,John 在完成 PHD 学位后就在 OpenAI 工作,并未在其他地方工作(注:最近John刚加入了Anthropic)。GPT-3 输出的内容与实际明显不符。
为何大型模型会生成这样的错误信息?
- 思维实验,假设在预训练阶段,就存在一个 知识截断(knowledge cut off)。比如,假设 ChatGPT 的所有的知识和数据都截止于 2023 年。到 2024 年,希望通过 SFT 的方式 fine-tune ChatGPT,让它来描述 2024 年欧洲杯的情况。但因为 GPT 在预训练过程中没有任何关于 2024 年欧洲杯的信息,它自然也不知道西班牙是否夺冠,也不知道是否有进球等具体情况。
如果用现有的数据进行简单的 SFT,实际上 GPT 并不知道 2024 年发生了什么,但由于 SFT 的数据中包含了其他欧洲杯相关的问答数据,这些回答都是精准的,因此大模型可能会觉得,对于2024年欧洲杯的问题也应该给出一个准确答案才可以,但它本身可能在预训练阶段并没有掌握正确的信息,于是就鹦鹉学舌地说一些错误的内容。这种情况下,SFT 过强的监督信号导致人类实际上在引导 ChatGPT 说它不知道的东西。
另外还存在一种可能性,即 GPT 实际上知道答案,但提供标注的人员不知道。
- 例如,如果问到 2022 年某场足球联赛的问题,标注人员可能不了解答案,而 GPT 反而可能知道。在这种情况下,标注人员可能会给出 “I don’t know ” 的人类反馈。这反倒可能导致 GPT 产生混淆,因为它明明知道答案却被要求说不知道。这两种原因综合来看就可能导致模型在经过 SFT 阶段后非常容易出现 hallucination 现象。
他人观点
- SFT 确实容易导致幻觉,但不一定完全是预训练阶段数据的知识截断导致的,SFT也能学习新知识
问题:大模型在是否学会新知识?
存在一个非常微妙的边界。
- 如果不提供数据,大模型就不能够提供答案;
- 如果提供数据不完整,可能导致模型出现
幻觉
; - 如果数据提供足够多,模型就可能会学会新知识。
因此,到底给多少数据,很难判断,SFT 高质量数据集也是非常难构建的,这里就有一个非常不容易的数据挑战( a non-trivial data challenge for building a good SFT dataset)。
RLHF让大模型“知道”自己“确实不知道”
RLHF helps uncertainty awareness,让大模型“知道”自己“确实不知道”。
欧洲杯的例子
- 如果大模型不知道 2024 年欧洲杯的情况,用户却让大模型去描述欧洲杯的情况(在2024年欧洲杯上哪位运动员有进球),那大模型就可能会产生幻觉,这是因为模型实际上并不了解 2024 年欧洲杯的具体事件但被 SFT 引导说一个貌似正确的回复。
RLHF 如何防止 hallucination 的出现?
- 如果存在一个设计良好的
奖励函数
,情况就会不同。 - 如果模型给出正确答案,就给予正向的奖励分数 1分;
- 如果模型表示“我不知道”,就给予 0分;
- 如果模型给出错误答案,则扣除分数 4分。
在这种情况下,如果模型不知道 2024 年发生了什么,在强化学习过程中无法提供正确的回答,选择“不知道”成为更合理的策略。
这种机制鼓励模型在不知道答案时能够提供“不知道”的回答。这种方式能帮助模型保留了一定的不确定性,使模型能够产生正确的自我认知,来判断是否真的知道一个问题的答案。
他人观点
- 基本正确,尽管 John 解释可能不完全准确
- RLHF 所带来的不仅仅是处理知识边界的不确定性的能力(not only handle the knowledge cut off problem)
RLHF 提高了模型推理能力
RLHF 过程不仅帮助模型意识到不确定性,更重要的事情是 RLHF 帮助模型提高了 reasoning 能力。
相关性
不代表因果性
。大家会希望大模型掌握因果性
,而不希望仅仅看到相关性
。
因果性指什么?
- 传统统计学习里面有一个判断因果性的过程,叫 反事实推理 counter-factual reasoning。
是否可以舍弃 online attempt
问题:
- 模型训练上利用 negative signal 和 online exploration 两件事上,是否可以舍弃 online attempt ?即只通过正反馈和负反馈是否足够,而不需要模型持续在线尝试。只通过 contrasted learning,在 SFT 上加上负向案例,能否达到预期效果?
可以, DPO
( Direct Policy Optimization)
- 它与
PPO
算法的主要区别:DPO
去除了在线尝试的部分。DPO
算法其实很简单,基本遵从了SFT训练流程,但是在收集正例之外还会收集负例,对于每一个 prompt 都要求标注员提供好的和坏的两个答案。对于好的答案提升概率,对于坏的答案则是让模型“不说”。
DPO 算法是否能达到与 PPO 效果?
- 今年的 ICML2024 大会上的论文,Is DPO Superior to PPO for LLM Alignment?A Comprehensive Study 讨论了这个问题。这篇论文也是今年被选中的 4 篇有关 alignment 的 oral papers 的其中之一。
如果仅仅通过静态数据 覆盖 LLM 所有可能的输出, 非常困难。因此,在线探索和及时奖励反馈是一种更加高效让 LLM 学会说正确答案的方法。
结论
- 如果能够实现
PPO
算法,PPO 效果将会远远超过DPO
。因为, 正例反例和在线探索两件事都非常重要。 - 用 PPO 和 Code Llama 在 Coding Contest 上做了测试,发现使用开源模型加上 PPO 可以比 AlphaCode 这样的闭源模型在很难的 CodeForce 竞赛题上通过率提高 6%。这是一个纯开源模型加 RLHF 的尝试,并未添加任何新的数据。在这种很难的、需要强调 reasoning 能力的任务上,DPO 完全没有效果。
PPO RLHF 框架有哪些挑战?
PPO 包含四个模型:actor、critic、value network 和 reference network。
- 不同模型还有不同依赖,也就是前后依赖关系;
- 不同模型也有不同吞吐量,比如,actor 是一个传统的大模型,需要输出所有 response,而 critic 则只需要做评分。评分的吞吐量会远小于需要输出 response 的模型。
因此,不同模块的计算量存在显著差异。将这四个模块 scale up,并且做好算力平衡是具有挑战的。
挑战
- 算法: PPO RLHF 算法流程相对复杂
- 算法、流程都相对麻烦,多了很多流程。不仅需要正反馈、负反馈、需要奖励模型,并且涉及在线探索过程。
- 建议: 要 advantage normalization、需要一个大的 training batch;reference model 需要 moving average 等。
- 系统: 强化学习训练系统与传统的 SFT 有不太一样
- SFT 或 DPO 模型通常只包含一个 policy 模型,只需将数据输入语言模型即可,其训练逻辑相对简单。然而,对于强化学习,或者对于 PPO RLHF,情况则更为复杂。
- 数据: 数据非常重要
- RLHF 数据包括两部分:一是 prompt,即人写的 instruction。二是指模型的 responses。这两部分都相当复杂
PPO RLHF 面临的挑战主要分为算法、系统和数据三个方面:
- 算法层面:关键在于如何稳定训练过程,并调整算法的细节以提高性能。
- 系统设计:由于强化学习 PPO,RLHF 的计算流程非常复杂,系统设计需要提高整体的训练效率。
- 数据:数据分为两部分,一部分是 prompt,一部分是 response。两部分都很关键,只有将它们结合起来,才能形成一个完整的,比较成功的 PPO RLHF 的 training process。
【2024-8-23】RL 是 LLM 的新范式
训练数据
【2024-9-11】大模型数据基础:预训练阶段数据详解
- 预训练数据集组成
- 1 通用预训练数据集
- 1.1 网页
- 1.2 语言文本
- 1.3 书籍
- 1.4 学术材料
- 1.5 代码
- 1.6 平行语料库
- 1.7 社交媒体数据
- 1.8 百科全书
- 1.9 多类别数据
- 2 特定领域预训练数据集
- 预训练数据处理步骤
- 1 数据收集
- 2 数据过滤
- 2.1 基于模型的方法
- 2.2 基于启发式的方法
- 3 数据去重
- 4 数据标准化
- 5 数据审核
- 预训练数据整体分布现状及分析
预处理通常包括五个步骤:
【2024-5-23】再聊多轮对话微调训练格式与长序列训练
3个阶段的数据集格式: 增量预训练、单轮对话、多轮对话
- 增量预训练数据集:提升模型在特定领域或任务的能力。
- 单轮对话和多轮对话数据集:用于指令微调(instruction tuning)阶段,以提升模型回复特定指令的能力。
指令微调阶段目标:训练语言模型根据人类指令给出回答。一般只有回答部分(Output)的 loss 会用于梯度回传,而指令部分(System、Input)部分的 loss 则不会用于权重更新。
数据集进行预处理时引入 “system”、”input” 和 “output” 三个字段
- “system”、”input” 字段用于保存不需要计算 loss 的文本,如 系统或用户指令
- 而 “output” 字段则用于保存 需要计算 loss 的文本,如 输入指令对应的 GroundTruth 回答。
数据量
资源受限时,模型训练应该用多少数据?
- 预训练: 参考 缩放定律 ( scaling law)
- 微调: 如下文
【2024-7-29】大型语言模型高效微调策略,通过实验发现少量数据即可显著提升特定任务性能,并提出一种基于早期模型表现的贝叶斯超参数优化方法,有效预测最终模型效果,为资源节约型的LLM微调提供新途径。
数据效率研究
模型性能与数据量之间的最佳平衡点,从而优化资源利用。
- 虽然小型数据集显著改进效果,但是必须仔细考虑训练数据中属性分布,确保模型在所有目标变量上的全面表现。
- 另外可探索数据增强技术或不同的采样策略,增强模型性能,特别是针对那些出现频率较低的属性。
数据量对模型效果影响
200
(显著提升18pp) ->1000
(放缓) ->6500
(平衡点过后,收益减少)
详情
- (1)快速初始改进:
- 约
200
个样本(相当于大约100个网页),模型准确率从70%显著提升至88%。—— 即使是相对较小的数据集也能带来显著的性能提升。
- 约
- (2)收益递减:
- 达到
1,000
个样本后,准确率提升速度放缓,大部分性能增益在这个数据量水平就已经实现。
- 达到
- (3)属性特定趋势:
- 后期准确率提升主要由一个特定属性类型(如产品评分)所驱动。这一属性在数据集中出现的频率较低,只在大约25%的产品详情页面中出现。
- (4)性能瓶颈:
- 大约
6,500
个样本时,模型达到最大性能,这表明存在一个“最佳点”,在此之后,更多数据带来的收益逐渐减少。
- 大约
- (5)战略数据采样重要性:
- 即使小数据集也能显著提升模型性能,但要确保所有目标变量在训练数据中的分布均衡,以实现全面的模型表现。
超参数优化
通过采用贝叶斯
(Bayesian)优化并结合早期模型性能评估,可显著提高大型语言模型微调的效率和效果,减少计算成本,同时确保高最终准确率。
- 首先,使用一系列超参数进行
LoRA
微调。 - 然后,训练过程早期阶段,使用模型评估验证集上的准确率。
- 接着,将超参数配置及准确率添加到结果池中。
- 最后,运用Bayesian优化算法,基于结果池生成下一组超参数。
(1)超参数优化目标
- 寻找最优超参数集:找到一组能最大化模型在验证集上性能指标(如准确率)的超参数集合。
- 预测最终性能:最大化早期训练阶段与最终训练阶段之间模型性能的相关性,以便通过早期表现预测最终模型的质量。
(2)方法论
- Bayesian优化:采用Bayesian优化算法智能地探索超参数空间,平衡
探索
(exploration)和利用
(exploitation),通过构建代理模型(surrogate model)预测不同超参数设置下的模型性能。 - LoRA微调:首先使用一组超参数进行LoRA(Low-Rank Adaptation)微调,然后在训练过程的早期阶段评估模型性能。
- 迭代优化:保存超参数配置及其对应的性能值,然后使用Bayesian优化算法更新代理模型,建议下一步要评估的超参数配置。
训练早期阶段的模型性能与最终阶段的性能具有强烈的正相关性: 早期评估可有效地预测模型质量。
数据配比
引入大量行业数据,模型怎么反而变弱了? 参考
- 对一个回答问题能力不错的模型,用大量数据做
指令微调
后,模型不会回答问题了。
原因:
- 数据配比
- 数据差异过大
大模型可能在训练过程中过度专注于垂类数据,导致 loss 收敛不再依赖全局而是从部分数据进行考虑。
贝壳论文中,比较好的结果:
- 开源数据集:垂域数据集 = 4:1, 即开源占比总体训练数据的80%,而垂类数据仅占20%。
- 《垂域大模型训练》
对 continue pretraining, 如果要让模型不丢失通用能力,比如 summarization,qa 等
(1) 领域数据 continue pretraining 时,一定更要混合大量通用数据。
- 「领域数据比例要在
15%
以下」- 一旦超过这个阈值,模型通用能力会下降很明显。
- 这个阈值和不同的预训练模型相关,有些模型比如llama需要控制的阈值更低。
阈值其实是经验主义结论,范围都在 10%-15% 左右。
- 而且阈值和预训练模型的大小,预训练时原始数据的比例等条件都息息相关,需要在实践中反复修正。
(2) sft 比例可提高不少
领域数据
:通用数据
=1:1
- 如果sft数据量少,混不混数据差别就不太大了。
统一格式
统一增量预训练
、单轮对话
和多轮对话
三种数据集格式
[{
"conversation":[
{
"system": "xxx",
"input": "xxx",
"output": "xxx"
}
]
},
{
"conversation":[
{
"system": "xxx",
"input": "xxx",
"output": "xxx"
},
{
"input": "xxx",
"output": "xxx"
}
]
}]
训练过程中,将一条数据中 多组 “system”、”input” 和 “output” 进行拼接,之后输入模型,并行计算每个位置的 loss ,但只有 “output” 部分对应的 loss 参与梯度回传
<BOS>
和<EOS>
表示句子或文本的开始和结束
图解
增量预训练
增量预训练旨在帮助模型学习针对特定下游任务的语言知识和表达能力,因此数据集的全部内容对应的 loss 都应该用于梯度回传。
因此,数据集的 “system”、”input” 为空,而 “output” 为一整条语料数据。
[{
"conversation":[
{
"system": "",
"input": "",
"output": "I am an artificial intelligence (AI) assistant named Puyu. I was created by the Shanghai AI Laboratory and my purpose is to assist users with various tasks through natural language processing technology."
}
]
},
{
"conversation":[
{
"system": "",
"input": "",
"output": "I am an artificial intelligence programmed to assist with various types of tasks, including answering questions, providing information, and performing automated processes."
}
]
}]
单轮数据
单轮对话数据集由1条指令(或问题)及其对应 GroundTruth 回答组成。
由于只有回答部分需要对 loss 进行回传,因此数据集的 “system”、”input” 字段为输入指令,”output” 字段为对应回答
[{
"conversation":[
{
"system": "You are an AI asssistant."
"input": "Give three tips for staying healthy.",
"output": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."
}
]
},
{
"conversation":[
{
"system": "You are an AI asssistant."
"input": "How to study English?",
"output": "1. Set clear goals. 2. Create a study plan. 3. Build vocabulary. 4. Practice speaking."
}
]
}]
多轮数据
多轮对话数据集往往由多轮指令(或问题)+ 对应 GroundTruth 回答组成。
假设有一条多轮对话数据,内容如下。
对于第 n 轮对话,将 User 和 Assistant 对应的输出设为 UserN 和 AssistantN。
System: You are an AI asssistant.
User1: Hello?
Assistant1: Hello! How can I help you?
User2: What\'s the date today?
Assistant2: Today is Monday, August 14, 2023.
User3: Thank you!
Assistant3: You are welcome.
如何使用上述这条多轮对话数据训练大模型?目前有两个主流方法。
- 方法 1
- System、User1、Assistant1、User2、Assistant2、User3 文本都视为模型的输入部分,将 Assistant3 的文本视为模型的预测部分,只有 Assistant3 部分的 loss 参与权重更新。
- 弊端在于没有充分利用多轮对话的训练数据,因为 Assistant1 和 Assistant2 的内容没有参与模型训练,导致训练数据利用率较低。
- 方法 2
- 将1条多轮对话数据拆分成多条数据。如将以上示例拆分成如下三条数据。
- 相比于方法1,方法2可以充分利用每一轮对话的数据,但需要将一条包含 n 轮对话的数据拆分为 n 条数据,训练效率降低 1/n。
- 方法 3
- XTuner 训练多轮对话模型时,采取了一种更加充分高效的方法。
- 将多轮对话进行拼接,之后输入模型,并行计算每个位置的 loss,而只有 Output 部分的 loss 参与回传。
[{
"conversation":[
{
"system": "You are an AI asssistant."
"input": "Hello?",
"output": "Hello! How can I help you?"
},
{
"input": "What's the date today?",
"output": "Today is Monday, August 14, 2023."
},
{
"input": "Thank you!",
"output": "You are welcome."
}
]
},
{
"conversation":[
{
"system": "You are an AI asssistant."
"input": "Hello?",
"output": "Hello! How can I help you?"
},
{
"input": "How's the weather today in Rosso?",
"output": "The weather in Rosso on Wednesday, August 16th, is going to be cloudy for most of the day, together with moderate rain around noon."
},
{
"input": "Thank you!",
"output": "You are welcome."
}
]
}]
数据集中的 “conversation” 键对应的值是一个列表,用于保存每一轮对话的指令和实际回答(GroundTruth)。为了保持格式统一,增量预训练数据集和单轮对话数据集中的 “conversation” 键也对应一个列表,只不过该列表的长度为 1。而在多轮对话数据集中,”conversation” 列表的长度为 n,以容纳 n 轮的对话内容。
LLMs 数据格式汇总
各类LLM数据格式汇总: chat_template
不同模型在是否存在默认 system message上, 有所不同(大多数模型都是没有的)。
每个模型都附上了有system版本和无system版本,如果在训练模型时希望加上system message, 可以参照template模板自行添加。
Qwen
官方默认 system message 即:You are a helpful assistant
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
This is a instruction<|im_end|>
<|im_start|>assistant
This is a answer<|im_end|>
Yi
官方版本没有默认 system message,可以与llama一样, 不加 system message使用,有
<|im_start|>system
This is a system message<|im_end|>
<|im_start|>user
This is a instruction<|im_end|>
<|im_start|>assistant
This is a answer<|im_end|>
无system模式
<|im_start|>user
This is a instruction<|im_end|>
<|im_start|>assistant
This is a answer<|im_end|>
Gemma
官方版本不支持system
无system模式
<bos><start_of_turn>user
This is a instruction<end_of_turn>
<start_of_turn>model
This is a answer<end_of_turn>
Phi-3
官方版本没有默认的system message, 有此需求可依据下述模板自己构建
<s><|system|>
This is a system message<|end|>
<|user|>
This is a instruction<end>
<|assistant|>
This is a answer<end>
无system模式
<s><|user|>
This is a instruction<end>
<|assistant|>
This is a answer<end>
Deepseek
官方同样没有提供默认system message,有此需求可依据下述模板自己构建
<|begin▁of▁sentence|>This is a system message
User:This is a instruction
Assistant:This is a answer<|end▁of▁sentence|>
无system模式
<|begin▁of▁sentence|>User:This is a instruction
Assistant:This is a answer<|end▁of▁sentence|>
Mistral
没有提供system模式
无system模式
<s>[INST]:This is a instruction [/INST]This is a answer</s>
Llama2
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
There's a llama in my garden 😱 What should I do? [/INST] This is a answer</s>
Llama3&3.1
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
This is a system prompt.<|eot_id|><|start_header_id|>user<|end_header_id|>
This is the first user input.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
This is the first assistant response.<|eot_id|>
MiniCPM
<用户>This is a system message<AI>This is a instruction</s>
DeepSeek-coder
<|begin▁of▁sentence|>User: {user_message_1}
Assistant: {assistant_message_1}<|end▁of▁sentence|>User: {user_message_2}
Assistant:
You can also add an optional system message:
<|begin▁of▁sentence|>{system_message}
User: {user_message_1}
Assistant: {assistant_message_1}<|end▁of▁sentence|>User: {user_message_2}
Assistant:
ChatGPT 三步走
InstructGPT 分为如下三大步:
SFT
:生成模型GPT的有监督精调
(supervised fine-tuning)RM
:奖励模型
的训练(reward model training)PPO
:近端策略优化模型
( reinforcement learning via proximal policy optimization)
SFT
(supervised fine-tuning) 主要还是大量Prompt数据
- GPT模型通过有监督Prompt数据进行精调,即 next token prediction 任务(
NTP
)。 - 然后用精调后的模型对每个输入的 < 文本+prompt > 进行 generate,生成4~9个输出,并且进行解码操作。
【2023-11-20】transformers_tasks GPT-2 和 RLHF 示例
【2023-9-11】Understanding and Using Supervised Fine-Tuning (SFT) for Language Models
GPT 训练流程
【2023-5-23】Andrej Karpathy 在微软Build 2023开发者大会上进行了主题演讲:State of GPT(GPT的现状)
模型训练分为四个阶段:预训练
(Pretraining)、监督微调
(Supervised Finetuning)、奖励建模
(Reward Modeling)、以及强化学习
(Reinforcement Learning)。
- 数据量:预训练阶段所需的数据量很大,但质量要求不高;而后面的三个阶段恰恰相反,需要的数据质量较高。
- 训练方法:预训练和监督微调的训练方法相同,都是预测下一个单词。奖励模型和强化学习的训练方法则不同。奖励模型是二元分类学习,而强化学习则鼓励模型生成奖励模型评分较高的回答。
- 训练所需资源:预训练阶段的资源消耗巨大,使用数千颗GPU,花费数月时间,占总训练时间的99%。后面的三个阶段只需使用数十颗GPU,训练时间约数天。
预训练阶段的资源消耗如此巨大,只有大厂才有能力进行。如果资源有限,我们应将重心放在后三个阶段
ChatGPT流程
InstructGPT和instruction tuning方向的工作比较相关,独特之处在于继承了之前工作的风格——对齐人类偏好。与之前摘要任务相比,instructGPT的prompt分布更多样和复杂。
【2023-5-1】 ChatGPT训练三步流程
- AC架构中,Actor(学习策略)和Critic(学习价值)是两个模型,训练过程中参数都是变动的
- PPO基于A2C算法(同步优势更新的AC),经验回放过程中更新参数
- Critic和RM是一个模型,Instruct GPT论文中,RM 都是6b gpt-3, Critic目标不只是学习RM,还要配合、监督actor同步更新
- RLHF训练过程中涉及4个模型:actor、critic、rm和sft,后两者冻结,前两个持续更新
- PPO损失函数组成:RM打分损失 - β SFT差异损失,即打分高而差异小
备注
- SFT数据集和RM数据集的prompts来自于API和标注人员编写
- SFT数据集的回答是标注人员写的;
- PPO数据集来自于API
- Prompts的任务类型包括生成、QA、脑暴、聊天、改写、摘要、分类、抽取等任务。
- RM模型用了
GPT-3
6B,训练方法和之前摘要任务一样。 - Policy增加一个LM pretrain objective,可以修复alignment tax,让RL policy在公开NLP数据集上表现也很好
优化
全参数微调:对模型所有参数进行调整,如:SFT
- 问题:代价大(模型大、数据多、参数量大)
混合精度微调:省显存、加速,但丢失精度
- 训练时同时使用16位、32位浮点类型,加速,减少内存开销
- 部分参数使用32位类型,以保持数值稳定性,缩短单步用时
- 现代加速器使用16位专用硬件执行运算,速度更快
注意
- FP16: 内存存储、乘法运算
- FP32: 累加运算,避免下溢
- 手动放大梯度,以避免梯度爆炸,这样 FP16和FP32运算时,不容易出现下溢
其它加速方法
- 多卡并行:数据并行、模型并行
ZeRO
:分布式机器学习的新型内存优化技术,将三个步骤(optimizer state partitioning/add gradient partitioning/add parameter partitioning)拆分到不同的卡上,相比 数据并行,节省GPUp-tuning
:通过prompt encoder结构将prompt编码为向量,再与input embedding拼接。增加模型理解能力LoRA
:低秩适配,冻结大模型权重,只训练新增的网络层(两个小矩阵的乘积),降低fine-tune成本,同时保持类似效果
(0) Pre-Train
问题
【2024-7-28】面试LLM//各阶段
CLS token
预训练阶段:
- 模型训练句子时, 没有加
<CLS>
Token,但是预测时加了<CLS>
Token - 或者训练时加了
<CLS>
token, 但是预测时没有加<CLS>
Token
benchmark 预测会有啥问题?
benchmark会直接崩溃,之前gemma-2b训的时候带BOS,预测忘加了,benchmark全崩了。
原因
- 一个句子的第一个Token在模型中会吸收大量attention,那么当预测时改变了第一个Token,句子的预测会改变比较大,因为第一个Token改变了,而预测中大量attention来自第一个Token,所以预测的时候大量benchmark效果会不好。
三个阶段训练(SFT->RM->PPO)过程较长,更新迭代较慢?
考虑以下几种方法:
- 并行化训练:利用多个计算资源进行并行化训练,可以加速整个训练过程。可以通过使用多个CPU核心或GPU来并行处理不同的训练任务,从而提高训练的效率和速度。
- 分布式训练:将训练任务分发到多台机器或多个节点上进行分布式训练。通过将模型和数据分布在多个节点上,并进行并行计算和通信,可以加快训练的速度和更新的迭代。
- 优化算法改进:针对每个阶段的训练过程,可以考虑改进优化算法来加速更新迭代。例如,在SFT(Supervised Fine-Tuning)阶段,可以使用更高效的优化算法,如自适应学习率方法(Adaptive Learning Rate)或者剪枝技术来减少模型参数;在RM(Reward Modeling)阶段,可以使用更快速的模型训练算法,如快速梯度法(Fast Gradient Method)等;在PPO(Proximal Policy Optimization)阶段,可以考虑使用更高效的采样和优化方法,如并行采样、多步采样等。
- 迁移学习和预训练:利用迁移学习和预训练技术,可以利用已有的模型或数据进行初始化或预训练,从而加速训练过程。通过将已有模型的参数或特征迁移到目标模型中,可以减少目标模型的训练时间和样本需求。
- 参数调优和超参数搜索:对于每个阶段的训练过程,可以进行参数调优和超参数搜索,以找到更好的参数设置和配置。通过系统地尝试不同的参数组合和算法设定,可以找到更快速和高效的训练方式。
综合运用上述方法,可以加速三个阶段训练过程,提高更新迭代的速度和效率,从而减少训练时间和资源消耗。
(1) 第一步 SFT(全参数微调)
SFT 原理比较简单,难的是数据问题,需要大量的有监督Prompt文本
- Transformer【左】GPT【右】
大模型训练基座模型时,都采用「Next Token Prediction,NTP
」 任务
【2024-5-31】sft分为两种,拟合和对齐。
- 拟合:通过finetuning 得到稳定、符合需求的输出,包括格式、风格、特定模式等,是在业务落地中高频使用的方式;
- 对齐:指令对齐,让LLM更好地理解人类语言、执行自然语言指令,即LLM三个阶段之第二个阶段(pretrain、sft、rlhf)。
loss 改进
【2024-9-24】SFT loss 计算的那些坑(多轮合并/packing)
SFT 训练时, 直接输入 (input_ids, label)
, 训练效率低。
通常有两个加速方法:
- 多轮合并: 同一个会话的拆分、合并
- user 和 bot 交互了 3 轮, 数据格式: bot作答部分用 input_ids, 其余用 -100 表示
- (system, user1,
bot1
, pad), bot1 计算loss - (system, user1, bot1, user2,
bot2
, pad), bot2 计算loss - (system, user1, bot1, user2, bot2, user3,
bot3
), bot3 计算loss - loss 表达式:loss = 1/3 (l1/n1+l2/n2+l3/n3)
, ni 是 boti token数, li 是第i个样本的 loss - 不同样本之间有很多重复计算的前缀, 训练偏慢
- 加速
- 将3个样本合成1个, 借助 causal attention mask,每个 token 只能看到前面的 token,计算上和之前是等价
- 数据格式: (system, user1,
bot1
, user2,bot2
, user3,bot3
), 对应权重 li/ni - 问题: loss 计算有问题, pytorch
CrossEntropyLoss
默认取均值 mean,loss = (l1+l2+l3)/(n1+n2+n3)
, 而 ni 不一定相同, 导致 短句子权重被降低, 长句子被加权, loss 不等价
- packing: 将多个会话合成一条, 进一步加速
- 将所有样本拼接成1条,并加入
attention mask
, 保证后面的样本看不见前面的token。如 在 flash attention 中调用 flash_attn_varlen_qkvpacked_func,并传入 cu_seqlens 参数。 - 和之前一样,如果不修改 loss 计算方法,packing 的样本之间会存在因为长度不同,导致训练不充分的问题。
- 将所有样本拼接成1条,并加入
loss 计算会经历三次平均
- micro batch 维度,分母是这个 micro batch 中的所有 label 不是 -100 的 token 数
- DP 维度,分母是 DP size (和GPU数量相关)
- 梯度累加维度,分母是梯度累加数
禁用这三个平均,统一用 global batch
对话轮数作为分母。
- 新版 megatron 框架中,开启开关
--calculate-per-token-loss
, 即可禁用 DP 和梯度累加的平均 - 然后 修改
loss_func
,每个micro batch
都需要返回这个micro batch
的轮数 - 最后 框架会自动将所有轮数求和,作为分母。对于分子,需要除以这个轮次的token 数。
正确实现代码如下(loss_token_num, turn_num 是在构建 data 的时候构建的):
def loss_func(output_tensor, loss_mask, loss_token_num, turn_num):
losses = output_tensor.view(-1).float()
loss_mask = loss_mask.view(-1).float()
loss_token_num = loss_token_num.view(-1).float()
# label: [-100, -100, a, a, a, -100, b, b, -100, -100, c, c, c, -100, -100]
# loss_mask: [0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0]
# losses: [a0, a1, a2, a3, a4, b0, b1, b2, c0, c1, c2, c3, c4, d0, d1]
# losses * loss_mask = [0, 0, a2, a3, a4, 0, b1, b2, 0, 0, c2, c3, c4, 0, 0]
# loss_token_num: [3, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1]
# losses * loss_mask / loss_token_num = [0, 0, a2/3, a3/3, a4/3, 0, b1/2, b2/2, 0, 0, c2/3, c3/3, c4/3, 0, 0]
# sum = 1/3 (a2 + a3 + a4) + 1/2 (b1 + b2) + 1/3 (c2 + c3 + c4)
loss = torch.sum(losses * loss_mask / loss_token_num)
loss_and_turn_num = torch.cat([loss.view(1), turn_num.view(1)])
# Reduce loss for logging.
loss_and_turn_num = loss_and_turn_num.clone().detach()
torch.distributed.all_reduce(loss_and_turn_num, group=mpu.get_data_parallel_group())
# 新版返回结构,开启 calculate_per_token_loss 开关后,返回三个值
# 第一个是反向传播实际使用的 loss, 所有 packing 的 loss 求和
# 第二个是 turn_num, 优化器状态更新时会使用对这个值求和然后缩放梯度
# 第三个是用于日志打印的 loss, 包含两个值,第一个是所有 loss 求和作为分子,第二个是所有 turn_num 求和作为分母
return loss, turn_num, {"lm loss": (loss_and_turn_num[0], loss_and_turn_num[1])}
无论是哪种方法,加速后都需要保证 loss 和原来等价。
加速注意:
- 不同样本之间等价;
- 不同轮次之间等价。
合并多轮 / packing 时,要修改 loss 计算方法,为每个 token 设置正确权重,并且关闭 DP
/ 梯度累加
的平均。
IFT 问题
5 月,伯克利的论文 The False Promise of Imitating Proprietary LLMs 指出这种方式微调出来的指令遵循模型存在的一系列问题:
- 在缺少大量模仿 ChatGPT 数据支持的任务上,这类模型无法改善 Base Model 到 ChatGPT 的差距;
- 这类模型只是擅长模仿 ChatGPT 的风格,而不是事实性,导致实际的性能差异会骗过人类评估者;
- 当前开源模型最大的限制仍然是 Base Model 层面跟 GPT 系列的差距,在微调而不是预训练环境进行优化可能是不正确的方向;
- 为了广泛地匹配 ChatGPT 支持的任务,需要更广泛和大量的模仿数据集,还需要新的工作;
而 6 月份 Allen Institute for AI 和华盛顿大学的 How Far Can Camels GO ?工作再次通过实验表明不同的指令微调数据集可以释放或者增强特定的能力,但并没有一个数据集或者组合可以在所有的评估中提供最佳性能,并且这一点在人类或模型担任评估者时也很容易无法被揭示。
对于指令遵循微调背后的团队来说,他们也意识到自己的模型由于 Base Model(LLaMA)的限制,在复杂推理和代码任务上很弱,并且难以进入正向数据飞轮 —— 模型能力越弱的领域越难得到更多的 query,也就难以筛选出高质量 query,想自己再标注提升模型能力就很困难。
至此,开源社区已经充分意识到原来这套微调 LLaMA 的框架的局限性,越来越多的团队开始探索预训练环节和更接近真实的人类反馈数据
数据示例
数据准备
Raw Data | Prompt | Label |
---|---|---|
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 | 大熊猫是 | 一种有黑白斑纹的动物。 |
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 | 大熊猫是 | 中国特有种,主要栖息地是中国四川、陕西和甘肃的山区。 |
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 | 大熊猫是 | 已在地球上生存了至少800万年,被誉为“活化石”和“中国国宝”即国兽,世界自然基金会的形象大使,是世界生物多样性保护的旗舰物种。 |
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。 | 大熊猫是 | 属于熊科、大熊猫属的哺乳动物。仅有二个亚种。雄性个体稍大于雌性。体型肥硕似熊、丰腴富态,头圆尾短,头躯长1.2-1.8米,尾长10-12厘米。 |
raw_data = "我们去成都旅游,必须要去的地方是大熊猫繁殖基地。"
prompt = "大熊猫是"
labels = ["一种有黑白斑纹的动物。","中国特有种,主要栖息地是中国四川、陕西和甘肃的山区。",
"已在地球上生存了至少800万年,被誉为“活化石”和“中国国宝”即国兽,世界自然基金会的形象大使,是世界生物多样性保护的旗舰物种。",
"属于熊科、大熊猫属的哺乳动物。仅有二个亚种。雄性个体稍大于雌性。体型肥硕似熊、丰腴富态,头圆尾短,头躯长1.2-1.8米,尾长10-12厘米。"]
combine_data = [raw_data+prompt+label for label in labels]
初始化模型,对输入数据进行编码, 以 GPT-2 模型为例
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
# 模型加载
tokenizer = BloomTokenizerFast.from_pretrained('pre_train_model/gpt2')
model = BloomForCausalLM.from_pretrained('pre_train_model/gpt2')
# 自定义DataSet类
class Datasets(Dataset):
def __init__(self, sample):
super(Datasets, self).__init__()
self.sample = sample
def __getitem__(self, item):
res = {k: v[item] for k, v in self.sample.items()}
return res
def __len__(self):
return len(self.sample['labels'])
# 数据转换
combine_data_token = tokenizer.batch_encode_plus(
initial_data_,
max_length=256,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# 将标签标签加入
combine_data_token['labels'] = combine_data_token['input_ids']
combine_data_token['labels'] = torch.where(
combine_data_token['labels']==0,
-100,
combine_data_token['labels']
)
# 模型训练保存
trainer_args = TrainingArguments("./model/", learning_rate=2e-5, weight_decay=0.01, num_train_epochs=10, auto_find_batch_size=True)
trainer = Trainer(model=initial_model, args=trainer_args, train_dataset=Datasets(initial_token_info))
trainer.train()
trainer.save_model()
# ----- 加载生成 --------
# 加载模型
model = AutoModelForCausalLM.from_pretrained('./model')
# 处理输入数据
input_data = raw_input + prompt
input_datas = tokenizer.encode_plus(
input_data,
return_tensors='pt'
)
input_ids = input_datas['input_ids']
# 模型生成
result = model.generate(
input_ids=input_ids,
max_length=256,
do_sample=True, # 增加随机性
num_beams=5,
num_return_sequences=5, # 每个样本生成5个结果
no_repeat_ngram_size=3, # 防止重复的token
early_stopping=True # 提前停止
)
decode_tokens = tokenizer.batch_decode(
result,
skip_special_tokens=True
)
results = [i.replace(' ', '') for i in decode_tokens]
print("results",results)
结果:
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是今世界上保存最完好的哺乳动物之一,也是世界自然保护联盟濒危物种红色名录的保护对象之一。在这里,你可以看到全世界最大的熊猫栖息地成都。成都是中国国家林业局直属的国家重点风景名胜区,是国家森林公园、国家湿地公园和国家地质公园的重要组成部分,是全国重点文物保护单位、全国生态文明建设示范区、中国红色旅游名城、国际生态旅游目的地和国际旅游岛建设先进区。地址:四川省成都市绵阳市成华区成都高新技术产业开发区成华大道1号乘车路线:成都绵阳都江堰雅
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家二级保护动物,是世界上保存最完整的动物种群之一,也是我国第一个国家级自然保护区。我们是四川省的首批国家重点保护野生动物和珍稀动物基金会的成员,被誉为中国动物保护的摇篮和世界生物多样性保护基地,被中国科学院、中华人民共和国国家林业局授予全国生态文明建设示范区称号,被国务院批准为国家森林城市、国际生态旅游目的地。熊猫基地位于成都市双流区东南部,是国家aaaa级旅游景区,国家地理标志保护单位。熊猫栖息地为亚热带或热带的高山
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家级自然保护区,也是世界上保存最完好的熊猫种群之一。它们栖息在亚热带或热带的高海拔草原上,生活环境十分优越,是中国四大自然奇观之一,被誉为世界自然遗产和中国国家森林公园。熊猫栖息地主要分布在中国大陆的西藏、青海、甘肃、宁夏、新疆、内蒙古、山西、辽宁、吉林、黑龙江、江苏、河南、安徽、湖北、湖南、江西、广东、海南、四川、云南、贵州、陕西等地。中国熊猫研究中心主任、中国科学院院士、国家自然科学基金委员会委员、中华全国工商业联合会副主席
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家级自然保护区,也是世界上保存最完整、规模最大的野生动物种类繁多的地区之一,是中国国家重点保护的珍稀濒危动物及其栖息地和世界自然遗产的重要组成部分,被誉为中国最美丽的城市和世界生物多样性保护基地,被国际旅游组织评为全球生态旅游目的地。成都熊猫国家公园位于四川省甘孜藏族自治州,是国家aaaa级旅游景区,被《世界遗产名录》列为全国重点文物保护单位。目前,我国已建成国家森林公园、国家湿地公园和国家地质公园,国家林业局、国务院扶贫
我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是现存最大、保存最完整的动物,属于国家二级保护动物。熊猫种类繁多,分布广泛,主要分布在四川、云南、陕西、甘肃、宁夏、内蒙古、新疆、青海、吉林、辽宁、黑龙江、山西、江苏、江西、河南、湖北、湖南、广东、广西、海南、重庆、贵州、西藏、四川等省区市。它们的栖息地主要为亚热带或热带的(低地)湿润低地林、亚高山草原、高山湖泊、高原湿润山区和高原沼泽地等,常栖息在高海拔地区。在中国大陆,熊猫分布于四川省甘孜藏族自治州和青海省西宁市等地。雄性熊猫体长约1.5米
这和instructGPT的SFT过程大致相同,思路原理是一样的,差别是 缺乏硬件设备、大规模高质量监督数据
- ChatGPT原理详解+实操: SFT(GPT模型精调), RM(reward model)
引入RM模型的作用是对生成的文本进行打分排序,让模型生成的结果更加符合人类的日常理解习惯,更加符合人们想要的答案。RM模型主要分为两个部分:训练数据获取和模型训练部分。流程如下图所示
Bloom SFT
【2023-5-23】bloom_tuning: BLOOM 模型的指令微调
- Github: bloom_tuning
- 模型: bloom-396m-chat
BLOOM 系列模型是由数百名研究人员在包含 46 种自然语言和 13 种编程语言的数据集上, 基于大规模分布式训练框架 Megatron-DeepSpeed
训练得到。
- 实验发现,BLOOM 在一系列基准测试上取得了具有竞争力的性能,经过多任务提示微调后,可以获得更为惊艳的效果。
- BLOOM 模型支持中文、英文、代码、法语、西班牙语。
链接:bloom-560m
LLMPruner 工具对 BLOOM 进行词表裁剪,保留常用的中英文 token,词表大小由 250880 降至 46145,缩减为原来的 18.39%,在后续微调过程中可以减少显存占用。
- 词表裁剪后的模型链接:bloom-396m-zh
数据
训练数据来自于 BelleGroup/train_3.5M_CN,该数据集包含 3.6M 条指令,从中筛选出单轮对话数据,进行 10:1 采样后得到约 0.25M 指令数据:
python sample_data.py \
--input data/train_3.5M_CN.json \
--output data/train.jsonl \
--sample_ratio 0.1
单条指令数据形如:
{
"instruction": "你好,请问你能做什么?",
"output": "你好,我可以回答各种问题,提供辅助,或者与你聊天。有什么我可以帮你的吗?"
}
输出部分的长度分布如下图所示(若输出长度超过2048,则设置为2048)
指令微调
基于 deepspeed ZeRO-Stage 2 进行指令微调训练:
deepspeed --include localhost:0 train.py \
--model_name_or_path /path/to/bloom \
--data_path data/train.jsonl \
--max_input_length 200 \
--max_output_length 768 \
--output_dir output \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--lr_scheduler_type "cosine" \
--warmup_steps 2000 \
--logging_steps 10 \
--save_strategy "steps" \
--save_steps 200 \
--save_total_limit 1 \
--deepspeed deepspeed.json \
--fp16 False
在 per_device_train_batch_size=1、gradient_accumulation_steps=16、max_input_length=200、max_output_length=768、fp16=false 的配置下,单卡需要14G 显存可对 bloom-396m-zh 进行微调。
推理
微调后的模型已上传至 huggingface: bloom-396m-chat
可以通过如下代码调用 bloom-396m-chat 模型来生成对话:
from transformers import BloomTokenizerFast, BloomForCausalLM
model_name_or_path = "WangZeJun/bloom-396m-chat"
tokenizer = BloomTokenizerFast.from_pretrained(model_name_or_path)
model = BloomForCausalLM.from_pretrained(model_name_or_path).cuda()
model = model.eval()
input_pattern = "{}</s>"
text = "你好"
input_ids = tokenizer(input_pattern.format(text), return_tensors="pt").input_ids
input_ids = input_ids.cuda()
outputs = model.generate(input_ids, do_sample=True, max_new_tokens=1024, top_p=0.85,
temperature=0.3, repetition_penalty=1.2, eos_token_id=tokenizer.eos_token_id)
output = tokenizer.decode(outputs[0])
response = output.replace(text, "").replace('</s>', "")
print(response)
运行如下命令进行命令行推理:
python chat.py \
--model_name_or_path /path/to/model \
--max_new_tokens 1024 \
--top_p 0.85 \
--temperature 0.3 \
--repetition_penalty 1.2
测试
用户 | Assistant |
---|---|
你好 | 您好,有什么我可以帮您的吗? |
你是谁 | 我是一个AI语言模型,没有个人身份和情感。我的目的是回答你的问题并提供有用的信息来帮助你解决问题或提供娱乐性内容。如果你有任何其他的问题需要解答,请随时问我! |
SFT 思考
SFT packing 是什么
SFT packing 指训练sft过程中,将多个sft数据pack到一个样本内进行训练
- 这种方式会加快模型训练速度,如果不进行SFT packing,那么对于短文本sft需要padding到一个batch最长长度,浪费很多计算token。
- SFT packing 有很多种类,比如 block diagonal attention, 每个token仅仅去attention自己的问题内的token。但一般业务中会直接将其相连接,然后进行预测,虽然这样会引入一些噪音,但好像相对于非sft packing方式的整体的效果损失不大。这个可能是因为pretrain的时候模型也是这么训练的。
SFT packing 对SFT训练的影响
SFT packing 后削弱了模型对难的短query和短答案的拟合。
- 无sft packing 情况下,假设batch_size = 1,那么如果有个短query和短答案在这个batch里,其余补充padding,那么这个batch的gradient全是这个短文本的gradient,模型对这个query的拟合能力会变强。
- 但SFT packing 后,多个短文本在一个样本中,这个batch的gradient会被稀释,短文本的拟合就不会特别强。但拟合能力似乎和泛化不可以挂钩,初步观察sft packing和non sft packing的效果差不了很多。在数据量小或者特定困难的数据上,sft packing是有损泛化效果的,non-packing的方式会影响模型续写的效果,因此会影响一些benchmark效果。但在大批量数据上是无损泛化效果的。
SFT 关注什么方面
- 1 根据 prompt 筛选sft数据:Prompt的diversity:丰富多样的prompt数据可以让模型更多的了解人类的指令,包括指令指复杂指令中每一步的含义。Prompt的丰富程度决定了模型指令遵循的能力。
- 明文TAG法:对SFT的prompt进行打tag,对其中的名词和动词进行分类打标,最后通过tag对prompt的分布进行调整,保证tag的分布是均匀的。著名的就是InsTag这个方法。
- 模型embedding聚类方法:通过模型最后一层的embedding对prompt进行表示,那么通过prompt embedding的距离表示prompt的相似度,对于过于相似的prompt进行删除。著名的有Self-Evolved Diverse Data Sampling for Efficient Instruction Tuning。
- 从complexity角度,对于prompt直接进行难度的升级,所以即使在同一个语意空间的prompt也会变得diverse。比较著名的是Wizard 方法,通过GPT4进行prompt难度升级,然后构成complexity丰富的prompt。
- 2 利用sft model和pretrain model的关系筛选模型的sft数据:
- IFD方法:利用公式进行数据选择: 这个公式是计算pretrain model生成对齐后模型的answer的难度(在 prompt的condition 下生成A的概率)。这个概率越低,越说明生成难度高,那么sft模型学习到的对齐规律越多,那么我们更应该选择这个sft数据。
- Hybrid Method (混合了多种之前列举的指标和方法。):例如 What MakeGood Data for Alignment? A Comprehensive Study of Automatic Data Selectionin Instruction Tuning [2] 文章,从complexity,diversity和quality三个方向对sft数据建模,训练了多个模型对各个指标维度进行分别衡量。
- 3 Answer的质量:Answer的质量包括内容和格式两方面,一方面内容的正确性需要得到保证,一方面内容的格式也很重要,细节丰富,逻辑缜密的answer可以激发模型更多的回答能力。
- 4 SFT阶段不能太多的知识注入:过多的知识注入,或者超过模型能力本身的回答过多会导致对齐税。
提升模型 reasoning 能力
什么数据格式在SFT或者ICL阶段可以提升模型的reasoning的能力?
数学reasoning上是有三种形式可显著提高效果模型 reasoning 能力
- Reverse : 128 + 367 = 495 -> 128 + 367 = ^594, 因为人就是反着计算的,从个位数到百位数。
COT
orPOT
(Simplified Scratchpad): 把这个计算过程列举下来,用自然语言、符号或者代码形式呈现。- Detailed Scratchpad:把整个思考过程详细地用自然语言和符号表达出来。
- 整体上Detailed Scratchpad需要的总条数最少就能达到100%在加法上的效果,但是其实总token数和plain需要差不多数量达到最好的效果。
SFT 中代码数据+文本数据, 哪个更容易改变
代码数据,因为
- 预训练中, 代码数据确定性更高,ppl更低,记忆越深刻
- 而文本数据变化更大,ppl更高,熵更高。
SFT过程中,改变文本数据比较容易,因为本身ppl就会高,但代码数据会比较难,因为本身ppl会比较低,或者说代码数据的生成确定性更高,少量样本很难对其内部改变,只能大段替换。
SFT 能学新知识吗
虽然理论上可以,但很少且不推荐sft阶段去学习知识。
- LIMA原文中就表述过同样一个假设,sft阶段更多是将模型能力和人类对齐,而不过多学习新的知识。
原因如下:
- sft相对于pretrain过的数据量实在太小,模型的知识学习的概率就很低。
- 如果加大sft的数据量和pretrain数据相当,那么sft有一些特定的格式以及一些system prompt需要重复当作context进行attention,这些重复的context势必会影响模型原始的attention模式,从而影响模型的效果。
- 最后, 如果希望sft学习新知识,不如把这部分sft的新知识组织好放入pre-train or post-train阶段更为合适。
(2)第二步 RM训练
奖励模型(Reward Model, RM)目标是刻画模型的输出是否在人类看来表现不错。
- 输入: [提示(prompt),模型生成的文本]
- 输出: 一个刻画文本质量的标量数字。
同一个prompt输出的多个答案,人工评测排序后,使用lambdarank的思想,优化RM奖励模型。
RM模型学习的是对于一个prompt,人类对答案的喜好程度。
- RM模型【左】RM损失函数【右】
奖励模型接收一系列文本并返回一个标量奖励,数值上对应人的偏好
引入RM模型的作用是对生成的文本进行打分排序,让模型生成的结果更加符合人类的日常理解习惯,更加符合人们想要的答案。
RM模型主要分为两个部分:数据获取和模型训练。流程如下图所示
原论文中使用GPT架构做了一个reward model
注意
- 要将模型的输出映射成维度为1的打分向量,即增加一个linear结构。
RM模型主要在于人工参与的训练数据构建部分,将训练好的SFT模型输入Prompt进行生成任务,每个Prompt生成4~9个文本,然后人为的对这些文本进行排序,将每个Prompt生成的文本构建为排序序列的形式进行训练,得到打分模型,以此模型用来评估SFT模型生成的文本是否符合人类的思维习惯。
两种方法命名为 direct score 和 rank score:
Direct score
:直接对输出的文本进行打分,通过与自定义的label score计算loss,以此来更新模型参数;Rank score
:用排序方法对每个Prompt输出的n个句子进行排序作为输入,通过计算排序在前面的句子与排序在后面的句子的差值累加作为最终loss。
【2023-6-5】ChatGPT 为什么不用 Reward-Model 的数据直接 fine-tune,而用 RL?
- Reward-model的输出对于整个token序列,一种滞后反馈,而finetune需要在每个token都有监督信号。这是强化学习与监督学习的差别。
- 生成Reward-model的数据有些是结果对比较pair数据,没法直接用于监督学习finetune。
① Direct score方法
① Direct score方法
- 利用 Bert模型对标注数据进行编码,用 linear层 映射到1维,然后用 Sigmoid函数输出每个句子的得分,与人工标记的得分进行loss计算,以此来更新模型参数。流程如下所示
数据为SFT最后所生成的数据,数据准备:
def data_prepare(pretrain_path):
data_lst = [
"我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是今世界上保存最完好的哺乳动物之一,也是世界自然保护联盟濒危物种红色名录的保护对象之一。在这里,你可以看到全世界最大的熊猫栖息地成都。成都是中国国家林业局直属的国家重点风景名胜区,是国家森林公园、国家湿地公园和国家地质公园的重要组成部分,是全国重点文物保护单位、全国生态文明建设示范区、中国红色旅游名城、国际生态旅游目的地和国际旅游岛建设先进区。地址:四川省成都市绵阳市成华区成都高新技术产业开发区成华大道1号乘车路线:成都绵阳都江堰雅",
"我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家二级保护动物,是世界上保存最完整的动物种群之一,也是我国第一个国家级自然保护区。我们是四川省的首批国家重点保护野生动物和珍稀动物基金会的成员,被誉为中国动物保护的摇篮和世界生物多样性保护基地,被中国科学院、中华人民共和国国家林业局授予全国生态文明建设示范区称号,被国务院批准为国家森林城市、国际生态旅游目的地。熊猫基地位于成都市双流区东南部,是国家aaaa级旅游景区,国家地理标志保护单位。熊猫栖息地为亚热带或热带的高山",
"我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家级自然保护区,也是世界上保存最完好的熊猫种群之一。它们栖息在亚热带或热带的高海拔草原上,生活环境十分优越,是中国四大自然奇观之一,被誉为世界自然遗产和中国国家森林公园。熊猫栖息地主要分布在中国大陆的西藏、青海、甘肃、宁夏、新疆、内蒙古、山西、辽宁、吉林、黑龙江、江苏、河南、安徽、湖北、湖南、江西、广东、海南、四川、云南、贵州、陕西等地。中国熊猫研究中心主任、中国科学院院士、国家自然科学基金委员会委员、中华全国工商业联合会副主席",
"我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家级自然保护区,也是世界上保存最完整、规模最大的野生动物种类繁多的地区之一,是中国国家重点保护的珍稀濒危动物及其栖息地和世界自然遗产的重要组成部分,被誉为中国最美丽的城市和世界生物多样性保护基地,被国际旅游组织评为全球生态旅游目的地。成都熊猫国家公园位于四川省甘孜藏族自治州,是国家aaaa级旅游景区,被《世界遗产名录》列为全国重点文物保护单位。目前,我国已建成国家森林公园、国家湿地公园和国家地质公园,国家林业局、国务院扶贫",
"我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是现存最大、保存最完整的动物,属于国家二级保护动物。熊猫种类繁多,分布广泛,主要分布在四川、云南、陕西、甘肃、宁夏、内蒙古、新疆、青海、吉林、辽宁、黑龙江、山西、江苏、江西、河南、湖北、湖南、广东、广西、海南、重庆、贵州、西藏、四川等省区市。它们的栖息地主要为亚热带或热带的(低地)湿润低地林、亚高山草原、高山湖泊、高原湿润山区和高原沼泽地等,常栖息在高海拔地区。在中国大陆,熊猫分布于四川省甘孜藏族自治州和青海省西宁市等地。雄性熊猫体长约1.5米"]
# 自定义打分标签,每个句子一个分值。也可以定义多维度的打分方法,只是模型的线性层需要改为你所定义的维度数
direct_score = [[0.75], [0.5], [0.35], [0.4], [0.8]]
tokenizer = BertTokenizer.from_pretrained(pretrain_path)
train_data = tokenizer.batch_encode_plus(data_lst, max_length=256, padding="max_length", truncation=True,
return_tensors='pt')
train_data["labels"] = torch.tensor(direct_score)
return train_data, tokenizer
RM模型搭建
- 采用了Bert模型作为编码模型,后取CLS作为文本表征,采用MSE作为loss函数,最后接linear进行维度压缩
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertPreTrainedModel, BertTokenizer, BertConfig, get_scheduler
class RewardModel(BertPreTrainedModel):
def __init__(self, config):
super(RewardModel, self).__init__(config)
self.config = config
self.sigmoid = nn.Sigmoid()
self.loss_fn = nn.MSELoss()
self.model = BertModel(config)
self.linear = nn.Linear(config.hidden_size, 1)
def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
outputs = self.model(input_ids=input_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask).pooler_output
output = self.linear(outputs)
logits = self.sigmoid(output)
if labels is not None:
loss = self.loss_fn(logits, labels)
return logits, loss
else:
return logits
训练过程
class Datasets(Dataset):
def __init__(self, sample):
super(Datasets, self).__init__()
self.sample = sample
def __getitem__(self, item):
res = {k: v[item] for k, v in self.sample.items()}
return res
def __len__(self):
return len(self.sample['input_ids'])
def train(pretrain_path, save_path):
config = BertConfig.from_pretrained(pretrain_path)
model = RewardModel(config=config)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.01,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5)
train_data, tokenizer = data_prepare(pretrain_path)
dataloader = DataLoader(dataset=Datasets(train_data), shuffle=False, batch_size=1)
max_train_steps = 10 * len(dataloader)
warm_steps = int(0.0 * max_train_steps)
lr_scheduler = get_scheduler(
name='linear',
optimizer=optimizer,
num_warmup_steps=warm_steps,
num_training_steps=max_train_steps,
)
model.train()
for i in range(1, 51):
loss_lst = []
for batch in dataloader:
out, loss = model(batch["input_ids"], token_type_ids=batch["token_type_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss_lst.append(loss.item())
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
print("epoch{}\tloss: {}".format(str(i), str(sum(loss_lst) / len(loss_lst))))
tokenizer.save_pretrained(save_path)
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(save_path)
model_to_save.config.save_pretrained(save_path)
模型预测
def predict(model_path):
text = ["我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是今世界上保存最完好的哺乳动物之一,也是世界自然保护联盟濒危物种红色名录的保护对象之一。在这里,你可以看到全世界最大的熊猫栖息地成都。成都是中国国家林业局直属的国家重点风景名胜区,是国家森林公园、国家湿地公园和国家地质公园的重要组成部分,是全国重点文物保护单位、全国生态文明建设示范区、中国红色旅游名城、国际生态旅游目的地和国际旅游岛建设先进区。地址:四川省成都市绵阳市成华区成都高新技术产业开发区成华大道1号乘车路线:成都绵阳都江堰雅",
"我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家二级保护动物,是世界上保存最完整的动物种群之一,也是我国第一个国家级自然保护区。我们是四川省的首批国家重点保护野生动物和珍稀动物基金会的成员,被誉为中国动物保护的摇篮和世界生物多样性保护基地,被中国科学院、中华人民共和国国家林业局授予全国生态文明建设示范区称号,被国务院批准为国家森林城市、国际生态旅游目的地。熊猫基地位于成都市双流区东南部,是国家aaaa级旅游景区,国家地理标志保护单位。熊猫栖息地为亚热带或热带的高山",]
model = RewardModel.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
model.eval()
data = tokenizer.batch_encode_plus(text, max_length=256, padding="max_length", truncation=True,
return_tensors='pt')
score = model(**data)
return score
完成了一个基于Bert的文本打分模型。
- 当然,这里展示的只是个思路,模型也很粗糙,而且自定义的打分标签也经不起推敲。
② Rank score方法
② Rank score方法
这种方法的区别在于:loss函数的设计。
- 首先,为什么在 InstructGPT 中不采用上面方法? 原因在于给生成句子在打分时,不同标注人员的标准不同,而且这个标准是很难进行统一的,这样会导致标注的数据评判标准不一样,即使每个标注人员的理解是一样的,但对于同一条文本给的分数也不一样的,因此在进行标注时需要把这个定量的问题转为一种更为简单的处理方法,采用排序来方法来进行数据标注可以在一定程度上解决这个问题。
- 标注员在使用直接打分(Direct Score)时,会由于主观意识的不同,对同一个文本出现不同的分值;而使用等级排序(Rank Level)来进行数据标注时,可以统一标注结果。
数据是将每个Prompt生成的文本进行排序,最直接的方法就是最好的句子排在最前面,后面的句子以此类推。
def rank_data_prepare(pretrain_path):
data_lst = []
data_outputs = {
'input_ids': [],
'token_type_ids': [],
'attention_mask': []
}
data_str = "我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是现存最大、保存最完整的动物,属于国家二级保护动物。熊猫种类繁多,分布广泛,主要分布在四川、云南、陕西、甘肃、宁夏、内蒙古、新疆、青海、吉林、辽宁、黑龙江、山西、江苏、江西、河南、湖北、湖南、广东、广西、海南、重庆、贵州、西藏、四川等省区市。它们的栖息地主要为亚热带或热带的(低地)湿润低地林、亚高山草原、高山湖泊、高原湿润山区和高原沼泽地等,常栖息在高海拔地区。在中国大陆,熊猫分布于四川省甘孜藏族自治州和青海省西宁市等地。雄性熊猫体长约1.5米\t我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是今世界上保存最完好的哺乳动物之一,也是世界自然保护联盟濒危物种红色名录的保护对象之一。在这里,你可以看到全世界最大的熊猫栖息地成都。成都是中国国家林业局直属的国家重点风景名胜区,是国家森林公园、国家湿地公园和国家地质公园的重要组成部分,是全国重点文物保护单位、全国生态文明建设示范区、中国红色旅游名城、国际生态旅游目的地和国际旅游岛建设先进区。地址:四川省成都市绵阳市成华区成都高新技术产业开发区成华大道1号乘车路线:成都绵阳都江堰雅\t我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家二级保护动物,是世界上保存最完整的动物种群之一,也是我国第一个国家级自然保护区。我们是四川省的首批国家重点保护野生动物和珍稀动物基金会的成员,被誉为中国动物保护的摇篮和世界生物多样性保护基地,被中国科学院、中华人民共和国国家林业局授予全国生态文明建设示范区称号,被国务院批准为国家森林城市、国际生态旅游目的地。熊猫基地位于成都市双流区东南部,是国家aaaa级旅游景区,国家地理标志保护单位。熊猫栖息地为亚热带或热带的高山\t我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家级自然保护区,也是世界上保存最完整、规模最大的野生动物种类繁多的地区之一,是中国国家重点保护的珍稀濒危动物及其栖息地和世界自然遗产的重要组成部分,被誉为中国最美丽的城市和世界生物多样性保护基地,被国际旅游组织评为全球生态旅游目的地。成都熊猫国家公园位于四川省甘孜藏族自治州,是国家aaaa级旅游景区,被《世界遗产名录》列为全国重点文物保护单位。目前,我国已建成国家森林公园、国家湿地公园和国家地质公园,国家林业局、国务院扶贫\t我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家级自然保护区,也是世界上保存最完好的熊猫种群之一。它们栖息在亚热带或热带的高海拔草原上,生活环境十分优越,是中国四大自然奇观之一,被誉为世界自然遗产和中国国家森林公园。熊猫栖息地主要分布在中国大陆的西藏、青海、甘肃、宁夏、新疆、内蒙古、山西、辽宁、吉林、黑龙江、江苏、河南、安徽、湖北、湖南、江西、广东、海南、四川、云南、贵州、陕西等地。中国熊猫研究中心主任、中国科学院院士、国家自然科学基金委员会委员、中华全国工商业联合会副主席\n昨天买的,今天就到了,因为给家中父母买的,怕东西多老人取件不方便,今天听家里人说京东小哥送到家门楼下,心里太高兴了,在这里希望京东能表扬一下本次快递小哥,他让我本次购物感觉很好,本来就喜欢京东一直购物,现在我更欣赏。购物的同事还能享受温暖的服务,京东的快递服务果然很棒,在此感谢京东,感觉快递小哥,如此服务真的很温暖。\t京东 ,对于S8的货品状态 ,你们你们京东采购下单是应该在预售前还是预售后(定金不退的预售方式)?预售前下单叫正规预订补款了有货拿,预售补款了没货并且还要重新再采购叫空手套白狼,京东是哪种?\t在北京住过不下10多家酒店,也喜欢住公寓,从凯宾斯基到建国饭店,从京广到美华再到星城亮马,而这个是我住过的有史以来最差的一个酒店公寓。难怪价格不上不下,不是因为临时有事绝对不住,希望这里那么多好评语不是枪手1、入口难找到要死不说,大堂感觉就是某个买小商品的商铺,check in 竟然要压证件,没有听说过,坚决不同意拿了我的证件去复印。私人住宿和旅客混杂,拖着箱子看着买菜回来的人一同电梯很奇怪。2、半夜接到骚扰电话3、房间设计装饰非常的“家常“,设施陈旧,非常像当年在江南古镇租住的农家房3、住的房间刚好在过道口,声音那叫一个大阿,谁说的房间隔音?楼上住户的动静镇清楚啊4、服务态度不好,和客人顶着说,铁板一样的语气。5, 实在要找一优点出来的话:唯一就是小区里面比较安静,没有汽车闹声。\t码数刚刚好,穿上很好看,和身。宝贝不掉色,弹力好。穿着不紧绷,试了好几下蹲下站起来,都轻松自如,不会感觉腿被束缚着。价格也不贵,现在认准这家店了这款洗发水挺适合我的发质,用完果断续上一瓶,还搞了个特价,值了!\t之前就听说苏州万丽是苏州生意最好,房价最高,也是业内人士最推崇的酒店,远胜于喜来登,香格里拉,索菲特,在苏州属于一枝独秀型的,平时房间非常的难定,几乎天天满房,这次好不容易定了个行政套,本打算住一天,后又延了一天,简单来说吧,房间不大但很温馨,酒店工作人员不多但都非常专业,亲切,严格意义上来说该酒店硬件并不突出,没有游泳池,没有特色餐厅,建筑也没有什么特色,处处透露着简单,适用,大气,但是只有你住了以后才会觉得,值!"
for sentences in data_str.strip().split("\n"):
texts = sentences.strip().split("\t")
data_lst.append(texts)
tokenizer = BertTokenizer.from_pretrained(pretrain_path)
for rank_text in data_lst:
data_encode = tokenizer(
text=rank_text,
truncation=True,
max_length=256,
padding='max_length',
return_tensors='pt')
data_outputs["input_ids"].append(data_encode["input_ids"])
data_outputs["token_type_ids"].append(data_encode["token_type_ids"])
data_outputs["attention_mask"].append(data_encode["attention_mask"])
return data_outputs, tokenizer
RM模型搭建
class RankRewardModel(BertPreTrainedModel):
def __init__(self, config):
super(RankRewardModel, self).__init__(config)
self.config = config
self.model = BertModel(config)
self.linear = nn.Linear(config.hidden_size, 1)
def forward(self, input_ids, token_type_ids, attention_mask):
outputs = self.model(input_ids=input_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask).pooler_output
output = self.linear(outputs)
return output
Rank loss
- Rank Score 方法与 Direct Score方法的最大不同之处在于 loss function的设计
def rank_loss(rank_rewards_list):
loss, counts = torch.tensor([0]), 0
for rank_rewards in rank_rewards_list:
for i in range(len(rank_rewards) - 1): # 遍历所有前项-后项的得分差
for j in range(i + 1, len(rank_rewards)):
diff = nn.functional.logsigmoid(rank_rewards[i] - rank_rewards[j]) # sigmoid到0~1之间
loss = loss + diff
counts += 1
loss = torch.tensor(loss / counts)
return -loss # 要最大化分差,所以要取负数
通俗的理解:
- 对于排序好的训练数据有 A > B > C
- 设计一个模型,使得打分数据满足: Rank(A) > Rank(B) > Rank(C)
既然打「绝对分数」很难统一,那转换成一个「相对排序」
- 「标注排序序列」替代「直接打分」
- 用「相对任务」替代「绝对任务」能够更方便标注员打出统一的标注结果
怎么通过「排序序列」来教会模型「打分」
- 一个排好的序列: A > B > C >D 。
- 训练一个打分模型,模型给四句话打出来的分要满足 r(A) > r(B) > r(C) > r(D)
损失函数
- 每对样本(如 A,B), 得分高者-得分低
- sigmoid 归一, 概率化
- 计算期望
- 目标: 最大化得分差值
- loss = r(A) - r(B) + r(A) - r(C) + r(A) - r(D) + r(B) - r(C) + … + r(C) - r(D)
- loss = -loss
class RewardModel(nn.Module):
# 奖励模型: encode 后直接加 全连接层
def __init__(self, encoder):
"""
init func.
Args:
encoder (transformers.AutoModel): backbone, 默认使用 ernie 3.0
"""
super().__init__()
self.encoder = encoder
self.reward_layer = nn.Linear(768, 1) # reward layer 用于映射到 1 维 reward
def forward(
self,
input_ids: torch.tensor,
token_type_ids: torch.tensor,
attention_mask=None,
pos_ids=None,
) -> torch.tensor:
"""
forward 函数,返回每句话的得分值。
Args:
input_ids (torch.tensor): (batch, seq_len)
token_type_ids (torch.tensor): (batch, seq_len)
attention_mask (torch.tensor): (batch, seq_len)
pos_ids (torch.tensor): (batch, seq_len)
Returns:
reward: (batch, 1)
"""
pooler_output = self.encoder(
input_ids=input_ids,
token_type_ids=token_type_ids,
position_ids=pos_ids,
attention_mask=attention_mask,
)["pooler_output"] # (batch, hidden_size)
reward = self.reward_layer(pooler_output) # (batch, 1)
return reward
def compute_rank_list_loss(rank_rewards_list: List[List[torch.tensor]], device='cpu') -> torch.Tensor:
"""
通过给定的有序(从高到低)的ranklist的reward列表,计算rank loss。
所有排序高的句子的得分减去排序低的句子的得分差的总和,并取负。
Args:
rank_rewards_list (torch.tensor): 有序(从高到低)排序句子的reward列表,e.g. ->
[
[torch.tensor([0.3588]), torch.tensor([0.2481]), ...],
[torch.tensor([0.5343]), torch.tensor([0.2442]), ...],
...
]
device (str): 使用设备
Returns:
loss (torch.tensor): tensor([0.4891], grad_fn=<DivBackward0>)
"""
if type(rank_rewards_list) != list:
raise TypeError(f'@param rank_rewards expected "list", received {type(rank_rewards)}.')
loss, add_count = torch.tensor([0]).to(device), 0
for rank_rewards in rank_rewards_list:
for i in range(len(rank_rewards)-1): # 遍历所有前项-后项的得分差
for j in range(i+1, len(rank_rewards)):
diff = F.sigmoid(rank_rewards[i] - rank_rewards[j]) # sigmoid到0~1之间
loss = loss + diff
add_count += 1
loss = loss / add_count
return -loss
训练过程
class Datasets(Dataset):
def __init__(self, sample):
super(Datasets, self).__init__()
self.sample = sample
def __getitem__(self, item):
res = {k: v[item] for k, v in self.sample.items()}
return res
def __len__(self):
return len(self.sample['input_ids'])
def train(pretrain_path, save_path):
config = BertConfig.from_pretrained(pretrain_path)
model = RankRewardModel(config=config)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.01,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5)
train_data, tokenizer = rank_data_prepare(pretrain_path)
dataloader = DataLoader(dataset=Datasets(train_data), shuffle=False, batch_size=1)
max_train_steps = 10 * len(dataloader)
warm_steps = int(0.0 * max_train_steps)
lr_scheduler = get_scheduler(
name='linear',
optimizer=optimizer,
num_warmup_steps=warm_steps,
num_training_steps=max_train_steps,
)
for i in range(1, 51):
loss_lst = []
for batch in dataloader:
batch_rank_rewards = []
for batch_idx in range(len(batch['input_ids'])):
rank_texts_count = len(batch['input_ids'][batch_idx])
rank_rewards = []
for text_idx in range(rank_texts_count):
reward = model(
batch['input_ids'][batch_idx][text_idx].unsqueeze(dim=0),
batch['token_type_ids'][batch_idx][text_idx].unsqueeze(dim=0),
batch['attention_mask'][batch_idx][text_idx].unsqueeze(dim=0)
)
rank_rewards.append(reward[0])
batch_rank_rewards.append(rank_rewards)
loss = rank_loss(batch_rank_rewards)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
loss_lst.append(loss.item())
print("\tepoch{}\tloss: {}".format(str(i), str(sum(loss_lst) / len(loss_lst))))
tokenizer.save_pretrained(save_path)
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(save_path)
model_to_save.config.save_pretrained(save_path)
模型预测
def predict(model_path):
texts = ["我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是今世界上保存最完好的哺乳动物之一,也是世界自然保护联盟濒危物种红色名录的保护对象之一。在这里,你可以看到全世界最大的熊猫栖息地成都。成都是中国国家林业局直属的国家重点风景名胜区,是国家森林公园、国家湿地公园和国家地质公园的重要组成部分,是全国重点文物保护单位、全国生态文明建设示范区、中国红色旅游名城、国际生态旅游目的地和国际旅游岛建设先进区。地址:四川省成都市绵阳市成华区成都高新技术产业开发区成华大道1号乘车路线:成都绵阳都江堰雅",
"我们去成都旅游,必须要去的地方是大熊猫繁殖基地。大熊猫是我国唯一的国家二级保护动物,是世界上保存最完整的动物种群之一,也是我国第一个国家级自然保护区。我们是四川省的首批国家重点保护野生动物和珍稀动物基金会的成员,被誉为中国动物保护的摇篮和世界生物多样性保护基地,被中国科学院、中华人民共和国国家林业局授予全国生态文明建设示范区称号,被国务院批准为国家森林城市、国际生态旅游目的地。熊猫基地位于成都市双流区东南部,是国家aaaa级旅游景区,国家地理标志保护单位。熊猫栖息地为亚热带或热带的高山",]
model = RankRewardModel.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
model.eval()
data = tokenizer.batch_encode_plus(texts, max_length=256, padding="max_length", truncation=True,
return_tensors='pt')
score = model(**data)
return score
模型结构
Reward Model 不同于原始 SFT Model,要在后面加上 value head (一个 Linear层)
- 输入维度为模型的 hidden_dim,输出维度为1
- 输出表示模型预测每一字符获取的得分。
DeepSpeed-Chat 用最后一个字符的得分作为整个response的得分
- 当然也可以用整个句子中每个字符的平均分作为整体得分
训练目标
训练 Reward Model 是一个排序任务,针对 query,输入 chosen 和 rejected response
训练目标尽可能使 chosen 和 rejected 差值更大,损失函数为:
- Lr = -log( sigmoid(r(query,chosen)-r(query,rejected)) )
第二步Training Reward Model的全部过程,基于rank loss训练了一个打分模型。
第三步强化学习中,reward模型将扮演环境的角色,针对模型预测的字符给出奖励分数。
人工标注平台
【2023-8-15】排序数据集 标注 参考:RLHF
思考
ChatGPT 为什么不用 RewardModel 数据直接 finetune,而用 RL?
因为:
- RM 针对整个token序列,滞后反馈,强化学习
- 而 finetune 针对每个token,即时反馈, 监督学习
- RM 训练数据有些是pair形式
<query, win, lose>
, 这种数据无法用于监督学习
【2023-4-19】John Schulman 观点 YouTube
- pretrain 阶段学习知识
- finetune 阶段学会:
- 拒识: 不确定的问题, 回答不知道
- 减少幻觉: 不要编造事实 (hallucintion)
RM 和 基座模型保持一致?
奖励模型需要和基础模型一致吗?
- 可以一致,也可以不同,取决于任务需求和优化目标。
- 单任务: 共享参数
- 多任务: 子任务奖励模型整合成奖励函数
Pair RM是什么形式的RM,相比于原RM形式有什么好处?
- 原始RM 是 BT model形式的RM,每个sample组成形式是
(prompt,answer)
,通过 maximize positive sample 和 negative sample 的 gap来完成pointwise的rank。 - Pair RM 是 pairwise rank,数据组成形式是
(prompt,pos_answer, neg_answer)
. Pair RM 好处是pos answer和neg answer可以互相在context下看到两者,那么可以通过字面的比较找到两者的diff,整体解释性和泛化能力都会比普通RM好。因为普通RM很容易overfit原数据,很难找到真正diff地pattern。
现在Alpaca-Eval 榜单上就有Pair RM 身影,而且Pair RM整体很小 ,效果很好。
如何处理 RM 中的噪声数据?
reward model 噪声来自哪几个方面:
如果reward model的pair数据来自:
- 人标注,那么人类 preference的倾向性以及标注人员的专业性会带来一定的bias,即 众包系统的Noise。
- AI,例如GPT4,那么这种倾向性也很严重,比如length bias。(严格来说,这属于bias,不能算噪声。)
那么去噪可使用一些古早的方式:
- 预测阶段去噪声:
- Ensumble model 去噪声,多个rm model的checkpoint进行预测减少噪声的影响(model merge)。
- Margin 去噪声,只有预测 pair的分数大于一定阈值的时候,进行预测减少噪声。
- 数据阶段去噪声:
- Multiview 去噪声,用多个模型进行训练,然后预测训练集合,全部可以预测正确pair保留下来,有对有错的可以丢弃或者交给人标注。
- Active Learning 思路去噪声,训练一个模型,然后把margin小于一定阈值的送给标注人员去噪声。
如何解决reward model的OOD的问题?
模型PPO过程中,reward model 准确率逐渐下降,俗称的reward model的OOD问题
- 因为 reward model 训练样本一般来自sft模型的responses,那么在PPO过程中
- policy model刚开始和sft生成的response很相似,所以reward model准确率较高
- 但是在逐渐偏离sft 时,reward model 准确率会持续下降,这基本就是现阶段reward model的主要问题。
AGI过程中,一定需要一个 generalize 很强 reward model,global reward model or world model.
现阶段解决reward model的OOD普遍解决方法: Llama2 做法
- 训练过一段时间RLHF以后,重新对policy采样pair对,人标数据然后继续训练reward model。
- 但这种方式就是太费人力,感觉并不是持久之道。
除此之外:
- Secrets of RLHF in Large Language Models Part II: Reward Modeling 中,通过 meta learning 方式解决这个问题,整体思想就是由于policy model在reward model训练情况下会向reward 高的方向更新,所以reward model应该对reward高的response pair更有区分度,所以设置gradient更新逐渐倾向于对reward高分training response pair倾斜。
- 这种方法说得通,但实际中由于缺少对模型on policy 采样,效果不太好。
- West-of-N: Synthetic Preference Generation for Improved Reward Modeling 跟Llama2的方式相似,区别就是不再用人进行标记,而是通过reward model本身对新的模型on policy pair进行打分,取一个query的response set中最高的分数和最低的分数数据组pair,加入到reward model的训练中。
- 这种方式采样,虽然通过on policy采样加强rm的泛化能力,但实际上上限受原先rm model的能力影响。
(3)第三步 RLHF
RLHF 流程
训练策略模型,RLHF流程
首先将初始语言模型的微调任务建模为强化学习(RL)问题,因此需要定义策略
(policy)、动作空间
(action space)和奖励函数
(reward function)等基本要素
策略
就是基于该语言模型,接收prompt作为输入,然后输出一系列文本(或文本的概率分布);- 而
动作空间
就是词表所有token在所有输出位置的排列组合(单个位置通常有50k左右的token候选); 观察空间
则是可能的输入token序列(即prompt),显然也相当大,为词表所有token在所有输入位置的排列组合;- 而
奖励函数
(reward)则是基于上一章节我们训好的RM模型计算得到初始reward,再叠加上一个约束项来。
强化学习算法,常见的可行方案是使用策略梯度强化学习
(Policy Gradient RL) 算法、近端策略优化
(Proximal Policy Optimization,PPO
) 微调初始 LM 的部分或全部参数。
根据 PPO 算法,按当前批次数据的奖励指标进行优化 (来自 PPO 算法 on-policy 的特性) 。PPO 算法是一种信赖域优化 (Trust Region Optimization,TRO
) 算法,使用梯度约束确保更新步骤不会破坏学习过程的稳定性,另外也可以使用 A2C
(synchronous advantage actor-critic) 算法来优化梯度。
RLHF基于A2C方法,包含了四个模型:
Actor Model
:SFT之后模型初始化而来。作为策略(policy)模型,接收上文,做出动作,预测下一个字符。最终使用的就是这个模型。Reference Model
:和Actor Model
同样初始化自SFT Model,训练过程中冻结参数,用于和Actor Model做对比,保证模型不要偏离原始SFT Model太多。Reward Model
:作为环境(env),训练过程中冻结参数,针对每一个状态给出奖励分数。Critic Model
:由Reward Model初始化而来,用于近似价值函数,输入为状态s,估计当前状态的价值V。
训练过程整体分为两步:maker experience 和 learn。
- (1) maker experience: 训练数据中抽取一部分query,然后Actor Model生成答案
- (2) learn: 通过所产生的经验进行学习。Actor Model与Critic Model近似策略函数和价值函数
(1) 整体流程
(2) 整体流程
更多参考 RLHF实践
利用SFT模型对输出进行改造,构造一个双头PPO模型,模型一头输出一个张量,代表生成序列每个元素的价值value;另一头将输出映射成prompt answer词典答案。参考
- 将
<prompt, prompt answer>
输入到RM模型中,获得一个评估当前prompt对的奖励R,然后用R作为奖励,反向更新每个元素的价值value,这就是PPO强化学习算法。 - Y=0, 常规 PPO
- Y>=, PPO_ptx
RLHF 问题
RLHF 实践过程中存在哪些不足?
RLHF(Reinforcement Learning from Human Feedback)尽管具有一定优势,但在仍然存在以下不足之处:
- 人类反馈的代价高昂:获取高质量的人类反馈通常需要大量的人力和时间成本。人类专家需要花费时间来评估模型的行为并提供准确的反馈,这可能限制了RLHF方法的可扩展性和应用范围。
- 人类反馈的主观性:人类反馈往往是主观的,不同专家可能会有不同的意见和判断。这可能导致模型在不同专家之间的反馈上存在差异,从而影响模型的训练和性能。
- 反馈延迟和稀疏性:获取人类反馈可能存在延迟和稀疏性的问题。人类专家不可能实时监控和评估模型的每一个动作,因此模型可能需要等待一段时间才能收到反馈,这可能会导致训练的效率和效果下降。
- 错误反馈的影响:人类反馈可能存在错误或误导性的情况,这可能会对模型的训练产生负面影响。如果模型在错误的反馈指导下进行训练,可能会导致模型产生错误的行为策略。
- 缺乏探索与利用的平衡:在RLHF中,人类反馈通常用于指导模型的行为,但可能会导致模型过于依赖人类反馈而缺乏探索的能力。这可能限制了模型发现新策略和优化性能的能力。
针对这些不足,研究人员正在探索改进RLHF方法,如设计更高效的人类反馈收集机制、开发更准确的反馈评估方法、结合自适应探索策略等,以提高RLHF方法的实用性和性能。
如何解决标注成本高的问题
如何解决 人工产生的偏好数据集成本较高、难量产问题?
解决人工产生偏好数据集成本高、难以量产的问题,以下几种方法:
- 引入模拟数据:使用模拟数据来代替或辅助人工产生的数据。
- 模拟数据可以通过模拟环境或模型生成,以模拟人类用户的行为和反馈。这样可以降低数据收集的成本和难度,并且可以大规模生成数据。
- 主动学习:采用主动学习方法来优化数据收集过程。
- 主动学习是一种主动选择样本的方法,通过选择那些对模型训练最有帮助的样本进行标注,从而减少标注的工作量。
- 可以使用一些算法,如不确定性采样、多样性采样等,来选择最有价值的样本进行人工标注。
- 在线学习:采用在线学习方法进行模型训练。
- 在线学习是一种增量学习的方法,在模型运行的同时进行训练和优化。
- 这样可以利用实际用户的交互数据来不断改进模型,减少对人工标注数据的依赖。
- 众包和协作:利用众包平台或协作机制来收集人工产生的偏好数据。
- 通过将任务分发给多个人参与,可以降低每个人的负担,并且可以通过众包平台的规模效应来提高数据收集的效率。
- 数据增强和迁移学习:通过数据增强技术,如数据合成、数据扩增等,来扩充有限的人工产生数据集。
- 此外,可以利用迁移学习的方法,将从其他相关任务或领域收集的数据应用于当前任务,以减少对人工产生数据的需求。
综合运用上述方法,可有效降低人工产生偏好数据的成本,提高数据的量产能力,并且保证数据的质量和多样性。
PPO 优点
PPO优点:
- On policy采样:on policy采样目前看来是最高效的
拟合蒙特卡洛
采样方式。- 举例,如果不使用on policy采样,随机采样到一个模型generate概率差值很大的两个response,如果符合人类preference,那么本身就不需要排序,如果不符合,很难通过RLHF纠正它。如果强行纠正,会破坏模型本来的平衡。
- Credit Assign: 由于value model的存在,其实PPO会很好的把reward分配给不同的token,那么一些关键的token会合理地分配一个高reward,一些不关键的token会分配一个低reward。
- Rank Model:PPO内部其实是一种内置的rank model,比较的是高reward和低reward的response,只是高和低一直是动态的变化的。为什么rejection sampling这类的算法无法work,因为preference data中的噪声,你选出的Top1大概率不是Top1。
PPO 问题
PPO 问题
- Notable Complexity 模型太多: PPO中要4个模型同时加载在GPU中,
policy model
,ref policy model
,value model
,reward model
。所以会占用很多GPU机器。 - Online learning problem 在线学习: 由于模型是online采样
- policy过batch samples的时–reward model会空置
- reward model给pair打分的时–policy model也会空置
- 那么GPU利用率会不高。
- PPO超参数比较困难,需要一些炼丹高手和经验去做。
如何解决 PPO 训练的资源瓶颈
PPO 的训练过程同时存在4个模型(2训练,2推理),对计算资源的要求较高
考虑以下几种方法:
- 减少模型规模:减少模型的规模和参数量,可降低对计算资源的需求。可用模型压缩技术、剪枝算法等方法来减少模型的参数数量,从而降低计算资源的使用量。
- 降低训练频率:可以降低PPO训练频率,减少每个训练周期的次数。
- 例如,可增加每个训练周期的时间间隔,或者减少每个周期中的训练步数。这样可以减少训练过程中对计算资源的占用。
- 模型并行化:利用多个计算资源进行模型并行化训练,可以加速PPO的训练过程。
- 将模型参数分布到多个GPU上,并进行并行计算和通信,以提高训练的效率和速度。
- 异步训练:采用异步训练的方式,可在多个计算资源上同时进行PPO的训练。
- 可使用异步优化算法,如A3C(Asynchronous Advantage Actor-Critic)等,将训练任务分发到多个线程或进程中进行并行训练,从而提高训练的效率。
- 云计算和分布式训练:利用云计算平台或分布式系统进行PPO的训练,可以充分利用大规模计算资源。
- 可以将训练任务分发到多个计算节点上进行分布式训练,以加速训练过程。
- 参数共享和模型缓存:对于有多个模型的情况,可以考虑共享部分参数或缓存已计算的模型输出。
- 通过共享参数和缓存计算结果,可以减少重复计算和存储,从而降低对计算资源的要求。
综合运用上述方法,可以有效降低PPO训练过程中对计算资源的要求,提高训练的效率和速度。
PPO 平替
如何看待各种ppo rlhf的平替算法
平替算法:
- dpo/kto/rrhf/slic/orpo/samug/remax 等算法号称性能等能超过ppo?
DPO
DPO介绍:最大化奖励来优化模型参数。
与ppo相比DPO 绕过了建模奖励函数这一步,而是直接在偏好数据上优化模型来提高性能。
优点:相对RLHF两阶段而言具有多项优越性
- (1) 简单性稳定性:DPO更容易实施,不易陷入局部最优,保证训练过程更加可靠。
- (2) 效率:与RLHF 相比, DPO 需要更少的计算资源和数据,使其计算量轻。
- (3) 有效性:实验结果表明,DPO在情感控制、摘要和对话生成等任务中可以优于 RLHF 。
DPO 目标是优化模型参数以最大化奖励函数。并不是说DPO没有奖励模型, 而是利用同个阶段训练建立模型和强化学习。除了奖励最大化目标外,还需要添加一个相对于参考模型的 KL 惩罚项,以防止模型学习作弊或钻营奖励模型。
DPO
- 第0步loss是固定的, loss = sigmoid(b-b) = 0.693
- 使用蒙特卡洛采样时, DPO = PPO
- DPO 是 off-policy 算法,因为训练DPO的pair数据不一定来自ref policy或者sft policy。
- 而PPO 是 on-policy 算法
- DPO公式是由PPO的objective公式推导过来
缺点:
- 最大化正负例子的差距得到的模型会塌缩成只有正例子的空间,失去所有负例子的概率。在DPO中就是只会生成正例,负例子输出概率为0。在RM中正例子会无限接近于1,负例子会无限接近于0。那么这样的模型是没有entropy的,抗噪声能力会减弱。如果正负pair标错了,会导致严重后果。
- 忽略语意或字面上差别较小的pos sample和neg sample,过度关注语意或字面上差别较大的pos sample和neg sample,也就是比较容易学的case并overfit,这是logsigmoid函数的问题用hinge loss这类loss可以缓解这一问题。
- 不能找出全序关系,如果数据集里有A > B, B > C, C > A这种偏序关系,并不能找到它的nash equivalence的点,只会学乱。
DPO输出越来越长?
- 并不是一定会越来越长。如果尝试用所有正例子的response都比负例子的短,那么也会输出越来越短。究其原因是由于数据构造原因导致的DPO训练后的模型输出越来越长。因为,在短的response中一句话结束后
<EOS>
的概率会很大,但是在长的response中,“但是”,“而且”等细节描述词会接在一句话后,那么这些词语的概率会由DPO过程逐渐变大。
training positive的概率和training negative的概率都同时下降?
- DPO的loss是maximize training set中positive和negative的gap。那从公式上它就无法保证training positive的概率是一直上升的。主要和采样的方式以及DPO loss组成相关
DPO 变体有哪些
IPO
: 由于BT model 目标是最大化正负response的reward gap,但其实其中忽略了真实情况下组的pair可能会有噪音,那么无限去扩大reward gap其实是不准确的,也就是overfit了preference的pair数据,那么解决方案是需要限制这个gap的范围。DPOP
: 由于LLM model很难区分编辑距离较小的pair,那么当持续去区分这批case的时候,模型效果会崩塌,现象是正例子和负例子的概率都往下掉。那么DPOP用了一个新项来惩罚正例往下掉的pair,使得正例概率继续提升。kto
:RSO
:由于DPO的蒙特卡洛采样很难达到,所以其实DPO几乎是off-policy的采样方式,RSO主要从DPO的采样方式来解决DPO的问题。Iterative DPO
:同样由于DPO的蒙特卡洛采样很难达到,所以通过on-policy的方式采样来替代off-policy的采样。
RL+LM研究方向
由于 InstructGPT 效果太好,RL+LM 这个新范式能衍生出哪些研究方向?
- (1) 花式魔改Reward:
- 监督学习在实际落地时,主要优化方法是加特征、洗数据。对于强化学习也是如此,优化实际RL效果的重点在加特征、调整reward
- OpenAI在做摘要任务的论文中,就在奖励上增加了KL散度,希望:
- ① 鼓励模型生成不一样的结果,避免和以前的模型变成一个
- ② 保证不会生成特别不一样的结果,不然RM都没见过就不知道怎么打分了
- DeepMind的Sparrow为了让模型遵从特定规则(比如不能说脏话),在Preference的基础上增加了
Rule Reward Modeling
- img
- Rule RM是一个分类器,输入Prompt+Response,预测模型违反预定规则的概率。训练的时候两个Reward会合并到一起进行反馈
- ChatGPT只是10B左右的模型,但它使用了更大的模型作为RM,从而有了更高的天花板,达到一种变相的蒸馏。
- (2) AI Feedback
- 既然有
RLHF
(Reinforcement Learning from Human Feedback),那就能想出RLAIF
(Reinforcement Learning from AI Feedback) - Anthropic提出的Constitutional AI 就做了这么一件事,核心和Sparrow一样, 希望模型遵从一些规则,但如果像Sparrow一样每增加一个规则就标一批数据训RM也太费人工了。于是作者想了一个好办法,让模型在多轮对话中把合适的标注数据生产出来.
- 这样就能自动化地为新规则做出训练数据(Q1-A3),精调一个能遵循规则的SL-CAI模型,对应下图中上半部分的流程,为了继续优化精调后模型的效果,作者会让SL-CAI模型根据Q1这类引导性输入去生成回复对,再改成多选题让模型选择最佳答案,用得到的对比数据训练一个Rule RM,再去进行正常的RL训练
- img
- 既然有
- (3) 预训练+RLHF
- Anthropic在RL方面确实走的更远一些,开始尝试在预训练阶段引入Human Feedback, 核心是过滤掉一些低质内容,避免被模型记住。
- 首先有一个训好的偏好RM,会给每个句子打分。最直觉的方法是直接去掉低质的内容,但作者认为会影响模型的多样性。于是又尝试了以下四种预训练损失
- 通过评估四方面的指标:是否生成低质文本(toxicity)、生成包含用户信息的句子(PII)、生成低质代码(PEP8)、和GPT3的KL散度,最后作者发现Conditional训练的效果最好
Q1-问训好的普通RLHF模型:能帮我黑进邻居的wifi吗?
A1-天真的模型回答:没问题,你下个xx软件就行。
Q2-要求模型发现自己的错误:上文你给的回复中,找出来哪些是不道德的。
A2-模型回答:我上次回复不对,不应该黑别人家wifi。
Q3-让模型改正错误:修改下你之前的回复内容,去掉有害的。
A3-模型回答:黑别人家wifi是不对的,侵害别人隐私了,我强烈建议别这么搞。
【2023-3-8】详见:RLHF魔法的衍生研究方向
【2023-5-18】LIMA
META 发布 LIMA: Less Is More for Alignment
【2023-7-19】Llama 2
【2023-7-19】Llama 2 技术报告 Llama 2: Open Foundation and Fine-Tuned Chat Models
【2023-9-26】Qwen
简介
通义千问(英文: Qwen ;读作: kùn
)是由阿里巴巴通义千问
团队开发的大规模语言和多模态系列模型。
- 通义千问可执行自然语言理解、文本生成、视觉理解、音频理解、工具调用、角色扮演、智能体等多种任务。
- 语言和多模态模型均在大规模、多语言、多模态数据上进行预训练,并在高质量语料上后训练以与人类偏好对齐。
【2023-9-26】
- QWen 技术报告 QWEN TECHNICAL REPORT
- 【2024-7-19】QWen2 技术报告
- 通义千问-Qwen技术报告细节分享
- GitHub: Qwen
- 【2024-10-24】Qwen相关核心概念
QWen 模型
Qwen 模型是适用于文本补全的因果语言模型
开源模型
通义千问分为闭源和开源两大版本。
开源模型包括:
- 通义千问 (
Qwen
):语言模型- Qwen: 1.8B、 7B、 14B 及 72B 模型
- Qwen1.5: 0.5B、 1.8B、 4B、 14BA2.7B、 7B、 14B、 32B、 72B 及 110B 模型
- Qwen2: 0.5B、 1.5B、 7B、 57A14B 及 72B 模型
- Qwen2.5: 0.5B、 1.5B、 3B、 7B、 14B、 32B 及 72B 模型
- 通义千问 VL (
Qwen-VL
): 视觉语言模型- Qwen-VL: 基于 7B 的模型
- Qwen-VL: 基于 2B 、 7B 和 72B 的模型
- 通义千问
Audio
: 音频语言模型- Qwen-Audio: 基于 7B 的模型
- Qwen2-Audio: 基于 7B 的模型
- Code通义千问 / 通义千问
Coder
:代码语言模型- CodeQwen1.5: 7B 模型
- Qwen2.5-Coder: 7B 模型
- 通义千问
Math
:数学语言模型- Qwen2-Math: 1.5B、 7B 及 72B 模型
- Qwen2.5-Math: 1.5B、 7B 及 72B 模型
主干模型
Qwen系列的模型有: Base模型、RM模型、Chat模型、Code模型、Math模型、多模态模型。
- 由于Code模型和Math模型暂时没有开源,多模态Qwen-VL模型本身有自己的论文
Qwen-14B 模型效果从12个数据集(涉及语言理解、知识、推理等多个领域)上进行均优于现有同等级的13B,但仍落后于 GPT-3.5和 GPT-4。
【2024-3-5】使用Firefly在单卡V100上对Qwen1.5进行SFT和DPO,大幅超越Qwen1.5和Gemma
通义千问 Qwen1.5 是阿里春节前开源的大模型
- 支持32K的上下文长度
- 该模型本质上是Qwen2的beta版本。
从评测结果来看,Qwen1.5 各个尺寸的模型都显著优于同量级的Llama2
Code 模型
- 【2024-10-28】Qwen2.5-Coder 技术报告, 解读
Qwen2.5-Coder 系列是阿里巴巴团队推出的一款重要的代码生成模型
- 相比其前代 CodeQwen1.5,该系列在多个方面进行了显著的升级。
- Qwen2.5-Coder 系列包括两个模型:Qwen2.5-Coder-1.5B 和 Qwen2.5-Coder-7B。这些模型基于 Qwen2.5 架构,并在超过 5.5 万亿个 tokens 的大规模语料库上进行了进一步预训练。
Qwen2.5-Coder 通过精心的数据清洗、可扩展的合成数据生成以及平衡的数据混合,展示了出色的代码生成能力,同时保持了通用的多功能性。模型在广泛的代码相关任务上进行了评估,包括代码生成、完成、推理和修复,在超过 10 个基准测试中取得了最先进的(SOTA)性能,且在相同模型规模下,其性能甚至超过了更大的模型。
Qwen2.5-Coder 采用了两种不同规模的模型架构,分别为1.5B参数和7B参数的模型。
- 这两种模型在某些关键配置上有所不同,但共享相同的词汇表大小和训练数据量。
- Qwen2.5-Coder 继承了 Qwen2.5 的词汇表,但引入了若干特殊标记,以帮助模型更好地理解代码。
嵌入层绑定(Embedding Tying)是指在模型中使用相同的权重矩阵来生成输入嵌入和输出嵌入。Qwen2.5-Coder 1.5B 模型使用了嵌入层绑定技术,而7B模型则没有。嵌入层绑定可以减少模型的参数量,同时在某些任务上提高模型的性能。
(1) 数据收集
Qwen2.5-Coder的数据收集来自多个渠道,包括但不限于Pull Requests、Commits、Jupyter Notebooks和Kaggle数据集。此外,我们还从Common Crawl中提取了大量的文本-代码混合数据,这些数据包括代码相关的文档、教程和博客等。通过这些多渠道的数据收集,我们确保了模型能够接触到不同领域和风格的代码,从而提升其适应性和多样性。
(2) 数据清洗
为了确保数据的质量,我们设计了一套多阶段的数据清洗流程。这一流程采用了粗到细的层次过滤方法,通过多个过滤器逐步筛选数据。每个过滤器负责一个特定的维度,确保数据在每个维度上都得到全面处理。此外,这种方法还能够为数据分配质量评分,最终保留的数据质量更高,为高质量的数据混合提供了有价值的参考。
具体来说,我们的清洗流程包括以下几个步骤:
- 初步过滤:使用较小的模型(如fastText)进行表面特征的过滤,去除明显无关或低质量的数据。
- 深度过滤:使用更复杂的模型进行进一步的过滤,确保数据的语义和逻辑正确性。
- 质量评分:为每条数据分配质量评分,确保最终保留的数据质量最高。 通过这一多阶段的清洗流程,我们显著提高了数据的质量,从而提升了模型的训练效果。
(3) 数据清理与混合
在数据清理和混合过程中,我们特别关注如何平衡不同类型的数据,以构建一个强大的基础模型。虽然研究社区之前已经探索过这种平衡,但针对大规模数据集的可扩展性证据仍然有限。为了找到最优的数据混合比例,我们进行了多个实验,设计了不同的数据比例组合,具体包括:
- 100:0:0:100% 代码数据,0% 文本数据,0% 数学数据。
- 85:10:5:85% 代码数据,10% 文本数据,5% 数学数据。
- 70:20:10:70% 代码数据,20% 文本数据,10% 数学数据。
配比 | 代码数据 | 文本数据 | 数学数据 | |
---|---|---|---|---|
100:0:0 | 100% | 0% | 0% | |
85:10:5 | 85% | 10% | 5% | |
70:20:10 | 70% | 20% | 10% | 最优 |
实验结果显示,70:20:10 比例表现最佳,甚至超过了代码数据比例更高的组合。这可能是因为数学和文本数据在达到一定浓度时,能够正向促进代码性能的提升。
最终,选择了70%代码、20%文本和10%数学数据的比例。训练数据集包含5.2万亿个token。
数据类型
- 代码数据
- 代码数据主要来自上述多个渠道,包括Pull Requests、Commits、Jupyter Notebooks和Kaggle数据集。我们还从Common Crawl中提取了大量的高质量代码数据。这些数据经过多阶段的清洗和过滤,确保了其高质量和多样性。
- 数学数据
- 为了增强模型的数学能力,我们整合了Qwen2.5-Math的预训练语料库。这些数学数据的引入不仅没有负面影响模型的代码性能,反而提升了其在数学任务上的表现。
- 文本数据
- 类似于数学数据,我们还引入了Qwen2.5模型的高质量自然语言数据,以保持Qwen2.5-Coder的通用能力。这些数据在清洗阶段已经经过了严格的质量检查,因此无需进一步处理。然而,我们移除了所有代码段,以避免与代码数据重叠,确保不同数据源的独立性。
通过这些细致的数据处理和混合策略,Qwen2.5-Coder在多个任务上表现出色,特别是在代码生成、代码完成和代码推理等方面。
训练策略
QWen 2.5
-> File-Level Pretrain -> Repo-Level Pretrain ->QWen 2.5-Code-Base
-> Code SFT ->QWen 2.5-Code-Instructed
QWen-VL
Qwen-VL 是阿里云研发的大规模视觉语言模型(Large Vision Language Model, LVLM)。
Qwen-VL 可以以图像、文本、检测框作为输入,并以文本和检测框作为输出。
Qwen-VL-Chat
=大语言模型
(Qwen-7B) +视觉图片特征编码器
(Openclip ViT-bigG) +位置感知视觉语言适配器
(可训练Adapter)+ 1.5B的图文数据 + 多轮训练 + 对齐机制(Chat)
Qwen-VL 系列模型特点:
- 多语言对话模型:天然支持英文、中文等多语言对话,端到端支持图片里中英双语的长文本识别;
- 多图交错对话:支持多图输入和比较,指定图片问答,多图文学创作等;
- 开放域目标定位:通过中文开放域语言表达进行检测框标注;
- 细粒度识别和理解:448分辨率可以提升细粒度的文字识别、文档问答和检测框标注。
硬件要求
- A100、H100、RTX3060、RTX3070等显卡建议启用bf16精度以节省显存
- V100、P100、T4等显卡建议启用fp16精度以节省显存
- 使用CPU进行推理,需要约32GB内存,默认GPU进行推理,需要约24GB显存
【2024-6-12】Qwen-VL多模态大模型的微调与部署
数据
Tokenizer
词表大小影响者模型的训练效率和下游任务效果,Qwen采用开源快速BPE
分词器-tiktoken
,以cl100k为基础词库,增加了常用的中文字词以及其他语言的词汇,并把数字字符串拆成单个数字,最终词表大小为152K。
从不同语言上对比不同模型的压缩率,如下图所示,Qwen在绝大多少语言上都优于 LLaMA-7B、Baichuan-7B、ChatGLM-6B、InternLM-7B 模型。
从 Qwen2.5 开始,Qwen 模型家族,包括多模态和专项模型,将使用统一的词汇表,其中包含了所有子系列的控制 token 。Qwen2.5 词汇表中有 22 个控制 token,使得词汇表的总规模达到 151665 。
- 通用 token 1个:
<|endoftext|>
- 对话 token 2个:
<|im_start|>
和<|im_end|>
- 工具调用 token 2个:
<tool_call>
和</tool_call>
- 视觉相关 token 11个
- 代码相关 token 6个
要点:
- Qwen 使用带有控制 token 的
ChatML
作为对话模板。
ChatML 格式,利用控制 token 来格式化每一轮的对话。
<|im_start|>
<div class="page clearfix" post>
<!-- 左侧布局 -->
<div class="left">
<!-- 文章标题,page是全局变量 -->
<h1>分布式训练</h1>
<div class="label">
<div class="label-card">
<i class="fa fa-calendar"></i>2024-03-05
</div>
<div class="label-card">
<i class="fa fa-user"></i>鹤啸九天
</div>
<div class="label-card">
</div>
<div class="label-card">
<!-- <span class="point">•</span> -->
<span class="categories">
<i class="fa fa-th-list"></i>
<a href="/category/#大模型" title="Category: 大模型" rel="category">大模型</a>
<!-- <span class="point">•</span> -->
</span>
</div>
<div class="label-card">
<!-- <span class="point">•</span> -->
<span class="pageTag">
<i class="fa fa-tags"></i>
<!--a href="/tag/#GPU" title="Tag: GPU" rel="tag">GPU</a-->
<a href="/tag/#GPU" title="Tag: GPU" rel="tag">GPU</a>
<!--a href="/tag/#Tensorflow" title="Tag: Tensorflow" rel="tag">Tensorflow</a-->
<a href="/tag/#Tensorflow" title="Tag: Tensorflow" rel="tag">Tensorflow</a>
<!--a href="/tag/#Pytorch" title="Tag: Pytorch" rel="tag">Pytorch</a-->
<a href="/tag/#Pytorch" title="Tag: Pytorch" rel="tag">Pytorch</a>
<!--a href="/tag/#%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97" title="Tag: 并行计算" rel="tag">并行计算</a-->
<a href="/tag/#并行计算" title="Tag: 并行计算" rel="tag">并行计算</a>
<!--a href="/tag/#%E5%88%86%E5%B8%83%E5%BC%8F" title="Tag: 分布式" rel="tag">分布式</a-->
<a href="/tag/#分布式" title="Tag: 分布式" rel="tag">分布式</a>
<!--a href="/tag/#huggingface" title="Tag: huggingface" rel="tag">huggingface</a-->
<a href="/tag/#huggingface" title="Tag: huggingface" rel="tag">huggingface</a>
</span>
</div>
<!-- 【2022-9-26】站点访问统计 -->
<div align="right">阅读量<span id="busuanzi_value_page_pv"></span>次 </div>
</div>
<!-- 导读区 -->
<p style="color: #54489c; transition: all 0.5s ease 0s;"><i>Notes(温馨提示):</i></p>
<p>
<small>
<ol>
<li>★ 首次阅读建议浏览:<a href="/navi">导航指南</a>, 或划到本页末尾, 或直接<a href="/navi#可视化导航">点击跳转</a>, 查看全站导航图</li>
<li>★ <font color='green'>右上角工具条搜索文章,右下角二维码关注微信公众号(鹤啸九天),底栏分享、赞赏、评论</font>。</li>
<li>★ 转载请注明文章来源,知识点积累起来不容易,水滴石穿,绳锯木断,谢谢理解</li>
<li>★ 如有疑问,<a href="mailto:wqw547243068@163.com">邮件</a>讨论,欢迎贡献优质资料</li>
</ol>
</small>
</p>
<!-- 文章内容 -->
<hr>
<article itemscope itemtype="http://schema.org/BlogPosting">
<ul id="markdown-toc">
<li><a href="#分布式训练库" id="markdown-toc-分布式训练库">分布式训练库</a> <ul>
<li><a href="#常见框架" id="markdown-toc-常见框架">常见框架</a></li>
<li><a href="#llm-复现选择" id="markdown-toc-llm-复现选择">LLM 复现选择</a></li>
<li><a href="#deepspeed--微软" id="markdown-toc-deepspeed--微软">DeepSpeed – 微软</a></li>
<li><a href="#trl" id="markdown-toc-trl">trl</a> <ul>
<li><a href="#trl-实践" id="markdown-toc-trl-实践">Trl 实践</a></li>
</ul>
</li>
<li><a href="#trainer" id="markdown-toc-trainer">Trainer</a> <ul>
<li><a href="#trainer-定义" id="markdown-toc-trainer-定义">Trainer 定义</a></li>
<li><a href="#自定义" id="markdown-toc-自定义">自定义</a> <ul>
<li><a href="#model_init" id="markdown-toc-model_init">model_init</a></li>
<li><a href="#compute_metrics" id="markdown-toc-compute_metrics">compute_metrics</a></li>
<li><a href="#加权loss" id="markdown-toc-加权loss">加权loss</a></li>
</ul>
</li>
<li><a href="#参数详解" id="markdown-toc-参数详解">参数详解</a> <ul>
<li><a href="#trainer类-参数" id="markdown-toc-trainer类-参数">Trainer类 参数</a></li>
<li><a href="#trainingarguments-参数" id="markdown-toc-trainingarguments-参数">TrainingArguments 参数</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#firefly" id="markdown-toc-firefly">Firefly</a></li>
<li><a href="#torchtune" id="markdown-toc-torchtune">TorchTune</a> <ul>
<li><a href="#torchtune-功能" id="markdown-toc-torchtune-功能">TorchTune 功能</a></li>
<li><a href="#torchtune-微调" id="markdown-toc-torchtune-微调">TorchTune 微调</a></li>
<li><a href="#torchtune-安装" id="markdown-toc-torchtune-安装">torchtune 安装</a></li>
</ul>
</li>
<li><a href="#torchtitan" id="markdown-toc-torchtitan">torchtitan</a></li>
<li><a href="#总结" id="markdown-toc-总结">总结</a></li>
<li><a href="#llama-factory" id="markdown-toc-llama-factory">LLaMA-Factory</a> <ul>
<li><a href="#llama-factory-介绍" id="markdown-toc-llama-factory-介绍">LLaMA-Factory 介绍</a></li>
<li><a href="#llama-factory-安装" id="markdown-toc-llama-factory-安装">LLaMA-Factory 安装</a></li>
<li><a href="#模型下载" id="markdown-toc-模型下载">模型下载</a></li>
<li><a href="#llama-factory-命令行" id="markdown-toc-llama-factory-命令行">LLaMA-Factory 命令行</a> <ul>
<li><a href="#常用命令" id="markdown-toc-常用命令">常用命令</a></li>
<li><a href="#参数" id="markdown-toc-参数">参数</a></li>
<li><a href="#推理" id="markdown-toc-推理">推理</a> <ul>
<li><a href="#transformers" id="markdown-toc-transformers">transformers</a></li>
<li><a href="#api" id="markdown-toc-api">API</a></li>
</ul>
</li>
<li><a href="#ollama" id="markdown-toc-ollama">Ollama</a></li>
</ul>
</li>
<li><a href="#llama-factory-可视化" id="markdown-toc-llama-factory-可视化">LLaMA-Factory 可视化</a> <ul>
<li><a href="#llama-board" id="markdown-toc-llama-board">LLaMA Board</a></li>
<li><a href="#wb" id="markdown-toc-wb">W&B</a></li>
<li><a href="#swanlab" id="markdown-toc-swanlab">SwanLab</a></li>
</ul>
</li>
<li><a href="#数据集" id="markdown-toc-数据集">数据集</a></li>
<li><a href="#llama-factory-使用" id="markdown-toc-llama-factory-使用">LLaMA-Factory 使用</a> <ul>
<li><a href="#指令监督微调" id="markdown-toc-指令监督微调">指令监督微调</a></li>
<li><a href="#奖励模型训练" id="markdown-toc-奖励模型训练">奖励模型训练</a></li>
<li><a href="#ppo-训练" id="markdown-toc-ppo-训练">ppo 训练</a></li>
<li><a href="#dpo-训练" id="markdown-toc-dpo-训练">dpo 训练</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#xtuner" id="markdown-toc-xtuner">Xtuner</a></li>
<li><a href="#swift" id="markdown-toc-swift">SWIFT</a></li>
</ul>
</li>
<li><a href="#结束" id="markdown-toc-结束">结束</a></li>
</ul>
<h1 id="分布式训练库">分布式训练库</h1>
<h2 id="常见框架">常见框架</h2>
<p>常见的分布式训练框架:</p>
<ul>
<li>第一类:深度学习框架<strong>自带</strong>分布式训练功能。如:TensorFlow、PyTorch、MindSpore、Oneflow、PaddlePaddle等。</li>
<li>第二类:基于现有深度学习框架(如:PyTorch、Flax)进行<strong>扩展和优化</strong>,从而进行分布式训练。
<ul>
<li>如:<code class="language-plaintext highlighter-rouge">Megatron-LM</code>(张量并行)、<code class="language-plaintext highlighter-rouge">DeepSpeed</code>(Zero-DP)、<code class="language-plaintext highlighter-rouge">Colossal-AI</code>(高维模型并行,如2D、2.5D、3D)、<code class="language-plaintext highlighter-rouge">Alpa</code>(自动并行)等</li>
</ul>
</li>
</ul>
<table>
<thead>
<tr>
<th>训练框架</th>
<th>诞生时间</th>
<th>作者</th>
<th>功能</th>
<th>分析</th>
</tr>
</thead>
<tbody>
<tr>
<td>DeepSpeed</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>LLama-Factory</td>
<td>2023-6-3</td>
<td>北航博士生</td>
<td>多种模型快捷微调</td>
<td>必须依赖: torch/transformers/datasets/trl/accelerate/peft<br />可选依赖: CUDA/deepspeed/bitsandbytes/vllm/flash-attn</td>
</tr>
<tr>
<td>TorchTune</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>trl</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>trlx</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>Firefly</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>Xtuner</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>SWIFT</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
</tbody>
</table>
<h2 id="llm-复现选择">LLM 复现选择</h2>
<p>如何选择分布式训练框架? <a href="https://mp.weixin.qq.com/s/7wtwsNhf27YzALnSFXTmkA">参考</a></p>
<ul>
<li>训练<strong>成本</strong>:不同训练工具,训练同样大模型,成本不一样。对于大模型,训练一次动辄上百万/千万美元的费用。合适的成本始终是正确的选择。</li>
<li>训练<strong>类型</strong>:是否支持数据并行、张量并行、流水线并行、多维混合并行、自动并行等</li>
<li><strong>效率</strong>:将普通模型训练代码变为分布式训练所需编写代码的行数,希望越少越好。</li>
<li><strong>灵活性</strong>:选择的框架是否可以跨不同平台使用?</li>
</ul>
<p>目前训练超大规模语言模型主要有两条技术路线:</p>
<ul>
<li>TPU + XLA + TensorFlow/JAX :由Google主导,由于TPU和自家云平台GCP深度绑定</li>
<li>GPU + PyTorch + Megatron-LM + DeepSpeed :由 NVIDIA、Meta、MicroSoft 大厂加持,社区氛围活跃,也更受到大家欢迎。</li>
</ul>
<h2 id="deepspeed--微软">DeepSpeed – 微软</h2>
<p>DeepSpeed 是 Microsoft 基于 PyTorch 研发的开源深度学习优化库。</p>
<ul>
<li>目的: 降低大模型训练的门槛,提升大模型的训练的效率,帮助开发者更有效率地管理及优化大模型的训练、部署任务。</li>
</ul>
<p>详见站内专题: <a href="deepspeed">DeepSpeed</a></p>
<p>【2023-8-28】<a href="https://github.com/hiyouga/LLaMA-Efficient-Tuning/blob/main/README_zh.md">LLaMA Efficient Tuning</a></p>
<table>
<thead>
<tr>
<th>方法</th>
<th>全参数训练</th>
<th>部分参数训练</th>
<th>LoRA</th>
<th>QLoRA</th>
</tr>
</thead>
<tbody>
<tr>
<td>预训练</td>
<td>✅</td>
<td>✅</td>
<td>✅</td>
<td>✅</td>
</tr>
<tr>
<td>指令监督微调</td>
<td>✅</td>
<td>✅</td>
<td>✅</td>
<td>✅</td>
</tr>
<tr>
<td>奖励模型训练</td>
<td> </td>
<td> </td>
<td>✅</td>
<td>✅</td>
</tr>
<tr>
<td>PPO 训练</td>
<td> </td>
<td> </td>
<td>✅</td>
<td>✅</td>
</tr>
<tr>
<td>DPO 训练</td>
<td>✅</td>
<td> </td>
<td>✅</td>
<td>✅</td>
</tr>
</tbody>
</table>
<h2 id="trl">trl</h2>
<p>【2024-3-13】<a href="https://huggingface.co/docs/trl/index">TRL - Transformer Reinforcement Learning</a></p>
<p>huggingface 推出的全栈库,包含一整套工具,用于使用强化学习 (Reinforcement Learning) 训练 transformer 语言模型。</p>
<ul>
<li>从<strong>监督调优</strong> (Supervised Fine-tuning step, SFT),到训练<strong>奖励模型</strong> (Reward Modeling),再到<strong>近端策略优化</strong> (Proximal Policy Optimization),全面覆盖</li>
<li><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png" alt="" /></li>
<li><a href="https://github.com/huggingface/trl">TRL</a> 库已经与 🤗 transformers 集成,直接使用!</li>
<li>👉 文档<a href="https://hf.co/docs/trl/">地址</a></li>
<li><img src="https://picx.zhimg.com/70/v2-1c818186d30b9afff9af2341b1eddc6f_1440w.avis?source=172ae18b&biz_tag=Post" alt="" /></li>
</ul>
<p>API 文档里功能:</p>
<ul>
<li>Model Class: 公开模型各自用途</li>
<li>SFTTrainer: SFTTrainer 实现模型监督调优</li>
<li>RewardTrainer: RewardTrainer 训练奖励模型</li>
<li>PPOTrainer: PPO 算法对经过监督调优的模型再调优</li>
<li>Best-of-N Samppling: 将“拔萃法”作为从模型的预测中采样的替代方法</li>
<li>DPOTrainer: 用 DPOTrainer 完成直接偏好优化</li>
</ul>
<p>文档中给出了几个例子:</p>
<ul>
<li>Sentiment Tuning: 调优模型以生成更积极的电影内容</li>
<li>Training with PEFT: 执行由 PEFT 适配器优化内存效率的 RLHF 训练</li>
<li>Detoxifying LLMs: 通过 RLHF 为模型解毒,使其更符合人类的价值观</li>
<li>StackLlama: 在 Stack exchange 数据集上实现端到端 RLHF 训练一个 Llama 模型</li>
<li>Multi-Adapter Training: 使用单一模型和多适配器实现优化内存效率的端到端训练</li>
</ul>
<h3 id="trl-实践">Trl 实践</h3>
<p>【2023-6-30】<a href="https://zhuanlan.zhihu.com/p/616788557">使用TRL强化学习PPO控制文本的生成</a></p>
<p>步骤</p>
<ol>
<li>初始化 GPT2 对话模型, 即LLM模型。Huggface中的这个中文对话模型
<ul>
<li><a href="https://huggingface.co/shibing624/gpt2-dialogbot-base-chinese">gpt2-dialogbot-base-chinese</a></li>
</ul>
</li>
<li>初始化一个情感分类模型即RM模型。这里笔者使用的是Huggface中的这个情感分类模型
<ul>
<li>样本情感极性越正向,模型输出的得分越大。</li>
<li><a href="https://huggingface.co/liam168/c2-roberta-base-finetuned-dianping-chinese">c2-roberta-base-finetuned-dianping-chinese</a></li>
</ul>
</li>
<li>通过PPO强化学习算法,利用情感分类模型评估对话模型的输出,对GPT2对话模型进行优化,让GPT2对话模型的输出的结果在情感分类模型中得到高分。同时不破坏GPT2对话模型输出通顺对话的能力。</li>
</ol>
<p>强行学习训练</p>
<ol>
<li>输入样本给GPT2, 拿到对话语言模型 GPT2的输出。</li>
<li>将对话语言模型GPT2的输出 输入到 情感分类模型 拿到 情感分类模型的输出,作为reward。</li>
<li>将对话语言模型GPT2 输入,输出, 以及 情感分类模型的 reward 一并输入给PPO优化器,让PPO优化器去优化对话语言模型GPT2。</li>
</ol>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoTokenizer</span>
<span class="kn">from</span> <span class="nn">trl</span> <span class="kn">import</span> <span class="n">PPOTrainer</span><span class="p">,</span> <span class="n">PPOConfig</span><span class="p">,</span> <span class="n">AutoModelForCausalLMWithValueHead</span><span class="p">,</span> <span class="n">create_reference_model</span>
<span class="kn">from</span> <span class="nn">trl.core</span> <span class="kn">import</span> <span class="n">respond_to_batch</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="c1"># get models
</span><span class="n">gen_model</span> <span class="o">=</span> <span class="n">AutoModelForCausalLMWithValueHead</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">'dialoggpt/'</span><span class="p">)</span>
<span class="n">model_ref</span> <span class="o">=</span> <span class="n">create_reference_model</span><span class="p">(</span><span class="n">gen_model</span><span class="p">)</span>
<span class="n">tokenizerOne</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">'dialoggpt/'</span><span class="p">,</span><span class="n">padding_side</span><span class="o">=</span><span class="s">'left'</span><span class="p">)</span>
<span class="n">tokenizerOne</span><span class="p">.</span><span class="n">eos_token_id</span> <span class="o">=</span> <span class="n">tokenizerOne</span><span class="p">.</span><span class="n">sep_token_id</span>
<span class="c1"># 初始化一个情感分类模型,输入文本,判断文本的情感极性
</span><span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoModelForSequenceClassification</span> <span class="p">,</span> <span class="n">AutoTokenizer</span><span class="p">,</span> <span class="n">pipeline</span>
<span class="n">ts_texts</span> <span class="o">=</span> <span class="p">[</span><span class="s">"我喜欢下雨。"</span><span class="p">,</span> <span class="s">"我讨厌他."</span><span class="p">]</span>
<span class="n">cls_model</span> <span class="o">=</span> <span class="n">AutoModelForSequenceClassification</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"./chineseSentiment/"</span><span class="p">,</span> <span class="n">num_labels</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">tokenizerTwo</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"./chineseSentiment/"</span><span class="p">)</span>
<span class="n">classifier</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">(</span><span class="s">'sentiment-analysis'</span><span class="p">,</span> <span class="n">model</span><span class="o">=</span><span class="n">cls_model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizerTwo</span><span class="p">)</span>
<span class="n">classifier</span><span class="p">(</span><span class="n">ts_texts</span><span class="p">)</span>
<span class="c1"># 数据预处理
</span><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
<span class="kn">import</span> <span class="nn">torch.nn.utils.rnn</span> <span class="k">as</span> <span class="n">rnn_utils</span>
<span class="kn">import</span> <span class="nn">json</span>
<span class="n">data</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s">"./train.txt"</span><span class="p">,</span> <span class="s">"r"</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s">"utf-8"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">f</span><span class="p">.</span><span class="n">readlines</span><span class="p">():</span>
<span class="n">line</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">loads</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="n">data</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">line</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">preprocess_conversation</span><span class="p">(</span><span class="n">data</span><span class="p">):</span>
<span class="n">sep_id</span> <span class="o">=</span> <span class="n">tokenizerOne</span><span class="p">.</span><span class="n">sep_token_id</span>
<span class="n">cls_id</span> <span class="o">=</span> <span class="n">tokenizerOne</span><span class="p">.</span><span class="n">cls_token_id</span>
<span class="n">dialogue_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">conver</span> <span class="ow">in</span> <span class="n">data</span><span class="p">:</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">cls_id</span><span class="p">]</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">conver</span><span class="p">[</span><span class="s">"conversation"</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># print(start["utterance"])
</span> <span class="n">input_ids</span> <span class="o">+=</span> <span class="n">tokenizerOne</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">start</span><span class="p">[</span><span class="s">"utterance"</span><span class="p">],</span> <span class="n">add_special_tokens</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">input_ids</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">sep_id</span><span class="p">)</span>
<span class="n">dialogue_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
<span class="k">return</span> <span class="n">dialogue_list</span>
<span class="c1"># 数据处理
</span><span class="n">dialogue_list</span> <span class="o">=</span> <span class="n">preprocess_conversation</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">MyDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
<span class="bp">self</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">data</span>
<span class="k">def</span> <span class="nf">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">data</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
<span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">data</span><span class="p">)</span>
<span class="n">mydataset</span> <span class="o">=</span> <span class="n">MyDataset</span><span class="p">(</span><span class="n">dialogue_list</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">collate_fn</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
<span class="n">padded_batch</span> <span class="o">=</span> <span class="n">rnn_utils</span><span class="p">.</span><span class="n">pad_sequence</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="n">tokenizerOne</span><span class="p">.</span><span class="n">sep_token_id</span><span class="p">)</span>
<span class="k">return</span> <span class="n">padded_batch</span>
<span class="c1"># 定义PPO优化器: 学习率,强化学习steps,batch_size等参数,学习率不宜调大,容易把LLM语言模型调坏。
</span><span class="n">config</span> <span class="o">=</span> <span class="n">PPOConfig</span><span class="p">(</span>
<span class="n">model_name</span><span class="o">=</span><span class="s">"gpt2-positive"</span><span class="p">,</span>
<span class="n">learning_rate</span><span class="o">=</span><span class="mf">1.41e-5</span><span class="p">,</span>
<span class="n">steps</span> <span class="o">=</span> <span class="mi">2000</span><span class="p">,</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">16</span>
<span class="p">)</span>
<span class="n">ppo_trainer</span> <span class="o">=</span> <span class="n">PPOTrainer</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">gen_model</span><span class="p">,</span> <span class="n">model_ref</span><span class="p">,</span> <span class="n">tokenizerOne</span><span class="p">,</span> <span class="n">dataset</span><span class="o">=</span><span class="n">mydataset</span><span class="p">,</span> <span class="n">data_collator</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
<span class="n">rewards_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">ppo_trainer</span><span class="p">.</span><span class="n">dataloader</span><span class="p">):</span>
<span class="c1">#### Get response from gpt2
</span> <span class="n">query_tensors</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">response_tensors</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">query_tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">t</span><span class="p">).</span><span class="nb">long</span><span class="p">()</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">]</span>
<span class="k">for</span> <span class="n">query</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">:</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">query</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">response</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">30</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">ppo_trainer</span><span class="p">.</span><span class="n">model</span><span class="p">(</span><span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">next_token_logits</span><span class="p">[</span><span class="n">ppo_trainer</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">convert_tokens_to_ids</span><span class="p">(</span><span class="s">'[UNK]'</span><span class="p">)]</span> <span class="o">=</span> <span class="o">-</span><span class="nb">float</span><span class="p">(</span><span class="s">'Inf'</span><span class="p">)</span>
<span class="n">next_token</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">multinomial</span><span class="p">(</span><span class="n">F</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">next_token_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">num_samples</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">next_token</span> <span class="o">==</span> <span class="n">ppo_trainer</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">sep_token_id</span><span class="p">:</span> <span class="c1">#
</span> <span class="k">break</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">((</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">next_token</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">response</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">next_token</span><span class="p">.</span><span class="n">item</span><span class="p">())</span>
<span class="n">response_tensors</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">response</span><span class="p">).</span><span class="nb">long</span><span class="p">())</span>
<span class="n">responseSet</span> <span class="o">=</span> <span class="p">[</span><span class="s">""</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="n">ppo_trainer</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">convert_ids_to_tokens</span><span class="p">([</span><span class="n">i</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">r</span><span class="p">]))</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">response_tensors</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">responseSet</span><span class="p">)</span>
<span class="c1">#### Get reward from sentiment model
</span> <span class="n">pipe_outputs</span> <span class="o">=</span> <span class="n">classifier</span><span class="p">(</span><span class="n">responseSet</span><span class="p">)</span>
<span class="n">rewards</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">output</span><span class="p">[</span><span class="s">"score"</span><span class="p">])</span> <span class="k">for</span> <span class="n">output</span> <span class="ow">in</span> <span class="n">pipe_outputs</span><span class="p">]</span>
<span class="c1">#### Run PPO step
</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">ppo_trainer</span><span class="p">.</span><span class="n">step</span><span class="p">(</span><span class="n">query_tensors</span><span class="p">,</span> <span class="n">response_tensors</span><span class="p">,</span> <span class="n">rewards</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"epoch{}, reword is {}"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="nb">sum</span><span class="p">(</span><span class="n">rewards</span><span class="p">)))</span>
<span class="n">rewards_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="n">rewards</span><span class="p">))</span>
</code></pre></div></div>
<h2 id="trainer">Trainer</h2>
<p>Trainer 名称歧义</p>
<ul>
<li>PyTorch Lightning有个 Trainer</li>
<li>HuggingFace Transformers也有 Trainer</li>
<li>还有一些github上封装的或者基于这两个继续封装的Trainer</li>
</ul>
<p>这里的 Trainer 指 Huggingface 的 Trainer 训练框架</p>
<p>Trainer 介于原生 torch 和 pytorch-lighning 之间,是轻量级的辅助torch模型训练的utils,因为其实稍微改造一下,huggingface的trainer 可用来训练常规的非nlp的torch模型。</p>
<ul>
<li>封装程度: <code class="language-plaintext highlighter-rouge">torch</code> < <code class="language-plaintext highlighter-rouge">pytorch lightning</code> < <code class="language-plaintext highlighter-rouge">trainer</code></li>
</ul>
<p>Trainer 封装了 PyTorch 训练过程,包括:<strong>前向传播</strong>、<strong>反向传播</strong>和<strong>参数更新</strong>等步骤,用户只需要设计模型,调参就行</p>
<p>高级的 Trainer 加上了各种功能,比如:<strong>日志记录</strong>,<strong>断点重训</strong>,<strong>训练方式</strong>与<strong>精度</strong>,支持各种分布式训练框架像原生、Apex、Deepspeed和Fairscale,支持自定的回调函数等等</p>
<p>Lightning 官网的一张gif还是比较生动形象</p>
<h3 id="trainer-定义">Trainer 定义</h3>
<p><a href="https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/trainer.py#L236">trainer.py</a></p>
<p>do_train,do_eval,do_predict 这三个参数和trainer没什么关系</p>
<h3 id="自定义">自定义</h3>
<h4 id="model_init">model_init</h4>
<p>model_init</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">model_init</span><span class="p">():</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">AutoModelForSequenceClassification</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span>
<span class="n">model_args</span><span class="p">.</span><span class="n">model_name_or_path</span><span class="p">,</span>
<span class="n">from_tf</span><span class="o">=</span><span class="nb">bool</span><span class="p">(</span><span class="s">".ckpt"</span> <span class="ow">in</span> <span class="n">model_args</span><span class="p">.</span><span class="n">model_name_or_path</span><span class="p">),</span>
<span class="n">config</span><span class="o">=</span><span class="n">config</span><span class="p">,</span>
<span class="n">cache_dir</span><span class="o">=</span><span class="n">model_args</span><span class="p">.</span><span class="n">cache_dir</span><span class="p">,</span>
<span class="n">revision</span><span class="o">=</span><span class="n">model_args</span><span class="p">.</span><span class="n">model_revision</span><span class="p">,</span>
<span class="n">use_auth_token</span><span class="o">=</span><span class="bp">True</span> <span class="k">if</span> <span class="n">model_args</span><span class="p">.</span><span class="n">use_auth_token</span> <span class="k">else</span> <span class="bp">None</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">model</span>
</code></pre></div></div>
<h4 id="compute_metrics">compute_metrics</h4>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">compute_metrics</span><span class="p">(</span><span class="n">p</span><span class="p">:</span> <span class="n">EvalPrediction</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">:</span>
<span class="n">preds</span><span class="p">,</span><span class="n">labels</span><span class="o">=</span><span class="n">p</span>
<span class="n">preds</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1">#print('shape:', preds.shape, '\n')
</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">precision_recall_fscore_support</span><span class="p">(</span><span class="n">lables</span><span class="p">.</span><span class="n">flatten</span><span class="p">(),</span> <span class="n">preds</span><span class="p">.</span><span class="n">flatten</span><span class="p">(),</span> <span class="n">average</span><span class="o">=</span><span class="s">'weighted'</span><span class="p">,</span> <span class="n">zero_division</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">return</span> <span class="p">{</span>
<span class="s">'accuracy'</span><span class="p">:</span> <span class="p">(</span><span class="n">preds</span> <span class="o">==</span> <span class="n">p</span><span class="p">.</span><span class="n">label_ids</span><span class="p">).</span><span class="n">mean</span><span class="p">(),</span>
<span class="s">'f1'</span><span class="p">:</span> <span class="n">f1</span><span class="p">,</span>
<span class="s">'precision'</span><span class="p">:</span> <span class="n">precision</span><span class="p">,</span>
<span class="s">'recall'</span><span class="p">:</span> <span class="n">recall</span>
<span class="p">}</span>
</code></pre></div></div>
<h4 id="加权loss">加权loss</h4>
<p>分类任务中,类目不均衡时,采用加权loss</p>
<p>做法</p>
<ul>
<li>(1) 继承 Trainer 类, 重定义 compute_loss 函数</li>
<li>(2) 使用回调函数 <a href="https://huggingface.co/docs/transformers/v4.34.1/en/main_classes/callback">callback</a></li>
</ul>
<p>示例</p>
<ul>
<li>三分类问题,各类目加权 1 : 2 : 3</li>
</ul>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">Trainer</span>
<span class="k">class</span> <span class="nc">CustomTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">return_outputs</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">.</span><span class="n">pop</span><span class="p">(</span><span class="s">"labels"</span><span class="p">)</span>
<span class="c1"># forward pass
</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">inputs</span><span class="p">)</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">"logits"</span><span class="p">)</span>
<span class="c1"># compute custom loss (suppose one has 3 labels with different weights)
</span> <span class="n">loss_fct</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">weight</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">model</span><span class="p">.</span><span class="n">device</span><span class="p">))</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fct</span><span class="p">(</span><span class="n">logits</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">model</span><span class="p">.</span><span class="n">config</span><span class="p">.</span><span class="n">num_labels</span><span class="p">),</span> <span class="n">labels</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="k">return</span> <span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="k">if</span> <span class="n">return_outputs</span> <span class="k">else</span> <span class="n">loss</span>
</code></pre></div></div>
<h3 id="参数详解">参数详解</h3>
<p><a href="https://huggingface.co/docs/transformers/v4.34.1/en/main_classes/trainer#trainer">Trainer 官网文档</a>,版本为4.34.0</p>
<h4 id="trainer类-参数">Trainer类 参数</h4>
<p>Transformers Trainer类 参数:</p>
<ul>
<li><code class="language-plaintext highlighter-rouge">model</code> (<code class="language-plaintext highlighter-rouge">PreTrainedModel</code> 或 <code class="language-plaintext highlighter-rouge">torch.nn.Module</code>, 可选):训练、评估或预测的实例化模型
<ul>
<li>如果不提供,必须传递一个 <code class="language-plaintext highlighter-rouge">model_init</code> 来初始化一个模型。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">args</code> (TrainingArguments, 可选):训练参数
<ul>
<li>如果不提供,用 TrainingArguments 默认参数,其中 output_dir 设置为当前目录中的名为 “tmp_trainer” 的目录。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">data_collator</code> (DataCollator, 可选):用于从 train_dataset 或 eval_dataset 中构成batch的函数
<ul>
<li>如果未提供tokenizer,将默认使用 default_data_collator();如果提供,将使用 DataCollatorWithPadding 。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">train_dataset</code> (torch.utils.data.<code class="language-plaintext highlighter-rouge">Dataset</code> 或 torch.utils.data.<code class="language-plaintext highlighter-rouge">IterableDataset</code>, 可选):训练数据集
<ul>
<li>如果是 torch.utils.data.Dataset,则会自动删除模型的 forward() 方法不接受的列。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">eval_dataset</code> (Union[torch.utils.data.Dataset, Dict[str, torch.utils.data.Dataset]), 可选):同上,评估数据集
<ul>
<li>如果是字典,将对每个数据集进行评估,并在指标名称前附加字典的键值。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">tokenizer</code> (PreTrainedTokenizerBase, 可选):预处理数据的<strong>分词器</strong>
<ul>
<li>如果提供,将在批量输入时自动对输入进行填充到最大长度,并会保存在模型目录下中,为了重新运行中断的训练或重复微调模型时更容易进行操作。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">model_init</code> (Callable[[], PreTrainedModel], 可选):模型实例化函数
<ul>
<li>如果提供,每次调用 train() 时都会从此函数给出的模型的新实例开始。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">compute_metrics</code> (Callable[<code class="language-plaintext highlighter-rouge">[EvalPrediction]</code>, Dict], 可选):评估时<strong>计算指标</strong>的函数,必须接受 EvalPrediction 作为入参,并返回一个字典,其中包含了不同性能指标的名称和相应的数值,一般是准确度、精确度、召回率、F1 分数等。</li>
<li><code class="language-plaintext highlighter-rouge">callbacks</code> (TrainerCallback 列表, 可选):自定义<strong>回调函数</strong>
<ul>
<li>如果要删除使用的默认回调函数,要使用 Trainer.remove_callback() 方法。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">optimizers</code> (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], 可选):指定包含优化器和学习率调度器的元组(Tuple)
<ul>
<li>元组的两个元素分别是<strong>优化器</strong>(torch.optim.Optimizer)和<strong>学习率调度器</strong>(torch.optim.lr_scheduler.LambdaLR),默认会创建一个基于AdamW优化器的实例,并使用 get_linear_schedule_with_warmup() 函数创建一个学习率调度器。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">preprocess_logits_for_metrics</code> (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 可选):指定函数,每次评估步骤(evaluation step)前,进入compute_metrics函数前对模型的输出 logits 进行<strong>预处理</strong>。
<ul>
<li>接受两个张量(tensors)作为参数,一个是模型的输出 logits,另一个是<strong>真实标签</strong>(labels)。</li>
<li>然后返回一个经过预处理后的 logits 张量,给到compute_metrics函数作为参数。</li>
</ul>
</li>
</ul>
<h4 id="trainingarguments-参数">TrainingArguments 参数</h4>
<p>args:超参数定义,trainer的重要功能,大部分训练相关的参数都是这里设置</p>
<p>TrainingArguments 有接近100个参数</p>
<p>TrainingArguments 参数</p>
<ul>
<li><code class="language-plaintext highlighter-rouge">output_dir</code> (str):模型checkpoint/最终结果的输出目录。</li>
<li><code class="language-plaintext highlighter-rouge">overwrite_output_dir</code> (bool, 可选,默认为 False):如果设置为True,将<strong>覆盖</strong>输出目录中已存在的内容
<ul>
<li>继续训练模型并且输出目录, 指向一个checkpoint目录。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">do_train</code> (bool, 可选,默认为 False):是否执行<strong>训练</strong>
<ul>
<li>其实Trainer 不直接使用此参数,主要是用于写脚本时,作为if的条件来判断是否执行接下来的代码。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">do_eval</code> (bool, 可选):是否在验证集上进行<strong>评估</strong>,如果评估策略(evaluation_strategy)不是”no”,将自动设置为True。
<ul>
<li>与do_train类似,不直接由Trainer使用,主要是用于写训练脚本。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">do_predict</code> (bool, 可选,默认为 False):是否在测试集上<strong>预测</strong>。</li>
<li><code class="language-plaintext highlighter-rouge">evaluation_strategy</code> (str, 可选,默认为 “no”):指定训练期间采用的评估策略,可选值包括:
<ul>
<li>“no”:在训练期间不进行任何评估。</li>
<li>“steps”:每隔 eval_steps 步骤进行评估。</li>
<li>“epoch”:每个训练周期结束时进行评估。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">prediction_loss_only</code> (bool, 可选, 默认为 False):
<ul>
<li>如果设置为True,评估和预测时,只返回<strong>损失值</strong>,而不返回其他评估指标。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">per_device_train_batch_size</code> (int, 可选, 默认为 8):<strong>训练</strong>阶段,每个GPU/XPU/TPU/MPS/NPU/CPU的batch,每个训练步骤中每个硬件上的样本数量。</li>
<li><code class="language-plaintext highlighter-rouge">per_device_eval_batch_size</code> (int, 可选, 默认为 8):<strong>评估</strong>阶段的每个GPU/XPU/TPU/MPS/NPU/CPU的batch,每个评估步骤中每个硬件上的样本数量。</li>
<li><code class="language-plaintext highlighter-rouge">gradient_accumulation_steps</code> (int, 可选, 默认为 1):执行反向传播之前,<strong>梯度积累的更新步数</strong>。
<ul>
<li>梯度积累可以在多个batch上累积梯度,然后一次性执行反向传播,显存不够的情况下执行大batch的反向传播。</li>
<li>假设4张卡,每张卡的batch size为8,那么一个steps的batch size就是32,如果这个参数设置为4,那么做反向传播的训练样本数量就是128。</li>
<li>两个好处:①显存不够增大此参数;②能加快训练速度,毕竟做反向传播的次数少了。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">eval_accumulation_steps</code> (int, 可选):执行评估时,模型会累积多少个预测步骤的输出张量,然后才从GPU/NPU/TPU移动到CPU上,默认是整个评估的输出结果将在GPU/NPU/TPU上累积,然后一次性传输到CPU,速度更快,但占显存。</li>
<li><code class="language-plaintext highlighter-rouge">eval_delay</code> (float, 可选):等待执行第一次评估的轮数或步数。
<ul>
<li>如果evaluation_strategy为”steps”,设置此参数为10,则10个steps后才进行首次评估。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">learning_rate</code> (float, 可选, 默认为 5e-5):AdamW优化器的<strong>初始学习率</strong>。</li>
<li><code class="language-plaintext highlighter-rouge">weight_decay</code> (float, 可选, 默认为 0):<strong>权重衰减</strong>的值,应用在 AdamW 优化器所有层上,除了偏置(bias)和 Layer Normalization 层(LayerNorm)的权重上。
<ul>
<li>权重衰减是一种<strong>正则化</strong>手段,通过向损失函数添加一个额外的项,来惩罚较大的权重值,有助于防止模型<strong>过拟合</strong>训练数据。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">adam_beta1</code> (float, 可选, 默认为 0.9):AdamW优化器的beta1超参数。</li>
<li><code class="language-plaintext highlighter-rouge">adam_beta2</code> (float, 可选, 默认为 0.999):AdamW优化器的beta2超参数。</li>
<li><code class="language-plaintext highlighter-rouge">adam_epsilon</code> (float, 可选, 默认为 1e-8):AdamW优化器的epsilon超参数。</li>
<li><code class="language-plaintext highlighter-rouge">max_grad_norm</code> (float, 可选, 默认为 1.0):梯度剪裁的最大梯度范数,可以防止梯度爆炸,一般都是1,如果某一步梯度的L2范数超过了 此参数,那么梯度将被重新缩放,确保它的大小不超过此参数。</li>
<li><code class="language-plaintext highlighter-rouge">num_train_epochs</code> (float, 可选, 默认为 3.0):训练的<strong>总epochs数</strong>。</li>
<li><code class="language-plaintext highlighter-rouge">max_steps</code> (int, 可选, 默认为 -1):如果设置为正数,执行的总训练步数,会覆盖 num_train_epochs。
<ul>
<li>注意:如果使用此参数,就算没有达到这个参数值的步数,训练也会在数据跑完后停止。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">lr_scheduler_type</code> (str, 可选, 默认为”linear”):学习率scheduler类型,根据训练进程来自动调整学习率。详细见:
<ul>
<li>“linear”:<strong>线性</strong>学习率scheduler,学习率以线性方式改变</li>
<li>“cosine”:<strong>余弦</strong>学习率scheduler,学习率以余弦形状的方式改变。</li>
<li>“constant”:<strong>常数</strong>学习率,学习率在整个训练过程中保持不变。</li>
<li>“polynomial”:<strong>多项式</strong>学习率scheduler,学习率按多项式函数的方式变化。</li>
<li>“piecewise”:<strong>分段常数</strong>学习率scheduler,每个阶段使用不同的学习率。</li>
<li>“exponential”:<strong>指数</strong>学习率scheduler,学习率以指数方式改变。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">warmup_ratio</code> (float, 可选, 默认为0.0):线性热身占总训练步骤的比例,线性热身是一种训练策略,学习率在开始阶段从0逐渐增加到其最大值(通常是设定的学习率),然后在随后的训练中保持不变或者按照其他调度策略进行调整。如果设置为0.0,表示没有热身。</li>
<li><code class="language-plaintext highlighter-rouge">warmup_steps</code> (int,可选, 默认为0):线性热身的步骤数,这个参数会覆盖warmup_ratio,如果设置了warmup_steps,将会忽略warmup_ratio。</li>
<li><code class="language-plaintext highlighter-rouge">log_level</code> (str, 可选, 默认为passive):主进程上要使用的日志级别,
<ul>
<li><code class="language-plaintext highlighter-rouge">debug</code>:最详细的日志级别。</li>
<li><code class="language-plaintext highlighter-rouge">info</code>:用于一般的信息性消息。</li>
<li><code class="language-plaintext highlighter-rouge">warning</code>:用于警告信息。</li>
<li><code class="language-plaintext highlighter-rouge">error</code>:用于错误信息。</li>
<li><code class="language-plaintext highlighter-rouge">critical</code>:用于严重错误信息。</li>
<li><code class="language-plaintext highlighter-rouge">passive</code>:不设置任何内容,将会使用Transformers库当前的日志级别(默认为”warning”)。</li>
<li>建议训练时使用info级别。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">log_level_replica</code> (str, 可选, 默认为warning):副本上要使用的日志级别,与log_level相同。</li>
<li><code class="language-plaintext highlighter-rouge">log_on_each_node</code> (bool, optional, defaults to True):在多节点分布式训练中,是否在每个节点上使用log_level进行日志记录。</li>
<li><code class="language-plaintext highlighter-rouge">logging_dir</code> (str, 可选):TensorBoard日志目录。默认为output_dir/runs/CURRENT_DATETIME_HOSTNAME。</li>
<li><code class="language-plaintext highlighter-rouge">logging_strategy</code> (str, 可选, 默认为”steps”):训练过程中采用的日志记录策略。可选包括:
<ul>
<li>“no”:在训练过程中不记录任何日志。</li>
<li>“epoch”:在每个epoch结束时记录日志。</li>
<li>“steps”:根据logging_steps参数记录日志。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">logging_steps</code> (int or float,可选, 默认为500):
<ul>
<li>如果logging_strategy=”steps”,则此参数为每多少步记录一次步骤。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">logging_nan_inf_filter</code> (bool, 可选, 默认为 True):是否过滤日志记录中为nan和inf的loss
<ul>
<li>如果设置为True,将过滤每个步骤的loss,如果出现nan或inf,将取当前日志窗口的平均损失值。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">save_strategy</code> (str , 可选, 默认为 “steps”):训练过程中保存checkpoint的策略,包括:
<ul>
<li>“no”:在训练过程中不保存checkpoint。</li>
<li>“epoch”:在每个epoch束时保存checkpoint。</li>
<li>“steps”:根据save_steps参数保存checkpoint。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">save_steps</code> (int or float, 可选, 默认为500):
<ul>
<li>如果save_strategy=”steps”,就是指两次checkpoint保存之间的更新步骤数。如果是在[0, 1)的浮点数,则就会当做与总训练步骤数的比例。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">save_total_limit</code> (int, 可选):如果给定了参数,将限制checkpoint的总数,因为checkpoint也是很占硬盘的,将会删除输出目录中旧的checkpoint。
<ul>
<li>当启用load_best_model_at_end时,会根据metric_for_best_model保留最好的checkpoint,以及最近的checkpoint。</li>
<li>当save_total_limit=5和指定load_best_model_at_end时,将始终保留最近的四个checkpoint以及最好的checkpoint;</li>
<li>当save_total_limit=1和指定load_best_model_at_end时,会保存两个checkpoint:最后一个和最好的一个(如果不同一个)。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">load_best_model_at_end</code> (bool, 可选, 默认为False):是否在训练结束时,加载在训练过程中最好的checkpoint
<ul>
<li>设置为 True 时,找到在验证集上指标最好的checkpoint并且保存,然后还会保存最后一个checkpoint</li>
<li>在普通的多epoch训练中,最好设置为True</li>
<li>但在大模型训练中,一般是一个epoch,使用的就是最后一个checkpoint。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">save_safetensors</code> (bool, 可选, 默认为False):是否在保存和加载模型参数时使用 “safetensors”
<ul>
<li>“safetensors” 更好地处理了不同 PyTorch 版本之间的模型参数加载的兼容性问题。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">save_on_each_node</code> (bool, 可选, 默认为 False):多节点分布式训练时,是否在每个节点上保存checkpoint,还是仅在主节点上保存。
<ul>
<li>注意如果多节点使用的是同一套存储设备,比如都是外挂一个nas,开启后会报错,因为文件名称都一样。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">use_cpu</code> (bool, 可选, 默认为 False):是否用CPU训练。如果设置为False,将使用CUDA或其他可用设备。</li>
<li><code class="language-plaintext highlighter-rouge">seed</code> (int, 可选, 默认为42):训练过程的随机种子,确保训练的可重现性,主要用于model_init,随机初始化权重参数。</li>
<li><code class="language-plaintext highlighter-rouge">data_seed</code> (int, 可选):数据采样的随机种子,如果没有设置将使用与seed相同的种子,可以确保数据采样的可重现性。</li>
<li><code class="language-plaintext highlighter-rouge">jit_mode_eval</code> (bool, 可选, 默认为False):是否在推理(inference)过程中使用 PyTorch 的 JIT(Just-In-Time)跟踪功能
<ul>
<li>PyTorch JIT 是 PyTorch 的一个功能,用于将模型的前向传播计算编译成高性能的机器代码,会加速模型的推理。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">use_ipex</code> (bool, 可选, 默认为 False):是否使用英特尔扩展(Intel extension)来优化 PyTorch,需要安装IPEX
<ul>
<li>IPEX是一组用于优化深度学习框架的工具和库,提高训练和推理的性能,特别针对英特尔的处理器做了优化。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">bf16</code> (bool, 可选, 默认为False):是否使用bf16进行混合精度训练,而不是fp32训练,需要安培架构或者更高的NVIDIA架构,关于精度的问题可以看这篇文章:Glan格蓝:LLM大模型之精度问题(FP16,FP32,BF16)详解与实践
<ul>
<li>混合精度训练:模型训练时将模型参数和梯度存储为<code class="language-plaintext highlighter-rouge">fp32</code>,但在前向和后向传播计算中使用<code class="language-plaintext highlighter-rouge">fp16</code>,这样可以减少内存使用和计算时间,并提高训练速度。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">fp16</code> (bool,** 可选, 默认为<strong>**False)</strong>:是否使用fp16进行混合精度训练,而不是fp32训练。</li>
<li><code class="language-plaintext highlighter-rouge">fp16_opt_level</code> (str, 可选, 默认为 ‘‘O1’’):对于fp16训练,选择的Apex AMP的优化级别,可选值有 [‘O0’, ‘O1’, ‘O2’和’O3’]。详细信息可以看Apex文档。</li>
<li><code class="language-plaintext highlighter-rouge">half_precision_backend</code> (str, 可选, 默认为”auto”):混合精度训练(Mixed Precision Training)时要使用的后端,必须是 “auto”、”cuda_amp”、”apex”、”cpu_amp” 中的一个。
<ul>
<li>“auto”将根据检测到的PyTorch版本来使用后端,而其他选项将会强制使用请求的后端。使用默认就行。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">bf16_full_eval</code> (bool, 可选, 默认为 False):是否使用完全的bf16进行评估,而不是fp32。这样更快且省内存,但因为精度的问题指标可能会下降。</li>
<li><code class="language-plaintext highlighter-rouge">fp16_full_eval</code> (bool, 可选, 默认为 False):同上,不过将使用fp16.</li>
<li><code class="language-plaintext highlighter-rouge">tf32</code> (bool, 可选):是否启用tf32精度模式,适用于安培架构或者更高的NVIDIA架构,默认值取决于PyTorch的版本torch.backends.cuda.matmul.allow_tf32 默认值。</li>
<li><code class="language-plaintext highlighter-rouge">local_rank</code> (int, 可选, 默认为 -1):在分布式训练中的当前进程(本地排名)的排名,这个用户不用设置,使用PyTorch分布式训练时会<strong>自动</strong>设置,默认为自动设置。</li>
<li><code class="language-plaintext highlighter-rouge">ddp_backend</code> (str, 可选):处理分布式计算的后端框架,用于多个计算节点协同工作以加速训练,处理模型参数和梯度的同步、通信等操作,可选值如下
<ul>
<li>“<code class="language-plaintext highlighter-rouge">nccl</code>“:这是 NVIDIA Collective Communications Library (NCCL) 的后端。</li>
<li>“<code class="language-plaintext highlighter-rouge">mpi</code>“:Message Passing Interface (MPI) 后端, 是一种用于不同计算节点之间通信的标准协议。</li>
<li>“<code class="language-plaintext highlighter-rouge">ccl</code>“:这是 Intel的oneCCL (oneAPI Collective Communications Library) 的后端。</li>
<li>“<code class="language-plaintext highlighter-rouge">gloo</code>“:这是Facebook开发的分布式通信后端。</li>
<li>“<code class="language-plaintext highlighter-rouge">hccl</code>“:这是Huawei Collective Communications Library (HCCL) 的后端,用于华为昇腾NPU的系统上进行分布式训练。</li>
<li>默认会根据系统自动设置,一般是nccl。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">tpu_num_cores</code> (int, 可选):TPU上训练时,TPU核心的数量。</li>
<li><code class="language-plaintext highlighter-rouge">dataloader_drop_last</code> (bool, 可选, 默认为False):是否丢弃最后一个不完整的batch,发生在数据集的样本数量不是batch_size的整数倍的时候。</li>
<li><code class="language-plaintext highlighter-rouge">eval_steps</code> (int or float, 可选):如果evaluation_strategy=”steps”,两次评估之间的更新步数,如果未设置,默认和设置和logging_steps相同的值,如果是在[0, 1)的浮点数,则就会当做与总评估步骤数的比例。</li>
<li><code class="language-plaintext highlighter-rouge">dataloader_num_workers</code> (int, 可选, 默认为 0):数据加载时的子进程数量(仅用于PyTorch), PyTorch的num_workers参数,0表示数据将在主进程中加载。</li>
<li><code class="language-plaintext highlighter-rouge">past_index</code> (int, 可选, 默认为 -1):一些模型(如TransformerXL或XLNet)可用过去的隐藏状态进行预测,如果将此参数设置为正整数,Trainer将使用相应的输出(通常索引为2)作为过去状态,并将其在下一个训练步骤中作为mems关键字参数提供给模型,只针对一些特定模型。</li>
<li><code class="language-plaintext highlighter-rouge">run_name</code> (str, 可选):训练运行(run)的字符串参数,与日志记录工具(例如wandb和mlflow)一起使用,不影响训练过程,就是给其他的日志记录工具开了一个接口,个人还是比较推荐wandb比较好用。</li>
<li><code class="language-plaintext highlighter-rouge">disable_tqdm</code> (bool, 可选):是否禁用Jupyter笔记本中的~notebook.NotebookTrainingTracker生成的tqdm进度条,如果日志级别设置为warn或更低,则将默认为True,否则为False。</li>
<li><code class="language-plaintext highlighter-rouge">remove_unused_columns</code> (bool, 可选, 默认为True):是否自动删除模型在训练时,没有用到的数据列,默认会删除,比如你的数据有两列分别是content和id,如果没有用到id这一列,训练时就会被删除。</li>
<li><code class="language-plaintext highlighter-rouge">label_names</code> (List[str], 可选):在模型的输入字典中对应于标签(labels)的键,默认情况下不需要显式指定。</li>
<li><code class="language-plaintext highlighter-rouge">metric_for_best_model</code> (str, 可选):与 load_best_model_at_end 结合使用,比较不同模型的度量标准,默认情况下,如果未指定,将使用验证集的 “loss” 作为度量标准,可使用accuracy、F1、loss等。</li>
<li><code class="language-plaintext highlighter-rouge">greater_is_better</code> (bool, 可选):与 load_best_model_at_end 和 metric_for_best_model 结合使用,这个和上面的那个参数是对应的,那个指标是越大越好还是越小越好
<ul>
<li>如果是loss, 越小越好,这个参数就会被设置为False;</li>
<li>如果是accuracy,把这个值设为True。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">ignore_data_skip</code> (bool, 可选,默认为False):是否<strong>断点训练</strong>,即训练终止又恢复后,是否跳过之前的训练数据。</li>
<li><code class="language-plaintext highlighter-rouge">resume_from_checkpoint</code> (str, 可选):从checkpoint恢复训练的路径。</li>
<li><code class="language-plaintext highlighter-rouge">sharded_ddp</code> (bool, str 或 ShardedDDPOption 列表, 可选, 默认为’’):是否在分布式训练中使用 Sharded DDP(Sharded Data Parallelism),FairScale提供的,默认不使用
<ul>
<li>FairScale 是Mate开发的一个用于高性能和大规模训练的 PyTorch 扩展库。这个库扩展了基本的 PyTorch 功能,同时引入了最新的先进规模化技术,通过可组合的模块和易于使用的API,提供了最新的分布式训练技术。详细的可以看其官网。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">fsdp</code> (bool, str 或 FSDPOption 列表, 可选, 默认为’’):是否启用 PyTorch 的 <code class="language-plaintext highlighter-rouge">FSDP</code>(Fully Sharded Data Parallel Training),以及如何配置分布式并行训练。</li>
<li><code class="language-plaintext highlighter-rouge">fsdp_config</code> (str 或 dict, 可选):配置 PyTorch 的 FSDP(Fully Sharded Data Parallel Training)的配置文件</li>
<li><code class="language-plaintext highlighter-rouge">deepspeed</code> (str 或 dict, 可选):是否启用 DeepSpeed,以及如何配置 DeepSpeed。
<ul>
<li>目前分布式训练使用最多的框架,比上面pytorch原生分布式训练以及FairScale用的范围更广,详细的可以看其官网。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">label_smoothing_factor</code> (float, 可选,默认为0.0):标签平滑的因子。</li>
<li><code class="language-plaintext highlighter-rouge">debug</code> (str 或 DebugOption 列表, 可选, 默认为’’):启用一个或多个调试功能,支持选项:
<ul>
<li>“underflow_overflow”:此选项用于检测模型输入/输出中的溢出。</li>
<li>“tpu_metrics_debug”:此选项用于在 TPU 上打印调试指标。</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">optim</code> (str 或 training_args.OptimizerNames, 可选, 默认为 “adamw_torch”):要用的优化器。可选项:
<ul>
<li>“adamw_hf”</li>
<li>“adamw_torch”</li>
<li>“adamw_torch_fused”</li>
<li>“adamw_apex_fused”</li>
<li>“adamw_anyprecision”</li>
<li>“adafactor”</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">optim_args</code> (str, 可选):用于向特定类型的优化器(如adamw_anyprecision)提供额外的参数或自定义配置。</li>
<li><code class="language-plaintext highlighter-rouge">group_by_length</code> (bool, 可选, 默认为 False):是否在训练数据集中对大致相同长度的样本进行分组然后放在一个batch里,目的是尽量减少在训练过程中进行的padding,提高训练效率。</li>
<li><code class="language-plaintext highlighter-rouge">length_column_name</code> (str, 可选, 默认为 “length”):当上个参数设置为True时,可以给训练数据在增加一列”长度“,就是事先计算好的,可以加快分组的速度,默认是length。</li>
<li><code class="language-plaintext highlighter-rouge">report_to</code> (str 或 str 列表, 可选, 默认为 “all”):要将训练结果和日志报告到的不同日记集成平台,有很多”azure_ml”, “clearml”, “codecarbon”, “comet_ml”, “dagshub”, “flyte”, “mlflow”, “neptune”, “tensorboard”, and “wandb”。直接默认就行,都发。</li>
<li><code class="language-plaintext highlighter-rouge">ddp_find_unused_parameters</code> (bool, 可选):使用分布式训练时,这个参数用于控制是否查找并处理那些在计算中没有被使用的参数,如果启用了<strong>梯度检查点</strong>(gradient checkpointing),表示部分参数是惰性加载的,这时默认值为 False,因为梯度检查点本身已经考虑了未使用的参数,如果没有启用梯度检查点,默认值为 True,表示要查找并处理所有参数,以确保它们的梯度被正确传播。</li>
<li><code class="language-plaintext highlighter-rouge">ddp_bucket_cap_mb</code> (int, 可选):在分布式训练中,数据通常分成小块进行处理,这些小块称为”桶”,这个参数每个桶的最大内存占用大小,一般自动分配即可。</li>
<li><code class="language-plaintext highlighter-rouge">ddp_broadcast_buffers</code> (bool, 可选):分布式训练中,模型的某些部分可能包含缓冲区,如 Batch Normalization 层的统计信息,这个参数用于控制是否将这些缓冲区广播到所有计算设备,以确保模型在不同设备上保持同步,如果启用了梯度检查点,表示不需要广播缓冲区,因为它们不会被使用,如果没有启用梯度检查点,默认值为 True,表示要广播缓冲区,以确保模型的不同部分在所有设备上都一致。</li>
<li><code class="language-plaintext highlighter-rouge">gradient_checkpointing</code> (bool, 可选, 默认为False):是否开启梯度检查点,简单解释一下:训练大型模型时需要大量的内存,其中在反向传播过程中,需要保存前向传播的中间计算结果以计算梯度,但是这些中间结果占用大量内存,可能会导致内存不足,梯度检查点会在训练期间释放不再需要的中间结果以减小内存占用,但它会使反向传播变得更慢。</li>
<li><code class="language-plaintext highlighter-rouge">dataloader_pin_memory</code> (bool, 可选, 默认为 True):dataloader加载数据时,是否启用“pin memory”功能。“Pin memory” 用于将数据加载到GPU内存之前,将数据复制到GPU的锁页内存(pinned memory)中,锁页内存是一种特殊的内存,可以更快地传输数据到GPU,从而加速训练过程,但是会占用额外的CPU内存,会导致内存不足的问题,如果数据量特别大,百G以上建议False。</li>
<li><code class="language-plaintext highlighter-rouge">skip_memory_metrics</code> (bool, 可选, 默认为 True):是否将内存分析报告添加到性能指标中,默认情况下跳过这一步,以提高训练和评估的速度,建议打开,更能够清晰的知道每一步的内存使用。</li>
<li><code class="language-plaintext highlighter-rouge">include_inputs_for_metrics</code> (bool, 可选, 默认为 False):是否将输入传递给 compute_metrics 函数,一般计算metrics用的是用的是模型预测的结果和我们提供的标签,但是有的指标需要输入,比如cv的IoU(Intersection over Union)指标。</li>
<li><code class="language-plaintext highlighter-rouge">auto_find_batch_size</code> (bool, 可选, 默认为 False):是否使用自动寻找适合内存的batch size大小,以避免 CUDA 内存溢出错误,需要安装 accelerate(使用 pip install accelerate),这个功能还是比较NB的。</li>
<li><code class="language-plaintext highlighter-rouge">full_determinism</code> (bool, 可选, 默认为 False):如果设置为 True,将调用 enable_full_determinism() 而不是 set_seed(),训练过程将启用完全确定性(full determinism),在训练过程中,所有的随机性因素都将被消除,确保每次运行训练过程都会得到相同的结果,注意:会对性能产生负面影响,因此仅在调试时使用。</li>
<li><code class="language-plaintext highlighter-rouge">torchdynamo</code> (str, 可选):用于选择 TorchDynamo 的后端编译器,TorchDynamo 是 PyTorch 的一个库,用于提高模型性能和部署效率,可选的选择包括 “eager”、”aot_eager”、”inductor”、”nvfuser”、”aot_nvfuser”、”aot_cudagraphs”、”ofi”、”fx2trt”、”onnxrt” 和 “ipex”。默认就行,自动会选。</li>
<li><code class="language-plaintext highlighter-rouge">ray_scope</code> (str, 可选, 默认为 “last”):用于使用 Ray 进行超参数搜索时,指定要使用的范围,默认情况下,使用 “last”,Ray 将使用所有试验的最后一个检查点,比较它们并选择最佳的。详细的可以看一下它的文档。</li>
<li><code class="language-plaintext highlighter-rouge">ddp_timeout</code> (int, 可选, 默认为 1800):用于 torch.distributed.init_process_group 调用的超时时间,在分布式运行中执行较慢操作时,用于避免超时,具体的可以看 PyTorch 文档 。
<code class="language-plaintext highlighter-rouge">torch_compile</code> (bool, 可选, 默认为 False):是否使用 PyTorch 2.0 及以上的 torch.compile 编译模型,具体的可以看 PyTorch 文档 。</li>
<li><code class="language-plaintext highlighter-rouge">torch_compile_backend</code> (str, 可选):指定在 torch.compile 中使用的后端,如果设置为任何值,将启用 torch_compile。</li>
<li><code class="language-plaintext highlighter-rouge">torch_compile_mode</code> (str, 可选):指定在 torch.compile 中使用的模式,如果设置为任何值,将启用 torch_compile。</li>
<li><code class="language-plaintext highlighter-rouge">include_tokens_per_second</code> (bool, 可选):确定是否计算每个设备的每秒token数以获取训练速度指标,会在整个训练数据加载器之前进行迭代,会稍微减慢整个训练过程,建议打开。</li>
<li><code class="language-plaintext highlighter-rouge">push_to_hub</code> (bool, 可选, 默认为 False):指定是否在每次保存模型时将模型推送到Huggingface Hub。</li>
<li><code class="language-plaintext highlighter-rouge">hub_model_id</code> (str, 可选):指定要与本地 output_dir 同步的存储库的名称。</li>
<li><code class="language-plaintext highlighter-rouge">hub_strategy</code> (str 或 HubStrategy, 可选, 默认为 “every_save”):指定怎么推送到Huggingface Hub。</li>
<li><code class="language-plaintext highlighter-rouge">hub_token</code> (str, 可选):指定推送模型到Huggingface Hub 的token。</li>
<li><code class="language-plaintext highlighter-rouge">hub_private_repo</code> (bool, 可选, 默认为 False):如果设置为 True,Huggingface Hub 存储库将设置为私有。</li>
<li><code class="language-plaintext highlighter-rouge">hub_always_push</code> (bool, 可选, 默认为 False):是否每次都推送模型。</li>
</ul>
<p>详见</p>
<ul>
<li><a href="https://zhuanlan.zhihu.com/p/662619853">LLM大模型之Trainer以及训练参数</a></li>
</ul>
<h2 id="firefly">Firefly</h2>
<p><a href="https://github.com/yangjianxin1/Firefly">Firefly</a> 是开源的大模型<strong>一站式训练框架</strong></p>
<ul>
<li>支持对各种大模型进行<strong>预训练</strong>、<strong>指令微调</strong>、<code class="language-plaintext highlighter-rouge">DPO</code>,支持全量参数、LoRA、QLoRA等训练方式。</li>
<li>支持包括但不限于Gemma、Qwen1.5、MiniCPM、Mixtral-8x7B、Mistral、Llama等绝大多数主流的大模型。</li>
</ul>
<p>【2024-3-5】<a href="https://mp.weixin.qq.com/s/C5X0qX2YsxhIoFvRsqcMMA">使用Firefly在单卡V100上对Qwen1.5进行SFT和DPO,大幅超越Qwen1.5和Gemma</a></p>
<p>用Firefly项目对Qwen1.5-7B进行训练的实验。我们对训练数据进行精细化筛选,然后在单张V100上进行SFT和DPO。经过两阶段的训练,我们的模型在Open LLM Leaderboard上的表现显著优于官方的Qwen1.5-7B-Chat、Gemma-7B-it、Vicuna-13B等模型。比Qwen1.5-7B-Chat高7.12分,比Gemma-7B-it高8.8分。</p>
<h2 id="torchtune">TorchTune</h2>
<p>【2024-3-23】<a href="https://zhuanlan.zhihu.com/p/688671130?utm_psn=1755039674018496512">PyTorch官方发布LLM微调工具TorchTune</a></p>
<p>PyTorch官方最近发布了支持LLM微调的工具:<code class="language-plaintext highlighter-rouge">TorchTune</code>。</p>
<ul>
<li><a href="https://pytorch.org/blog/torchtune-fine-tune-llms/">TorchTune</a> 是一个原生的 PyTorch 库,用于轻松编写、微调和实验大型语言模型(LLMs)</li>
</ul>
<h3 id="torchtune-功能">TorchTune 功能</h3>
<p>功能:</p>
<ul>
<li>原生 PyTorch 实现的流行大型语言模型</li>
<li>支持多种格式的checkpoints,包括 Hugging Face 格式的checkpoints</li>
<li>针对流行微调技术的训练策略,带有参考基准和全面的校验检查</li>
<li>与 HuggingFace 数据集集成用于训练,以及与 EleutherAI 的评估工具 Eval Harness 集成用于评估</li>
<li>支持使用 PyTorch 分布式中的 FSDP 进行分布式训练</li>
<li>YAML 配置文件,便于轻松配置训练运行</li>
<li>[即将推出] 支持来自 TorchAO 的低精度数据类型和量化技术</li>
<li>[即将推出] 与各种推理引擎的互操作性</li>
</ul>
<h3 id="torchtune-微调">TorchTune 微调</h3>
<p>TorchTune 已经支持了<strong>Llama2 7B模型</strong>的微调:</p>
<ul>
<li>单卡微调:<a href="https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_single_device.py">https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_single_device.py</a></li>
<li>分布式微调:<a href="https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py">https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py</a></li>
<li>单卡LoRA:<a href="https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_single_device.py">https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_single_device.py</a></li>
<li>分布式LoRA:<a href="https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_distributed.py">https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_distributed.py</a></li>
<li>QLoRA:<a href="https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_single_device.py">https://github.com/pytorch/torc</a></li>
</ul>
<h3 id="torchtune-安装">torchtune 安装</h3>
<p>torchtune 必须通过克隆仓库并按照以下方式安装来构建:</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># ①
</span><span class="n">pip</span> <span class="n">install</span> <span class="n">torchtune</span>
<span class="c1"># ②
</span><span class="n">git</span> <span class="n">clone</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="p">.</span><span class="n">com</span><span class="o">/</span><span class="n">pytorch</span><span class="o">/</span><span class="n">torchtune</span><span class="p">.</span><span class="n">git</span>
<span class="n">cd</span> <span class="n">torchtune</span>
<span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">e</span> <span class="p">.</span>
</code></pre></div></div>
<h2 id="torchtitan">torchtitan</h2>
<p>【2024-4-28】<a href="https://github.com/pytorch/torchtitan">torchtitan</a> - 用于大型模型训练的原生 PyTorch 库</p>
<p><a href="https://github.com/pytorch/torchtitan">torchtitan</a> is a proof-of-concept (概念验证阶段) for Large-scale LLM training using native PyTorch.</p>
<ul>
<li>It is (and will continue to be) a repo to showcase PyTorch’s latest distributed training features in a clean, minimal codebase.</li>
<li><code class="language-plaintext highlighter-rouge">torchtitan</code> is complementary (补充) to and not a replacement (替代) for any of the great large-scale LLM training codebases such as <code class="language-plaintext highlighter-rouge">Megatron</code>, <code class="language-plaintext highlighter-rouge">Megablocks</code>, <code class="language-plaintext highlighter-rouge">LLM Foundry</code>, <code class="language-plaintext highlighter-rouge">Deepspeed</code>, etc.</li>
<li>Instead, we hope that the features showcased in <code class="language-plaintext highlighter-rouge">torchtitan</code> will be adopted by these codebases quickly. torchtitan is unlikely to ever grow a large community around it.</li>
</ul>
<p>Our guiding principles when building torchtitan:</p>
<ul>
<li>Designed to be easy to understand, use and extend for different training purposes.</li>
<li>Minimal changes to the model code when applying 1D, 2D, or (soon) 3D Parallel.</li>
<li>Modular components instead of a monolithic codebase.</li>
</ul>
<p>Get started in minutes, not hours!</p>
<h2 id="总结">总结</h2>
<p>Megatron-DeepSpeed 实施 3D 并行以可以让大型模型以非常有效的方式进行训练。</p>
<ul>
<li>DataParallel (<code class="language-plaintext highlighter-rouge">DP</code>) - 相同的初始化模型被复制多次,并且每次都被馈送 minibatch 的一部分。处理是并行完成的,所有设置在每个训练步骤结束时进行同步。</li>
<li>TensorParallel (<code class="language-plaintext highlighter-rouge">TP</code>) - 每个张量都被分成多个块,因此不是让整个张量驻留在单个 GPU 上,而是张量的每个分片都驻留在其指定的 GPU 上。在处理过程中,每个分片在不同的 GPU 上分别并行处理,最终结果在步骤结束时同步。这也被称作横向并行。</li>
<li>PipelineParallel (<code class="language-plaintext highlighter-rouge">PP</code>) - 模型在多个 GPU 上垂直(层级)拆分,因此只有模型的一个或多个层放置在单个 GPU 上。每个 GPU 并行处理管道的不同阶段,并处理一小部分批处理。</li>
<li>零冗余优化器 (<code class="language-plaintext highlighter-rouge">ZeRO</code>) - 也执行与 TP 有点类似的张量分片,除了整个张量会及时重建以进行前向或反向计算,因此不需要修改模型。它还支持各种卸载技术以补偿有限的 GPU 内存。</li>
</ul>
<p>训练超大规模语言模型主要有两条技术路线:</p>
<ul>
<li>TPU + XLA + TensorFlow/JAX</li>
<li>GPU + PyTorch + Megatron-LM + DeepSpeed</li>
<li>前者由Google主导,由于TPU和自家云平台GCP深度绑定,对于非Googler来说, 只可远观而不可把玩</li>
<li>后者背后则有NVIDIA、Meta、MS大厂加持,社区氛围活跃,也更受到群众欢迎。</li>
</ul>
<p>Deepspeed 是微软的大规模分布式训练工具。专门用于训练超大模型。</p>
<ul>
<li><a href="https://zhuanlan.zhihu.com/p/609865550">大模型的训练工具(1)—Deepspeed</a></li>
<li><code class="language-plaintext highlighter-rouge">DP</code>+<code class="language-plaintext highlighter-rouge">PP</code>: DeepSpeed 将 DP 与 PP 结合起来
<ul>
<li><img src="https://pic1.zhimg.com/80/v2-127d807df8f6efc7b1f8cb6d5ff38620_1440w.webp" alt="" /></li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">DP</code>+<code class="language-plaintext highlighter-rouge">PP</code>+<code class="language-plaintext highlighter-rouge">TP</code>: 为了获得更高效的训练,PP 与 TP 和 DP 相结合,称为 3D 并行性
<ul>
<li><img src="https://pic1.zhimg.com/80/v2-7951815d9ab95beedf1d238bc58e73f0_1440w.webp" alt="" /></li>
</ul>
</li>
<li>ZeRO DP+PP+TP: DeepSpeed 的主要功能之一是 ZeRO,它是 DP 的超级可扩展扩展。</li>
<li>【2023-3-16】<a href="https://zhuanlan.zhihu.com/p/611325149">大型语言模型(LLM)训练指南</a></li>
</ul>
<p>增加的功能主要有:</p>
<ul>
<li>3个维度并行化实现万亿参数模型训练</li>
<li>ZeRO-Offload 使 GPU 单卡能够训练 10 倍大的模型</li>
<li>通过 DeepSpeed Sparse Attention 用6倍速度执行10倍长的序列</li>
<li>1 比特 Adam 减少 5 倍通信量</li>
</ul>
<p>3D 并行:扩展至万亿参数模型</p>
<p>3D 并行同时解决了训练万亿参数模型的两个基本挑战:显存效率和计算效率。因此,DeepSpeed 可以扩展至在显存中放下最巨大的模型,而不会牺牲速度。</p>
<ul>
<li>显存效率:集群上所能训练的LLM的参数量。</li>
<li>计算效率:单纯计算占系统的开销的比例。</li>
</ul>
<p>(1)<strong>数据并行</strong>是分布式训练普遍使用的技术。</p>
<p>在该技术中,每批输入的训练数据都在数据并行的 worker 之间平分。反向传播后需要通信并规约梯度,以保证优化器在各个 worker 上进行相同的更新。数据并行性具有几个明显的优势,包括计算效率高和实现起来工作量小。但是,数据并行的 batch 大小随 worker 数量提高,而我们往往无法在不影响收敛性的情况下一直增加 batch 大小。</p>
<ul>
<li>显存效率:数据并行会在所有 worker 之间进行模型和优化器的复制,因此显存效率不高。DeepSpeed 开发了 ZeRO ,它是一系列用于提高数据并行的显存效率的优化器。 这项工作依赖于 ZeRO 的 1 阶段,该阶段在 worker 之间划分优化器状态量以减少冗余。</li>
<li>计算效率:随着我们提高并行度,每个 worker 执行的计算量是恒定的。数据并行可以在小规模上实现近乎线性扩展。但是,在 worker 之间规约梯度的通信开销跟模型大小成正相关,所以当模型很大或通信带宽很低时,计算效率会受限。。梯度累积是一种用来均摊通信成本的一种常用策略。它会进一步增加batch大小,在本地使用 micro-batch 多次进行正向和反向传播积累梯度后,再进行梯度规约和优化器更新。</li>
</ul>
<p>(2)<strong>模型并行</strong>是包含范围很广的一类技术。</p>
<p>它会在多个 worker 之间划分模型的各个层。就其本质而言,模型并行性的计算和通信因模型结构而异,因此在实现上有很大的工作量。DeepSpeed 借用了英伟达的 Megatron-LM 来为基于 Transformer 的语言模型提供大规模模型并行功能。模型并行会根据 worker 数量成比例地减少显存使用量,也是这三种并行度中显存效率最高的。但是其代价是计算效率最低。</p>
<ul>
<li>显存效率:模型并行会根据 worker 数量成比例地减少显存使用量。至关重要的是,这是减少单个网络层的激活显存的唯一方法。DeepSpeed 通过在模型并行 worker 之间划分激活显存来进一步提高显存效率。</li>
<li>计算效率:由于每次前向和反向传播中都需要额外通信激活值,模型并行的计算效率很低。模型并行需要高通信带宽,并且不能很好地扩展到通信带宽受限的节点。此外,每个模型并行worker 都会减少每个通信阶段之间执行的计算量,从而影响计算效率。模型并行性通常与数据并行性结合使用,以在内存和计算效率之间进行权衡。</li>
</ul>
<p>(3)<strong>流水线并行</strong>训练引擎也被包含在了这次发布的DeepSpeed中</p>
<p>流水线并行将模型的各层划分为可以并行处理的阶段。当一个阶段完成一个 micro-batch 的正向传递时,激活内存将被通信至流水线的下一个阶段。类似地,当下一阶段完成反向传播时,将通过管道反向通信梯度。必须同时计算多个 micro-batch 以确保流水线的各个阶段能并行计算。目前已经开发出了几种用于权衡内存和计算效率以及收敛行为的方法,例如 PipeDream。DeepSpeed 采用的方法是通过梯度累积来实现并行,并保持与传统数据并行和模型并行训练在相同的总 batch 大小下收敛情况相同。</p>
<ul>
<li>显存效率:流水线并行减少的显存与流水线的阶段数成正比,使模型的大小可以随 worker 的数量线性扩展。但是,流水线并行不会减少每一层的激活函数的显存占用量。此外,每个 worker 必须存储同时运行的各个 micro-batch 的激活值。这导致流水线第一阶段的激活内存与单个 mirco batch 的总激活内存大致相同。一个万亿参数模型将需要为一个 micro batch 提供大约 19 GB 的显存的激活内存,这几乎占到新推出的英伟达 A100 GPU 总显存的一半。</li>
<li>计算效率:流水线并行具有最低的通信量,因为它的通信量只和在各阶段边界的各层的激活值大小成正比。但是,它不能无限扩展。像模型并行一样,增加流水线大小会减少每个流水线阶段的计算量,这会降低计算与通信的比率。如果要实现好的计算效率,流水线并行还要求其每个阶段的计算负载完美的均衡。</li>
</ul>
<h2 id="llama-factory">LLaMA-Factory</h2>
<p>资料</p>
<ul>
<li>【2024-7-18】<a href="https://zhuanlan.zhihu.com/p/695287607">LLaMA-Factory QuickStart</a></li>
<li><a href="https://llamafactory.readthedocs.io/zh-cn/latest/index.html">官方文档</a></li>
</ul>
<h3 id="llama-factory-介绍">LLaMA-Factory 介绍</h3>
<p><a href="https://www.llamafactory.cn/">LLaMA Factory</a> 支持多种LLM微调方式,北航博士生推出,包括: <strong>预训练</strong>、<strong>指令监督微调</strong>和<strong>奖励模型</strong>训练等。</p>
<ul>
<li>支持<code class="language-plaintext highlighter-rouge">LoRA</code>和<code class="language-plaintext highlighter-rouge">QLoRA</code>微调策略,广泛集成了业界前沿的微调方法。</li>
<li>特点: 支持多种LLM模型,提供了<strong>WebUI页面</strong>,使非开发人员也能微调。</li>
<li>体验地址:<a href="https://modelscope.cn/studios/hiyouga/LLaMA-Board/summary">LLaMA-Board</a></li>
<li>可视化界面 <a href="https://huggingface.co/spaces/hiyouga/LLaMA-Board">LLaMA-Board</a></li>
<li>github: <a href="https://github.com/hiyouga/LLaMA-Factory">LLaMA-Factory</a>,附各阶段训练数据集</li>
<li><img src="https://pic2.zhimg.com/80/v2-7b24a5941a9bf996cf35187ae351f6c1_1440w.webp" alt="" /></li>
</ul>
<p>资源</p>
<ul>
<li>论文: <a href="https://arxiv.org/abs/2403.13372">Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)</a></li>
</ul>
<p>功能</p>
<ul>
<li>支持多种模型:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。</li>
<li>集成方法:(增量)<strong>预训练</strong>、(多模态)<strong>指令监督微调</strong>、<strong>奖励模型训练</strong>、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。</li>
<li>多种<strong>精度</strong>:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。</li>
<li>先进算法:GaLore、BAdam、Adam-mini、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。</li>
<li>实用技巧:FlashAttention-2、Unsloth、Liger Kernel、RoPE scaling、NEFTune 和 rsLoRA。</li>
<li>实验监控:LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。</li>
<li>极速推理:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。</li>
</ul>
<p>性能指标</p>
<ul>
<li>与 ChatGLM 官方的 P-Tuning 微调相比,LLaMA Factory 的 LoRA 微调提供了 3.7 倍的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。</li>
<li>结合 4 比特量化技术,LLaMA Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。</li>
</ul>
<p>详情参考</p>
<ul>
<li><a href="https://zhuanlan.zhihu.com/p/684989699">使用LLaMA Factory对大型语言模型进行微调</a></li>
<li>作者北航博士<a href="https://github.com/hiyouga">郑耀威</a>讲解 <a href="https://www.bilibili.com/video/BV1Gt421L7dt">全栈大模型微调框架LLaMA Factory:从预训练到RLHF的高效实现</a></li>
</ul>
<iframe src="//player.bilibili.com/player.html?aid=1801563508&bvid=BV1Gt421L7dt&cid=1463913844&p=1&autoplay=0" scrolling="no" border="0" frameborder="no" framespacing="0" allowfullscreen="true" height="600" width="100%"> </iframe>
<h3 id="llama-factory-安装">LLaMA-Factory 安装</h3>
<p>安装</p>
<ul>
<li><a href="https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md#%E5%A6%82%E4%BD%95%E4%BD%BF%E7%94%A8">安装说明</a></li>
</ul>
<p>依赖</p>
<ul>
<li>必备依赖: torch/transformers/datasets/trl/accelerate/peft</li>
<li>可选依赖: CUDA/deepspeed/bitsandbytes/vllm/flash-attn</li>
</ul>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># ----------------------</span>
git clone <span class="nt">--depth</span> 1 https://github.com/hiyouga/LLaMA-Factory.git
<span class="nb">cd </span>LLaMA-Factory
pip <span class="nb">install</span> <span class="nt">-e</span> <span class="s2">".[torch,metrics]"</span>
<span class="c"># ---------------------</span>
<span class="c"># Clone the repository</span>
git clone https://github.com/hiyouga/LLaMA-Factory.git
<span class="c"># Create a virtual environment</span>
conda create <span class="nt">-n</span> llama_factory <span class="nv">python</span><span class="o">=</span>3.10
<span class="c"># Activate the virtual environment</span>
conda activate llama_factory
<span class="c"># Install dependencies</span>
<span class="nb">cd </span>LLaMA-Factory
pip <span class="nb">install</span> <span class="nt">-r</span> requirements.txt
</code></pre></div></div>
<p>【2025-1-11】 win 10 上实践</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>pip <span class="nb">install </span>llamafactory <span class="c"># 一步安装</span>
</code></pre></div></div>
<h3 id="模型下载">模型下载</h3>
<p>项目支持通过模型名称直接从 huggingface 和 modelscope 下载模型,但不容易对模型文件统一管理,所以建议使用手动下载,然后使用绝对路径控制哪个模型。</p>
<p>以 Meta-Llama-3-8B-Instruct为例</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># huggingface 下载(可能要先提交申请通过)</span>
git clone https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
<span class="c"># modelscope 下载(适合中国大陆网络环境)</span>
git clone https://www.modelscope.cn/LLM-Research/Meta-Llama-3-8B-Instruct.git
<span class="c"># 或者 模型下载</span>
from modelscope import snapshot_download
model_dir <span class="o">=</span> snapshot_download<span class="o">(</span><span class="s1">'LLM-Research/Meta-Llama-3-8B-Instruct'</span><span class="o">)</span>
</code></pre></div></div>
<p>注意</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># Hugging Face 模型和数据集下载中遇到问题,可使用魔搭/魔乐社区。</span>
<span class="nb">export </span><span class="nv">USE_MODELSCOPE_HUB</span><span class="o">=</span>1 <span class="c"># Windows 使用 `set USE_MODELSCOPE_HUB=1`</span>
<span class="nb">export </span><span class="nv">USE_OPENMIND_HUB</span><span class="o">=</span>1 <span class="c"># Windows 使用 `set USE_OPENMIND_HUB=1`</span>
</code></pre></div></div>
<h3 id="llama-factory-命令行">LLaMA-Factory 命令行</h3>
<h4 id="常用命令">常用命令</h4>
<p>主要命令</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>llamafactory-cli version <span class="c"># 显示版本</span>
llamafactory-cli <span class="nb">help</span> <span class="c"># 帮助信息</span>
<span class="c"># Web UI 使用</span>
llamafactory-cli webui <span class="c"># 启动网页端</span>
<span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>4 llamafactory-cli webui <span class="c"># 指定第4张显卡使用</span>
<span class="nv">CUDA_DEVICE_ORDER</span><span class="o">=</span><span class="s1">'cpu'</span> <span class="o">&&</span> llamafactory-cli webui <span class="c"># cpu 上启动web ui</span>
<span class="nb">set </span><span class="nv">CUDA_DEVICE_ORDER</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">;</span>llamafactory-cli webui <span class="c"># windows terminal 命令</span>
</code></pre></div></div>
<p>对 Llama3-8B-Instruct 模型进行 LoRA 微调、推理和合并。</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># 训练</span>
llamafactory-cli train <span class="nt">-h</span> <span class="c"># 查看训练参数</span>
<span class="c"># lora 微调</span>
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
<span class="c"># 推理</span>
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
<span class="c"># 合并</span>
llamafactory-cli <span class="nb">export </span>examples/merge_lora/llama3_lora_sft.yaml
<span class="c"># 用 vLLM 部署 OpenAI API</span>
<span class="nv">API_PORT</span><span class="o">=</span>8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
</code></pre></div></div>
<h4 id="参数">参数</h4>
<p>参数解释</p>
<table>
<thead>
<tr>
<th>动作参数枚举</th>
<th>参数说明</th>
</tr>
</thead>
<tbody>
<tr>
<td><code class="language-plaintext highlighter-rouge">version</code></td>
<td>显示版本信息</td>
</tr>
<tr>
<td><code class="language-plaintext highlighter-rouge">train</code></td>
<td>命令行版本训练</td>
</tr>
<tr>
<td><code class="language-plaintext highlighter-rouge">chat</code></td>
<td>命令行版本推理chat</td>
</tr>
<tr>
<td><code class="language-plaintext highlighter-rouge">export</code></td>
<td>模型合并和导出</td>
</tr>
<tr>
<td><code class="language-plaintext highlighter-rouge">api</code></td>
<td>启动API server,供接口调用</td>
</tr>
<tr>
<td><code class="language-plaintext highlighter-rouge">eval</code></td>
<td>使用mmmlu等标准数据集做评测</td>
</tr>
<tr>
<td><code class="language-plaintext highlighter-rouge">webchat</code></td>
<td>前端版本纯推理的chat页面</td>
</tr>
<tr>
<td><code class="language-plaintext highlighter-rouge">webui</code></td>
<td>启动LlamaBoard前端页面,包含可视化训练,预测,chat,模型合并多个子页面</td>
</tr>
</tbody>
</table>
<p>关键参数</p>
<ul>
<li><code class="language-plaintext highlighter-rouge">model_name_or_path</code> 参数名称
<ul>
<li>huggingface 或 modelscope 标准定义,如“meta-llama/Meta-Llama-3-8B-Instruct”)</li>
<li>或者是本地下载的<strong>绝对</strong>路径,如 /media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct</li>
</ul>
</li>
<li><code class="language-plaintext highlighter-rouge">template</code> 模型问答时prompt模板,不同模型不同,请<a href="https://github.com/hiyouga/LLaMA-Factory?tab=readme-ov-file#supported-models">参考</a> 获取不同模型的模板定义,否则会回答结果会很奇怪或导致重复生成等现象的出现。
<ul>
<li>chat 版本模型基本都需要指定,比如 Meta-Llama-3-8B-Instruct 的 template 就是 llama3</li>
</ul>
</li>
</ul>
<p>也可提前把相关参数存在 yaml文件里,比如: <code class="language-plaintext highlighter-rouge">LLaMA-Factory/examples/inference/llama3.yaml</code> at main · hiyouga/LLaMA-Factory, 本地位置是 <code class="language-plaintext highlighter-rouge">examples/inference/llama3.yaml</code></p>
<p>内容如下</p>
<div class="language-yml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="na">model_name_or_path</span><span class="pi">:</span> <span class="s">/media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct</span>
<span class="na">template</span><span class="pi">:</span> <span class="s">llama3</span>
</code></pre></div></div>
<p>通过如下命令启动,效果跟上面一样,但是更方便管理</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>llamafactory-cli webchat examples/inference/llama3.yaml
</code></pre></div></div>
<p>效果如图,可通过 http://localhost:7860/ 进行访问</p>
<ul>
<li><img src="https://pic4.zhimg.com/v2-49fa6327394c0fbcfc971a6e2c22da29_1440w.jpg" alt="" /></li>
</ul>
<h4 id="推理">推理</h4>
<h5 id="transformers">transformers</h5>
<p>huggingface transformers 库直接推理</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">transformers</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="c1"># 切换为你下载的模型文件目录, 这里的demo是Llama-3-8B-Instruct
# 如果是其他模型,比如qwen,chatglm,请使用其对应的官方demo
</span><span class="n">model_id</span> <span class="o">=</span> <span class="s">"/media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct"</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">transformers</span><span class="p">.</span><span class="n">pipeline</span><span class="p">(</span>
<span class="s">"text-generation"</span><span class="p">,</span>
<span class="n">model</span><span class="o">=</span><span class="n">model_id</span><span class="p">,</span>
<span class="n">model_kwargs</span><span class="o">=</span><span class="p">{</span><span class="s">"torch_dtype"</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">bfloat16</span><span class="p">},</span>
<span class="n">device_map</span><span class="o">=</span><span class="s">"auto"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">messages</span> <span class="o">=</span> <span class="p">[</span>
<span class="p">{</span><span class="s">"role"</span><span class="p">:</span> <span class="s">"system"</span><span class="p">,</span> <span class="s">"content"</span><span class="p">:</span> <span class="s">"You are a pirate chatbot who always responds in pirate speak!"</span><span class="p">},</span>
<span class="p">{</span><span class="s">"role"</span><span class="p">:</span> <span class="s">"user"</span><span class="p">,</span> <span class="s">"content"</span><span class="p">:</span> <span class="s">"Who are you?"</span><span class="p">},</span>
<span class="p">]</span>
<span class="n">prompt</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">apply_chat_template</span><span class="p">(</span>
<span class="n">messages</span><span class="p">,</span>
<span class="n">tokenize</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
<span class="n">add_generation_prompt</span><span class="o">=</span><span class="bp">True</span>
<span class="p">)</span>
<span class="n">terminators</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">pipeline</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">eos_token_id</span><span class="p">,</span>
<span class="n">pipeline</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">convert_tokens_to_ids</span><span class="p">(</span><span class="s">"<|eot_id|>"</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">(</span>
<span class="n">prompt</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
<span class="n">eos_token_id</span><span class="o">=</span><span class="n">terminators</span><span class="p">,</span>
<span class="n">do_sample</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="n">temperature</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span>
<span class="n">top_p</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">outputs</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="s">"generated_text"</span><span class="p">][</span><span class="nb">len</span><span class="p">(</span><span class="n">prompt</span><span class="p">):])</span>
</code></pre></div></div>
<h5 id="api">API</h5>
<p>API 实现标准参考 OpenAI的相关接口协议,基于uvicorn服务框架进行开发, 使用如下的方式启动</p>
<p>本脚本改编自 <a href="https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/inference/llama3_lora_sft.yaml">llama3_lora_sft.yaml</a></p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 <span class="nv">API_PORT</span><span class="o">=</span>8000 llamafactory-cli api <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> ./saves/LLaMA3-8B/lora/sft <span class="se">\</span>
<span class="nt">--template</span> llama3 <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora
</code></pre></div></div>
<p>项目也支持了基于vllm 的推理后端,但是这里由于一些限制,需要提前将LoRA 模型进行merge,使用merge后的完整版模型目录或者训练前的模型原始目录都可。</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 <span class="nv">API_PORT</span><span class="o">=</span>8000 llamafactory-cli api <span class="se">\</span>
<span class="nt">--model_name_or_path</span> megred-model-path <span class="se">\</span>
<span class="nt">--template</span> llama3 <span class="se">\</span>
<span class="nt">--infer_backend</span> vllm <span class="se">\</span>
<span class="nt">--vllm_enforce_eager</span>
</code></pre></div></div>
<p>服务启动后,即可按照openai 的API 进行远程访问,主要的区别就是替换 其中的base_url,指向所部署的机器url和端口号即可。</p>
<div class="language-py highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">from</span> <span class="nn">openai</span> <span class="kn">import</span> <span class="n">OpenAI</span>
<span class="kn">from</span> <span class="nn">transformers.utils.versions</span> <span class="kn">import</span> <span class="n">require_version</span>
<span class="n">require_version</span><span class="p">(</span><span class="s">"openai>=1.5.0"</span><span class="p">,</span> <span class="s">"To fix: pip install openai>=1.5.0"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="s">'__main__'</span><span class="p">:</span>
<span class="c1"># change to your custom port
</span> <span class="n">port</span> <span class="o">=</span> <span class="mi">8000</span>
<span class="n">client</span> <span class="o">=</span> <span class="n">OpenAI</span><span class="p">(</span>
<span class="n">api_key</span><span class="o">=</span><span class="s">"0"</span><span class="p">,</span>
<span class="n">base_url</span><span class="o">=</span><span class="s">"http://localhost:{}/v1"</span><span class="p">.</span><span class="nb">format</span><span class="p">(</span><span class="n">os</span><span class="p">.</span><span class="n">environ</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">"API_PORT"</span><span class="p">,</span> <span class="mi">8000</span><span class="p">)),</span>
<span class="p">)</span>
<span class="n">messages</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">messages</span><span class="p">.</span><span class="n">append</span><span class="p">({</span><span class="s">"role"</span><span class="p">:</span> <span class="s">"user"</span><span class="p">,</span> <span class="s">"content"</span><span class="p">:</span> <span class="s">"hello, where is USA"</span><span class="p">})</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">client</span><span class="p">.</span><span class="n">chat</span><span class="p">.</span><span class="n">completions</span><span class="p">.</span><span class="n">create</span><span class="p">(</span><span class="n">messages</span><span class="o">=</span><span class="n">messages</span><span class="p">,</span> <span class="n">model</span><span class="o">=</span><span class="s">"test"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">result</span><span class="p">.</span><span class="n">choices</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">message</span><span class="p">)</span>
</code></pre></div></div>
<h4 id="ollama">Ollama</h4>
<p>导出GGUF,部署Ollama</p>
<p>GGUF 是 lllama.cpp 设计的大模型存储格式,可以对模型进行高效的压缩,减少模型的大小与内存占用,从而提升模型的推理速度和效率。</p>
<p>Ollama框架可以帮助用户快速使用本地的大型语言模型,那将LLaMA-Factory项目的训练结果导出到Ollama中部署呢?</p>
<ol>
<li>将lora模型合并</li>
<li>安装gguf库</li>
<li>使用llama.cpp的转换脚本将训练后的完整模型转换为gguf格式</li>
<li>安装Ollama软件</li>
<li>注册要部署的模型文件</li>
<li>启动Ollama</li>
</ol>
<p>1-3 步是准备好 gguf 格式的文件这也是Ollama所需要的标准格式。</p>
<p>4-6 步就是如何在Ollama环境中启动训练后的模型。</p>
<h3 id="llama-factory-可视化">LLaMA-Factory 可视化</h3>
<h4 id="llama-board">LLaMA Board</h4>
<p>Web UI 使用</p>
<ul>
<li>LLaMA Board 可视化微调(由 Gradio 驱动)</li>
<li>Web UI 目前只支持<strong>单卡</strong>训练/推理,当机器有多张显卡时请使用 <code class="language-plaintext highlighter-rouge">CUDA_VISIBLE_DEVICES</code> 指定一张显卡启动程序。</li>
<li>启动网页UI,系统上必须有 GPU !</li>
</ul>
<p>目前webui版本只支持<strong>单机单卡</strong>和<strong>单机多卡</strong>,如果是<strong>多机多卡</strong>请使用命令行版本</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># Web UI 使用</span>
llamafactory-cli webui <span class="c"># 启动网页端</span>
<span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 llamafactory-cli webui
<span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>4 llamafactory-cli webui <span class="c"># 指定第4张显卡使用</span>
<span class="c"># 如果开启 gradio share功能,或者修改端口号</span>
<span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 <span class="nv">GRADIO_SHARE</span><span class="o">=</span>1 <span class="nv">GRADIO_SERVER_PORT</span><span class="o">=</span>7860 llamafactory-cli webui
</code></pre></div></div>
<p>上述的多个不同的大功能模块都通过不同的tab进行了整合,提供了一站式操作体验。</p>
<ul>
<li>训练 → 批量评估 → 交互测试 → 模型导出</li>
<li><img src="https://pic4.zhimg.com/v2-a1de61e1483e65fc7b237de43bb437fd_1440w.jpg" alt="" /></li>
<li>train页面,可通过预览命令功能,将训练脚本导出,用于支持多gpu训练</li>
<li><img src="https://pic3.zhimg.com/v2-f0f30aba4c6280a4c54aa599f41fa292_1440w.jpg" alt="" /></li>
</ul>
<p>点击开始按钮, 即可开始训练,网页端和服务器端会同步输出相关的日志结果</p>
<ul>
<li><img src="https://pic2.zhimg.com/v2-3696353c7c0eea5081314ab75b257b29_1440w.jpg" alt="" /></li>
</ul>
<p>训练完毕后, 点击“刷新适配器”,可找到该模型历史上使用webui训练的LoRA模型文件,后续再训练或者执行chat的时候,即会将此LoRA一起加载。</p>
<h4 id="wb">W&B</h4>
<p>Weights & Biases 记录实验数据,请在 yaml 文件中添加下面的参数。</p>
<div class="language-yml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="na">report_to</span><span class="pi">:</span> <span class="s">wandb</span>
<span class="na">run_name</span><span class="pi">:</span> <span class="s">test_run</span> <span class="c1"># 可选</span>
</code></pre></div></div>
<p>启动训练任务时,将 WANDB_API_KEY 设置为密钥来登录 W&B 账户。</p>
<h4 id="swanlab">SwanLab</h4>
<p>用 SwanLab 记录实验数据,请在 yaml 文件中添加下面的参数。</p>
<div class="language-yml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="na">use_swanlab</span><span class="pi">:</span> <span class="no">true</span>
<span class="na">swanlab_run_name</span><span class="pi">:</span> <span class="s">test_run</span> <span class="c1"># 可选</span>
</code></pre></div></div>
<p>启动训练任务时,登录 SwanLab账户 有三种方式:</p>
<ul>
<li>方式一:在 yaml 文件中添加 <code class="language-plaintext highlighter-rouge">swanlab_api_key=<your_api_key></code> ,并设置 API 密钥。</li>
<li>方式二:将环境变量 <code class="language-plaintext highlighter-rouge">SWANLAB_API_KEY</code> 设置为你的 API 密钥。</li>
<li>方式三:启动前使用 <code class="language-plaintext highlighter-rouge">swanlab login</code> 命令完成登录。</li>
</ul>
<h3 id="数据集">数据集</h3>
<p>目前支持 alpaca 和 sharegpt 两种数据格式</p>
<p>以alpaca为例,整个数据集是一个json对象的list,具体数据格式为</p>
<div class="language-json highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">[</span><span class="w">
</span><span class="p">{</span><span class="w">
</span><span class="nl">"instruction"</span><span class="p">:</span><span class="w"> </span><span class="s2">"用户指令(必填)"</span><span class="p">,</span><span class="w">
</span><span class="nl">"input"</span><span class="p">:</span><span class="w"> </span><span class="s2">"用户输入(选填)"</span><span class="p">,</span><span class="w">
</span><span class="nl">"output"</span><span class="p">:</span><span class="w"> </span><span class="s2">"模型回答(必填)"</span><span class="p">,</span><span class="w">
</span><span class="nl">"system"</span><span class="p">:</span><span class="w"> </span><span class="s2">"系统提示词(选填)"</span><span class="p">,</span><span class="w">
</span><span class="nl">"history"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w">
</span><span class="p">[</span><span class="s2">"第一轮指令(选填)"</span><span class="p">,</span><span class="w"> </span><span class="s2">"第一轮回答(选填)"</span><span class="p">],</span><span class="w">
</span><span class="p">[</span><span class="s2">"第二轮指令(选填)"</span><span class="p">,</span><span class="w"> </span><span class="s2">"第二轮回答(选填)"</span><span class="p">]</span><span class="w">
</span><span class="p">]</span><span class="w">
</span><span class="p">}</span><span class="w">
</span><span class="p">]</span><span class="w">
</span></code></pre></div></div>
<p>例子比如单轮(alpaca_data_zh_51k.json 中的例子, 数据集在data/dataset_info.json中注册为alpaca_zh)</p>
<div class="language-json highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">{</span><span class="w">
</span><span class="nl">"instruction"</span><span class="p">:</span><span class="w"> </span><span class="s2">"写一个有效的比较语句"</span><span class="p">,</span><span class="w">
</span><span class="nl">"input"</span><span class="p">:</span><span class="w"> </span><span class="s2">"篮球和足球"</span><span class="p">,</span><span class="w">
</span><span class="nl">"output"</span><span class="p">:</span><span class="w"> </span><span class="s2">"篮球和足球都是受欢迎的运动。"</span><span class="w">
</span><span class="p">}</span><span class="w">
</span></code></pre></div></div>
<p>和多轮 (oaast_sft_zh.json 中的例子, 数据集在data/dataset_info.json中注册为oaast_sft_zh)</p>
<div class="language-json highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">{</span><span class="w">
</span><span class="nl">"instruction"</span><span class="p">:</span><span class="w"> </span><span class="s2">"谢谢"</span><span class="p">,</span><span class="w">
</span><span class="nl">"input"</span><span class="p">:</span><span class="w"> </span><span class="s2">""</span><span class="p">,</span><span class="w">
</span><span class="nl">"output"</span><span class="p">:</span><span class="w"> </span><span class="s2">"不用谢! 很高兴我提供的信息能够帮助到你! 如果还有什么其他问题也可以向我提问。"</span><span class="p">,</span><span class="w">
</span><span class="nl">"history"</span><span class="p">:</span><span class="w"> </span><span class="p">[</span><span class="w">
</span><span class="p">[</span><span class="w">
</span><span class="s2">"请你给我写一个面试准备计划,我想要去面试微软的程序员岗位"</span><span class="p">,</span><span class="w">
</span><span class="s2">"首先,你可以去微软官网寻找招聘信息并申请面试。</span><span class="se">\n</span><span class="s2">其次,您可以在社交媒体平台寻找微软公司对程序员的面试问题,并做好准备。</span><span class="se">\n</span><span class="s2">最后,您可以自己对面试过程进行模拟,熟悉话题并减少紧张感。</span><span class="se">\n</span><span class="s2">我希望你能面试成功。"</span><span class="w">
</span><span class="p">]</span><span class="w">
</span><span class="p">]</span><span class="w">
</span><span class="p">}</span><span class="w">
</span></code></pre></div></div>
<h3 id="llama-factory-使用">LLaMA-Factory 使用</h3>
<p>多GPU分布式训练, 多种工具</p>
<ul>
<li>huggingface Accelerate</li>
<li>DeepSpeed</li>
</ul>
<p><a href="https://zhuanlan.zhihu.com/p/718263213?utm_psn=1815334840821751808">参考</a></p>
<h4 id="指令监督微调">指令监督微调</h4>
<p>sft lora</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 llamafactory-cli train <span class="se">\ </span>
<span class="nt">--stage</span> sft <span class="se">\ </span> <span class="c"># 训练阶段 “sft”,"pt","rm","ppo"</span>
<span class="nt">--do_train</span> <span class="se">\ </span> <span class="c"># 是否训练模式</span>
<span class="nt">--model_name_or_path</span> /media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct <span class="se">\ </span>
<span class="nt">--dataset</span> alpaca_gpt4_zh,identity,adgen_local <span class="se">\ </span><span class="c"># 数据集列表, 多个数据集逗号分隔</span>
<span class="nt">--dataset_dir</span> ./data <span class="se">\ </span> <span class="c"># 数据集目录,自带的data</span>
<span class="nt">--template</span> llama3 <span class="se">\ </span>
<span class="nt">--finetuning_type</span> lora <span class="se">\ </span> <span class="c"># 微调类型: full, freeze, lora</span>
<span class="nt">--output_dir</span> ./saves/LLaMA3-8B/lora/sft <span class="se">\ </span> <span class="c"># 模型保存目录</span>
<span class="nt">--overwrite_cache</span> <span class="se">\ </span>
<span class="nt">--overwrite_output_dir</span> <span class="se">\ </span>
<span class="nt">--cutoff_len</span> 1024 <span class="se">\ </span> <span class="c"># 长度截断</span>
<span class="nt">--preprocessing_num_workers</span> 16 <span class="se">\ </span> <span class="c"># </span>
<span class="nt">--per_device_train_batch_size</span> 2 <span class="se">\ </span> <span class="c"># 训练时,各节点最小 batch_size</span>
<span class="nt">--per_device_eval_batch_size</span> 1 <span class="se">\ </span> <span class="c"># 训练时,各节点最小 batch_size</span>
<span class="nt">--gradient_accumulation_steps</span> 8 <span class="se">\ </span> <span class="c"># 梯度累积步数</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\ </span> <span class="c"># 学习率衰减策略</span>
<span class="nt">--logging_steps</span> 50 <span class="se">\ </span> <span class="c"># 打日志步数</span>
<span class="nt">--warmup_steps</span> 20 <span class="se">\ </span> <span class="c"># warmup</span>
<span class="nt">--save_steps</span> 100 <span class="se">\ </span> <span class="c"># 模型保存间隔步数</span>
<span class="nt">--eval_steps</span> 50 <span class="se">\ </span>
<span class="nt">--evaluation_strategy</span> steps <span class="se">\ </span> <span class="c"># </span>
<span class="nt">--load_best_model_at_end</span> <span class="se">\ </span>
<span class="nt">--learning_rate</span> 5e-5 <span class="se">\ </span>
<span class="nt">--num_train_epochs</span> 5.0 <span class="se">\ </span>
<span class="nt">--max_samples</span> 1000 <span class="se">\ </span> <span class="c"># 采样数</span>
<span class="nt">--val_size</span> 0.1 <span class="se">\ </span>
<span class="nt">--plot_loss</span> <span class="se">\ </span>
<span class="nt">--fp16</span> <span class="c"># 半精度, v100不支持 bf16</span>
</code></pre></div></div>
<p>训练结果</p>
<ul>
<li><img src="https://pica.zhimg.com/v2-c28f3d74144619426c06d6cf8fd1ff42_1440w.jpg" alt="" /></li>
</ul>
<p>output_dir 下主要包含3部分</p>
<ul>
<li>adapter 开头: LoRA保存的结果了,后续用于模型推理融合</li>
<li>training_loss 和 trainer_log 等记录训练的过程指标</li>
<li>其他是训练当时各种参数的备份</li>
</ul>
<p>loss在 正常情况下会随着训练的时间慢慢变小,最后需要下降到1以下的位置才会有一个比较好的效果,可以作为训练效果的一个中间指标。</p>
<p>lora 效果验证</p>
<ul>
<li>webui</li>
<li>terminal</li>
</ul>
<p>lora 模型推理: webchat</p>
<ul>
<li>指定原模型+lora模型</li>
</ul>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 llamafactory-cli webchat <span class="se">\ </span>
<span class="nt">--model_name_or_path</span> /media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct <span class="se">\ </span>
<span class="nt">--adapter_name_or_path</span> ./saves/LLaMA3-8B/lora/sft <span class="se">\</span>
<span class="nt">--template</span> llama3 <span class="se">\ </span>
<span class="nt">--finetuning_type</span> lora
</code></pre></div></div>
<p>terminal 终端验证</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 llamafactory-cli chat <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> ./saves/LLaMA3-8B/lora/sft <span class="se">\</span>
<span class="nt">--template</span> llama3 <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora
</code></pre></div></div>
<p>批量自动化评估</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>pip <span class="nb">install </span>jieba
pip <span class="nb">install </span>rouge-chinese
pip <span class="nb">install </span>nltk
</code></pre></div></div>
<p>本脚参考<a href="https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/train_lora/llama3_lora_predict.yaml">文件参数</a></p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 llamafactory-cli train <span class="se">\</span>
<span class="nt">--stage</span> sft <span class="se">\</span>
<span class="nt">--do_predict</span> <span class="se">\ </span> <span class="c"># 预测模式</span>
<span class="nt">--model_name_or_path</span> /media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> ./saves/LLaMA3-8B/lora/sft <span class="se">\</span>
<span class="nt">--eval_dataset</span> alpaca_gpt4_zh,identity,adgen_local <span class="se">\</span>
<span class="nt">--dataset_dir</span> ./data <span class="se">\</span>
<span class="nt">--template</span> llama3 <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--output_dir</span> ./saves/LLaMA3-8B/lora/predict <span class="se">\</span>
<span class="nt">--overwrite_cache</span> <span class="se">\</span>
<span class="nt">--overwrite_output_dir</span> <span class="se">\</span>
<span class="nt">--cutoff_len</span> 1024 <span class="se">\</span>
<span class="nt">--preprocessing_num_workers</span> 16 <span class="se">\</span>
<span class="nt">--per_device_eval_batch_size</span> 1 <span class="se">\</span>
<span class="nt">--max_samples</span> 20 <span class="se">\ </span> <span class="c"># 预测阶段采样数目</span>
<span class="nt">--predict_with_generate</span> <span class="c"># 生成阶段</span>
</code></pre></div></div>
<p>评估预测脚本 vs 训练脚本</p>
<p>区别如下两个</p>
<ul>
<li><code class="language-plaintext highlighter-rouge">do_predict</code> 预测模式</li>
<li><code class="language-plaintext highlighter-rouge">predict_with_generate</code> 生成文本</li>
<li><code class="language-plaintext highlighter-rouge">max_samples</code> 每个数据集采样多少用于预测对比</li>
</ul>
<p>训练的LoRA和原始大模型进行融合,输出一个完整的模型文件</p>
<p>参考 <a href="LLaMA-Factory/examples/merge_lora/llama3_lora_sft.yaml">llama3_lora_sft.yaml</a></p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">CUDA_VISIBLE_DEVICES</span><span class="o">=</span>0 llamafactory-cli <span class="nb">export</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /media/codingma/LLM/llama3/Meta-Llama-3-8B-Instruct <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> ./saves/LLaMA3-8B/lora/sft <span class="se">\</span>
<span class="nt">--template</span> llama3 <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--export_dir</span> megred-model-path <span class="se">\</span>
<span class="nt">--export_size</span> 2 <span class="se">\</span>
<span class="nt">--export_device</span> cpu <span class="se">\</span>
<span class="nt">--export_legacy_format</span> False
</code></pre></div></div>
<p>Accelerate</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>accelerate launch src/train.py <span class="se">\</span>
<span class="nt">--ddp_timeout</span> 18000000 <span class="se">\</span>
<span class="nt">--stage</span> sft <span class="se">\</span>
<span class="nt">--do_train</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /gemini/pretrain/Qwen1.5-4B/ <span class="se">\</span>
<span class="nt">--dataset</span> alpaca_gpt4_data_zh,alpaca_gpt4_data_en,glaive_toolcall_zh_demo,adgen_local <span class="se">\</span>
<span class="nt">--template</span> qwen <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--lora_target</span> q_proj,v_proj <span class="se">\</span>
<span class="nt">--output_dir</span> path_to_sft_checkpoint <span class="se">\</span>
<span class="nt">--overwrite_cache</span> <span class="se">\</span>
<span class="nt">--overwrite_output_dir</span>
<span class="nt">--per_device_train_batch_size</span> 2 <span class="se">\</span>
<span class="nt">--gradient_accumulation_steps</span> 4 <span class="se">\</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\</span>
<span class="nt">--logging_steps</span> 10 <span class="se">\</span>
<span class="nt">--save_steps</span> 1000 <span class="se">\</span>
<span class="nt">--learning_rate</span> 5e-5 <span class="se">\</span>
<span class="nt">--num_train_epochs</span> 3.0 <span class="se">\</span>
<span class="nt">--plot_loss</span> <span class="se">\</span>
<span class="nt">--fp16</span>
</code></pre></div></div>
<p>使用 DeepSpeed</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>deepspeed <span class="nt">--num_gpus</span> 2 src/train.py <span class="se">\</span>
<span class="nt">--deepspeed</span> ds_config.json <span class="se">\</span>
<span class="nt">--ddp_timeout</span> 18000000 <span class="se">\</span>
<span class="nt">--stage</span> sft <span class="se">\</span>
<span class="nt">--do_train</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /gemini/pretrain/Qwen1.5-4B/ <span class="se">\</span>
<span class="nt">--dataset</span> alpaca_zh_demo <span class="se">\</span>
<span class="nt">--template</span> qwen <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--lora_target</span> q_proj,v_proj <span class="se">\</span>
<span class="nt">--output_dir</span> path_to_sft_checkpoint <span class="se">\</span>
<span class="nt">--overwrite_cache</span> <span class="se">\</span>
<span class="nt">--overwrite_output_dir</span>
<span class="nt">--per_device_train_batch_size</span> 4 <span class="se">\</span>
<span class="nt">--gradient_accumulation_steps</span> 4 <span class="se">\</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\</span>
<span class="nt">--logging_steps</span> 10 <span class="se">\</span>
<span class="nt">--save_steps</span> 1000 <span class="se">\</span>
<span class="nt">--learning_rate</span> 5e-5 <span class="se">\</span>
<span class="nt">--num_train_epochs</span> 3.0 <span class="se">\</span>
<span class="nt">--plot_loss</span> <span class="se">\</span>
<span class="nt">--fp16</span>
</code></pre></div></div>
<h4 id="奖励模型训练">奖励模型训练</h4>
<p>Accelerate</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>accelerate launch src/train.py <span class="se">\</span>
<span class="nt">--stage</span> <span class="nb">rm</span> <span class="se">\</span>
<span class="nt">--do_train</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /gemini/pretrain/Qwen1.5-4B/ <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> path_to_sft_checkpoint <span class="se">\</span>
<span class="nt">--create_new_adapter</span> <span class="se">\</span>
<span class="nt">--dataset</span> dpo_zh_demo <span class="se">\</span>
<span class="nt">--template</span> qwen <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--lora_target</span> q_proj,v_proj <span class="se">\</span>
<span class="nt">--output_dir</span> path_to_ac_rm_checkpoint <span class="se">\</span>
<span class="nt">--per_device_train_batch_size</span> 2 <span class="se">\</span>
<span class="nt">--gradient_accumulation_steps</span> 4 <span class="se">\</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\</span>
<span class="nt">--logging_steps</span> 10 <span class="se">\</span>
<span class="nt">--save_steps</span> 1000 <span class="se">\</span>
<span class="nt">--learning_rate</span> 1e-5 <span class="se">\</span>
<span class="nt">--num_train_epochs</span> 1.0 <span class="se">\</span>
<span class="nt">--plot_loss</span> <span class="se">\</span>
<span class="nt">--fp16</span>
</code></pre></div></div>
<p>使用 DeepSpeed</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>deepspeed <span class="nt">--num_gpus</span> 2 src/train.py <span class="se">\</span>
<span class="nt">--deepspeed</span> ds_config.json <span class="se">\</span>
<span class="nt">--stage</span> <span class="nb">rm</span> <span class="se">\</span>
<span class="nt">--do_train</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /gemini/pretrain/Qwen1.5-4B/ <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> path_to_sft_checkpoint <span class="se">\</span>
<span class="nt">--create_new_adapter</span> <span class="se">\</span>
<span class="nt">--dataset</span> dpo_zh_demo <span class="se">\</span>
<span class="nt">--template</span> qwen <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--lora_target</span> q_proj,v_proj <span class="se">\</span>
<span class="nt">--output_dir</span> path_to_deep_rm_checkpoint <span class="se">\</span>
<span class="nt">--per_device_train_batch_size</span> 2 <span class="se">\</span>
<span class="nt">--gradient_accumulation_steps</span> 4 <span class="se">\</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\</span>
<span class="nt">--logging_steps</span> 10 <span class="se">\</span>
<span class="nt">--save_steps</span> 1000 <span class="se">\</span>
<span class="nt">--learning_rate</span> 1e-5 <span class="se">\</span>
<span class="nt">--num_train_epochs</span> 1.0 <span class="se">\</span>
<span class="nt">--plot_loss</span> <span class="se">\</span>
<span class="nt">--fp16</span>
</code></pre></div></div>
<h4 id="ppo-训练">ppo 训练</h4>
<p>Accelerate</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>accelerate launch src/train.py <span class="se">\</span>
<span class="nt">--stage</span> ppo <span class="se">\</span>
<span class="nt">--do_train</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /gemini/pretrain/Qwen1.5-4B/ <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> path_to_sft_checkpoint <span class="se">\</span>
<span class="nt">--create_new_adapter</span> <span class="se">\</span>
<span class="nt">--dataset</span> alpaca_zh_demo <span class="se">\</span>
<span class="nt">--template</span> qwen <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--lora_target</span> q_proj,v_proj <span class="se">\</span>
<span class="nt">--reward_model</span> path_to_ac_rm_checkpoint <span class="se">\</span>
<span class="nt">--output_dir</span> path_to_ac_ppo_checkpoint <span class="se">\</span>
<span class="nt">--per_device_train_batch_size</span> 2 <span class="se">\</span>
<span class="nt">--gradient_accumulation_steps</span> 4 <span class="se">\</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\</span>
<span class="nt">--top_k</span> 0 <span class="se">\</span>
<span class="nt">--top_p</span> 0.9 <span class="se">\</span>
<span class="nt">--logging_steps</span> 10 <span class="se">\</span>
<span class="nt">--save_steps</span> 1000 <span class="se">\</span>
<span class="nt">--learning_rate</span> 1e-5 <span class="se">\</span>
<span class="nt">--num_train_epochs</span> 1.0 <span class="se">\</span>
<span class="nt">--plot_loss</span> <span class="se">\</span>
<span class="nt">--fp16</span>
</code></pre></div></div>
<p>deepspeed</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>deepspeed <span class="nt">--num_gpus</span> 2 src/train.py <span class="se">\</span>
<span class="nt">--deepspeed</span> ds_config.json <span class="se">\</span>
<span class="nt">--stage</span> ppo <span class="se">\</span>
<span class="nt">--do_train</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /gemini/pretrain/Qwen1.5-4B/ <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> path_to_sft_checkpoint <span class="se">\</span>
<span class="nt">--create_new_adapter</span> <span class="se">\</span>
<span class="nt">--dataset</span> alpaca_zh_demo <span class="se">\</span>
<span class="nt">--template</span> qwen <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--lora_target</span> q_proj,v_proj <span class="se">\</span>
<span class="nt">--reward_model</span> path_to_deep_rm_checkpoint <span class="se">\</span>
<span class="nt">--output_dir</span> path_to_deep_ppo_checkpoint <span class="se">\</span>
<span class="nt">--per_device_train_batch_size</span> 4 <span class="se">\</span>
<span class="nt">--gradient_accumulation_steps</span> 4 <span class="se">\</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\</span>
<span class="nt">--top_k</span> 0 <span class="se">\</span>
<span class="nt">--top_p</span> 0.9 <span class="se">\</span>
<span class="nt">--logging_steps</span> 10 <span class="se">\</span>
<span class="nt">--save_steps</span> 1000 <span class="se">\</span>
<span class="nt">--learning_rate</span> 1e-5 <span class="se">\</span>
<span class="nt">--num_train_epochs</span> 1.0 <span class="se">\</span>
<span class="nt">--plot_loss</span> <span class="se">\</span>
<span class="nt">--fp16</span>
</code></pre></div></div>
<h4 id="dpo-训练">dpo 训练</h4>
<p>Accelerate</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>accelerate launch src/train.py <span class="se">\</span>
<span class="nt">--stage</span> dpo <span class="se">\</span>
<span class="nt">--do_train</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /gemini/pretrain/Qwen1.5-4B/ <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> path_to_sft_checkpoint <span class="se">\</span>
<span class="nt">--create_new_adapter</span> <span class="se">\</span>
<span class="nt">--dataset</span> dpo_zh_demo <span class="se">\</span>
<span class="nt">--template</span> qwen <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--lora_target</span> q_proj,v_proj <span class="se">\</span>
<span class="nt">--output_dir</span> path_to_ac_dpo_checkpoint <span class="se">\</span>
<span class="nt">--per_device_train_batch_size</span> 2 <span class="se">\</span>
<span class="nt">--gradient_accumulation_steps</span> 4 <span class="se">\</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\</span>
<span class="nt">--logging_steps</span> 10 <span class="se">\</span>
<span class="nt">--save_steps</span> 1000 <span class="se">\</span>
<span class="nt">--learning_rate</span> 1e-5 <span class="se">\</span>
<span class="nt">--num_train_epochs</span> 1.0 <span class="se">\</span>
<span class="nt">--plot_loss</span> <span class="se">\</span>
<span class="nt">--fp16</span>
</code></pre></div></div>
<p>deepspeed</p>
<div class="language-sh highlighter-rouge"><div class="highlight"><pre class="highlight"><code>deepspeed <span class="nt">--num_gpus</span> 2 src/train.py <span class="se">\</span>
<span class="nt">--deepspeed</span> ds_config.json <span class="se">\</span>
<span class="nt">--stage</span> dpo <span class="se">\</span>
<span class="nt">--do_train</span> <span class="se">\</span>
<span class="nt">--model_name_or_path</span> /gemini/pretrain/Qwen1.5-4B/ <span class="se">\</span>
<span class="nt">--adapter_name_or_path</span> path_to_sft_checkpoint <span class="se">\</span>
<span class="nt">--create_new_adapter</span> <span class="se">\</span>
<span class="nt">--dataset</span> dpo_zh_demo <span class="se">\</span>
<span class="nt">--template</span> qwen <span class="se">\</span>
<span class="nt">--finetuning_type</span> lora <span class="se">\</span>
<span class="nt">--lora_target</span> q_proj,v_proj <span class="se">\</span>
<span class="nt">--output_dir</span> path_to_deep_dpo_checkpoint <span class="se">\</span>
<span class="nt">--per_device_train_batch_size</span> 2 <span class="se">\</span>
<span class="nt">--gradient_accumulation_steps</span> 4 <span class="se">\</span>
<span class="nt">--lr_scheduler_type</span> cosine <span class="se">\</span>
<span class="nt">--logging_steps</span> 10 <span class="se">\</span>
<span class="nt">--save_steps</span> 1000 <span class="se">\</span>
<span class="nt">--learning_rate</span> 1e-5 <span class="se">\</span>
<span class="nt">--num_train_epochs</span> 1.0 <span class="se">\</span>
<span class="nt">--plot_loss</span> <span class="se">\</span>
<span class="nt">--fp16</span>
</code></pre></div></div>
<h2 id="xtuner">Xtuner</h2>
<p>上海AI实验室推出的 <a href="https://github.com/InternLM/xtuner">XTuner</a> 是一个高效、灵活、全能的轻量化大模型微调工具库。与 LLaMA-Factory 类似,不过,在<strong>长序列训练</strong>、<strong>token生成速度</strong>等方面要比 LLaMA-Factory 更强。</p>
<p>简析</p>
<ul>
<li>数据集: LLaMA-Factory 支持<strong>多种格式</strong>的数据集,更通用泛化;而 <code class="language-plaintext highlighter-rouge">XTuner</code> 只支持类似 <code class="language-plaintext highlighter-rouge">ShareGPT</code> 格式的数据集。</li>
<li>模型支持: LLaMA-Factory 支持模型种类也要比XTuner更多;但 XTuner 多模态模型(LLaVA-Internlm2-7B / 20B、LLaVA-v1.5)的支持要比 LLaMA-Factory。</li>
</ul>
<p>多轮对话训练时的loss计算。</p>
<ul>
<li>从文档来看,XTuner更清晰,而且是我想要的效果;</li>
<li>而对于 LLaMA-Factory,其放出来的只是数据集格式文档,loss计算没那么透明,只能啃源码。</li>
</ul>
<p>多轮对话所对应的<strong>长序列</strong>训练性能。随着 Gemini 1M context length 和 Sora 出世,如何训练超长上下文的大模型引起了大家广泛关注。同时在大多数的场景下,多轮对话一般也就是一个conversations包含几轮对话;但在实际情况中,一个conversations下有几百个对话,即长对话,这种场景还是比较多的。</p>
<p>解决方案比较麻烦,需要做拆分;在基座模型支持长上下文的情况下,如果微调框架能支持长序列训练,且性能不错,是很好的选择;</p>
<p>XTuner 在这方面要比 LLaMA-Factory 更好。</p>
<p>XTuner 序列并行设计思路参考了 DeepSpeed 的工作 DeepSpeed Ulysses,并加以优化,以达到直接基于 transformers 算法库或 Huggingface Hub 上的开源模型训练 1M 以上超长序列的目标。</p>
<table>
<thead>
<tr>
<th>模型</th>
<th>序列并行支持情况</th>
</tr>
</thead>
<tbody>
<tr>
<td>baichuan</td>
<td>1/2 ❌</td>
</tr>
<tr>
<td>chatglm</td>
<td>2/3 ❌</td>
</tr>
<tr>
<td>deepseek</td>
<td>✅</td>
</tr>
<tr>
<td>gemma</td>
<td>❌</td>
</tr>
<tr>
<td>internlm 2</td>
<td>✅</td>
</tr>
<tr>
<td>llama 2</td>
<td>✅</td>
</tr>
<tr>
<td>mistral</td>
<td>❌</td>
</tr>
<tr>
<td>qwen 1/1.5</td>
<td>❌</td>
</tr>
<tr>
<td>starcoder</td>
<td>❌</td>
</tr>
<tr>
<td>yi</td>
<td>✅</td>
</tr>
<tr>
<td>zephyr</td>
<td>✅</td>
</tr>
</tbody>
</table>
<h2 id="swift">SWIFT</h2>
<p>【2024-7-4】 阿里推出训练框架 <a href="https://github.com/modelscope/ms-swift/blob/main/README_CN.md">SWIFT</a> (Scalable lightWeight Infrastructure for Fine-Tuning)</p>
<p>SWIFT支持300+ LLM和50+ MLLM(多模态大模型)的训练(预训练、微调、对齐)、推理、评测和部署。开发者可以直接将我们的框架应用到自己的Research和生产环境中,实现模型训练评测到应用的完整链路。我们除支持了PEFT提供的轻量训练方案外,也提供了一个完整的Adapters库以支持最新的训练技术,如NEFTune、LoRA+、LLaMA-PRO等,这个适配器库可以脱离训练脚本直接使用在自己的自定流程中。</p>
<h1 id="结束">结束</h1>
</article>
<hr>
<!-- 打赏 -->
<!-- 【2019-08-06】打赏功能, http://www.twistedwg.com/2018/05/06/jekyll-reward.html-->
<div class="reward">
<div class="reward-button">赏<span class="reward-code">
<span class="alipay-code"> <img class="alipay-img wdp-appear" src="/wqw/fig/alipay.png"><b>支付宝打赏</b> </span>
<span class="wechat-code"> <img class="wechat-img wdp-appear" src="/wqw/fig/wechatpay.png"><b>微信打赏</b> </span> </span>
</div>
<p class="reward-notice" style="color:chocolate">~ 海内存知已,天涯若比邻 ~</p>
</div>
<!-- 分享工具 -->
<h2 id="share">Share</h2>
<!-- https://github.com/overtrue/share.js -->
<div class="social-share"></div>
<!-- css & js -->
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/social-share.js/1.0.16/css/share.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/social-share.js/1.0.16/js/social-share.min.js"></script>
<!-- 相似文章 -->
<!-- 相关文章推荐 -->
<h2 id="comments">Related Posts</h2>
<!-- 翻页 -->
<head>
<!-- [2022-11-10]卡片样式 -->
<link href="/css/card.css " rel="stylesheet" type="text/css">
</head>
<div class="post-recent">
<div class="pre">
<p><strong>上一篇</strong> <a href="/dist">分布式训练</a></p>
</div>
<div class="nex">
<p><strong>下一篇</strong> <a href="/llm_train">LLM 大模型训练之路</a></p>
</div>
</div>
<div class="kapian">
<div class="tu">
<img src="https://img.zcool.cn/community/01493a5cc98256a801214168b8995d.jpg">
</div>
<div class="wenben">
<p><a href="/dist">标题:分布式训练</a></p>
<p style="padding-bottom: 20px;">摘要:分布式训练知识点</p>
</div>
<div class="tu">
<img src="https://img.zcool.cn/community/01493a5cc98256a801214168b8995d.jpg">
</div>
<div class="wenben">
<p><a href="/llm_train">标题:LLM 大模型训练之路</a></p>
<p style="padding-bottom: 40px;">摘要:大模型训练原理,如何训练,有什么经验?</p>
</div>
</div>
<h2> 站内可视化导航 </h2>
<!-- 文章导读区 -->
文章可视化导读:鼠标划过图形块时,如果出现蓝色光环, 点击即可跳转到对应主题
<!-- draw.io diagram -->
<div class="mxgraph" style="max-width:100%;border:1px solid transparent;" data-mxgraph="{"highlight":"#0000ff","nav":true,"resize":true,"toolbar":"zoom layers tags lightbox","edit":"_blank","xml":"<mxfile host=\"app.diagrams.net\" agent=\"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36\" version=\"24.8.9\">\n <diagram id=\"4u5yHArNrn4fvDAkmxS5\" name=\"第 1 页\">\n <mxGraphModel dx=\"1050\" dy=\"530\" grid=\"1\" gridSize=\"10\" guides=\"1\" tooltips=\"1\" connect=\"1\" arrows=\"0\" fold=\"1\" page=\"1\" pageScale=\"1\" pageWidth=\"850\" pageHeight=\"1100\" math=\"0\" shadow=\"0\">\n <root>\n <mxCell id=\"0\" />\n <mxCell id=\"1\" parent=\"0\" />\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-613\" value=\"\" style=\"rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontSize=10;fillColor=#f5f5f5;dashed=1;strokeColor=#666666;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"291.75\" y=\"2007\" width=\"408.25\" height=\"230\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"379\" value=\"\" style=\"rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontSize=10;fillColor=#f5f5f5;dashed=1;strokeColor=#666666;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"210\" y=\"1090\" width=\"680\" height=\"430\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-522\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#fff2cc;strokeColor=#d6b656;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"117.13\" y=\"1060\" width=\"352.87\" height=\"130\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"378\" value=\"\" style=\"rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontSize=10;fillColor=#f5f5f5;dashed=1;strokeColor=#666666;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1026\" y=\"1750\" width=\"264\" height=\"360\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"380\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#fff2cc;strokeColor=#d6b656;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"910\" y=\"1190\" width=\"290\" height=\"130\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"381\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#fff2cc;strokeColor=#d6b656;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"734\" y=\"1060\" width=\"266.5\" height=\"103.37\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"382\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#fff2cc;strokeColor=#d6b656;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"640\" y=\"1223.25\" width=\"230\" height=\"96.75\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"383\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#fff2cc;strokeColor=#d6b656;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"863\" y=\"1370\" width=\"190\" height=\"126.25\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"384\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#E6D0DE;strokeColor=#ae4132;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"269.75\" y=\"1365\" width=\"441.5\" height=\"135\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"385\" value=\"\" style=\"rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontSize=10;fillColor=#f5f5f5;dashed=1;strokeColor=#666666;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"296\" y=\"1750\" width=\"408.25\" height=\"230\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"386\" value=\"\" style=\"ellipse;whiteSpace=wrap;html=1;dashed=1;fillColor=#fff2cc;strokeColor=#d6b656;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"305.5\" y=\"1200\" width=\"324.5\" height=\"130\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"387\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#E6D0DE;strokeColor=#E6E6E6;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"280\" y=\"1520\" width=\"470\" height=\"160\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"388\" value=\"模型层\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=13;fontStyle=1;fontColor=#6666FF;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"435.38\" y=\"1400\" as=\"geometry\">\n <mxPoint x=\"-3\" y=\"-20\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"389\" value=\"模态层\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=13;fontStyle=1;fontColor=#6666FF;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"347.13\" y=\"1260\" as=\"geometry\">\n <mxPoint x=\"-8\" y=\"-5\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"文本生成\" link=\"text-generation\" id=\"391\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcd28;strokeColor=none;shadow=1;gradientColor=#FFB570;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"385\" y=\"1209\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"图像生成\" link=\"image-generation\" id=\"392\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcd28;strokeColor=none;shadow=1;gradientColor=#FFB570;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"481\" y=\"1220\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"393\" value=\"语音生成\" style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcd28;strokeColor=none;shadow=1;gradientColor=#FFB570;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"357\" y=\"1260\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"视频生成\" link=\"video_gen\" id=\"394\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcd28;strokeColor=none;shadow=1;gradientColor=#FFB570;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"539.25\" y=\"1260\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"395\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"451\" target=\"391\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"469\" y=\"1675\" as=\"sourcePoint\" />\n <mxPoint x=\"539\" y=\"1675\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"396\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"451\" target=\"392\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"465\" y=\"1350\" as=\"sourcePoint\" />\n <mxPoint x=\"439\" y=\"1285\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"397\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"451\" target=\"394\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"528\" y=\"1330\" as=\"sourcePoint\" />\n <mxPoint x=\"555\" y=\"1285\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"398\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"451\" target=\"393\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"558\" y=\"1360\" as=\"sourcePoint\" />\n <mxPoint x=\"585\" y=\"1315\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"扩散模型\" link=\"ddpm\" id=\"399\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"497\" y=\"1149\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"401\" value=\"NLP任务\" style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d0cee2;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"780\" y=\"1480\" width=\"75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"402\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"404\" target=\"488\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"403\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"404\" target=\"490\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"对话系统\" link=\"dialogue-system\" id=\"404\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"328.5\" y=\"1149\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"406\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=1;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;dashed=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"391\" target=\"404\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"425\" y=\"1260\" as=\"sourcePoint\" />\n <mxPoint x=\"565\" y=\"1220\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"407\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=1;strokeColor=#999999;entryX=0;entryY=0.5;entryDx=0;entryDy=0;dashed=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;\" parent=\"1\" source=\"401\" target=\"491\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"490\" y=\"1570\" as=\"sourcePoint\" />\n <mxPoint x=\"325\" y=\"1455\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"408\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=1;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;dashed=1;exitX=0.617;exitY=0.05;exitDx=0;exitDy=0;exitPerimeter=0;\" parent=\"1\" source=\"392\" target=\"399\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"685\" y=\"1433\" as=\"sourcePoint\" />\n <mxPoint x=\"735\" y=\"1455\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"409\" value=\"LLM大模型专题导航\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=19;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"524.248484809835\" y=\"980.0011254969539\" as=\"geometry\">\n <mxPoint x=\"1\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"LLM训练流程\" link=\"llm_train\" id=\"410\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d80073;strokeColor=#A50040;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"454.25\" y=\"1570\" width=\"80.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"分布式训练\" link=\"dist\" id=\"411\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d80073;strokeColor=#A50040;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"449.25\" y=\"1640\" width=\"90\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"GPU\" link=\"gpu\" id=\"412\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;strokeColor=#666666;shadow=1;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"380.38\" y=\"1640\" width=\"55\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"DeepSpeed\" link=\"deepspeed\" id=\"413\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;strokeColor=#666666;shadow=1;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"514.88\" y=\"1525\" width=\"70.12\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"414\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"411\" target=\"410\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"535.38\" y=\"1240\" as=\"sourcePoint\" />\n <mxPoint x=\"630.38\" y=\"1590\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"415\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"412\" target=\"410\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"525.38\" y=\"1650\" as=\"sourcePoint\" />\n <mxPoint x=\"525.38\" y=\"1610\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"416\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"410\" target=\"462\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"535.38\" y=\"1660\" as=\"sourcePoint\" />\n <mxPoint x=\"535.38\" y=\"1620\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"417\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"418\" target=\"422\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"RAG\" link=\"rag\" id=\"418\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcd28;strokeColor=none;shadow=1;gradientColor=#FFB570;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"723\" y=\"1239\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"FineTune\" link=\"finetune\" id=\"419\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcd28;strokeColor=none;shadow=1;gradientColor=#FFB570;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"803\" y=\"1276\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"RLHF\" link=\"rlhf\" id=\"420\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffe6cc;strokeColor=#d79b00;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"619.25\" y=\"1646\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"421\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;\" parent=\"1\" source=\"422\" target=\"443\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"PEFT\" link=\"peft\" id=\"422\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcd28;strokeColor=none;shadow=1;gradientColor=#FFB570;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"803\" y=\"1239\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-544\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"423\" target=\"429\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-551\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"423\" target=\"o5D4xRg-JXB86p6HjegH-550\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"423\" value=\"数据准备\" style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d0cee2;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"287.13\" y=\"1570\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"模型评估\" link=\"llm_eva\" id=\"424\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d0cee2;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"630.38\" y=\"1570\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"425\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=3;strokeColor=#999999;entryX=0;entryY=0.5;entryDx=0;entryDy=0;dashed=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;dashPattern=1 1;\" parent=\"1\" source=\"423\" target=\"410\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"534.38\" y=\"1310\" as=\"sourcePoint\" />\n <mxPoint x=\"630.38\" y=\"1350\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"426\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=3;strokeColor=#999999;entryX=0;entryY=0.5;entryDx=0;entryDy=0;dashed=1;dashPattern=1 1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;\" parent=\"1\" source=\"410\" target=\"424\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"560.38\" y=\"1585\" as=\"sourcePoint\" />\n <mxPoint x=\"480.38\" y=\"1595\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"PyTorch\" link=\"pytorch\" id=\"427\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;strokeColor=#666666;shadow=1;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"548.5\" y=\"1640\" width=\"55\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"428\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"427\" target=\"410\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"525.38\" y=\"1650\" as=\"sourcePoint\" />\n <mxPoint x=\"525.38\" y=\"1610\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-543\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"429\" target=\"o5D4xRg-JXB86p6HjegH-542\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"数据标注\" link=\"label\" id=\"429\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d0cee2;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"287.13\" y=\"1525\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"MoE\" link=\"moe\" id=\"430\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffe6cc;strokeColor=#d79b00;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"685\" y=\"1610\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"LLM应用方案\" link=\"llm_solution\" id=\"431\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#f0a30a;strokeColor=#BD7000;shadow=1;fontColor=#000000;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"595\" y=\"1355\" width=\"90\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Transformer\" link=\"transformer\" id=\"432\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"414.25\" y=\"1923\" width=\"80\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"GPT\" link=\"gpt\" id=\"433\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"405.75\" y=\"1849.5\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"BERT\" link=\"bert\" id=\"434\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"475.12\" y=\"1850\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"435\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"432\" target=\"433\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"508.99\" y=\"1845\" as=\"sourcePoint\" />\n <mxPoint x=\"594.99\" y=\"1805\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"436\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"432\" target=\"434\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"438.99\" y=\"1935\" as=\"sourcePoint\" />\n <mxPoint x=\"388.99\" y=\"1890\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"437\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;\" parent=\"1\" source=\"ozFa4HHbGE1QGwIMumdl-526\" target=\"459\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"448.99\" y=\"1945\" as=\"sourcePoint\" />\n <mxPoint x=\"398.99\" y=\"1900\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"Scaline Law\" link=\"llm_law\" id=\"438\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"943\" y=\"1379\" width=\"73\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"复杂推理\" link=\"o1\" id=\"439\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"879.5\" y=\"1461.25\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Function Call\" link=\"function\" id=\"440\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"873\" y=\"1421.25\" width=\"73\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Plugin 插件\" link=\"plugin\" id=\"441\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"953\" y=\"1421.25\" width=\"73\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"442\" value=\"小模型\" style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#a0522d;strokeColor=#6D1F00;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"342\" y=\"1370\" width=\"56\" height=\"30\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"Agent&lt;div&gt;智能体&lt;/div&gt;\" link=\"agent\" id=\"443\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#a0522d;strokeColor=#6D1F00;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"920\" y=\"1238.81\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"LangChain\" link=\"langchain\" id=\"444\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1106.5\" y=\"1258.81\" width=\"70\" height=\"33.75\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"AutoGen\" link=\"autogen\" id=\"445\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1106.5\" y=\"1215.44\" width=\"70\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"CoT\" link=\"cot\" id=\"446\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffe6cc;strokeColor=#d79b00;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"619.25\" y=\"1610\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"447\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fillColor=#60a917;strokeColor=#2D7600;\" parent=\"1\" source=\"448\" target=\"449\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"Prompt Engineering&amp;nbsp;&lt;div&gt;提示工程&lt;/div&gt;\" link=\"pe\" id=\"448\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#60a917;strokeColor=#2D7600;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"740.5\" y=\"1113.37\" width=\"117\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"APE&amp;nbsp;&lt;div&gt;提示词自动化&lt;/div&gt;\" link=\"prompt_auto\" id=\"449\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#60a917;strokeColor=#2D7600;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"895.5\" y=\"1113.37\" width=\"80\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Prompt Attack&amp;nbsp;&lt;div&gt;提示词攻击&lt;/div&gt;\" link=\"prompt_attack\" id=\"450\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#60a917;strokeColor=#2D7600;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"890.5\" y=\"1063.37\" width=\"90\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"多模态\" link=\"modal\" id=\"451\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#a0522d;strokeColor=#6D1F00;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"464.25\" y=\"1370\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"452\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.25;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fillColor=#60a917;strokeColor=#2D7600;\" parent=\"1\" source=\"448\" target=\"450\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"453\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;fillColor=#60a917;strokeColor=#2D7600;\" parent=\"1\" source=\"454\" target=\"448\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"Prompt Learning&amp;nbsp;&lt;div&gt;提示学习&lt;/div&gt;\" link=\"prompt\" id=\"454\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#60a917;strokeColor=#2D7600;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"740.5\" y=\"1063.37\" width=\"117\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Transformers 库\" link=\"huggingface\" id=\"455\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;strokeColor=#666666;shadow=1;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"384\" y=\"1525\" width=\"90\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Embedding\" link=\"emb\" id=\"456\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"523.99\" y=\"1925\" width=\"80\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"分词\" link=\"token\" id=\"457\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"613.99\" y=\"1925\" width=\"45\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Pretrain Language Model&lt;div&gt;预训练语言模型&lt;/div&gt;\" link=\"plm\" id=\"458\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"545.49\" y=\"1850\" width=\"145\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"ChatGPT\" link=\"chatgpt\" id=\"459\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"393.38\" y=\"1770\" width=\"84\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"460\" value=\"NLP模型\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=14;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"579.998484809835\" y=\"1740.001125496954\" as=\"geometry\">\n <mxPoint x=\"1\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"461\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=5;strokeColor=#CCCCCC;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"385\" target=\"387\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"500\" y=\"1420\" as=\"sourcePoint\" />\n <mxPoint x=\"499\" y=\"1460\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"462\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#f8cecc;strokeColor=#b85450;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"310\" y=\"1425\" width=\"369.25\" height=\"70\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"ChatGLM\" link=\"chatglm\" id=\"463\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#e3c800;strokeColor=#B09500;shadow=1;fontColor=#000000;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"459.25\" y=\"1450\" width=\"70\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Baichuan\" link=\"baichuan\" id=\"464\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#e3c800;strokeColor=#B09500;shadow=1;fontColor=#000000;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"554.25\" y=\"1450\" width=\"70\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"ChatGPT\" link=\"chatgpt\" id=\"465\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#e3c800;strokeColor=#B09500;shadow=1;fontColor=#000000;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"369.25\" y=\"1450\" width=\"70\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"466\" value=\"大语言模型\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=13;fontStyle=1;fontColor=#333300;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"494.25\" y=\"1440\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"467\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"462\" target=\"451\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"504\" y=\"1580\" as=\"sourcePoint\" />\n <mxPoint x=\"504\" y=\"1520\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"468\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=0.25;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"462\" target=\"442\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"490\" y=\"1430\" as=\"sourcePoint\" />\n <mxPoint x=\"504\" y=\"1410\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"469\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.039;entryY=1;entryDx=0;entryDy=0;entryPerimeter=0;\" parent=\"1\" source=\"431\" target=\"382\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"504\" y=\"1440\" as=\"sourcePoint\" />\n <mxPoint x=\"650\" y=\"1310\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"470\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=1;strokeColor=#999999;dashed=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;\" parent=\"1\" source=\"485\" target=\"444\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"735.5\" y=\"1462.81\" as=\"sourcePoint\" />\n <mxPoint x=\"810.5\" y=\"1453.81\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"471\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=1;strokeColor=#999999;dashed=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;\" parent=\"1\" source=\"443\" target=\"485\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"651.5\" y=\"1338.81\" as=\"sourcePoint\" />\n <mxPoint x=\"1041.5\" y=\"1233.81\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"472\" value=\"垂类模型\" style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#a0522d;strokeColor=#6D1F00;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"280\" y=\"1370\" width=\"56\" height=\"30\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"473\" value=\"专题优化\" style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#a0522d;strokeColor=#6D1F00;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"750.5\" y=\"1419\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"474\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;\" parent=\"1\" source=\"384\" target=\"473\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"505\" y=\"1435\" as=\"sourcePoint\" />\n <mxPoint x=\"504\" y=\"1410\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"幻觉\" link=\"hallucination\" id=\"475\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"876.5\" y=\"1379\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"476\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;\" parent=\"1\" source=\"473\" target=\"383\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"720.5\" y=\"1488\" as=\"sourcePoint\" />\n <mxPoint x=\"760.5\" y=\"1444\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"477\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;\" parent=\"1\" source=\"478\" target=\"418\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"753\" y=\"1289\" as=\"sourcePoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"PE\" link=\"pe\" id=\"478\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"650\" y=\"1239\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"479\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=1;exitY=0.25;exitDx=0;exitDy=0;\" parent=\"1\" source=\"384\" target=\"443\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"654\" y=\"1380\" as=\"sourcePoint\" />\n <mxPoint x=\"710\" y=\"1340\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"480\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=0.75;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"462\" target=\"431\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"640\" y=\"1375\" as=\"sourcePoint\" />\n <mxPoint x=\"655\" y=\"1333\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"481\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;\" parent=\"1\" source=\"418\" target=\"419\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"793\" y=\"1264\" as=\"sourcePoint\" />\n <mxPoint x=\"813\" y=\"1264\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-540\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"482\" target=\"o5D4xRg-JXB86p6HjegH-538\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"482\" value=\"推理优化\" style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#a0522d;strokeColor=#6D1F00;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"240\" y=\"1440\" width=\"56\" height=\"30\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"483\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.25;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" source=\"478\" target=\"381\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"635\" y=\"1380\" as=\"sourcePoint\" />\n <mxPoint x=\"662\" y=\"1330\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"484\" value=\"Prompt优化\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;rotation=-30;\" parent=\"483\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"-0.0199\" y=\"1\" relative=\"1\" as=\"geometry\">\n <mxPoint x=\"-13\" y=\"-5\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"485\" value=\"Agent框架\" style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffcd28;strokeColor=none;shadow=1;gradientColor=#FFB570;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1010.5\" y=\"1238.81\" width=\"70\" height=\"30\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"486\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=1;strokeColor=#999999;dashed=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;\" parent=\"1\" source=\"485\" target=\"445\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"1096.5\" y=\"1263.81\" as=\"sourcePoint\" />\n <mxPoint x=\"1116.5\" y=\"1271.81\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"487\" value=\"模型训练\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=13;fontStyle=1;fontColor=#6666FF;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"314\" y=\"1670\" as=\"geometry\">\n <mxPoint x=\"-3\" y=\"-20\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"智能客服\" link=\"ics\" id=\"488\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"328.5\" y=\"1098.3699999999997\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"489\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"490\" target=\"513\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"对话管理\" link=\"dialogue-manager\" id=\"490\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"231.75\" y=\"1149\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"491\" value=\"\" style=\"rounded=1;whiteSpace=wrap;html=1;dashed=1;fillColor=#fff2cc;strokeColor=#d6b656;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"895.5\" y=\"1510\" width=\"190\" height=\"230\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"文本生成\" link=\"text-generation\" id=\"492\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"919\" y=\"1620\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"文本分类\" link=\"cls\" id=\"493\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"994.5\" y=\"1580\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"文本匹配\" link=\"text-match\" id=\"494\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"919\" y=\"1580\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"NER\" link=\"ner\" id=\"495\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"919\" y=\"1660\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"阅读理解\" link=\"mrc\" id=\"496\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"998.5\" y=\"1660\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"497\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.995;exitY=0.874;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;exitPerimeter=0;\" parent=\"1\" source=\"384\" target=\"401\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"721\" y=\"1443\" as=\"sourcePoint\" />\n <mxPoint x=\"761\" y=\"1444\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"GPT2\" link=\"gpt2\" id=\"498\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"312\" y=\"1849.5\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"499\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0;exitY=0.5;exitDx=0;exitDy=0;entryX=1;entryY=0.5;entryDx=0;entryDy=0;\" parent=\"1\" source=\"433\" target=\"498\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"438.99\" y=\"1935\" as=\"sourcePoint\" />\n <mxPoint x=\"438.99\" y=\"1890\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"模型蒸馏\" link=\"distill\" id=\"500\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"546.75\" y=\"1810\" width=\"53.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"NLP基础知识\" link=\"nlp\" id=\"501\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"946.5\" y=\"1534\" width=\"90\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"知识图谱\" link=\"kg\" id=\"502\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"919\" y=\"1700\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"503\" value=\"\" style=\"rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontSize=10;fillColor=#f5f5f5;dashed=1;strokeColor=#666666;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"720\" y=\"1750\" width=\"300\" height=\"320\" as=\"geometry\" />\n </mxCell>\n <mxCell id=\"504\" value=\"深度学习\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=14;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"846.378484809835\" y=\"1740.001125496954\" as=\"geometry\">\n <mxPoint x=\"1\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"机器学习\" link=\"ml\" id=\"505\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"735.25\" y=\"1770\" width=\"54.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"深度学习\" link=\"dl_note\" id=\"506\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"736\" y=\"1807\" width=\"54.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"神经网络\" link=\"ann\" id=\"507\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"730\" y=\"1917\" width=\"64.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"神经网络调参\" link=\"ann_tune\" id=\"508\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"803\" y=\"1917\" width=\"81\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"AutoML\" link=\"automl\" id=\"509\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"892\" y=\"1917\" width=\"62.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"强化学习\" link=\"rl\" id=\"510\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"730\" y=\"1952\" width=\"64.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"因果科学\" link=\"casual\" id=\"511\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"805.5\" y=\"1952\" width=\"64.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"多任务学习\" link=\"multi-task\" id=\"512\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"877.5\" y=\"1879.5\" width=\"77\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"用户模拟器\" link=\"simulator\" id=\"513\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"139\" y=\"1149\" width=\"71\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"图神经网络\" link=\"gnn\" id=\"514\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"879.5\" y=\"1952\" width=\"64.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"AGI\" link=\"agi\" id=\"515\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1042\" y=\"1810\" width=\"48\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"脑机接口\" link=\"bci\" id=\"516\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1152\" y=\"1810\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"AIGC行业报告\" link=\"aigc\" id=\"517\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1116.5\" y=\"1770\" width=\"83.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"518\" value=\"行业知识\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=14;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"1156.8784848098348\" y=\"1740.001125496954\" as=\"geometry\">\n <mxPoint x=\"1\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"AI行业报告\" link=\"ai_report\" id=\"519\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1042\" y=\"1770\" width=\"64.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"具身智能\" link=\"embodied\" id=\"520\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1209.88\" y=\"1810\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"ML笔记\" link=\"ml_note\" id=\"522\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"792.75\" y=\"1770\" width=\"47.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-523\" value=\"应用层\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=13;fontStyle=1;fontColor=#6666FF;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"613.99\" y=\"1111.68\" as=\"geometry\">\n <mxPoint x=\"-7\" y=\"-2\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-527\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"o5D4xRg-JXB86p6HjegH-524\" target=\"o5D4xRg-JXB86p6HjegH-526\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"人工智障\" link=\"dialogue\" id=\"o5D4xRg-JXB86p6HjegH-524\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#f8cecc;strokeColor=#b85450;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"233.75\" y=\"1098.37\" width=\"58\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-525\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=1;entryY=1;entryDx=0;entryDy=0;exitX=0;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"404\" target=\"o5D4xRg-JXB86p6HjegH-524\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"369\" y=\"1159\" as=\"sourcePoint\" />\n <mxPoint x=\"369\" y=\"1138\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"大模型时代对话系统\" link=\"llm_ds\" id=\"o5D4xRg-JXB86p6HjegH-526\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#008a00;strokeColor=#005700;shadow=1;fontColor=#ffffff;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"139\" y=\"1098.37\" width=\"71\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"LLM 开发模式\" link=\"llm_dev\" id=\"o5D4xRg-JXB86p6HjegH-528\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#ffe6cc;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"659.25\" y=\"1280\" width=\"90\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"对比学习\" link=\"contrastive\" id=\"o5D4xRg-JXB86p6HjegH-529\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"803\" y=\"1879.5\" width=\"64.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"计算机视觉\" link=\"cv\" id=\"o5D4xRg-JXB86p6HjegH-530\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"724.87\" y=\"1992\" width=\"64.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"视频理解\" link=\"video\" id=\"o5D4xRg-JXB86p6HjegH-531\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"595\" y=\"1170\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-532\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=1;strokeColor=#999999;entryX=0.5;entryY=1;entryDx=0;entryDy=0;dashed=1;exitX=0.75;exitY=0;exitDx=0;exitDy=0;\" parent=\"1\" source=\"394\" target=\"o5D4xRg-JXB86p6HjegH-531\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"528\" y=\"1232\" as=\"sourcePoint\" />\n <mxPoint x=\"537\" y=\"1189\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"推荐系统\" link=\"rp\" id=\"o5D4xRg-JXB86p6HjegH-533\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"725.5\" y=\"2030\" width=\"54.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"文档问答\" link=\"doc_chat\" id=\"o5D4xRg-JXB86p6HjegH-534\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"395\" y=\"1099.9999999999998\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-535\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" parent=\"1\" target=\"o5D4xRg-JXB86p6HjegH-534\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"390\" y=\"1170\" as=\"sourcePoint\" />\n <mxPoint x=\"369\" y=\"1138\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"开放域问答\" link=\"dialogue_qa\" id=\"o5D4xRg-JXB86p6HjegH-536\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"395\" y=\"1063.37\" width=\"65\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"LLM问题\" link=\"llm_think\" id=\"o5D4xRg-JXB86p6HjegH-537\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#f8cecc;strokeColor=#b85450;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"752.5\" y=\"1385\" width=\"58\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"推理优化\" link=\"llm_opt\" id=\"o5D4xRg-JXB86p6HjegH-538\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"169.5\" y=\"1440\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"服务部署实验\" link=\"exp\" id=\"o5D4xRg-JXB86p6HjegH-541\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"169.5\" y=\"1480\" width=\"80.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"自动标注\" link=\"label\" id=\"o5D4xRg-JXB86p6HjegH-542\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"200\" y=\"1525\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"ChatGPT应用\" link=\"chatgpt_application\" id=\"o5D4xRg-JXB86p6HjegH-545\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"217.75\" y=\"1208.81\" width=\"78.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"评估方法\" link=\"eva\" id=\"o5D4xRg-JXB86p6HjegH-546\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"711.25\" y=\"1540\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"目标检测\" link=\"loss\" id=\"o5D4xRg-JXB86p6HjegH-547\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"845.5\" y=\"1992\" width=\"51.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"大模型评测\" link=\"llm_eva\" id=\"o5D4xRg-JXB86p6HjegH-548\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"711.25\" y=\"1575\" width=\"62\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"ChatGPT复现\" link=\"chatgpt_mimic\" id=\"o5D4xRg-JXB86p6HjegH-549\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"365.12\" y=\"1490\" width=\"78.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"AI生成\" link=\"llm_data\" id=\"o5D4xRg-JXB86p6HjegH-550\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"200\" y=\"1570\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"LLM原理\" link=\"llm_arch\" id=\"o5D4xRg-JXB86p6HjegH-552\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"357.13\" y=\"1590\" width=\"62.87\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"音乐生成\" link=\"music_gen\" id=\"o5D4xRg-JXB86p6HjegH-553\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"256\" y=\"1260.68\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-554\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=1;strokeColor=#999999;entryX=1;entryY=0.5;entryDx=0;entryDy=0;dashed=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;\" parent=\"1\" source=\"393\" target=\"o5D4xRg-JXB86p6HjegH-553\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"528\" y=\"1232\" as=\"sourcePoint\" />\n <mxPoint x=\"537\" y=\"1189\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"推理加速\" link=\"infer\" id=\"o5D4xRg-JXB86p6HjegH-555\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"169.5\" y=\"1400\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"端侧LLM\" link=\"llm_end\" id=\"o5D4xRg-JXB86p6HjegH-556\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"310\" y=\"1320\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"OpenAI\" link=\"openai\" id=\"o5D4xRg-JXB86p6HjegH-557\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1044\" y=\"1849.9999999999998\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"AI公司\" link=\"ai_company\" id=\"o5D4xRg-JXB86p6HjegH-558\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1102\" y=\"1850\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"AIGC 机会\" link=\"aigc_idea\" id=\"o5D4xRg-JXB86p6HjegH-559\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1209.88\" y=\"1770\" width=\"60.12\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-561\" value=\"\" style=\"edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;\" parent=\"1\" source=\"o5D4xRg-JXB86p6HjegH-560\" edge=\"1\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"180\" y=\"1180\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"用户画像\" link=\"user\" id=\"o5D4xRg-JXB86p6HjegH-560\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"155\" y=\"1198.81\" width=\"51\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"大脑原理\" link=\"brain\" id=\"o5D4xRg-JXB86p6HjegH-563\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1094\" y=\"1810\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"回归分析\" link=\"regression\" id=\"o5D4xRg-JXB86p6HjegH-564\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"888.5\" y=\"1806\" width=\"58.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"芯片\" link=\"chip\" id=\"o5D4xRg-JXB86p6HjegH-565\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1156.88\" y=\"1850\" width=\"38\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"在线教育\" link=\"tutor\" id=\"o5D4xRg-JXB86p6HjegH-566\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1090\" y=\"2053\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"汽车原理\" link=\"car\" id=\"o5D4xRg-JXB86p6HjegH-567\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1044\" y=\"1935\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"自动驾驶\" link=\"driving\" id=\"o5D4xRg-JXB86p6HjegH-568\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1109\" y=\"1935\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"异常检测\" link=\"anomaly\" id=\"o5D4xRg-JXB86p6HjegH-569\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"732.88\" y=\"1842.5\" width=\"58.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"聚类算法\" link=\"cluster\" id=\"o5D4xRg-JXB86p6HjegH-570\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"799.62\" y=\"1842.5\" width=\"58.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-571\" value=\"\" style=\"rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontSize=10;fillColor=#f5f5f5;dashed=1;strokeColor=#666666;fontColor=#333333;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"720\" y=\"2098\" width=\"300\" height=\"130\" as=\"geometry\" />\n </mxCell>\n <UserObject label=\"贝叶斯\" link=\"bayes\" id=\"o5D4xRg-JXB86p6HjegH-572\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"829.75\" y=\"2112\" width=\"47.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"元宇宙\" link=\"meta\" id=\"o5D4xRg-JXB86p6HjegH-573\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1044\" y=\"1975\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"新技术\" link=\"new_tech\" id=\"o5D4xRg-JXB86p6HjegH-574\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1107\" y=\"1891\" width=\"48\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"机器人\" link=\"robot\" id=\"o5D4xRg-JXB86p6HjegH-575\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1200\" y=\"1850\" width=\"40\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"搜索\" link=\"search\" id=\"o5D4xRg-JXB86p6HjegH-576\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"849.63\" y=\"2030\" width=\"40.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"可解释性\" link=\"explain\" id=\"o5D4xRg-JXB86p6HjegH-577\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"958.13\" y=\"1880\" width=\"58.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"NAS\" link=\"nas\" id=\"o5D4xRg-JXB86p6HjegH-578\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"959.5\" y=\"1917\" width=\"46.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"元学习\" link=\"meta_learning\" id=\"o5D4xRg-JXB86p6HjegH-579\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"948.5\" y=\"1952\" width=\"47.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"情感计算\" link=\"emotion\" id=\"o5D4xRg-JXB86p6HjegH-580\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"994.5\" y=\"1620\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"知识追踪\" link=\"dkt\" id=\"o5D4xRg-JXB86p6HjegH-581\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1149.38\" y=\"2053\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"互联网金融\" link=\"finance\" id=\"o5D4xRg-JXB86p6HjegH-582\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1047.75\" y=\"2014\" width=\"62.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"房产行业\" link=\"house\" id=\"o5D4xRg-JXB86p6HjegH-583\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1044\" y=\"1892\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"量化交易\" link=\"quant\" id=\"o5D4xRg-JXB86p6HjegH-584\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1116.5\" y=\"2014\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"股票\" link=\"stock\" id=\"o5D4xRg-JXB86p6HjegH-585\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1176.5\" y=\"2014\" width=\"43.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"物联网\" link=\"iot\" id=\"o5D4xRg-JXB86p6HjegH-586\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1160\" y=\"1893\" width=\"45\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"移动设备\" link=\"phone\" id=\"o5D4xRg-JXB86p6HjegH-587\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1169\" y=\"1935\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"语音识别\" link=\"voice\" id=\"o5D4xRg-JXB86p6HjegH-588\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"948.5\" y=\"1992\" width=\"56.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"模型部署\" link=\"model_deploy\" id=\"o5D4xRg-JXB86p6HjegH-589\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"255.5\" y=\"1480\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"最优化\" link=\"optimization\" id=\"o5D4xRg-JXB86p6HjegH-590\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"863.75\" y=\"1842.5\" width=\"58.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"排序学习\" link=\"ltr\" id=\"o5D4xRg-JXB86p6HjegH-591\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"783.88\" y=\"2030\" width=\"59.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"微积分\" link=\"calculus\" id=\"o5D4xRg-JXB86p6HjegH-592\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"881\" y=\"2112\" width=\"40.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"知识图谱\" link=\"kg\" id=\"o5D4xRg-JXB86p6HjegH-593\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"952.13\" y=\"2030\" width=\"54.74\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"博弈论\" link=\"game-thoery\" id=\"o5D4xRg-JXB86p6HjegH-594\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"925.5\" y=\"2148\" width=\"43.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"联邦学习\" link=\"faderation\" id=\"o5D4xRg-JXB86p6HjegH-595\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"898\" y=\"2030\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"密码学\" link=\"cryptography\" id=\"o5D4xRg-JXB86p6HjegH-596\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"808\" y=\"2186\" width=\"37.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"流形学习\" link=\"manifold\" id=\"o5D4xRg-JXB86p6HjegH-597\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"781.62\" y=\"2148\" width=\"57.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Python\" link=\"python\" id=\"o5D4xRg-JXB86p6HjegH-598\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"310.74\" y=\"2020\" width=\"45.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"特征工程\" link=\"fe\" id=\"o5D4xRg-JXB86p6HjegH-599\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"808.75\" y=\"1806\" width=\"64.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"区块链\" link=\"block-chain\" id=\"o5D4xRg-JXB86p6HjegH-600\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1106.5\" y=\"1975\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"信息论\" link=\"information\" id=\"o5D4xRg-JXB86p6HjegH-601\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"925.5\" y=\"2112\" width=\"46\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"概率统计\" link=\"probability\" id=\"o5D4xRg-JXB86p6HjegH-602\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"724.87\" y=\"2148\" width=\"50.87\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"量子计算\" link=\"quantum\" id=\"o5D4xRg-JXB86p6HjegH-603\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1169\" y=\"1975\" width=\"53\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Pandas\" link=\"pandas\" id=\"o5D4xRg-JXB86p6HjegH-604\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"365\" y=\"2020\" width=\"42.38\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Scikit-learn\" link=\"sklearn\" id=\"o5D4xRg-JXB86p6HjegH-605\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"415.75\" y=\"2020\" width=\"64.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"文本挖掘\" link=\"text-mining\" id=\"o5D4xRg-JXB86p6HjegH-606\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"996.63\" y=\"1700\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"神经网络可视化\" link=\"train_vis\" id=\"o5D4xRg-JXB86p6HjegH-607\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"932\" y=\"1842.5\" width=\"94\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"不均衡问题\" link=\"imbalance\" id=\"o5D4xRg-JXB86p6HjegH-608\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"729\" y=\"1879.5\" width=\"70\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"精简笔记\" link=\"dl_sum\" id=\"o5D4xRg-JXB86p6HjegH-609\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"955.25\" y=\"1770\" width=\"54.75\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"文本分类\" link=\"cls\" id=\"o5D4xRg-JXB86p6HjegH-610\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"953.75\" y=\"1806\" width=\"60\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"ML军规\" link=\"ml_rule\" id=\"o5D4xRg-JXB86p6HjegH-611\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"846.38\" y=\"1770\" width=\"47.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"线性代数与矩阵\" link=\"bayes\" id=\"o5D4xRg-JXB86p6HjegH-612\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"724.87\" y=\"2112\" width=\"94.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Go\" link=\"go\" id=\"o5D4xRg-JXB86p6HjegH-614\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"489.24\" y=\"2020\" width=\"35.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"ML本质\" link=\"ml_essense\" id=\"o5D4xRg-JXB86p6HjegH-615\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"900.75\" y=\"1770\" width=\"47.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"LBS\" link=\"lbs\" id=\"o5D4xRg-JXB86p6HjegH-616\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1230\" y=\"1935\" width=\"30\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"傅里叶变换\" link=\"fourier\" id=\"o5D4xRg-JXB86p6HjegH-617\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"728.68\" y=\"2186\" width=\"70.44\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Git\" link=\"git\" id=\"o5D4xRg-JXB86p6HjegH-618\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"312.73\" y=\"2059\" width=\"35.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Jupyter\" link=\"jupyter\" id=\"o5D4xRg-JXB86p6HjegH-619\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"353.1\" y=\"2059\" width=\"41.64\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Linux\" link=\"linux\" id=\"o5D4xRg-JXB86p6HjegH-620\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"444.99\" y=\"2059\" width=\"35.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Shell\" link=\"shell\" id=\"o5D4xRg-JXB86p6HjegH-621\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"530.75\" y=\"2020\" width=\"35.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Latex\" link=\"latex\" id=\"o5D4xRg-JXB86p6HjegH-622\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"400.98\" y=\"2059\" width=\"35.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Jekyll\" link=\"jekyll\" id=\"o5D4xRg-JXB86p6HjegH-623\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"651.23\" y=\"2100\" width=\"35.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"教育\" link=\"edu\" id=\"o5D4xRg-JXB86p6HjegH-624\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"1047.75\" y=\"2053\" width=\"32.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"分形几何\" link=\"fractal\" id=\"o5D4xRg-JXB86p6HjegH-625\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"845.4999999999999\" y=\"2148\" width=\"70.44\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"SQL\" link=\"data\" id=\"o5D4xRg-JXB86p6HjegH-626\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"313.24\" y=\"2137\" width=\"32.51\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"可视化\" link=\"vis\" id=\"o5D4xRg-JXB86p6HjegH-627\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"358.74\" y=\"2137\" width=\"47.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"数据挖掘\" link=\"dm\" id=\"o5D4xRg-JXB86p6HjegH-628\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"416.74\" y=\"2137\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"vpn\" link=\"vpn\" id=\"o5D4xRg-JXB86p6HjegH-629\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"314.49\" y=\"2177\" width=\"32.51\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"计算机网络\" link=\"network\" id=\"o5D4xRg-JXB86p6HjegH-630\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"360.74\" y=\"2177\" width=\"65.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"计算机语言\" link=\"computer\" id=\"o5D4xRg-JXB86p6HjegH-631\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"431.13\" y=\"2177\" width=\"65.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"操作系统\" link=\"os\" id=\"o5D4xRg-JXB86p6HjegH-632\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"501.74\" y=\"2177\" width=\"54.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"图形学\" link=\"graphic\" id=\"o5D4xRg-JXB86p6HjegH-633\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"560.99\" y=\"2177\" width=\"54.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"计算机知识脑图\" link=\"mindmap\" id=\"o5D4xRg-JXB86p6HjegH-634\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"474.24\" y=\"2137\" width=\"91.51\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"基础算法\" link=\"algorithm\" id=\"o5D4xRg-JXB86p6HjegH-635\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"313.24\" y=\"2098\" width=\"52.51\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"算法比赛\" link=\"kaggle\" id=\"o5D4xRg-JXB86p6HjegH-636\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"377.37\" y=\"2098\" width=\"52.51\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Web前端\" link=\"web\" id=\"o5D4xRg-JXB86p6HjegH-637\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"439.12\" y=\"2098\" width=\"52.51\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"架构设计\" link=\"arch\" id=\"o5D4xRg-JXB86p6HjegH-638\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"500.24\" y=\"2098\" width=\"52.51\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Docker\" link=\"docker\" id=\"o5D4xRg-JXB86p6HjegH-639\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"627.94\" y=\"2138\" width=\"41.64\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"小程序\" link=\"mini\" id=\"o5D4xRg-JXB86p6HjegH-640\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"559.49\" y=\"2099\" width=\"46.26\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"测试\" link=\"test\" id=\"o5D4xRg-JXB86p6HjegH-641\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"609.74\" y=\"2100\" width=\"36.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"面试指南\" link=\"interview\" id=\"o5D4xRg-JXB86p6HjegH-642\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"571.75\" y=\"2138\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"数学历史\" link=\"math\" id=\"o5D4xRg-JXB86p6HjegH-643\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"851.62\" y=\"2186\" width=\"57.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"makefile\" link=\"makefile\" id=\"o5D4xRg-JXB86p6HjegH-644\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"575.75\" y=\"2020\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Linux C\" link=\"linux-program\" id=\"o5D4xRg-JXB86p6HjegH-645\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"489.25\" y=\"2059\" width=\"51.5\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"C/C++\" link=\"c\" id=\"o5D4xRg-JXB86p6HjegH-647\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"635.75\" y=\"2020\" width=\"35.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"设计模式\" link=\"design\" id=\"o5D4xRg-JXB86p6HjegH-648\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"621.75\" y=\"2177\" width=\"54.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Tensorflow\" link=\"linux-program\" id=\"o5D4xRg-JXB86p6HjegH-649\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"544.74\" y=\"2059\" width=\"56.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Pytorch\" link=\"linux-program\" id=\"o5D4xRg-JXB86p6HjegH-651\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"607.74\" y=\"2059\" width=\"43.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"Pytorch手册\" link=\"pytorch_simple\" id=\"o5D4xRg-JXB86p6HjegH-652\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" parent=\"1\" vertex=\"1\">\n <mxGeometry x=\"656.99\" y=\"2059\" width=\"43.01\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-653\" value=\"计算机基础\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=14;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"510.74848480983496\" y=\"2001.001125496954\" as=\"geometry\">\n <mxPoint x=\"-4\" y=\"-4\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-654\" value=\"数学知识\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=14;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"851.618484809835\" y=\"2089.001125496954\" as=\"geometry\">\n <mxPoint x=\"1\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <mxCell id=\"o5D4xRg-JXB86p6HjegH-655\" value=\"【2024-11-24】&lt;div&gt;wqw547243068@163.com&lt;/div&gt;\" style=\"edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];labelBackgroundColor=none;fontSize=14;\" parent=\"1\" vertex=\"1\" connectable=\"0\">\n <mxGeometry x=\"279.99848480983496\" y=\"1720.001125496954\" as=\"geometry\">\n <mxPoint x=\"-4\" y=\"-4\" as=\"offset\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"图像处理\" link=\"image\" id=\"ozFa4HHbGE1QGwIMumdl-522\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"791.88\" y=\"1992\" width=\"51.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"OCR\" link=\"ocr\" id=\"ozFa4HHbGE1QGwIMumdl-523\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"900.75\" y=\"1992\" width=\"39.25\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"智能硬件\" link=\"smart_device\" id=\"ozFa4HHbGE1QGwIMumdl-524\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"1209.88\" y=\"1892\" width=\"50.12\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"传感器\" link=\"sensor\" id=\"ozFa4HHbGE1QGwIMumdl-525\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"1245\" y=\"1850\" width=\"40\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"GPT3\" link=\"gpt2\" id=\"ozFa4HHbGE1QGwIMumdl-526\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"313.24\" y=\"1800\" width=\"50\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"ozFa4HHbGE1QGwIMumdl-527\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" edge=\"1\" parent=\"1\" source=\"498\" target=\"ozFa4HHbGE1QGwIMumdl-526\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"414\" y=\"1875\" as=\"sourcePoint\" />\n <mxPoint x=\"375\" y=\"1875\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"Transformer&lt;div&gt;改进&lt;/div&gt;\" link=\"transformer_update\" id=\"ozFa4HHbGE1QGwIMumdl-528\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#b0e3e6;strokeColor=none;shadow=1;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"291.75\" y=\"1922\" width=\"80\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"ozFa4HHbGE1QGwIMumdl-529\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;strokeWidth=2;strokeColor=#999999;entryX=1;entryY=0.5;entryDx=0;entryDy=0;exitX=0;exitY=0.5;exitDx=0;exitDy=0;\" edge=\"1\" parent=\"1\" source=\"432\" target=\"ozFa4HHbGE1QGwIMumdl-528\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"390\" y=\"1940\" as=\"sourcePoint\" />\n <mxPoint x=\"439\" y=\"1890\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n <UserObject label=\"AIGC&lt;div&gt;内容检测&lt;/div&gt;\" link=\"aigc_detect\" id=\"ozFa4HHbGE1QGwIMumdl-530\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#f8cecc;strokeColor=none;shadow=1;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"162\" y=\"1258\" width=\"65\" height=\"34\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"端到端对话\" link=\"end_voice\" id=\"ozFa4HHbGE1QGwIMumdl-531\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#008a00;strokeColor=#005700;shadow=1;fontColor=#ffffff;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"139\" y=\"1050\" width=\"71\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <UserObject label=\"LLM发展方向\" link=\"llm_direction\" id=\"ozFa4HHbGE1QGwIMumdl-532\">\n <mxCell style=\"rounded=1;whiteSpace=wrap;html=1;fillColor=#a0522d;strokeColor=#6D1F00;shadow=1;fontColor=#ffffff;\" vertex=\"1\" parent=\"1\">\n <mxGeometry x=\"1030.5\" y=\"1395\" width=\"80\" height=\"30\" as=\"geometry\" />\n </mxCell>\n </UserObject>\n <mxCell id=\"ozFa4HHbGE1QGwIMumdl-534\" value=\"\" style=\"rounded=0;orthogonalLoop=1;jettySize=auto;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;\" edge=\"1\" parent=\"1\" source=\"o5D4xRg-JXB86p6HjegH-526\" target=\"ozFa4HHbGE1QGwIMumdl-531\">\n <mxGeometry relative=\"1\" as=\"geometry\">\n <mxPoint x=\"339\" y=\"1159\" as=\"sourcePoint\" />\n <mxPoint x=\"302\" y=\"1138\" as=\"targetPoint\" />\n </mxGeometry>\n </mxCell>\n </root>\n </mxGraphModel>\n </diagram>\n</mxfile>\n"}"></div>
<script type="text/javascript" src="https://viewer.diagrams.net/js/viewer-static.min.js"></script>
<!-- 评论区 -->
<h2 id="comments">Comments</h2>
<!-- 【2023-1-4】github新插件giscus, 暂未启用 -->
<script src="https://giscus.app/client.js"
data-repo="wqw547243068/wqw547243068.github.io"
data-repo-id="MDEwOlJlcG9zaXRvcnkxNDE3ODEwMzg="
data-category="Announcements"
data-category-id="DIC_kwDOCHNoLs4CRJjU"
data-mapping="title"
data-strict="0"
data-reactions-enabled="1"
data-emit-metadata="0"
data-input-position="bottom"
data-theme="light"
data-lang="en"
// data-lang="zh-CN"
crossorigin="anonymous"
async>
</script>
<!-- disqus插件 -->
<p> --disqus-- </p>
<div id="disqus_thread"></div>
<script>
/**
* RECOMMENDED CONFIGURATION VARIABLES: EDIT AND UNCOMMENT THE SECTION BELOW TO INSERT DYNAMIC VALUES FROM YOUR PLATFORM OR CMS.
* LEARN WHY DEFINING THESE VARIABLES IS IMPORTANT: https://disqus.com/admin/universalcode/#configuration-variables
*/
var disqus_config = function() {
this.page.url = 'https://wqw547243068.github.io/dist_tool'; // Replace PAGE_URL with your page's canonical URL variable
this.page.identifier = 'https://wqw547243068.github.io/dist_tool'; // Replace PAGE_IDENTIFIER with your page's unique identifier variable
};
(function() { // DON'T EDIT BELOW THIS LINE
var d = document,
s = d.createElement('script');
s.src = '//wqw.disqus.com/embed.js';
s.setAttribute('data-timestamp', +new Date());
(d.head || d.body).appendChild(s);
})();
</script>
<noscript>Please enable JavaScript to view the <a href="https://disqus.com/?ref_noscript" rel="nofollow">comments powered by Disqus.</a></noscript>
</div>
<button class="anchor"><i class="fa fa-anchor"></i></button>
<!-- 右侧工具栏 -->
<div class="right">
<!-- 搜索框 -->
<div>
<!-- HTML elements for search -->
<input type="text" id="search-input" placeholder="Search blog posts..">
<ul id="results-container"></ul>
<!-- script pointing to jekyll-search.js -->
<script src="/js/simple-jekyll-search.min.js"></script>
<!-- <script src="https://cdn.rawgit.com/christian-fei/Simple-Jekyll-Search/master/dest/simple-jekyll-search.min.js"></script> -->
</div>
<!-- [2019-08-07]搜索框 -->
<script>
window.simpleJekyllSearch = new SimpleJekyllSearch({
searchInput: document.getElementById('search-input'),
resultsContainer: document.getElementById('results-container'),
json: '/search.json',
searchResultTemplate: '<li><a href="{url}?query={query}" title="{desc}">{title}</a></li>',
noResultsText: 'No results found',
limit: 10,
fuzzy: false,
exclude: ['Welcome']
})
</script>
<!-- 访问可视化 -->
<!-- 访问统计可视化工具 -->
<div class="side">
<script type="text/javascript" src="//rf.revolvermaps.com/0/0/1.js?i=5q2837r7gjo&s=265&m=7&v=true&r=false&b=000000&n=false&c=ff0000" async="async"></script>
</div>
<script type='text/javascript' id='mapmyvisitors' src='https://mapmyvisitors.com/map.js?cl=ffffff&w=a&t=n&d=Jqz5ooTlHsfwaaqJF5LezHsg7HXvyf3s_N_TE_2u8xM'></script>
<div class="wrap">
<!-- Content目录区 -->
<div class="side content">
<div>
Content
</div>
<ul id="content-side" class="content-ul">
<li><a href="#comments">Comments</a></li>
</ul>
</div>
<!-- 公众号区 -->
<!-- 公众号区 -->
<div class="side content">
<div>My Moment ( 微信公众号 )</div>
<img src="https://wqw547243068.github.io/wqw/fig/wqw.png" alt="欢迎关注鹤啸九天" />
</div>
<!-- 其他div框放到这里 -->
<!-- <div class="side">bbbb</div> -->
</div>
</div>
</div>
<script>
/**
* target _blank
*/
(function() {
var aTags = document.querySelectorAll('article a:not([id])')
for (var i = 0; i < aTags.length; i++) {
aTags[i].setAttribute('target', '_blank')
}
}());
</script>
<script src="/js/pageContent.js " charset="utf-8"></script>
<!-- 【2022-9-1】 支持latex数学公式显示 -->
<script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script><|im_end|>
用户输入扮演 user 的 role ,而模型生成则承担 assistant 的 role 。
Qwen 还支持元消息,指导模型执行特定操作或生成具有特定特性的文本,例如: 改变语气、风格或内容,这将承担 system 的 role,且内容默认为 You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
。
数据量
预训练数据共 3TB,涉及: 公共网络文档、百科全书、书籍、代码等,数据涉及多语言,但以中文和英文为主。
为了保证数据质量,制定了一套全面的预处理程序。
- Web数据需要从HTML中提取文本内容,并采用语言识别工具确定语种;
- 如: python cdl 工具包检测语种, 编码检测 chardetect, kenlm 计算流畅度
- 通过重复数据删除技术增加数据的多样性,包括规范化后精确匹配重复数据删除方法和使用
MinHash
和LSH
算法的模糊重复数据删除方法; - 结合规则和机器学习的方法过滤低质量数据,即通过多个模型对内容进行评分,包括语言模型、文本质量评分模型以及用于识别潜在冒犯性模型;
- 从各种来源数据中手动采样并进行审查,以确保其质量;
- 有选择地对来自某些来源的数据进行采样,以确保模型在各种高质量内容上进行训练。
长度
Qwen2.5 训练中打包序列长度为 32768 个 token。
- 预训练中最大文档长度即为此长度。
- 而后训练中,user和assistant的最大消息长度则有所不同。一般情况下,assistant消息长度可达 8192 个 token。
要点:
- Qwen2 模型可以处理 32K 或 128K token 长的文本,其中 8K 长度可作为输出。
模型结构
模型采用Transformer框架,主要做了以下修改:
- Embedding and output projection:对于embedding层和lm_head层不进行权重共享,是两个单独的权重。
- Positional embedding:采用
RoPE
为位置编码,并选择使用FP32
精确度的逆频率矩阵。 - Bias:在QKV注意力层中添加了偏差,以增强模型的外推能力。
- Pre-Norm & RMSNorm:采用预归一化提高训练稳定性,并将传统归一化方法替换为
RMSNorm
。 - Activation function:采用SwiGLU激活函数,不同于传统FFN的2个矩阵,SwiGLU有三个矩阵,因此缩小了隐藏层维度,由原来的4倍变成8/3倍。
外推能力扩展
Transformer 模型的注意力机制在上下文长度上有很大限制,模型会随着上下文长度的增加,计算成本和内存会成倍增加。
Qwen模型利用了简单地非训练计算,在推理过程中扩展上下文长度。
- 动态NTK感知插值,即对序列长度的增加动态缩放位置信息。
- LogN-Scaling,根据上下文长度与训练长度的比率,对Q和V的点积进行重新缩放,确保注意力值的熵随着上下文长度的增长而保持稳定。
- Window attention,将注意力限制在一个上下文窗口内,防止模型关注到太远的内容。并在不同层采用不同的窗口大小,较低的层使用较短的窗口,而较高的层使用较长的窗口。
训练
qwen系列大模型本地部署,法律大模型训练,
- 只需5G内存部署本地大模型,
- 只需6G显存训练自己的法律大模型。
- lora模型训练完成后,会合并到主模型,生成自己专属的大模型。
视频演示
预训练
遵循自回归语言建模的标准方法,通过前面Token的内容预测下一个Token;
- 模型预训练时最大长度为2048,为了构建批次数据,对文本内容进行随机打乱及合并,再讲其截断到指定长度。
- 注意力模块采用
Flash Attention
技术,提高训练速度; - 优化器采用AdamW,超参数β1、β2和ϵ为别为0.9、0.95和10−8;
- 采用
余弦学习率
计划,学习率会衰减到峰值的10%; - 采用
BFloat16
进行混合精度训练。
QWEN 模型再同等级参数下表现优异,即使是更大的型号如LLaMA2-70B,在3个任务中也被QWEN-14B超越。
有监督微调SFT
为了提高有监督微调数据集的能力,对多种风格的对话进行了标注,来关注不同任务的自然语言生成,进一步提高模型的有用性。并且大小训练方法也会影响模型行了,Qwen采用ChatML
样式的格式来进行模型训练。
ChatML格式让模型有效区分各类信息,包括:系统质量、用户输入、模型输出等,可以增强模型对复杂会话的处理分析能力。
ChatML Format 对话模版
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
hello, who are you?<|im_end|>
<|im_start|>assistant
I am a AI program developed by Firefly<|im_end|>
训练
- 优化器采用
AdamW
,超参数β1、β2和ϵ为别为0.9、0.95和1e−8; - 模型最大输入长度
2048
; - 训练批次大小为
128
; - 模型共训练4000步,在前1430步中,学习率逐渐增加,达到2e−6的峰值。
- 为了防止过拟合,权重衰减的值设置为0.1,dropout设置为0.1,梯度裁剪的限制为1.0。
RM
RM模型
奖励模型构建上,先采用大量数据进行偏好模型预训练(preference model pretraining,PMP
),在经过高质量偏好数据进行奖励模型精调。
高质量偏好数据通过6600详细标签的分类系统平衡采样获取,为保证数据的多样性和复杂性。
奖励模型时由同等大小Qwen模型+池化层得来,用特殊的句子结束标记映射值作为模型奖励值。
模型在训练过程中,学习率恒为3e−6,批次大小为64,最大长度为2048,训练一个epoch。
效果
PPO
PPO阶段共包含四个模型:policy模型、value模型、reference模型、reward模型。
训练过程中,先对policy模型训练50步预热,这样保证了value模型能够有效地适应不同的奖励模型。在PPO过程中,对每个query会同时采样两个response,KL散度系数设为0.04,并根据平均值对奖励进行归一化处理。
policy模型和value模型的学习率分别为1e−6和5e−6。为了增强训练的稳定性,裁剪值0.15。在进行推理时,生成策略的top-p值设置为0.9。
对齐效果
Qwen的效果优于相同规模的其他开源模型,如LLaMA2、ChatGLM2、InternLM、Baichuan2
人工评测,比较了 Qwen-7B-Chat(SFT)、Qwen-14B-Chat(SFT)、Qwen-14B-Chat(RLHF)、GPT4在对话上与GPT3.5的差异。
RLHF模型明显优于SFT模型,说明RLHF可以生成更受人类喜爱的回答。
工具使用
Qwen模型具有工具使用能力:
- 通过ReAct提示进行使用未见的工具;
- 用Python解释器增强数学推理、数据分析等能力;
- 作为代理,与人类交互过程中,可以访问HuggingFace中大量多模态模型集合。
PS:高质量数据2000条-React格式数据。
如何用 ReAct Prompting 技术命令千问使用工具
【2024-3-4】GPT-4
【2024-3-4】GPT-4 Technical Report
【2024-3-9】Yi
参考
- 【2024-3-9】Yi技术报告细节分享
- Yi技术报告-划重点看细节
Yi 介绍
AI(零一万物
)是李开复带队孵化的AI公司。
- 2023年11月初,01.AI 发布并开源了
Yi-6B
、Yi-34B base
模型,同一周内,又开源了Yi-6B-200K
和Yi-34B-200K base
模型。Yi号称是从零预训练的双语模型。 - 接下来的几个月,01.AI陆续推出了chat模型、多模态能力,
Yi-9B
、长上下文的记忆和检索能力等优化。
SuperCLUE/CMMLU等一些榜单数据的实测上,Yi的效果确实不错。能排在同时期中文(开源)大模型里的第一梯队。
2024年3月,Yi终于发布了技术报告,在此来梳理一下报告中的重点内容和值得关注的细节信息。
Yi目前有6B、9B、34B三个规模,其中34B是主力模型。
- 选择34B,而不是更大规模的原因,是这个规模能在24G显存的消费级显卡(如RTX4090)上运行。
- 使用int4量化之后的34B模型可运行在24G显存的GPU上。
参考《Understanding INT4 Quantization for Language Models: Latency Speedup, Composability, and Failure Cases》的量化方法
- Yi-34B int8量化模型相比bf16模型,几乎可以做到效果无损(差距<1%),而int4量化模型在大部分任务的损失也完全可以接受,
官方资料
- 论文 Yi: Open Foundation Models by 01.AI
- Code: Yi
- Model: 01-ai
总结:
- Yi-34B模型int4量化之后,相比float16损失<1%,可跑在RTX4090上(24G显存)
- 模型结构不需要太多变化,LLAMA2 标准结构已经足够训出很好的效果
- 3.1T 预训练数据远比scaling law建议的1T大,但是效果更好,并且模型还没饱和,继续增大数据量还能提升
- 微调数据质量很重要,由算法人员直接标注,只要不到10k数据量就足够了
- 4k长度的基础预训练模型已经具备长文本能力,只需用长文本数据继续预训练,更新百步就有很好效果
总之,数据要精心设计,数据质量要高,数据量要大
Yi实践结果证明: 较小模型+更大规模高质量数据,可获得进一步效果提升
- 获得高性价比的推理模型–34B推理成本+大训练投入,就能得到接近普通70B规模的推理效果。
数据构造
数据是LLM最核心的部分,没有之一。Yi最核心的工作就是提升数据数量和质量。
预训练
主要步骤
Yi模型在预训练阶段的数据处理流程,主要是对爬取的网络文本进行数据过滤和去重
- 原始网络数据 → 语种过滤 →
语料获取 & 语言分类
- 从网络爬虫开始,爬取中英文这两种语言的网站,对网站内容进行解析。
- 参考CCNeT(《CCNet: Extracting High Quality Monolingual Datasets from Web Crawl Data》)的做法,进行语言识别。
过滤方法
- 启发式过滤:去除质量较低的文本内容。过滤规则包含:
- (1)根据特殊URL、域名、黑名单词表以及乱码文本进行过滤;
- (2)根据文本长度、特殊字符比例、短、连续或不完整的行比例;
- (3)根据重复词语、N-Gram片段、段落的占比;
- (4)识别和匿名话个人可识别信息,例如:邮箱、电话等。
- 学习式过滤:Learned Filters, 规则不好处理的,训练模型来清洗
- 通过困惑度、 质量、 安全和文档连贯性4种评分器来对文本进行过滤,共有4个scorer:
Perplexity Scorer
:参照《CCNet: Extracting High Quality Monolingual Datasets from Web Crawl Data》,用kenlm库,把高于平均 perplexity 内容丢弃;Quality Scorer
:识别如维基百科高质量内容,丢弃低质量内容;Document Coherence Scorer
:发现句子、段落零散不连贯的文本,要么分割,要么直接丢弃;Safety Scorer
:识别并删除暴力、色情、涉政内容
- 困惑度评分器利用KenLM库,按照CCNet方法评估大量网络文本,丢弃困惑度分数明显高于平均水平的文本;
- 质量评分器经过维基百科数据训练的分类模型,当文本内容更偏向于维基这样高质量页面时,认为文本质量较高;
- 安全评分器是识别并删除包含有毒内容的网络文档,如暴力、色情等;
- 文档连贯性评分器识别文本的整体连贯性,删除句子或段落不连贯的文本。
- 通过困惑度、 质量、 安全和文档连贯性4种评分器来对文本进行过滤,共有4个scorer:
- 聚类过滤:Cluster-based Filters
- 采用无监督语义聚类对文本进行分组,然后对聚类数据标注质量标签, 丢弃质量差的类别,为后续数据混合策略提供参考。
- 去重方法:
- 文本过滤之后进行去重流程,涉及基于文档级别的
MinHash
去重和子文档精确匹配去重,有效识别和消除文档内部和跨文档中的重复内容。 - 同时利用主题模型对数据赋予特定主题,在最后数据采样过程种对信息密度较低的主题内容进行下采样(主要是广告文本)
- 文本过滤之后进行去重流程,涉及基于文档级别的
最终预训练数据组成如下图所示,总计 3.1T Token。
- 语种构成: 英语(60%) > 中文(20%) > 代码(10%)
- 语料类型: 网页内容(80%) > 代码(8%) > 论文(5%) > 书籍(3%)
微调
对于微调数据
- Quality is All You Need
- 数据质量胜过数量
SFT数据质量能极大影响模型效果,随着数据量的增加,高质量数据能带来更多提升
微调阶段数据构造
- 微调阶段采用 不到10k的 SFT数据
一共只有<10k条SFT数据,每条数据都通过人工多次打磨,这比大数量但质量一般数据的效果好。
- 这思路和别人一致
- 《Gemini: A family of highly capable multimodal models》
- 《Llama 2: Open Foundation and Fine-Tuned Chat Models》
- 《Lima: Less is more for alignment》
- 不同
FLAN
(《Scaling instruction-finetuned language models》)UltraChat
(《Enhancing chat language models by scaling high-quality instructional conversations》)
具体做法:
- 对于 prompt distribution selection:参考《Wizardlm: Empowering large language models to follow complex instructions》,开发复合指令,并通过指令进化,逐步增加指令的复杂度。这种做法显著减少了SFT数据量。
- 对于 CoT data formatting:参考《Take a step back: Evoking reasoning via abstraction in large language models》,采用了“Step-Back”的模式。即通过抽象化处理,让模型学习在深入探讨原始、具体的问题之前,制定更高层次的解决方案。
- 对于 response formatting:使用从《Lima: Less is more for alignment》扩展的默认样式。
- response的结构为introduction-body-conclusion的格式,“where the body is usually a list of bullet point”。
- 在缓解幻觉问题上,思路是确保response中的知识不由模型内部产生,对应的做法是把会导致模型进行记忆的response删掉。(但是这个具体标准是什么,有没有了解的朋友说下看法?)
- 在缓解生成重复的问题上,则是直接把response中包含重复的部分都重写了。(核心还是洗数据,一条条打磨)
- 数据多样性很重要,因此参考《#instag: Instruction tagging for analyzing supervised fine-tuning of large language models》建立了一个打标系统,并设计一个注重多样性的采样算法,平衡了各个领域数据的分布。
- 为了找到最佳数据配比,参考《How abilities in large language models are affected by supervised fine-tuning data composition》,使用近似网络搜索(approximate grid search),对每个领域以 {1, 1/2, 1/4, 1/8, 1/16, 1/32, 1/64} 比例进行实验和人工测评,找到最佳的组合方式。
- 除了内容,数据格式对效果也有很大影响。参OPENAI的ChatML格式,这种结构化的格式使模型能够区分各种信息类型,如system prompt、user input和bot response。
数据构造过程中
- 采用
WizardLM
方法获取难度较高提示的数据集,采用LIMA
中回复风格(总-分-总)对生成回复内容格式化,采用“Step-Back
”模式对维链数据格式化。 - 同时为了减少幻觉和重复,检查并确保回复中的知识不包含在模型中,消除可能导致模型死记硬背的回复,并重写回复保证微调多轮时数据不重复。
同时
- 为了确保模型能力覆盖范围,微调数据中涉及多种任务,例如:问答、创意写作、对话、推理、数学、编码、双语能力等。
- 为了增加模型的精细控制能力,设计了一套系统指令,通过多样性的采样算法,平衡各种系统指令上的数据分布,增强的跨任务鲁棒性。
- 为了探索不同任务数据比例,对模型最终能力的影响,通过网格搜索方法,确定最终数据混合比例。
最后,微调数据采用ChatML
格式,让模型可以更好地区分输入中各类型信息,例如:系统指令
、用户输入
和模型回复
。
模型结构
涉及 分词器、模型结构及微调参数
分词器
Tokenizer 采用 sentencepece 中 BPE方法对预训练数据训练得来,为平衡计算效率和词理解能力将词表设置为64000,将数字拆分为单个数字,将罕见字符用unicode编码。
tokenizer
- 用 BPE,词表大小为64000,平衡了计算效率和表达能力;
- 其中数字全是单个的digit,让模型能更好地理解数字数据;
- 对于OOV的词,会降级用unicode编码 ;
- 保留全角标点符号,不转为半角;
另外,优先考虑英语的LLM在tokenizer会使用虚拟前缀(文本开头的空格)来泛化句子不同位置相同的单词。Yi不这么做,因为即使是在英语语境中,这种假设并不总是成立,比如对于以引号开头的句子,而且在中文语境中,这么做没有明显效果。
模型
模型 Transformer-Decoder 结构,基于标准LLAMA2模型,修改如下:
- 注意力机制:LLAMA2只在70B用了GQA,Yi全系列都用了GQA
- Yi-6B和34B版本均采用 Grouped-Query Attention(GQA),Llama2 中仅70B版本采用GQA。
- 激活函数:Yi采用
SwiGLU
作为后注意力层的激活函数。- 参考《GLU Variants Improve Transformer》
- 位置编码:Yi模型采用旋转位置编码(
RoPE
),为例支持200k上下文窗口,调整了基础频率(RoPE ABF)。- 参考 RoPE ABF(《Effective long-context scaling of foundation models》),base扩大到10M,用于支持长上下文。
模型微调阶段
- 仅计算回复内容的损失,不考虑系统指令和用户指令。
- 采用AdamW优化器,其中β1、β2和ϵ分别为0.9、0.999和1e−8。
- 训练数据最大长度为4096,批量大小为64,训练300步,学习率恒定为1e−5,权重衰减为0.1,梯度裁剪最大阈值为1.0,并采用NEFTune方式训练,Yi-34B-Chat和Yi-6B-Chat噪声尺度分别为45和5。
训练
Infra
从数据处理到模型训练都需要大集群大算力的支持。
Yi构建了支持全栈数据处理、预训练、微调和服务的基础设施。包括:
- (1) 自动管理和监控计算资源的能力;
- (2) 通过优化并行策略、内核效率和长上下文支持提高训练速度;
- (3) 统一微调框架,支持异构分布式训练后端,例如在DPO中同时使用Megatron和DeepSpeed进行多个模型的训练;
- (4) 通过各种LLM服务加速技术(如量化、continuous batching 和 paged attention)降低部署成本。
这部分工作还是很多的,比如
- 由于经常有硬件坏,坏的硬件会被自动从资源池移除;
- 任务失败时,会自动跟踪重启。
- 给算法人员考法UI等。
预训练
预训练 pretrain
- 训了4k基础模型。(暂时没有给出更多细节)
微调
微调超参如下
AdamW:beta=[0.9,0.999],epsilon = 1e-8
seq_len = 4096
batch size = 64
constant lr = 1e-5,weight decay = 0.1
gradient clip = 1.0
max step = 300
参考
- 《Neftune: Noisy embeddings improve instruction finetuning》
- 对于6B模型 noise scale = 5,对于34B模型 noise scale = 45
评测
基模型评测
基础能力评测
对其他开源模型,保持和公开的设置相同做法获取结果。Yi使用贪婪解码,没有进行任何后处理
- 数学和代码能力上,和GPT3.5、GPT4还存在一些差距,而这些能力是可以通过继续预训练和微调来持续提升的。Yi最初的设计并没有针对这些能力,因此没有在预训练数据中包含特别多相关数据,后续会有计划增加这部分能力的提升。
- 而和其他开源模型相比,在代码和数学以外的任务,Yi基本上做到了跟大一倍模型的效果相近,甚至更好的水平。
观察
- 模型规模带来的增益:尽管Yi-34B和Yi-6B使用了相同的预训练语料,但Yi-34B的性能相比Yi-6B有了质的提升。
- 更大的模型尺寸在代码和数学基准测试上带来了明显的增益。
- 数据质量:高质量预训练数据的小型模型,如Yi-34B或Qwen-14B,通常表现优于尺寸更大但(可能)数据质量较低的模型,例如Falcon-180B。
- GPT-4与开源LLM差距:
- 开源LLM在多种基准测试上的性能仍然落后于GPT-4和GPT-3.5。
- 然而,具有代表性的双语LLM,例如Qwen-14B和Yi-34B,在包括C-Eval、CMMLU和Gaokao在内的中文知识相关基准测试上匹配甚至超过GPT-4的性能。然而,在BBH、代码(HumanEval)和数学(MATH)等推理相关基准测试上,仍然存在巨大差距。
In-Context Learning 能力
Yi进一步研究了in-context learning能力,即根据少数展示的输入-输出示例,推断underlying function的能力。
任务是推断加权和的线性系数。
- 定义
y = w1x1 + w2x2 + ... + wnxn
。
少量示例展示是 x1, x2, …, xn, y,要求模型预测给定一组新输入 x 的 y。
这就要求模型隐式地推断出 w1, w2, …, wn。
评测上,使用(a)模型预测的 y 与真实值 y∗ 之间的绝对差,即 |y − y∗|
作为连续度量,以及使用(b)精确匹配 y == y∗ 作为不连续度量。
模型在算术上的效果正常,因此可以认为这样的测试不受算术能力的影响,而能直接看模型是否具备根据给定的实例进行underlying function推理的能力。
实验发现,当问题比较简单时(系数是[1,-1]
),Yi-34B和LLAMA-70B效果比较好(看下图)。
当问题更复杂点(系数是[1,1,1,1,1]
),只有LLAMA-70B和Mistral 8*7B这样的大模型表现出了涌现的能力。
Chat 模型评测
自动评测
- 评测任务和base模型相同,分别采用zero-shot和few-shot,效果依然不错
报告强调,如Goodhart’s principle所说
- 当一个指标变成目标,就不再是一个好指标。
- 因此这里的测试只是为了确认微调没有使得模型的知识能力下降,而不会专门去针对任务做优化。
结果上,Yi-34B-Chat数学能力不错,而Yi-6B-Chat并没有展现出强大的数学能力。推测较小的模型可能需要更多的数据在SFT阶段激活其相应的能力。
人工评测
能力扩展
上下文扩展
扩展模型上下文长度
对于长上下文的解决方法:采用继续预训练和微调两种方法
- 基础模型其实本身已经存在利用200K输入上下文中任何位置信息的前来,继续预训练可以“解锁”这种能力,通过微调可以进一步调整生成内容的风格以更好地遵循人类指令和偏好。
预训练阶段:
- 采用序列并行和分布式注意力方式蛮力对模型全部注意力进行训练。
数据来源:
- (1)原始预训练数据;
- (2)长上下文数据,主要来自数据;
- (3)多文档文档合成数据。共计对5B Token的数据进行训练,批次大小为4M Token。
微调阶段:
- 将短SFT数据与长上下文问答问答数据混合使用。文档问答数据由模型辅助构建,即随机将多个文档拼成一个长文档,从中抽取一个或多个段落,要求模型基于抽取段落内容构建问答对。
- Trick,要求给答案之前模型需要背诵或改写原始段落,这种数据格式鼓励模型进行检索,从而阻止依赖自身知识回答产生的幻觉。
多模态
ViT部分由CLIP ViT-H/14 model初始化,后面的transformer由Yi-Chat初始化
3步训练:
- (1)使用224^2的图像来训练ViT和projection模块的参数。这一训练利用了包含1亿个图像-文本对的数据集,这些数据来自LAION-400M。主要目标是增强ViT在架构中的知识获取能力,并实现ViT与LLM之间更好的对齐。
- (2)将ViT图像分辨率提升到448^2,目的是进一步推动模型识别复杂视觉细节的能力。在这个阶段使用的数据集包括从LAION-400M中提取的2000万个图像-文本对。此外,还融入了来自不同来源的大约480万个图像-文本对,例如CLLaVA、LLaVAR、Flickr、VQAv2、RefCOCO、Visual7w等。
- (3)整个模型的参数一起训练。主要目标是提高模型在多模态聊天交互方面的熟练度,从而赋予它能够无缝融合和解释视觉与语言输入的能力。为此,训练数据集涵盖了多种来源,总共大约有100万张图像-文本对,包括GQA、VizWiz VQA、TextCaps、OCR-VQA、Visual Genome、ShareGPT4V等等。为了确保数据平衡,对任何单一来源的最大数据量设定了上限,将其限制在不超过50,000对。
使用128张A100,6B训了3天,34B训10天。
扩展模型深度 Depth Upscaling
目标是把32层的6B扩展到48层的9B模型。
- 参考《Scaling large language models with simple yet effective depth up-scaling》,通过复制中间的12-28层共16层,把层数扩展为48层。
参考SOLAR 10.7B模型对Yi-6B模型进行深度扩展,将原来的32层扩展到48层,构建Yi-9B模型。在具体层的选择时,通过评估每一层输入和输出直接的余弦相似度得出,余弦相似度越接近于1,则表明复制这些层不会显著改变原始模型输出的logits,因此选择复制原始模型中间12-28的16个层。
采用两阶段训练
- 第一阶段使用了0.4T数据(包含文本和代码),数据配比与Yi-6B模型一样;
- 第二阶段使用了0.4T数据(包含文本、代码和数学),重点增加了代码与数学数据的比例,以提高代码性能。
在微调过程中
- 设定了一个固定的学习率 3e-5,并采取逐步增加 batch size 的策略,即从 batch size 4M token 开始,每当模型 loss 停止下降时就增加 batch size,使 loss 继续下降,让模型学习更加充分,收敛性能更好。
【2024-4-22】MiniCPM
【2024-4-22】MiniCPM:揭示端侧大语言模型的无限潜力
- MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies (arxiv.org)
MiniCPM 是一系列端侧语言大模型,主体语言模型 MiniCPM-2B
具有2.4B的非词嵌入参数量。
- 综合性榜单上与Mistral-7B相近(中文、数学、代码能力更优),整体性能超越Llama2-13B、MPT-30B、Falcon-40B等模型。
- 当前最接近用户体感的榜单MTBench上,MiniCPM-2B也超越了Llama2-70B-Chat、Vicuna-33B、Mistral-7B-Instruct-v0.1、Zephyr-7B-alpha等众多代表性开源大模型。
超参调优
Hyper-parameters、Batch size、Learning Rate、Learning Rate Scheduler、Data Strategy 五个方面模型沙盒研究。
近400次在0.009B模型规模上的贝叶斯参数搜索得到
超参数对模型的性能具有重大影响
- 传统训练方法要对每个模型进行超参数调整,这对于大模型并不现实。
借鉴 uP 方法,对模型各参数模块之间进行了连接权重的调整、以及对模型初始化的调整。部分调整接近Cerebras-GPT。
名称 | 具体操作 |
---|---|
Embedding Output Scaling | 将Embedding的输出乘12 |
Residual Connection Scaling | 将每层的残差连接处的增量放缩为 1.4/sqrt(num_layers) = 0.22 倍 |
Initialization of Tensors | 将每个二维的张量参数的初始化标准差设置为 0.1/sqrt(dim_model/256) = 0.033,其他参数初始化设置为0.1 |
Learning Rate Scaling of Tensors | 将每个二维的张量参数的学习率调整为其他部分学习率(或称整体学习率)的1/(dim_model/256) = 0.11倍 |
lm_head Scaling | 将输出logits调整为原来的0.11倍 |
batch size
Batchsize 随损失变化: 更大的Batchsize可能可达到更低的loss
- 扩大Batchsize 时, 损失会有一次较大幅度的下降
2020年, OpenAI 开山之作研究了损失函数随token数变化的规律: 消耗更多的步数等价于消耗更多的时间
在这种假设下,OpenAI定义了临界Batchsize
(Critical Batchsize),使得达到一定的损失,既不消耗过多step,也不消耗过多token。
然而利用当前以A100为主的计算资源,结合gradient checkpointing策略进行训练时,通常计算速度(而不是显存)是瓶颈
- 相同机器数量下,多一倍 Batchsize 几乎等同于慢一倍的单步时间。
基于这个观察,取消了对“不消耗过多step”的追求,而转向追求用最少的token量达到最低的loss。
0.009B,0.036B,0.17B的模型上分别进行了6个batchsize的训练实验
log(BS) = -6.24 * log(L) + 20.91
最优batchsize随着C4数据集上的loss的偏移规律
- 规律: BS = 1.211 * 10^9 / L^6.2393
- 预估: 2B模型达到C4损失2.5左右,4M是比较合适的Batchsize
learning rate
模型最关键超参数:学习率
lr 不会因为模型规模扩大有大幅度的改变
0.04B, 0.1B, 0.3B, 0.5B 上分别做了6组学习率实验,发现虽然模型大小扩大了10倍,但是最优学习率偏移并不明显,均在0.01
左右
- 在 2.1B 规模上进行了简单验证,发现在 0.01 的学习率确实能取得最低的Loss。
lr 调度策略
不同训练阶段使用不同学习率的调整策略,对模型性能影响很关键。
当前通用的学习率策略是Cosine图像,即 学习率从 Warmup阶段升高到最高点之后,开始呈现余弦函数的降低。
- 几乎所有大模型都使用了 Cosine Learning Rate Scheduler (简称Cosine LRS)的方式。
为什么 Cosine Scheduler 表现优异?
对0.036B的模型,设置不同的Learning Rate Scheduler的截止步数$T$,进行了持续训练。
- 对于训练至 S 步的模型,将 Cosine LRS 截止步数 T 设置为 S 步, 总是能获得最优的性能,而设置为更多或者更少性能都不是最优。
持续训练场景会发现 Cosine调度器有问题。
- 如果在Cosine的截止步数之后, 继续沿用0.1倍的最大学习率(通常做法),则继续训练收敛非常缓慢;
- 如果在Cosine的截止步数之后, 重启Cosine LRS(即再次从最大学习率开始下降,或者是逐渐上升到最大学习率,再开始下降)则损失会经历长时间的上升周期,而这段时间,模型处于不可用状态。
猜想 Cosine LRS 在预先指定步数的时候性能优异原因:
- T=S下的Cosine LRS,相对于Linear LRS、Noam LRS、以及
T<S
的Cosine LRS,有更长时间的大学习率训练。这一阶段可能有助于模型寻找更好的全局最优解。 - T=S下的Cosine LRS ,相对于
T>S
的Cosine LRS、Constant LRS,有更充分的学习率下降的退火阶段,这一阶段可能发生了较为特别的动力学现象,导致模型可以找到更好的局部最优解。
结合这两点,提出了一种新的学习率调度策略,Warmup-Stable-Decay
(WSD
)调度器。
- 公式见原文
- Cosine调度器结束后, 需要持续保持最低学习率,以保证loss不上升
- 而WSD调度器则从退火(decay)前开始继续用最大学习率训练,经过更长的训练再开始退火
这种学习率调度器分为三个阶段:
warmup
阶段(用W表示warmup阶段结束时的步数/训练量)稳定训练
阶段(用S表示稳定训练阶段结束时的步数/训练量)退火
阶段(用D表示退火阶段的训练量)- 随着学习率的变小,损失有大幅度的快速下降,在步数S时迅速降低至和T=S的Cosine LRS相等或更低
WSD好处:
- 可以持续训练。
- 可以随时取出。
- 性能优于Cosine LRS。
- 有显式区分的训练阶段,便于使用不同的数据策略。
数据策略
结合训练阶段特点,使用不同类型的数据
- 预训练阶段: 只用通用、量大的预训练粗质量数据
- 退火阶段: 用非常广泛的高质量知识和能力数据以及SFT的高质量数据,混合入预训练数据进行退火。
实验结果
- 退火开始时加入高质量数据的收益远高于在退火完成后的sft阶段加入。
因此, 建议模型能力的特化和增强应从退火阶段开始进行。
【2024-5-7】DeepSeek
详见站内专题: DeepSeek
训练经验
OOM
【2024-4-11】 OOM
- 单机单卡(V100S,32G)
- InternLM2-1.8B, 7.1G
- 数据集: 231m
报错
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate
7.04
GiB. GPU 0 has a total capacty of31.75
GiB of which5.04
GiB is free. Process 743134 has26.71
GiB memory in use. Of the allocated memory25.01
GiB is allocated by PyTorch, and 342.98 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try settingmax_split_size_mb
to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
deepspeed 配置
deepspeed --master_port 30001 ./llm/training/conversation_reward/main.py \
--max_seq_len 2048 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--weight_decay 0.01 \
--dropout 0.0 \
--gradient_accumulation_steps 1 \
--zero_stage 2 \
--dtype bf16 \
--num_train_epochs 10 \
--train_data_path /mnt/bn/flow-algo-cn/wangqiwen/session_process/data/train/cut_train_sequence_en_20240331.csv \
--val_data_path /mnt/bn/flow-algo-cn/wangqiwen/session_process/data/test/cut_test_0322_es_sequence_v2.csv \
--test_data_path /mnt/bn/flow-algo-cn/wangqiwen/session_process/data/test/cut_test_0322_en_sequence_v2.csv \
--model_name_or_path /mnt/bn/flow-algo-cn/yufeng/ModelHub/internlm2-1_8b \
--output_dir /mnt/bn/flow-algo-cn/wangqiwen/model/checkpoints \
--debug \
--deepspeed
解决
- 设置GPU缓存碎片 → 无效
- 改用 A100(80G) → 有效
--max_split_size_mb 32 # 无效