鹤啸九天 自律更自由,平凡不平庸 Less is More

ChatGLM 系列

2023-11-06
阅读量

Notes(温馨提示):

  1. ★ 首次阅读建议浏览:导航指南
  2. 右上角工具条搜索文章,右下角二维码关注微信公众号(鹤啸九天),底栏分享、赞赏、评论
  3. ★ 转载请注明文章来源,知识点积累起来不容易,水滴石穿,绳锯木断,谢谢理解
  4. ★ 如有疑问,邮件讨论,欢迎贡献优质资料


ChatGLM

资讯

ChatGLM:开源双语对话语言模型,官方B站

  • 【2023-4-29】清华大学计算机系第一届基础模型前沿研讨会,唐杰教授报告《ChatGLM:从千亿到开源的一点思考》
  • 【2023-6-3】从GLM-130B到ChatGLM:大模型预训练与微调
  • 【2023-8-31】智谱AI正式上线首款AI助手:智谱清言,基于 ChatGLM2,采用监督微调,以通用对话形式提供智能化服务

文心一言和智谱清言(ChatGLM)这次连夜开放

ChatGLM 介绍

GLM 时间线

几个关键的里程碑:

  • 2020年11月,开始大规模预训练
  • 2021年5月,做出百亿的GLM模型
  • 2021年底,开始训练GLM-130B千亿模型
  • 2022年8月,GLM-130B训练完成, 清华智谱AI开源GLM-130B。该大语言模型基于之前提出的GLM(General Language Model),在Norm处理、激活函数、Mask机制等方面进行了调整,目的是训练出开源开放的高精度千亿中英双语稠密模型,能够让更多研发者用上千亿模型
  • 2023年3月, 通过SFT+RLHF训练ChatGLM-130B/ChatGLM-6B(开源), 解决大基座模型在复杂问题、动态知识、人类对齐场景的不足,基于GLM-130B,引入面向对话的用户反馈,进行指令微调后,得到对话机器人。
  • 2023年6月25日,ChatGLM2-6B
  • 2023年11月

图解

多场景涌现能力

ChatGLM 生态

ChatGLM3 推出了可在手机上部署的端测模型 ChatGLM3-1.5BChatGLM3-3B,支持包括vivo、小米、三星在内的多款手机以及车载平台,甚至支持移动平台上CPU芯片的推理,速度可达20tokens每秒(token是语言模型中用来表示单词或短语的符号)。

ChatGLM 第三方扩展

【2023-5-19】基于ChatGLM的扩展模型

  • Chinese-LangChain: 中文langchain项目,基于ChatGLM-6b+langchain实现本地化知识库检索与智能答案生成,增加web search功能、知识库选择功能和支持知识增量更新
  • bibliothecarius: 快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。
  • langchain-ChatGLM: 基于 langchain 的 ChatGLM 应用,实现基于可扩展知识库的问答
  • InstructGLM: 基于ChatGLM-6B进行指令学习,汇总开源中英文指令数据,基于Lora进行指令数据微调,开放了Alpaca、Belle微调后的Lora权重,修复web_demo重复问题
  • ChatGLM-Efficient-Tuning: 基于ChatGLM-6B模型进行定制化微调,汇总10余种指令数据集和3种微调方案,实现了4/8比特量化和模型权重融合,提供微调模型快速部署方法。
  • ChatGLM-Finetuning: 基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。
  • ChatGLM-Tuning: 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 Humanable ChatGLM/GPT Fine-tuning

GLM 模型

2022年8月,智谱AI基于GLM框架,推出1300亿参数的中英双语稠密模型GLM-130B,综合能力与GPT3相当 内存节省75%,可在单台3090 (4)或单台2080(8)进行无损推理 高速推理,比Pytorch提升7-8倍速度

原有大模型问题

  • 规模过大或精度一般
  • 大都无法支持单机推理
  • 基于NVIDIA为主,缺少国产芯片支持
  • 训练成本高昂
  • 人力投入极大
  • 训练过程不稳定
  • 缺少充分训练、开源的稠密千亿大模型

GLM 特点

相较于自回归模型GPT,自编码模型BERT,以及encoder-decoder模型T5,GLM 模型架构是设计了自回归填空结构,通过双向注意力,对masked字段进行自回归预测。

GLM的出发点是将3种主流预训练模型进行统一:

  • GPT训练目标是从左到右文本生成
  • BERT训练目标是对文本进行随机掩码,然后预测被掩码的词
  • T5则是接受一段文本,从左到右的生成另一段文本
  • img

GLM 模型对比

GLM 既可以做 Encoder 也可以做 Decoder。

主要通过 两种mask方式来实现:

  • [mask]:bert形式,随机mask 文本中的短span
  • [gmask]:gpt 形式,mask末尾的长span 在chatglm里面做生成任务时,是用 [gmask]。chaglm2 完全采用 gmask来进行预训练。

在ChatGLM 的内部结构中的变换,从下到上依次是:

  • 位置编码:从BERT的训练式位置编码转变为旋转位置编码
  • 激活函数:从BERT中的 GeLU 转变为 GLU, 在ChatGLM2 中又变成了SwiGLU
  • LayerNormalization:采用的是DeepNorm,是对post-Normalization 的改进,即在残差之后做Normalization。在ChatGLM中,把 layer-normalization 改为 RMSNormalization。

在ChatGLM 2.0 中还添加了一些其他变化:

  • FlashAttenion:利用显存和内存来做加速
  • Multi-Query Attention:多个头只采用一个 KV对,通过参数共享来降低显存占用

整体结构

模型 发布时间 模型结构 位置编码 激活函数 Layer norm 总结
LLaMA - Casual decoder RoPE SwiGLU Pre RMS Norm  
Bloom - Casual decoder ALiBi GeLU Pre Layer Norm  
ChatGLM-6B 2023.3.14 Prefix decoder RoPE GeGLU Post Deep Norm ChatGLM上中文问答优化,监督微调/RLHF, 推理加速。改动点:
embedding层梯度缩减: 梯度缩小10倍,提升训练稳定性
layer normalization: 基于Deep Norm的 Post layer norm
激活函数: 采用 GeGLU(3个权重矩阵)
位置编码: 绝对位置编码 → 旋转位置编码 RoPE
【劣势】
prefix decoder-only结构导致训练效率低(损失计算不含prefix)
ChatGLM2-6B 2023.6.25 Casual decoder RoPE SwiGLU Post Deep Norm 升级点:
效果大幅提升: GLM 混合目标函数,偏好对齐
上下文扩充: 借助 FlashAttention技术,将 context 2k→32k,对话模型使用8k,多轮会话
推理提速:Multi-Query Attention技术,速度提升42%
可商用
ChatGLM3-6B 2023.11.6 Casual decoder RoPE SwiGLU Post Deep Norm 升级点:
基座模型升级: ChatGLM3-6B-Base
全新的Prompt格式, 支持 工具调用/代码执行/Agent

问题

  • ChatGLM2-6B 架构变化?ChatGLM3-6B 是什么结构

tokenizer

