kplug / demo_sum.py
xu-song's picture
rebuild
4a3c603
# coding=utf-8
# author: xusong <[email protected]>
# time: 2022/8/23 12:58
"""
## TODO:
1. 下拉框,选择类目。 gr.Radio(['服饰','箱包', '鞋靴']
2. 支持输入特效
- 示例:https://huggingface.co./uer/gpt2-chinese-lyric
- 参考 https://github.com/huggingface/hub-docs/blob/main/js/src/lib/components/InferenceWidget/shared/WidgetTextarea/WidgetTextarea.svelte
3. 待开放参数:No Repeat Ngram Size、Length Penalty、Number of Beams。topk-sampling, topp-sampling,
num_beam_groups = return_sequences数吗?
## badcase:
1. 结尾容易出多个句号。为啥?
2. 重复
## 解码demo (能够调整解码参数的demo)
- https://huggingface.co./spaces/THUDM/GLM-130B
## 解码参数示例
**greedy策略**
**sample策略**
- moss: do_sample=True, temperature=0.7, top_p=0.8, top_k=40, repetition_penalty=1.02
- chatglm:do_sample=True, temperature=0.95, top_p=0.7, max_length=2048
- chatglm2:do_sample=True, top_p=0.8, temperature=0.8 https://huggingface.co./THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1023
- glm130b:
- temperature=1, top_p=0.7, top_k=0, no_repeat_ngram_size=3, length_penalty=1, num_beams=2
- vicuna: do_sample=True, temperature=0.7, top_p=1, top_k=-1, repetition_penalty=1
- chatgpt
- baichuan-chat: do_sample=True, temperature=0.3, top_p=0.85, top_k=5, repetition_penalty=1.05 https://huggingface.co./baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json
- internlm-chat: do_sample=True, temperature=0.8, top_p=0.8 https://huggingface.co./internlm/internlm-chat-7b-v1_1/blob/main/modeling_internlm.py#L783
- 解决重复问题,需要添加 repetition_penalty=1.05 https://github.com/InternLM/InternLM/issues/28
- llama2-chat: top_p=0.6, temperature=0.9
- qwen: top_p=0.8, top_k= 0, repetition_penalty=1.1 https://huggingface.co./Qwen/Qwen-7B-Chat/blob/main/generation_config.json
- gpt4:
- temperature=1, top_p=1,
https://platform.openai.com/docs/api-reference/chat/object
- claude:
-
**beam_search策略**
- tensor2tensor:
- opennmt:
- transformers:
- asr: num_beams=5, max_length=200 https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/README.md
- wmt:
- "num_beams=5:10:15 length_penalty=0.6:0.7:0.8:0.9:1.0:1.1" https://github.com/huggingface/transformers/blob/main/scripts/fsmt/eval-allenai-wmt16.sh
- num_beams=5 length_penalty=0.8:1.2 early_stopping=true:false https://github.com/huggingface/transformers/tree/main/examples/legacy/seq2seq
- tensor2tensor: beam_size=4, alpha=0.6 https://github.com/tensorflow/tensor2tensor/tree/master#walkthrough
## 解码参数
- generate官方文档:
- https://huggingface.co./blog/how-to-generate
- https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md
- https://github.com/huggingface/transformers/blob/main/src/transformers/generation/configuration_utils.py
- generate 解码策略介绍:
-
- 去重
- no_repeat_ngram_size
- 源码: [NoRepeatNGramLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L490)
- 逻辑:
- 取值: 默认0, If set to int > 0, all ngrams of that size can only occur once
no_repeat_ngram_size=6 即代表: 6-gram不出现2次
- 兼容:与greedy、sampling、beam_search 兼容
- 缺陷:
- 这个可能把GPT的输入都算进去了。比如商品文案写作场景,输入"雅诗兰黛小棕瓶",加入no_repeat_ngram_size参数可能就不能输出"雅诗兰黛小棕瓶"了
- "only occur once", 需要一个参数 调整成最大允许次数
- encoder_no_repeat_ngram_size
- 源码:[EncoderNoRepeatNGramLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L525)
- 逻辑:
- 兼容:与greedy、sampling、beam_search 兼容
- repetition_penalty:
- 源码:[RepetitionPenaltyLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L206)
- 逻辑:对input_ids 做去重逻辑,其中 input_ids 是随着解码动态变化的。对于 logits>0 会 logits/=penalty,才叫惩罚。
类似 coverage mechanism
- 取值:取值范围(0, inf),>1 才叫惩罚,<1 就叫奖励了,=1 就是 no penalty。论文里说 1.2 能够balance truthful generation and lack of repetition.
- 公式:
- 默认 p=softmax(logits)
- 加 temperature后 p=softmax(logits/T)
- 加 repetition_penalty Θ 后 p=softmax(logits/(T* (Θ if i∈g else 1) ) ,其中 i∈g 表示已经生成过的 token
- 缺陷:未考虑重复次数,也就是 重复2次和重复100次的惩罚是一样的。
- 兼容:与greedy、sampling、beam_search 兼容
- encoder_repetition_penalty
- 源码:[EncoderRepetitionPenaltyLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L228)
- 逻辑:只对 self.encoder_input_ids 做去重逻辑,self.encoder_input_ids 是静态的
- 冲突:
- 多样性:
- do_sample
- temperature:
- 取值范围(0, inf),大于1 则会平均化(inf则相当于均匀采样,更多样化),小于1则会集中化(逼近0则相当于greedy)
- 理解:温度越高,系统越混乱,熵越大(概率越平均化,不确定性越大,生成文本的自由创作空间越大)。温度越低,生成的文本越偏保守。
- 公式: p=softmax(logits) , 加 temperature后 p=softmax(logits/T)
- Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
- diversity_penalty
-
- 长度
- max_length
- max_new_tokens
- min_length
- min_new_tokens
- early_stopping
- max_tim
- length_penalty: 长度惩罚因子
- 取值(-inf, inf),大于0会生成更长的序列,小于0会生成更短的序列。默认值=1.0。
- 应用场景:仅用于 beam search。(sampling策略建议也加上)
- 公式: score = sum_logprobs / (generated_len**self.length_penalty) 即:长度越长,当前生成序列(路径)的得分越低。
- 源码:https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/generation/beam_search.py#L965
- 参考文档:https://cloud.tencent.com/developer/article/2295947
- exponential_decay_length_penalty
- ss
- 公式
- 源码:
- 截止符
- eos_token_id
- 禁用词黑名单
- 源码:[NoBadWordsLogitsProcessor](https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L590)
- bad_words_ids
- suppress_tokens
- 强制解码词
- force_words_ids
- constraints
- 其他参数
- top_p: only used in sample-based generation
- 又称Nucleus Sampling
- 每个时间步,按照字出现的概率由高到底排序,当概率之和大于top-p的时候,就不取后面的样本了。然后对取到的这些字的概率重新归一化后,进行采样。
- 取值范围:0-1
- 0表示?
- top_k: only used in sample-based generation
- 取值范围:
- top-P采样方法往往与top-K采样方法结合使用,每次选取两者中最小的采样范围进行采样,可以减少预测分布过于平缓时采样到极小概率单词的几率。
## TODO:
- counted_repetition_penalty: 解决 repetition_penalty 不考虑重复次数的问题,重复越多惩罚越大
- no_repeat_ngram_size:
- {"ngram": 3, "max_repeat": 1, "ignore_prefix": False}
"max_allowed_repetition":
"""
import torch
import gradio as gr
from info import article
from kplug import modeling_kplug_s2s_patch
from transformers import BertTokenizer, BartForConditionalGeneration
model = BartForConditionalGeneration.from_pretrained("eson/kplug-base-cepsum-jiadian") # cnn指的是cnn daily mail
tokenizer = BertTokenizer.from_pretrained("eson/kplug-base-cepsum-jiadian")
"""
解码策略
https://zhuanlan.zhihu.com/p/267471193
https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/utils.py#L473
"""
gen_mode_params = {
"greedy": {
"num_beams": 1,
"do_sample": False,
},
# 核心:next_tokens = torch.multinomial(next_token_probs, num_samples=1)
"sampling": {
"num_beams": 1,
"do_sample": True,
"repetition_penalty": 1.2
# temperature # 大于1 则会平均化(inf则相当于均匀采样,更多样化),小于1则会集中化(0则相当于greedy)
# top_p
# top_k
},
# TODO:
# typical sampling:
# https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L332
# TODO:
# Truncation Sampling: EtaLogitsWarper、EpsilonLogitsWarper
# https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/logits_process.py#L387
"beam search": {
"num_beams": 10,
"do_sample": False,
},
# 算法? 复杂度?
"contrastive search": {
"top_k": 4,
"penalty_alpha": 0.2,
},
# 算法? 复杂度?
# 网格波束搜索(Hokamp和Liu,2017)和约束波束搜索(Anderson等,2017) https://blog.csdn.net/qq_36533552/article/details/106317720
"diverse beam search": {
"num_beams": 5,
"num_beam_groups": 5,
"num_return_sequences": 5,
"diversity_penalty": 1.0,
}
}
all_decoding_strategys = list(gen_mode_params.keys())
def summarize(text, prefix_text, constrained_text, decoding_strategys):
"""
prefix_text: 能叫 prompt吗?
constrained_text: 受限解码效果怎么这么差.
gen_modes: Search Strategy、Decoding strategy、
"""
# bad_words_ids num_return_sequences=1, no_repeat_ngram_size=1, remove_invalid_values=True,
common_params = {"min_length": 20, "max_length": 100}
inputs = tokenizer([text], max_length=512, return_tensors="pt")
# prompt_text = GPT2里的参数. 这里是 decoder_input_ids。 shape=(batch_size, n)
if prefix_text:
decoder_input_ids = tokenizer([prefix_text], max_length=30, return_tensors="pt")
# decoder_input_ids = tokenizer(["采用优质的"], max_length=30, return_tensors="pt")
decoder_input_ids = decoder_input_ids.input_ids[:, :-1]
decoder_input_ids[:, 0] = model.config.decoder_start_token_id
common_params["decoder_input_ids"] = decoder_input_ids
#
if constrained_text:
common_params["force_words_ids"] = tokenizer(
[constrained_text], add_special_tokens=False, max_length=30).input_ids
result = {}
print(decoding_strategys)
for strategy in decoding_strategys:
if constrained_text and strategy in ["greedy", "sampling", "diverse beam search"]:
# `num_beams` needs to be greater than 1 for constrained generation.
# `num_beam_groups` not supported yet for constrained generation.
result[strategy] = "不支持 constrained text"
continue
summary_ids = model.generate(inputs["input_ids"][:, 1:], **common_params, **gen_mode_params[strategy])
summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
clean_up_tokenization_spaces=False)
print(strategy, summary)
result[strategy] = summary
return result
# return pd.DataFrame([result])
sum_examples = [
[
"美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
"", "", all_decoding_strategys],
[
"美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
"智能", "", all_decoding_strategys],
[
"美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开门俯视图,全开门俯视图,预留参考图",
"", "风冷无霜", all_decoding_strategys],
[
"爱家乐新加坡电风扇静音无叶风扇健康空气循环扇儿童球形风扇落地扇外观,宁静节能,产品结构,现代科技的结晶,品质,气家,未来风新时代,动里,空让,健康,低至13分贝/DC直流马达/低耗24,亲密玩伴,24W功率,/低耗,别加坡国民品牌,气流通道,增强室内空气运动,过尘栅网,1-12档风力调速,涡轮风扇,吸气口,大于6米随心掌控,电源适配暑,装箱明细,摆头角度,手动摇摆轨道,操作方式,与空调同时使用不仅可以让室温快速均衡作,电源插口,适用环境,还可以在短时间内,导引出风口,产品类型,快件重量,电机,暖空气向上冷空气向下,线长,使房间温度均衡,省电环保,定时,功率,将凉风或热风送给到附近的房间,轻松享受生活,左右自动(上下手动)摇摆9度,进风口,能够很快中和空气温度差",
"", "", all_decoding_strategys],
[
"海尔8公斤节能静音高温消毒烫烫净全自动滚筒洗衣机靠实力说话,一掌控时间掌控自由,i-time智能时间洗,8公斤容量全家衣物一次清洗,细节绝不含糊,真正实力派,自动添加洗衣盒,洗羽绒服,就要专属程序,羊毛,牛仔,习绒,海尔洗衣机蓝晶系列滚筒,个性范儿,按照程序需求自动冲入洗衣机内,灵活旋钮,创新下排水洁净不残留,强力筋内筒,AMT防霉窗垫,LED大屏显示,洗衣液,消毒剂分别置放在洗衣盒中,从根本上解决污水残留问题避免,全新LD面板显示,更宽阔更大气操作信息一目了然,宽阔大气操作信息一目了然,右槽:消毒剂,简化洗衣程序,弹力筋中间的凹槽内分布,无残留排水模块,海尔洗衣机具有专业级羽绒洗护程序,为羽绒服营造洗护,一体化环境彻底告别手洗或者机洗,左槽:洗涤剂,我的智慧生活,中槽:柔顺剂,满足各种洗涤需求,告别昂贵洗衣店,自家",
"", "", all_decoding_strategys],
]
sum_iface = gr.Interface(
fn=summarize,
inputs=[
gr.Textbox(
label="商品信息(Product Info)",
value="美的对开门风冷无霜家用智能电冰箱波光金纤薄机身高颜值助力保鲜,美的家居风,尺寸说明:"
"M以上的距离尤其是左右两侧距离必须保证。关于尺寸的更多问题可,LED冷光源,纤薄机身,风冷"
"无霜,智能操控,远程调温,节能静音,照亮你的视野,535L大容量,系统散热和使用的便利性,"
"建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开"
"门俯视图,全开门俯视图,预留参考图"),
gr.Textbox(
"",
label="前缀词(Prefix Text)"
),
gr.Textbox(
"",
label="限定词(Constrained Text)"
),
gr.Checkboxgroup(
all_decoding_strategys, value=all_decoding_strategys[0:1],
label="解码策略(Decoding Strategy)"
),
],
# outputs=gr.Textbox(
# label="文本摘要(Summarization)",
# lines=4,
# ),
# outputs=gr.DataFrame(
# label="文本摘要(Summarization)",
# ),
outputs=gr.JSON( # TODO:去掉json array的数字标号
label="文本摘要(Summarization)",
),
examples=sum_examples,
title="生成式摘要(Abstractive Summarization)",
description='生成式摘要,用于电商领域的商品营销文案写作。输入商品信息,输出商品的营销文案。',
article=article
)
if __name__ == "__main__":
sum_iface.launch()