ysharma and mikeee committed
Commit 9fe4d8e (0 parents)

Duplicate from mikeee/chatglm2-6b-4bit


Co-authored-by: mikeee <[email protected]>

Files changed (4)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +188 -0
  4. requirements.txt +9 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Chatglm2 6b 4bit
+ emoji: 🌖
+ colorFrom: purple
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 3.35.2
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: mikeee/chatglm2-6b-4bit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,188 @@
+ # import gradio as gr
+
+ # model_name = "models/THUDM/chatglm2-6b-int4"
+ # gr.load(model_name).launch()
+
+ # %%writefile demo-4bit.py
+
+ from textwrap import dedent
+
+ # credit to https://github.com/THUDM/ChatGLM2-6B/blob/main/web_demo.py
+ # while mistakes are mine
+ from transformers import AutoModel, AutoTokenizer
+ import gradio as gr
+ import mdtex2html
+
+ from loguru import logger
+
+ model_name = "THUDM/chatglm2-6b"  # overridden just below in favor of the int4 checkpoint
+ model_name = "THUDM/chatglm2-6b-int4"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
+
+ # 4/8 bit
+ # model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
+
+ import torch
+
+ has_cuda = torch.cuda.is_available()
+ # has_cuda = False  # force cpu
+
+ if has_cuda:
+     model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()  # 3.92G
+ else:
+     model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half()  # .float() .half().float()
+
+ model = model.eval()
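+ # note: when no GPU is available the int4 checkpoint runs on the CPU, which the
+ # Info accordion further down warns is very slow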
+
+ _ = """Override Chatbot.postprocess"""
+
+ def postprocess(self, y):
+     if y is None:
+         return []
+     for i, (message, response) in enumerate(y):
+         y[i] = (
+             None if message is None else mdtex2html.convert(message),
+             None if response is None else mdtex2html.convert(response),
+         )
+     return y
+
+
+ gr.Chatbot.postprocess = postprocess
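+ # the patch above makes the Chatbot render responses as HTML, with Markdown and
+ # LaTeX converted by mdtex2html instead of being shown as raw text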
+
+
+ def parse_text(text):
+     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+     lines = text.split("\n")
+     lines = [line for line in lines if line != ""]
+     count = 0
+     for i, line in enumerate(lines):
+         if "```" in line:
+             count += 1
+             items = line.split("`")
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = "<br></code></pre>"
+         else:
+             if i > 0:
+                 if count % 2 == 1:
+                     line = line.replace("`", r"\`")
+                     line = line.replace("<", "&lt;")
+                     line = line.replace(">", "&gt;")
+                     line = line.replace(" ", "&nbsp;")
+                     line = line.replace("*", "&ast;")
+                     line = line.replace("_", "&lowbar;")
+                     line = line.replace("-", "&#45;")
+                     line = line.replace(".", "&#46;")
+                     line = line.replace("!", "&#33;")
+                     line = line.replace("(", "&#40;")
+                     line = line.replace(")", "&#41;")
+                     line = line.replace("$", "&#36;")
+                 lines[i] = "<br>" + line
+     text = "".join(lines)
+     return text
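+ # parse_text rewrites ``` fences into <pre><code> blocks and escapes the
+ # characters inside them, so code in a reply is displayed literally instead of
+ # being interpreted as HTML or Markdown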
+
+
+ def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
+     chatbot.append((parse_text(input), ""))
+     for response, history, past_key_values in model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
+                                                                 return_past_key_values=True,
+                                                                 max_length=max_length, top_p=top_p,
+                                                                 temperature=temperature):
+         chatbot[-1] = (parse_text(input), parse_text(response))
+
+         yield chatbot, history, past_key_values
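+ # predict is a generator: every partial response streamed by model.stream_chat
+ # is yielded immediately, so the Chatbot updates as generation runs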
+
+
+ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
+     if max_length < 100:
+         max_length = 4096
+     if top_p < 0.1:
+         top_p = 0.8
+     if temperature <= 0:
+         temperature = 0.01
+     try:
+         res, _ = model.chat(
+             tokenizer,
+             input,
+             history=[],
+             past_key_values=None,
+             max_length=max_length,
+             top_p=top_p,
+             temperature=temperature,
+         )
+         # logger.debug(f"{res=} \n{_=}")
+     except Exception as exc:
+         logger.error(f"{exc=}")
+         res = str(exc)
+
+     return res
+
+
+ def reset_user_input():
+     return gr.update(value="")
+
+
+ def reset_state():
+     return [], [], None
+
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.HTML("""<h1 align="center">ChatGLM2-6B-int4</h1>""")
+     with gr.Accordion("Info", open=False):
+         _ = """
+         A query takes from 30 seconds to a few tens of seconds, depending on the length
+         of the question and of the answer.
+
+         * Low temperature: responses are more deterministic and focused; high temperature: responses are more creative.
+
+         * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
+
+         * Top P controls dynamic vocabulary selection based on context.
+
+         For a table of example values for different scenarios, refer to [this](https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683)
+
+         If the instance is not on a GPU (T4), it will be very slow. You can give the [chatglm2-6b-4bit colab notebook](https://colab.research.google.com/drive/1WkF7kOjVCcBBatDHjaGkuJHnPdMWNtbW?usp=sharing) a spin instead.
+
+         The T4 GPU is sponsored by a community GPU grant from Hugging Face. Thanks a lot!
+         """
+         gr.Markdown(dedent(_))
+     chatbot = gr.Chatbot()
+     with gr.Row():
+         with gr.Column(scale=4):
+             with gr.Column(scale=12):
+                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+                     container=False)
+             with gr.Column(min_width=32, scale=1):
+                 submitBtn = gr.Button("Submit", variant="primary")
+         with gr.Column(scale=1):
+             emptyBtn = gr.Button("Clear History")
+             max_length = gr.Slider(0, 32768, value=8192 / 2, step=1.0, label="Maximum length", interactive=True)
+             top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
+             temperature = gr.Slider(0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+
+     history = gr.State([])
+     past_key_values = gr.State(None)
+
+     user_input.submit(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
+                       [chatbot, history, past_key_values], show_progress=True)
+     submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
+                     [chatbot, history, past_key_values], show_progress=True, api_name="predict")
+     submitBtn.click(reset_user_input, [], [user_input])
+
+     emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
+
+     with gr.Accordion("For Translation API", open=False):
+         input_text = gr.Text()
+         tr_btn = gr.Button("Go", variant="primary")
+         out_text = gr.Text()
+         tr_btn.click(trans_api, [input_text, max_length, top_p, temperature], out_text, show_progress=True, api_name="tr")
+         input_text.submit(trans_api, [input_text, max_length, top_p, temperature], out_text, show_progress=True, api_name="tr")
+
+ # demo.queue().launch(share=False, inbrowser=True)
+ # demo.queue().launch(share=True, inbrowser=True, debug=True)
+
+ demo.queue().launch(debug=True)
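
Because the app registers named endpoints (api_name="predict" and api_name="tr"), the Space can also be called programmatically. Below is a minimal sketch using gradio_client, assuming the package is installed; the Space id (taken from the duplicated_from field above), the prompt, and the parameter values are illustrative, not part of this commit.

from gradio_client import Client

# Space id assumed from the duplicated_from field; adjust to your own duplicate
client = Client("mikeee/chatglm2-6b-4bit")

# /tr maps to trans_api(input, max_length, top_p, temperature); per the Info
# notes in app.py, a low temperature suits translation
result = client.predict(
    "Translate to English: 你好，世界",  # input text (illustrative)
    4096,  # max_length
    0.8,   # top_p
    0.2,   # temperature
    api_name="/tr",
)
print(result)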
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ protobuf
+ transformers==4.30.2
+ cpm_kernels
+ torch>=2.0
+ # gradio
+ mdtex2html
+ sentencepiece
+ accelerate
+ loguru