分析

  1. LLaMA 词表最小,LLaMA在中英文上的平均token数都是最多的,这意味着LLaMA对中英文分词都会比较碎,比较细粒度。尤其在中文上平均token数高达1.45,这意味着LLaMA大概率会将中文字符切分为2个以上的token。
  2. ChatGLM-6B 是平衡中英文分词效果最好的tokenizer。由于词表比较,中文处理时间也有增加。
  3. BLOOM 虽然是词表最大,但由于是多语种的,在中英文上分词效率与ChatGLM-6B基本相当。
模型 词表大小 中文平均token数 英文平均token数 中文处理时间(s) 英文处理时间(s)
LLaMA 32000 1.45 0.25 12.6 19.4
ChatGLM-6B 130528 0.55 0.19 15.91 20.84
Bloom 250880 0.53 0.22 9.87 15.6

直观对比不同tokenizer 分词结果。

token数目 男儿何不带吴钩,收取关山五十州 杂申椒与菌桂兮,岂维纫夫蕙茝
LLaMA 24 37
Bloom 13 17
ChatGLM-6B 11 17

(1) “男儿何不带吴钩,收取关山五十州

# LLaMA分词为24个token:
[ '男', '<0xE5>', '<0x84>', '<0xBF>', '何', '不', '<0xE5>', '<0xB8>', '<0xA6>', '<0xE5>', '<0x90>', '<0xB4>', '<0xE9>', '<0x92>', '<0xA9>', ',', '收', '取', '关', '山', '五', '十', '州', '。'] 
# Bloom分词为13个token:
 ['男', '儿', '何不', '带', '吴', '钩', ',', '收取', '关', '山', '五十', '州', '。']
# ChatGLM-6B分词为11个token:
[ '男儿', '何不', '带', '吴', '钩', ',', '收取', '关山', '五十', '州', '。'] 

(2) “杂申椒与菌桂兮,岂维纫夫蕙茝。”

# LLaMA分词为37个token:
[ '<0xE6>', '<0x9D>', '<0x82>', '<0xE7>', '<0x94>', '<0xB3>', '<0xE6>', '<0xA4>', '<0x92>', '与', '<0xE8>', '<0x8F>', '<0x8C>', '<0xE6>', '<0xA1>', '<0x82>', '<0xE5>', '<0x85>', '<0xAE>', ',', '<0xE5>', '<0xB2>', '<0x82>', '<0xE7>', '<0xBB>', '<0xB4>', '<0xE7>', '<0xBA>', '<0xAB>', '夫', '<0xE8>', '<0x95>', '<0x99>', '<0xE8>', '<0x8C>', '<0x9D>', '。']
# Bloom分词为17个token:
 ['杂', '申', '椒', '与', '菌', '桂', '兮', ',', '岂', '维', '�', '�', '夫', '蕙', '�', '�', '。'] 
 
# ChatGLM-6B分词为17个token:
 [ '杂', '申', '椒', '与', '菌', '桂', '兮', ',', '岂', '维', '纫', '夫', '蕙', '<0xE8>', '<0x8C>', '<0x9D>', '。'] 

Mask机制

transformer中mask机制:

  • Transformer中mask机制用于 self-attention,控制不同token之间的注意力交互。
  • Transformer中使用两种类型的mask:padding mask 和 sequence mask。
    • Padding mask(填充掩码):自注意力机制中,句子所有单词都会参与计算。但是,实际句子中,往往会存在填充符,用来填充句子长度不够的情况。Padding mask就是将这些填充符对应的位置标记为0,以便在计算中将这些位置的单词忽略掉。
    • Sequence mask(序列掩码):sequence mask 在Decoder端的self-attention中,以保证在生成序列时不会将未来的信息泄露给当前位置的单词
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("models/chatglm3-6b", trust_remote_code=True)

if __name__ == '__main__':
    promt = ["你好", "今天过得怎么样?", "好"]
    print(tokenizer(promt, return_tensors="pt", padding=True))
""" 
{'input_ids': tensor([[    0,     0, 64790, 64792, 36474, 54591],
                      [64790, 64792, 53456, 37072, 35367, 30987],
                      [    0,     0,     0, 64790, 64792, 46458]]), 
'attention_mask': tensor([[0, 0, 1, 1, 1, 1],
                          [1, 1, 1, 1, 1, 1],
                          [0, 0, 0, 1, 1, 1]]), 
'position_ids': tensor([[0, 0, 0, 1, 2, 3],
                        [0, 1, 2, 3, 4, 5],
                        [0, 0, 0, 0, 1, 2]])}
"""

比较不同LLM架构,其实是在比较sequence mask。

Decoder Only

LLM(Language Model)是一种处理自然语言的模型架构,对输入的序列进行预测和生成。

LLM可以使用 encoder-decoder或者decoder-only架构。

LLM所做的都是seq2seq任务,用户输入一个文本prompt,LLM对应输出一个文本completion。

  • 示例: A Survey of Large Language Models
  • 用户输入 prompt为 A Survey of,期望语言模型输出 Large Language Models
模型架构 代表LLM 注意力机制 是否属于Decoder-Only
Causal Decoder GPT3/ChatGPT/ChatGLM2-6B 纯单向 YES
Encoder-Decoder Flan-T5 输入双向,输出单向 NO
Prefix Decoder GLM-130B/ChatGLM-6B 输入双向,输出单向 YES

训练目标上,ChatGLM-6B 训练任务是自回归文本填空

  • 相比于采用causal decoder-only结构的大语言模型,采用prefix decoder-only结构的ChatGLM-6B存在一个劣势:训练效率低。
  • causal decoder结构会在所有的token上计算损失,而prefix decoder只会在输出上计算损失,而不计算输入上的损失。
  • 相同数量的训练tokens的情况下,prefix decoder要比causal decoder的效果差,因为训练过程中实际用到的tokens数量要更少。

分析

  1. Causal Decoder
    • Causal Decoder 架构典型代表是GPT系列模型,单向注意力掩码,以确保每个输入token只能注意到过去的token和它本身,输入和输出的token通过Decoder以相同的方式进行处理。
    • 灰色代表对应的两个token互相之间看不到,否则就代表可以看到。例如,”Survery”可以看到前面的“A”,但是看不到后面的“of”。
    • Causal Decoder的sequence mask矩阵是一种典型的下三角矩阵。
  2. Encoder-Decoder
    • Transformer最初被提出来时, 采用Encoder-Decoder架构,模型包含两部分Encoder和Decoder,两部分参数独立。
    • 其中Encoder将输入序列处理为一种中间表示,而Decoder则基于中间表示自回归地生成目标序列。
  3. Prefix Decoder

以上三种架构都可以通过MoE技术扩展,针对输入只激活一部分权重,典型代表: Switch Transformer and GLaM

  • All three architecture types can be extended using the mixture-of-experts (MoE) scaling technique, which sparsely activates a subset of neural network weights for each input.
  • This approach has been used in models like Switch Transformer and GLaM, and increasing the number of experts or the total parameter size has shown significant performance improvements.

为什么LLM都是decoder-only架构?

