Sentdex committed on
Commit
076e8b6
β€’
1 Parent(s): ae2aac3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import transformers
3
+ from torch import bfloat16
4
+ # from dotenv import load_dotenv # if you wanted to adapt this for a repo that uses auth
5
+ from threading import Thread
6
+
7
+
8
+ #HF_AUTH = os.getenv('HF_AUTH')
9
+ model_id = "stabilityai/StableBeluga-7B"
10
+
11
+ bnb_config = transformers.BitsAndBytesConfig(
12
+ load_in_4bit=True,
13
+ bnb_4bit_quant_type='nf4',
14
+ bnb_4bit_use_double_quant=True,
15
+ bnb_4bit_compute_dtype=bfloat16
16
+ )
17
+ model_config = transformers.AutoConfig.from_pretrained(
18
+ model_id,
19
+ #use_auth_token=HF_AUTH
20
+ )
21
+
22
+ model = transformers.AutoModelForCausalLM.from_pretrained(
23
+ model_id,
24
+ trust_remote_code=True,
25
+ config=model_config,
26
+ quantization_config=bnb_config,
27
+ device_map='auto',
28
+ #use_auth_token=HF_AUTH
29
+ )
30
+
31
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
32
+ model_id,
33
+ #use_auth_token=HF_AUTH
34
+ )
35
+
36
+
37
+ def prompt_build(system_prompt, user_inp, hist):
38
+ prompt = f"""### System:\n{system_prompt}\n\n"""
39
+
40
+ for pair in hist:
41
+ prompt += f"""### User:\n{pair[0]}\n\n### Assistant:\n{pair[1]}\n\n"""
42
+
43
+ prompt += f"""### User:\n{user_inp}\n\n### Assistant:"""
44
+ return prompt
45
+
46
+ def chat(user_input, history, system_prompt):
47
+
48
+ prompt = prompt_build(system_prompt, user_input, history)
49
+ model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
50
+
51
+ streamer = transformers.TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
52
+
53
+ generate_kwargs = dict(
54
+ model_inputs,
55
+ streamer=streamer,
56
+ max_new_tokens=2048,
57
+ do_sample=True,
58
+ top_p=0.95,
59
+ temperature=0.8,
60
+ top_k=50
61
+ )
62
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
63
+ t.start()
64
+
65
+ model_output = ""
66
+ for new_text in streamer:
67
+ model_output += new_text
68
+ yield model_output
69
+ return model_output
70
+
71
+
72
+ with gr.Blocks() as demo:
73
+ system_prompt = gr.Textbox("You are helpful AI.", label="System Prompt")
74
+ chatbot = gr.ChatInterface(fn=chat, additional_inputs=[system_prompt])
75
+
76
+ demo.queue().launch()