haidlir committed
Commit b448ace
1 Parent(s): bdc017f

Update app.py

Files changed (1):
  1. app.py  +44 -1
app.py CHANGED
@@ -1,3 +1,46 @@
  import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
+ from threading import Thread

- gr.load("models/haidlir/bloom-chatml-id").launch()
+ tokenizer = AutoTokenizer.from_pretrained("haidlir/bloom-chatml-id")
+ model = AutoModelForCausalLM.from_pretrained("haidlir/bloom-chatml-id")
+
+ # Stop criterion for generate() below; assumes generation should halt once
+ # the tokenizer's end-of-sequence token is produced.
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         return bool(input_ids[0][-1] == tokenizer.eos_token_id)
+
+ stop = StopOnTokens()
+
+ def predict(message, history):
+     # Rebuild the Gradio chat history as ChatML-style role/content messages.
+     history_chatml_format = []
+     for human, assistant in history:
+         history_chatml_format.append({"role": "user", "content": human})
+         history_chatml_format.append({"role": "assistant", "content": assistant})
+     history_chatml_format.append({"role": "user", "content": message})
+
+     model_inputs = tokenizer.apply_chat_template(
+         history_chatml_format,
+         tokenize=True,
+         add_generation_prompt=True,
+         return_tensors="pt",
+         return_dict=True,  # dict output lets generate() receive input_ids and attention_mask
+     )
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=1000,
+         temperature=1.0,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop]),
+     )
+     # Run generation on a background thread so tokens can be yielded as they arrive.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     partial_message = ""
+     for new_token in streamer:
+         if new_token != '<':  # skip tokens that are just '<' (ChatML marker residue)
+             partial_message += new_token
+             yield partial_message
+
+
+ gr.ChatInterface(predict).launch()
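
For reference, the ChatML prompt that apply_chat_template builds can be inspected without loading the model weights. A minimal sketch, assuming the haidlir/bloom-chatml-id tokenizer ships a ChatML chat template (the sample turn is made up):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("haidlir/bloom-chatml-id")

# Render the prompt as a string instead of token ids to see the ChatML framing.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Halo, apa kabar?"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
# Expected shape, assuming a ChatML template:
#   <|im_start|>user
#   Halo, apa kabar?<|im_end|>
#   <|im_start|>assistant

Setting add_generation_prompt=True appends the opening assistant tag, so the model's continuation is the assistant reply that predict() streams back.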