decoder-only架构具有一定优势。

  • 首先,decoder-only模型更简单,因为只需要生成输出,不需要考虑输入。这样可以减少计算负担,提高模型的训练和推理速度。
  • 另外,decoder-only模型更好地解决语言建模问题,如自然语言生成、文本分类等问题。
  • 其次,decoder-only模型更好地利用预训练任务数据。在预训练任务中
    • decoder-only模型只需要通过掩码语言建模(Masked Language Modelling)任务来学习上下文和语言规律。
    • 而encoder-decoder模型需要尝试预测中间编码表示。这对于decoder-only模型来说是一种更容易的任务,因此可以更好地利用数据,使得模型表现更好。
  • 最后,decoder-only模型可以更好地处理长序列
    • encoder-decoder模型需要在解码的时候进行对齐操作,因此当输入序列长度变化时,需要重新对齐,导致计算复杂度的提升。
    • 而decoder-only模型不需要进行对齐操作,可更好地处理长序列。

综上所述,decoder-only架构具有简单、高效、更好地利用预训练数据以及更好地处理长序列等优势

成诚补充

  • Encoder-Decoder 架构没落
    • 代表模型T5最大模型只有 11B,效果不佳
    • T5 很难使用 Pipeline Parallelism, 即便 Megatron 支持 T5 训练(代码),效果非常差,增加一倍 GPU,开启 PP 反而会让总吞吐下降
    • 流水并行是千卡以上分布式训练中最重要的特性
  • 即便 prefix decoder架构 也纷纷转向 decoder-only
    • GLM-130B 不是 Decoder-Only 架构,但是 GLM-3 及以后(含 GLM-4)的 GLM 系列模型都改为 Decoder-Only架构。 就连 GLM 团队都抛弃了原有架构,Follow LLaMa 了。
    • HuggingFace 上可以尝试 GLM-130B 的 Playground,即使仅从 Foundation-Model 的角度评价,效果也很糟糕。
  • Decoder-Only 架构最核心优势是便于 Scale Up,基于 Scaling Laws 的实际训练成本最低
    • LLM 时代,如果新算法结构可能有 5% 效果提升,但是引入了额外 50%训练成本(计算时间 or 通信量) ,那这个新算法是负优化。
    • 因为这 50% 的训练成本,基于 Scaling Laws 在原模型上多训练 50% 的 tokens ,或者训练大一半的模型, 带来的最终提升都远大于新算法的 5%。
  • 至此 2023 年下半年之后的所有 LLM (可以被用户使用的 Chat 模型)均为 Decoder-Only 架构

Encoder-Only、Encoder-Decoder、Decoder-Only 三者架构:

  • 相同参数量的训练效率上:Decoder-Only > Encoder-Only > Encoder-Decoder
  • 现行分布式并行策略下,可以扩展参数量上限 和 分布式集群规模的上限:Decoder-Only » Encoder-Only » Encoder-Decoder

Encoder-Only 并不适合做 NLP 生成式任务(Chat)

  • 目前在 CV 领域应用的比较多(ViT),输入/输出通常是固定像素大小的图片,token 之间没有非常强的先后依赖关系,因此先排除 BERT

总结

  • T5 Scale up 到 100B、500B 的难度很大,训练成本的增加远远高于 GPT。
  • 因此也许 100B 的 T5 训练 10T tokens 的模型能力比 100B 的 GPT 更强,但为此要支付的算力/时间成本远大于 100B GPT 训练 10T tokens,以至于:
  • 没有公司愿意支付这样的代价我还不如支付相同的代价,让 GPT 多训练更多倍的 Tokens;
  • 或者训练一个参数量大很多的 GPT

Layer Normalization

按照 layer normalization 位置不同,可分为 post layer norm 和 pre layer norm。

  • post layer norm。原始transformer中,layer normalization 放在残差连接之后,称为post LN
    • 使用Post LN的深层transformer模型容易出现训练不稳定问题。
    • post LN随着transformer层数的加深,梯度范数逐渐增大,导致了训练的不稳定性。
  • pre layer norm。将layer norm放在残差连接的过程中,self-attention或FFN块之前,称为“Pre LN”。
    • Pre layer norm在每个transformer层的梯度范数近似相等,有利于提升训练稳定性,但缺点是pre LN可能会轻微影响transformer模型的性能,为了提升训练稳定性,GPT3、PaLM、BLOOM、OPT等大语言模型都采用了pre layer norm。
  • LayerNorm:LayerNorm对每一层的所有激活函数进行标准化,使用它们的均值和方差来重新定位和调整激活函数。其公式如下:
  • RMSNorm:RMSNorm通过仅使用激活函数的均方根来重新调整激活,从而提高了训练速度。
  • DeepNorm:为了进一步稳定深度Transformer的训练,Microsoft推出了DeepNorm。这是一个创新的方法,它不仅用作标准化,还作为残差连接。有了DeepNorm的帮助,我们现在可以轻松训练高达1000层的Transformer模型,同时保持稳定性和高性能。其中,GLM-130B和ChatGLM就是采用了这种技术的代表。其公式如下:其中SublayerSublayer是FFN或Self-Attention模块。

自回归填空

GLM 预训练任务是一种自回归填空任务(Autoregressive Blank Infilling),和大多数预训练任务的设计思路一样,采用“先破坏,再重建”的方式,将原始文本的部分进行mask(先破坏),再对mask的部分进行预测(再重建

  • 不同:被mask的输入部分使用和bert相同的双向注意力,在生成预测的一侧使用的是自回归的单向注意力

根据mask的长度不同,可以分为三种方式:单词(MASK)、句子(sMASK)、文档(gMASK)

  • 实际使用中,可以根据不同的任务需要,设置不同mask方式的比例。
  • 例如,如果希望模型有更强的生成能力,可以把文档级别的gMASK的比例设置地比较高。
  • GLM-130B中,采用了70%文档级别的gMASK和30%单词级别的MASK

谷歌的UL2(UL2: UL2: Unifying Language Learning Paradigms),其中的预训练任务和GLM高度相似,但是晚于GLM一年后提出

LayerNorm

LayerNorm 会影响训练的稳定性

通常认为稳定性上: DeepNorm > Sandwich-LN > Pre-LN > Post-LN

130B 规模实验

  • DeepNormSandwich-LN 更稳定

GLM-130B最终采用 DeepNorm(Deepnet: Scaling transformers to 1,000 layers

Positional Embedding

位置编码分为绝对位置编码和相对位置编码。

绝对位置编码代表:

相对位置编码代表性的有:

大模型中应用较多的位置编码:

GLM-130B

GLM-130B 介绍

2022 年 8 月,清华大学联合智谱AI(唐杰、李涓子,公司名:北京智谱华章科技有限公司) 向研究界和工业界开放了拥有 1300 亿参数的中英双语双向稠密模型 GLM-130B

  • 截至2022年7月,它已经训练了超过4000亿个文本标记。
  • 底层架构基于通用语言模型(GLM),在语言理解和语言生成任务上均展示出强大的性能。

官方github对ChatGLM的介绍

该模型有一些独特的优势:

  • 双语:同时支持中文和英文;
  • 高精度(英文):在公开的英文自然语言榜单 LAMBADA、MMLU 和 Big-bench-lite 上优于 GPT-3 175B(API: davinci,基座模型)、OPT-175B 和 BLOOM-176B;
  • 高精度(中文):在 7 个零样本 CLUE 数据集和 5 个零样本 FewCLUE 数据集上明显优于 ERNIE TITAN 3.0 260B 和 YUAN 1.0-245B;
  • 快速推理:首个实现 INT4 量化的千亿模型,支持用一台 4 卡 3090 或 8 卡 2080Ti 服务器进行快速且基本无损推理;
  • 可复现性:所有结果(超过 30 个任务)均可通过我们的开源代码和模型参数复现;
  • 跨平台:支持在国产的海光 DCU、华为昇腾 910 和申威处理器及美国的英伟达芯片上进行训练与推理。

GLM-130B 资料

清华大学曾奥涵报告“从GLM-130B到ChatGLM:大模型预训练与微调”,整个报告分为三个部分

  • 第二段“大规模语言模型系列技术:以GLM-130B为例”, GLM-130B的训练过程

清华官方资料

更多资料

GLM-130B 原理

GLM-130B 将 BERTGPT 目标进行了统一,并与最近提出的一些技术进行结合以提升语言模型的性能表现。

2022年11月,斯坦福大学大模型中心对全球30个主流大模型进行了全方位的评测,GLM-130B是亚洲唯一入选的大模型。在与OpenAI、Google Brain、微软、英伟达、Meta AI的各大模型对比中,评测报告显示GLM-130B在准确性和公平性指标上与GPT-3 175B(davinci)接近或持平,鲁棒性、校准误差和无偏性优于GPT-3 175B。

2022年8月,向研究界和工业界开放了拥有1300亿参数的中英双语稠密模型 GLM-130B1,该模型有一些独特的优势:

  • 双语: 同时支持中文和英文。
  • 高精度(英文): 在公开的英文自然语言榜单 LAMBADA、MMLU 和 Big-bench-lite 上优于 GPT-3 175B(API: davinci,基座模型)、OPT-175B 和 BLOOM-176B。
  • 高精度(中文): 在7个零样本 CLUE 数据集和5个零样本 FewCLUE 数据集上明显优于 ERNIE TITAN 3.0 260B 和 YUAN 1.0-245B。
  • 快速推理: 首个实现 INT4 量化的千亿模型,支持用一台 4 卡 3090 或 8 卡 2080Ti 服务器进行快速且基本无损推理。
  • 可复现性: 所有结果(超过 30 个任务)均可通过我们的开源代码和模型参数复现。
  • 跨平台: 支持在国产的海光 DCU、华为昇腾 910 和申威处理器及美国的英伟达芯片上进行训练与推理。

参考 ChatGPT 的设计思路, ChatGLM 在千亿基座模型 GLM-130B 中注入了代码预训练,通过有监督微调(Supervised Fine-Tuning)等技术实现人类意图对齐。

GLM-130B 训练

【2023-6-12】大规模语言模型系列技术:以GLM-130B为例

训练难点

训练中遇到的难题及解决方案

训练中最大的挑战: 如何平衡训练稳定性(高精度低效)还是训练效率(低精度高效)

  • 训练稳定方面,团队在Attention score层使用了softmax in 32避免上下溢出,并调小了embbeding层梯度,缓解前期的梯度爆炸问题。
  • 训练效率方面,为了实现并行训练策略,采用了多种方案:
    • 采用ZeRO优化器在数据并行组内分摊优化器状态
    • 模型并行:将模型参数分布到多个GPU上
    • 算子融合
    • 流水线平衡
    • 跨平台兼容

训练中的工程挑战

  • 首先,大模型最大的训练挑战就是它的训练成本非常高,体现在训练过程中模型的计算量非常大

  • 其次,是内存的挑战,体现在单个大模型的参数量早已超出单张显卡的显存大小。同时,在训练过程中,除了模型参数外,优化器的状态(例如Adam优化器里面的m、v等等)也需要占用大量内存
  • 接下来围绕上述两个挑战,介绍一下大模型训练的常用技术(并非全部是GLM-130B实际采用的)。

混合精度训练

  • 混合精度训练实际“混”的是16位浮点数和32位浮点数
  • 从右面nvidia提供的表格可以看到FP16 Tensor Core和BFLOAT16 Tensor Core比Tensor Float 32(TF32)快两倍,比FP32快10倍以上
  • 16位浮点数有两种格式:FP16/BF16,FP16比BF16精度更高,但是表示范围更小。在实际训练中,表示范围更加重要(为了防止上溢和下溢),因此更倾向于选择BF16。但是要注意,BF16只支持3080、A100、TPU

  • 这里介绍的是,混合精度训练的执行流程,上半部分对应的是计算损失函数和反向传播计算梯度,这部分采用的是fp16,下半部分的优化器状态则采用fp32格式
  • 想进一步了解的话可以看下原论文: Mixed Precision Training

  • 这里介绍的是采用了混合精度训练后,训练过程中的内存占用
  • 可以看到,大部分的存储占用都在激活函数的保存存储上

激活函数重演

  • 前面提到的激活函数占用存储过多的问题,可以通过激活函数重演的技术来解决

数据并行

  • 这里介绍的数据并行中的参数服务器(Parameter Server)方案

  • 这里介绍的是另一种数据并行方案:All-Reduce

模型并行

  • 需要模型并行解决不了一张显卡放不下单个模型的问题
  • 模型并行最简单的方法是张量并行,这个方法将矩阵切分到不同的显卡上,分别计算,最后再通过all-reduce操作把计算结果同步回来,显然这种方法通信量是比较大的,因此实际更多应用在单机多卡,有nvlink的场景下
  • 对于更大的模型,做模型并行的方案是流水线并行

  • 这里介绍的是流水线并行最朴素的实现
  • 先Forward,再Backward,最后再做更新,每次只有1张卡在运算
  • 其中的Bubble time(也就是图中空白区域)比较大,也就是这种方案下显卡的整体使用率是不高的

  • GPipe是一种改进的流水线并行方案,进一步降低了气泡占比

  • 但是GPipe的峰值内存占用比较大

  • 流水线并行策略:1F1B

  • 流水线并行策略:Interleaved 1F1B

并行策略

GLM-130B同时使用了多种并行策略

稳定性

训练中的稳定性问题

  • 稳定性可能是训练过程中最大的问题了
  • 在Major Issues中可以看到大部分的问题都和disconverge/spike有关

  • 这张图里的经验价值千金

量化

  • 模型量化的目标是降低推理阶段的显存成本
  • 上图中的策略都没有采用

  • 值得注意的是,GLM系列的量化方案只降低了显存的占用,中间的计算量(推理时间)并不会有明显下降,因为仍然使用FP16进行计算。

训练成果

  • 双语:同时支持中文和英文
  • 高精度(英文):在LAMBADA上优于GPT-3 175B(+4.0%)、OPT-175B(+5.5%)和BLOOM-176B(+13.0%),在MMLU上略优于GPT-3 175B(+0.9%)
  • 高精度(中文):在7个零样本CLUE数据集(+24.26%)和5个零样本FewCLUE数据集(+12.75%)上明显优于ERNIE Titan 3.0 260B
  • 高效推理:支持用一台A100(8×40G)/V100(8×32G)服务器基于FasterTransformer进行快速推理(相比Megatron提速最高可达2.5倍)
  • 低门槛推理:最低量化到INT4,则可在4张3090/8张 2080Ti上完成推理
  • 跨平台:支持在NVIDIA、海关DCU、昇腾910和神威处理器上的训练

GLM-130B API

【2023-11-13】火山方舟大模型服务 api 支持多种大模型调用

  • baichuan: 7b
  • ChatGLM:多个版本 6B,130B,ChatGLM2-Pro
  • MiniMax
  • Skylark: 多个版本 lite, plus, pro, chat(豆包)

调用语言

  • Go
  • Java
  • Python

【2023-3-14】ChatGLM-6B

ChatGLM-6B 介绍

【2023-3-14】ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。

  • 由清华大学知识工程 (KEG) 实验室和智谱AI公司与于2023年共同训练的语言模型
  • 结合模型量化技术,用户可以在消费级显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。
  • ChatGLM-6B 使用了和 ChatGLM 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调反馈自助人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 虽然规模不及千亿模型,但大大降低了推理成本,提升了效率,并且已经能生成相当符合人类偏好的回答。
  • 模型开源地址, huggingface
  • finetune代码:ChatGLM-Tuning
  • API: 调用方法参考智谱AI, ChatGLM 商用 Issue
  • 【2023-3-17】issue: Cannot import name ‘convert_file_size_to_int’ from ‘transformers.utils.hub’

ChatGLM-6B 模型结构

ChatGLM-6B 模型结构:采用了prefix decoder-only的transformer模型框架,在输入上采用双向的注意力机制,在输出上采用单向注意力机制。

模型细节几点改动:

  • embedding层梯度缩减:为了提升训练稳定性,减小了embedding层的梯度。具体地, word_embedding = word_embedding * alpha + word_embedding.detatch() * (1-alpha) ,其中 alpha=0.1 ,这里detach()函数的作用是返回一个新的tensor,并从计算图分离出来。梯度缩减的效果相当于把embedding层的梯度缩小了10倍,减小了梯度的范数。
  • layer normalization:采用了基于Deep Norm的post layer norm。
  • 激活函数:采用了GeGLU激活函数。相比于普通的FFN,使用线形门控单元的GLU新增了一个权重矩阵,共有三个权重矩阵,为了保持参数量一致,中间维度采用了 8d/3 ,而不是4d。
  • 位置编码:去除了绝对位置编码,采用了旋转位置编码RoPE

训练目标: ChatGLM-6B的训练任务是自回归文本填空。相比于采用causal decoder-only结构的大语言模型,采用prefix decoder-only结构的ChatGLM-6B存在一个劣势:训练效率低。

  • causal decoder结构会在所有的token上计算损失,而prefix decoder只会在输出上计算损失,而不计算输入上的损失。
  • 相同数量的训练tokens的情况下,prefix decoder要比causal decoder的效果差,因为训练过程中实际用到的tokens数量要更少。

另外,ChatGPT的成功已经证明了causal decoder结构的大语言模型可以获得非常好的few-shot和zero-shot生成能力,通过指令微调可以进一步激发模型的能力。至于prefix decoder结构的大语言模型能否获得相当的few-shot和zero-shot能力还缺少足够的验证。

关于tokenizer,ChatGLM在25GB的中英双语数据上训练了SentencePiece作为tokenizer,词表大小为130528。

ChatGLM-6B 特点

ChatGLM-6B 具备以下特点:

  • 充分的中英双语预训练:ChatGLM-6B 在 1:1 比例的中英语料上训练了 1T 的 token 量,兼具双语能力。
  • 优化的模型架构和大小:吸取 GLM-130B 训练经验,修正了二维 RoPE 位置编码实现,使用传统 FFN 结构。6B(62 亿)的参数大小,也使得研究者和个人开发者自己微调和部署 ChatGLM-6B 成为可能。
  • 较低的部署门槛:FP16 半精度下,ChatGLM-6B 需要至少 13 GB 的显存进行推理,结合模型量化技术,这一需求可以进一步降低到 10GB(INT8) 和 6GB(INT4),使得 ChatGLM-6B 可以部署在消费级显卡上。
    • 在现代 GPU 和 TPU 上,tensor 计算可以在 16 位浮点上高效完成。 但并非简单将 tensor 的 dtype 设置为 torch.float16。对于某些部分,如 loss,仍然需要 32 位精度。
    • 半精度优化能使内存占用减半,或者说能使有效内存翻倍。
  • 更长的序列长度:相比 GLM-10B(序列长度 1024),ChatGLM-6B 序列长度达 2048,支持更长对话和应用。
  • 人类意图对齐训练:使用了监督微调(Supervised Fine-Tuning)、反馈自助(Feedback Bootstrap)、人类反馈强化学习(Reinforcement Learning from Human Feedback)等方式,使模型初具理解人类指令意图的能力。输出格式为 markdown,方便展示。

ChatGLM-6B 不足

不过由于 ChatGLM-6B 模型的容量较小,不可避免的存在一些局限和不足,目前已知其具有相当多的局限性,如:

  • 事实性/数学逻辑错误,可能生成有害/有偏见内容,较弱的上下文能力,自我认知混乱,以及对英文指示生成与中文指示完全矛盾的内容。
  • 相对较弱的模型记忆和语言能力。在面对许多事实性知识任务时,ChatGLM-6B 可能会生成不正确的信息,也不太擅长逻辑类问题(如数学、编程)的解答。
  • 可能会产生有害说明或有偏见的内容:ChatGLM-6B 只是一个初步与人类意图对齐的语言模型,可能会生成有害、有偏见的内容。
  • 较弱的多轮对话能力:ChatGLM-6B 的上下文理解能力还不够充分,在面对长答案生成和多轮对话的场景时,可能会出现上下文丢失和理解错误的情况。
  • 参考:清华系千亿基座对话模型ChatGLM启动内测,开源单卡版模型

ChatGLM-6B 实践

ChatGLM-6B 接入

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

print(type(model), model) # 显示模型结构
model = model.eval() # 是否必须?

response, history = model.chat(tokenizer, "你好", history=[])
print(response)
# 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
print(response)

ChatGLM-6B 本地部署

【2023-4-13】清华ChatGLM-6B模型本地部署

硬件要求

git clone https://github.com/THUDM/ChatGLM-6B.git
pip install -r requirements.txt

使用

from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()
model = model.eval()
response, history = model.chat(tokenizer, "你是谁", history=[])
# The dtype of attention mask (torch.int64) is not bool
print(response)
# 我是一个名为 ChatGLM-6B 的人工智能助手,是基于清华大学 KEG 实验室和智谱 AI 公司于 2023 年共同训练的语言模型开发的。我的任务是针对用户的问题和要求提供适当的答复和支持。
response, history = model.chat(tokenizer, "你会什么", history=history)
print(response)
# ---------
# 默认的
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
# 我的 INT4的
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()

【2023-6-25】ChatGLM2-6B – 升级

【2023-6-25】ChatGLM2-6B:性能大幅提升,8-32k上下文,推理提速42%

  • CEval榜单,ChatGLM2 暂时位居 Rank 0,ChatGLM2-6B 位居 Rank 6
  • 截至6月25日 ChatGLM2 模型以 71.1 的分数位居 Rank 0 ,ChatGLM2-6B 模型以 51.7 的分数位居 Rank 6,是榜单上排名最高的开源模型。

ChatGLM2-6B的安装请参考官方

ChatGLM2-6B 是开源中英双语对话模型 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM2-6B 引入了如下新特性:

  • 更强大的性能:基于 ChatGLM 初代模型的开发经验,我们全面升级了 ChatGLM2-6B 的基座模型。ChatGLM2-6B 使用了 GLM 的混合目标函数,经过了 1.4T 中英标识符的预训练与人类偏好对齐训练,评测结果显示,相比于初代模型,ChatGLM2-6B 在 MMLU(+23%)、CEval(+33%)、GSM8K(+571%) 、BBH(+60%)等数据集上的性能取得了大幅度的提升,在同尺寸开源模型中具有较强的竞争力。
  • 更长的上下文:基于 FlashAttention 技术,我们将基座模型的上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段使用 8K 的上下文长度训练,允许更多轮次的对话。但当前版本的 ChatGLM2-6B 对单轮超长文档的理解能力有限,我们会在后续迭代升级中着重进行优化。
  • 更高效的推理:基于 Multi-Query Attention 技术,ChatGLM2-6B 有更高效的推理速度和更低的显存占用:在官方的模型实现下,推理速度相比初代提升了 42%,INT4 量化下,6G 显存支持的对话长度由 1K 提升到了 8K。
  • 更开放的协议:ChatGLM2-6B 权重对学术研究完全开放,在获得官方的书面许可后,亦允许商业使用。如果发现开源模型对业务有用,我们欢迎您对下一代模型 ChatGLM3 研发的捐赠。

【2023-7-14】ChatGLM2-6B,免费商用,扫码登记即可

自 3 月 14 日发布 ChatGLM-6B 及 6 月 25 日发布 ChatGLM2-6B 以来,这两个模型在 Huggingface 上的下载量已经先后超过了 300 万和 120 万。非常感谢大家对 ChatGLM 模型的支持。为了更好地支持国产大模型开源生态的繁荣发展,经智谱 AI 及清华 KEG 实验室决定,自即日起 ChatGLM-6B 和 ChatGLM2-6B 权重对学术研究完全开放,并且在完成企业登记获得授权后,允许免费商业使用。

相比于初代模型,ChatGLM2-6B 多个维度的能力都取得了提升,对比示例

  • 数理逻辑
  • 知识推理
  • 长文档理解

部署

(1)下载方式

  • 手动下载: 下载完毕上传到租赁GPU服务器就行,可能比较费流量
  • git lfs 工具: 下载大文件的工具(受网络限制 ,可能需要多次尝试)
git clone https://github.com/THUDM/ChatGLM-6B
# model文件最好像我这样放置,好找一些~
cd ChatGLM-6B
mkdir model
cd model

apt-get update
apt-get install git-lfs 
git-lfs install 
git lfs clone https://huggingface.co/THUDM/chatglm2-6b 
# 下载glm2 代码、和模型文件
# 连接不稳定,可能需要多clone几次,或者直接本机download然后上传(ps 还是自己upload万无一失)

(2)环境部署

conda create -n chatglm2 python=3.10
conda activate chatglm2 
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple #这里配了下载源,更快一些!

(3)修改代码

修改web_demo.py配置信息

  • model_path : 加载本地模型,而不是从huggingface上pull
  • launch : 默认不会生成可访问的公网url链接
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html
from utils import load_model_on_gpus

#model_path = 'THUDM/chatglm2-6b'
model_path = './chatglm2-6b' # 本地模型地址
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda() # cpu -> gpu
# 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量
# from utils import load_model_on_gpus
# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2) # 多卡模式
model = model.eval()

# ....
# demo.queue().launch(share=True, inbrowser=True) # old
demo.queue().launch(share=True, inbrowser=True,server_name='0.0.0.0', server_port=7860)) # new

(4)启动服务

  • gradio 公网url有时失败 → ssh隧道 or 其它平台(如 autoDL
python web_demo.py
  • api 方式: 运行api.py文件,默认部署在本地的 8000 端口,通过 POST 方法进行调用
curl -X POST "http://127.0.0.1:8000" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
# {"response":"你好 !我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。","history":[["你好","你好 !我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。"]],"status":200,"time":"2023-09-25 22:23:34"}

OpenAI 格式的流式 API 部署

import openai
if __name__ == "__main__":
    openai.api_base = "http://localhost:8000/v1"
    openai.api_key = "none"
    for chunk in openai.ChatCompletion.create(
        model="chatglm2-6b",
        messages=[
            {"role": "user", "content": "你好"}
        ],
        stream=True
    ):
        if hasattr(chunk.choices[0].delta, "content"):
            print(chunk.choices[0].delta.content, end="", flush=True)

知识注入

【2023-7-11】单样本微调给ChatGLM2注入知识

  • 借助 AdaLoRA算法,使用1条样本对ChatGLM2-6b实施微调。几分钟就成功注入了有关知识
  • AdaLoRA是LoRA方法的一种升级版本,使用方法与LoRA基本一样。主要差异
    • LoRA中不同训练参数矩阵的秩被固定。
    • 但AdaLoRA中不同训练参数矩阵的秩是会在一定范围内自适应调整的,那些更重要的训练参数矩阵会分配到更高的秩。
    • AdaLoRA的效果会好于LoRA。

备注

  • (1) 只需要1条样本,很少的训练时间,就可以通过微调给LLM注入知识。
  • (2) LLM 可以看做类似Key-Value形式的知识数据库,支持增删改查。通过微调可以增删修改知识,通过条件生成可以查询提取知识。
  • (3) LoRA 微调是一种高效的融入学习算法。类似人类把新知识融入现有知识体系的学习过程。学习时无需新知识特别多的样本,学习后原有的庞大知识和能力可以基本不受影响。
from peft import get_peft_model, AdaLoraConfig, TaskType

#训练时节约GPU占用
model.config.use_cache=False
model.supports_gradient_checkpointing = True  #
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

peft_config = AdaLoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False,
    r=8,
    lora_alpha=32, lora_dropout=0.1,
    target_modules=["query", "value"]
)

peft_model = get_peft_model(model, peft_config)

peft_model.is_parallelizable = True
peft_model.model_parallel = True
peft_model.print_trainable_parameters()

验证模型

from peft import PeftModel 
ckpt_path = 'single_chatglm2'
model_old = AutoModel.from_pretrained("chatglm2-6b",
                                  load_in_8bit=False, 
                                  trust_remote_code=True)
peft_loaded = PeftModel.from_pretrained(model_old,ckpt_path).cuda()
model_new = peft_loaded.merge_and_unload() #合并lora权重
chatglm = ChatGLM(model_new,tokenizer,max_chat_rounds=20) #支持多轮对话,可以从之前对话上下文提取知识。

ChatGLM3

【2023-10-27】ChatGLM3 是智谱AI和清华大学 KEG 实验室联合发布的新一代对话预训练模型。

ChatGLM3 还推出了可在手机上部署的端测模型ChatGLM3-1.5BChatGLM3-3B,支持包括vivo、小米、三星在内的多款手机以及车载平台,甚至支持移动平台上CPU芯片的推理,速度可达20tokens每秒(token是语言模型中用来表示单词或短语的符号)。

【2023-11-7】ChatGLM3-6b

【2023-11-7】chatglm3-6b

ChatGLM3-6B 是 ChatGLM3 系列中的开源模型,在保留了前两代模型对话流畅、部署门槛低等众多优秀特性的基础上,ChatGLM3-6B 引入了如下特性:

  1. 更强大的基础模型: ChatGLM3-6B 的基础模型 ChatGLM3-6B-Base 采用了更多样的训练数据、更充分的训练步数和更合理的训练策略。在语义、数学、推理、代码、知识等不同角度的数据集上测评显示,ChatGLM3-6B-Base 具有在 10B 以下的基础模型中最强的性能
  2. 更完整的功能支持: ChatGLM3-6B 采用了全新设计的 Prompt 格式,除正常的多轮对话外。同时原生支持工具调用(Function Call)、代码执行(Code Interpreter)和 Agent 任务等复杂场景。
  3. 更全面的开源序列:

ChatGLM3 GitHub

全新的 Prompt 格式

  • 为了避免用户输入的注入攻击以及统一 Code Interpreter,Tool & Agent 等任务的输入,ChatGLM3 采用了全新的对话格式。

整体结构

  • ChatGLM3 对话的格式由若干对话组成,其中每个对话包含对话头和内容,一个典型的多轮对话结构如下
<|system|>
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
<|user|>
Hello
<|assistant|>
Hello, I'm ChatGLM3. What can I assist you today?

其中 <|role|> 部分使用 special token 表示,无法从文本形式被 tokenizer 编码以防止注入。metadata 部分采用纯文本表示,为可选内容。

  • <|system|>:系统信息,设计上可穿插于对话中,但目前规定仅可以出现在开头
  • <|user|>:用户
    • 不会连续出现多个来自 <|user|> 的信息
  • <|assistant|>:AI 助手
    • 在出现之前必须有一个来自 <|user|> 的信息
  • <|observation|>:外部的返回结果
    • 必须在 <|assistant|> 的信息之后

问题

ChatGLM3-6b 部署

git clone https://github.com/THUDM/ChatGLM3
cd ChatGLM3
# transformers 库版本推荐为 4.30.2,torch 推荐使用 2.0 及以上
pip install -r requirements.txt
# 本地加载模型并启动 demo
# export MODEL_PATH=/path/to/model
streamlit run main.py

ChatGLM3 Demo

ChatGLM3 Demo 拥有三种模式:

  • Chat: 对话模式,在此模式下可以与模型进行对话。
    • 直接在侧边栏修改 top_p, temperature, System Prompt 等参数来调整模型的行为
  • Tool: 工具模式,模型除了对话外,还可以通过工具进行其他操作。
    • 在 tool_registry.py 中注册新工具来增强模型能力,@register_tool 装饰函数
    • 通过 Manual mode 进入手动模式,YAML 直接指定工具列表,但需要手动将工具的输出反馈给模型。
  • Code Interpreter: 代码解释器模式,模型可以在一个 Jupyter 环境中执行代码并获取结果,以完成复杂任务。
    • 执行更为复杂的任务,例如绘制图表、执行符号运算等等。

示例

  • 函数名称为工具名称
  • 函数 docstring 即为工具的说明;
  • 工具参数使用 Annotated[typ: type, description: str, required: bool] 标注参数的类型、描述和是否必须。
@register_tool
def get_weather(
    city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
    """
    Get the weather for `city_name` in the following week
    """
    ...

文件结构

# 模型文件
pytorch_model-00001-of-00007.bin

配置

config.json

configuration_chatglm.py

分词

tokenization_chatglm.py

import json
import os
import torch
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding


class SPTokenizer:
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop", "<|system|>", "<|user|>", "<|assistant|>",
                          "<|observation|>"]
        self.special_tokens = {}
        self.index_special_tokens = {}
        for token in special_tokens:
            self.special_tokens[token] = self.n_words
            self.index_special_tokens[self.n_words] = token
            self.n_words += 1

    def tokenize(self, s: str):
        return self.sp_model.EncodeAsPieces(s)

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        text, buffer = "", []
        for token in t:
            if token in self.index_special_tokens:
                if buffer:
                    text += self.sp_model.decode(buffer)
                    buffer = []
                text += self.index_special_tokens[token]
            else:
                buffer.append(token)
        if buffer:
            text += self.sp_model.decode(buffer)
        return text

    def decode_tokens(self, tokens: List[str]) -> str:
        text = self.sp_model.DecodePieces(tokens)
        return text

    def convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        if token in self.special_tokens:
            return self.special_tokens[token]
        return self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.index_special_tokens:
            return self.index_special_tokens[index]
        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
            return ""
        return self.sp_model.IdToPiece(index)


class ChatGLMTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "tokenizer.model"}

    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
        self.name = "GLMTokenizer"

        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id
        }
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)

    def get_command(self, token):
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

    @property
    def unk_token(self) -> str:
        return "<unk>"

    @property
    def pad_token(self) -> str:
        return "<unk>"

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        return "</s>"

    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        return self.tokenizer.n_words

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.
        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                An optional prefix to add to the named of the saved files.
        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def get_prefix_tokens(self):
        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
        return prefix_tokens

    def build_single_message(self, role, metadata, message):
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message)
        tokens = role_tokens + message_tokens
        return tokens

    def build_chat_input(self, query, history=None, role="user"):
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, "", query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:
        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`
        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        return token_ids_0

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.
                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:
                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        assert self.padding_side == "left"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # Initialize attention mask if not present.
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            if "attention_mask" in encoded_inputs:
                encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs

模型结构

modeling_chatglm.py

""" PyTorch ChatGLM model. """

import math
import copy
import warnings
import re
import sys

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from copy import deepcopy

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput

from .configuration_chatglm import ChatGLMConfig

# flags required to enable jit fusion kernels

if sys.platform != 'darwin':
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
_CONFIG_FOR_DOC = "ChatGLMConfig"

CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "THUDM/chatglm3-6b",
    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
]


def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_size)
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len,
                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values


def split_tensor_along_last_dim(
        tensor: torch.Tensor,
        num_partitions: int,
        contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
    Returns:
        A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.
        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )


@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2:
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 is_causal=True)
            else:
                if attention_mask is not None:
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 attention_mask)
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores

            # [b, np, sq, sk]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            # preallocting input tensor: [b * np, sq, sk]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
                device=query_layer.device
            )

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ===========================
            # Attention probs and dropout
            # ===========================

            # attention scores and attention mask [b, np, sq, sk]
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()
            if self.coeff is not None:
                attention_scores = attention_scores * self.coeff
            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                            device=attention_scores.device, dtype=torch.bool)
                attention_mask.tril_()
                attention_mask = ~attention_mask
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = attention_probs.type_as(value_layer)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.attention_dropout(attention_probs)
            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.
    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        if self.multi_query_attention:
            num_attention_heads = self.num_multi_query_groups_per_partition
        else:
            num_attention_heads = self.num_attention_heads_per_partition
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                               (self.num_attention_heads_per_partition,
                                3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache


def _config_to_kwargs(args):
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    return common_kwargs


class MLP(torch.nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(torch.nn.Module):
    """A single transformer layer.
    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm

        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                             dtype=config.torch_dtype)

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                      dtype=config.torch_dtype)

        # MLP
        self.mlp = MLP(config, device=device)

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output

        return output, kv_cache


class GLMTransformer(torch.nn.Module):
    """Transformer class."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                 dtype=config.torch_dtype)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        presents = () if use_cache else None
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
                presents = presents + (kv_cache,)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, presents, all_hidden_states, all_self_attentions


class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        batch_size, seq_length = input_ids.shape
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                        device=input_ids.device), full_attention_mask), dim=-1)
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def get_position_ids(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value


class Embedding(torch.nn.Module):
    """Language model embeddings."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()
        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings


class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection
        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.multi_query_group_num,
            self.kv_channels
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        return past_key_values

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if self.pre_seq_len is not None:
            if past_key_values is None:
                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
                                            attention_mask], dim=-1)

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def quantize(self, weight_bit_width: int):
        from .quantization import quantize
        quantize(self.encoder, weight_bit_width)
        return self


class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            is_first_forward: bool = True,
            **kwargs
    ) -> dict:
        # only last token for input_ids if past is not None
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            if past_key_values is not None:
                position_ids = position_ids[..., -1:]
                input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache
        }

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, output, history):
        content = ""
        history = deepcopy(history)
        for response in output.split("<|assistant|>"):
            metadata, content = response.split("\n", maxsplit=1)
            if not metadata.strip():
                content = content.strip()
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                content = content.replace("[[训练时间]]", "2023年")
            else:
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                if history[0]["role"] == "system" and "tools" in history[0]:
                    content = "\n".join(content.split("\n")[1:-1])
                    def tool_call(**kwargs):
                        return kwargs
                    parameters = eval(content)
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    content = {"name": metadata.strip(), "content": content}
        return content, history

    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(self.device)
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        response, history = self.process_response(response, history)
        return response, history

    @torch.inference_mode()
    def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
                    past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, return_past_key_values=False, **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if past_key_values is None:
            inputs = tokenizer.build_chat_input(query, history=history, role=role)
        else:
            inputs = tokenizer.build_chat_input(query, role=role)
        inputs = inputs.to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask
        history.append({"role": role, "content": query})
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                            **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
            response = tokenizer.decode(outputs)
            if response and response[-1] != "�":
                response, new_history = self.process_response(response, history)
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warn(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                    UserWarning,
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break

    def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
        if bits == 0:
            return

        from .quantization import quantize

        if self.quantized:
            logger.info("Already quantized.")
            return self

        self.quantized = True

        self.config.quantization_bit = bits

        self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
                                            **kwargs)
        return self


class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)

        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
        if config.classifier_dropout is not None:
            self.dropout = nn.Dropout(config.classifier_dropout)
        else:
            self.dropout = None
        self.config = config

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            full_attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        pooled_hidden_states = hidden_states[-1]
        if self.dropout is not None:
            pooled_hidden_states = self.dropout(pooled_hidden_states)
        logits = self.classifier_head(pooled_hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())
                else:
                    loss = loss_fct(logits.float(), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

其它

量化 quantization.py

【2024-1-16】GLM-4

【2024-1-16】智谱AI举办首届技术开放日(Zhipu DevDay),正式发布新一代基座大模型GLM-4。

GLM-4 最大的亮点:

  • 多模态能力:推出了CogView3代,效果超过开源SD模型,逼近 DALLE-3。
  • All Tools能力:GLM-4能自主理解复杂指令,自由调用WebGLM搜索增强、Code Interpreter代码解释器和多模态生成能力,完成复杂任务。
  • GLMs个性化智能体定制:用户可以通过智谱清言官方网站创建属于自己的GLM智能体,无需编程基础。
  • MaaS平台和API:GLM-4登陆了Maas平台,提供API访问,支持开发者内测Assistant API。

GLM-4的整体性能相比上一代大幅提升,逼近GPT-4。它可以支持更长的上下文,具备更强的多模态能力。同时,它的推理速度更快,支持更高的并发,大大降低推理成本。

GLM-4 性能已经超过 Claude 2.1,直接逼近 GPT 4

  • GLM-4 在 MMLU(81.5)达到 GPT-4 94% 水平,GSM8K(87.6) 达到 GPT-4 95% 水平,MATH(47.9)达到 GPT-4 91% 水平,BBH (82.25) 达到 GPT-4 99% 水平,HellaSwag (85.4) 达到 GPT-4 90% 水平,HumanEval(72)达到 GPT-4 100% 水平。
  • 指令跟随方面,GLM-4 也实现了媲美 GPT-4 的水准。根据指令跟随评估基准 IFEval 的结果,GLM-4 在 Prompt 提示词跟随(中文)方面达到了 GPT-4 88% 的水平;在指令跟随(中文)方面,达到了 GPT-4 90% 的水平。
  • 公开数据集 AlignBench 的评估结果,GLM-4 超过了 GPT-4 在 6 月 13 日发布的版本,逼近 GPT-4 最新(11 月 6 日版本)效果,在专业能力、中文理解、角色扮演方面超过了最新 GPT-4 的精度,唯一有待提升的是 GLM-4 在中文推理方面的能力。

GLM-4 All Tools 实现自主根据用户意图,自动理解、规划复杂指令,自由调用网页浏览器、Code Interpreter代码解释器等以完成复杂任务。

代码解释器

网页浏览

Function Call

  • 根据⽤户提供的function描述,⾃动选择所需function并⽣成参数,以及根据function的返回值⽣成回复

CogView: 首个中文全领域文生图预训练模型

  • Diffusion模型训练时不太稳定,智谱提出了中继扩散模型, 图像质量更高

评测

VisualGLM-6B 首个中文开源多模态对话模型

CogVLM 全新视觉专家结构,显著超越其他开源模型

  • 增加 Cross-Attention高分辨率结构,认清GUI每个文字和图标

CogAgent成为开源榜首

1~2年内,多模态模型的视觉理解能力会超过人类

编码模型CodeGeeX

GLMs 个性化智能体定制功能亦同时上线,用户用简单的提示词指令就能创建属于自己的 GLM 智能体。

资料

ChatGLM 微调

【2023-6-2】 chatglm的微调有没有保姆式的教程

怎么finetuning?

  • P-tuning v2
    • 官方示例,单卡3090没问题
    • P-Tuning,设计了一种连续可微的virtual token(同Prefix-Tuning类似)
    • 清华大学发布的两种参数高效Prompt微调方法: P-Tuning、P-Tuning v2,可以简单的将P-Tuning认为是针对Prompt Tuning的改进,P-Tuning v2认为是针对Prefix Tuning的改进。
  • Full parameter
    • 全量参数finetune, 8卡3090 没跑起来
    • 训练资源: 8 x A100 40G 或者 4 x A100 80G.
  • LoRA

ChatGLM-Tuning

【2023-3-25】ChatGLM-Tuning

  • 一种平价的chatgpt实现方案,基于清华的 ChatGLM-6B + LoRA 进行finetune.
  • 数据集: alpaca

应用

金融财报问答系统

【2023-12-4】ChatGLM金融大模型挑战赛赛题总结

以ChatGLM2-6B模型为中心制作一个问答系统,回答用户的金融相关的问题,不允许使用其他的大语言模型。参赛选手可以使用其他公开访问的外部数据来微调模型,也可以使用向量数据库等技术。

优胜方案

结束


支付宝打赏 微信打赏

~ 海内存知已,天涯若比邻 ~

Share

Similar Posts

Related Posts

标题:Baichuan 百川模型系列

摘要:开源大模型百川笔记

标题:混合专家模型(MoE)专题

摘要:混合专家模型,MoE系列

Comments

--disqus--

    My Moment ( 微信公众号 )
    欢迎关注鹤啸九天