O S I H committed on
Commit
8ac3910
1 Parent(s): 2600d7f

upload files

Files changed (2)
  1. app.py +115 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,115 @@
+ import gradio as gr
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+ )
+ import os
+ from threading import Thread
+ import spaces
+ import subprocess
+
+ # Install flash-attn at runtime (the Space's build step cannot compile its
+ # CUDA kernels). Extend os.environ rather than replacing it so pip still
+ # sees PATH and the rest of the environment.
+ subprocess.run(
+     "pip install flash-attn --no-build-isolation",
+     env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+     shell=True,
+ )
+
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "microsoft/Phi-3-small-128k-instruct",
+     torch_dtype="auto",
+     trust_remote_code=True,
+ )
+ tok = AutoTokenizer.from_pretrained(
+     "microsoft/Phi-3-small-128k-instruct",
+     trust_remote_code=True,
+ )
+ terminators = [
+     tok.eos_token_id,
+ ]
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+     print(f"Using GPU: {torch.cuda.get_device_name(device)}")
+ else:
+     device = torch.device("cpu")
+     print("Using CPU")
+
+ model = model.to(device)
+ # Dispatch Errors
+
+
+ @spaces.GPU(duration=60)
+ def chat(message, history, system_prompt, temperature, do_sample, max_tokens, top_k, repetition_penalty, top_p):
+     # Rebuild the conversation from Gradio's (user, assistant) history pairs.
+     conversation = [
+         {"role": "system", "content": system_prompt}
+     ]
+     for item in history:
+         conversation.append({"role": "user", "content": item[0]})
+         if item[1] is not None:
+             conversation.append({"role": "assistant", "content": item[1]})
+     conversation.append({"role": "user", "content": message})
+     messages = tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+     model_inputs = tok([messages], return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(
+         tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+     )
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=max_tokens,
+         do_sample=do_sample,
+         temperature=temperature,
+         eos_token_id=terminators,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         top_p=top_p,
+     )
+
+     # Temperature 0 means greedy decoding regardless of the sampling checkbox.
+     if temperature == 0:
+         generate_kwargs["do_sample"] = False
+
+     # Run generation on a worker thread and stream partial output as it arrives.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+         yield partial_text
+
+     yield partial_text
+
+
+ demo = gr.ChatInterface(
+     fn=chat,
+     examples=[
+         ["Write me a poem about Machine Learning."],
+         ["Write the Fibonacci sequence in Python."],
+         ["Who won the World Cup in 2018?"],
+         ["When was the first computer invented?"],
+     ],
+     additional_inputs_accordion=gr.Accordion(
+         label="⚙️ Parameters", open=False, render=False
+     ),
+     # The order of these inputs must match the extra parameters of chat().
+     additional_inputs=[
+         gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
+         gr.Slider(
+             minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
+         ),
+         gr.Checkbox(label="Sampling", value=True),
+         gr.Slider(
+             minimum=128,
+             maximum=4096,
+             step=1,
+             value=512,
+             label="Max new tokens",
+             render=False,
+         ),
+         gr.Slider(1, 80, 40, label="Top K sampling"),
+         gr.Slider(0, 2, 1.1, label="Repetition penalty"),
+         gr.Slider(0, 1, 0.95, label="Top P sampling"),
+     ],
+     stop_btn="Stop Generation",
+     title="Chat With Phi-3-small-128k-instruct",
+     description="[microsoft/Phi-3-small-128k-instruct](https://huggingface.co/microsoft/Phi-3-small-128k-instruct)",
+ )
+ demo.launch()
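
The heart of `chat` is the standard `transformers` streaming recipe: `model.generate` blocks until generation finishes, so it runs on a worker thread while the main thread consumes the `TextIteratorStreamer` as an iterator. A minimal standalone sketch of just that pattern, using a small placeholder model (`gpt2`) rather than the Phi-3 checkpoint this Space loads:

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    name = "gpt2"  # placeholder for illustration; the Space uses Phi-3-small-128k-instruct
    tok = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)

    inputs = tok(["The quick brown fox"], return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    # generate() runs in the background; the streamer yields decoded text
    # chunks to this thread as soon as they are produced.
    Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=30)).start()

    for chunk in streamer:
        print(chunk, end="", flush=True)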
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ tiktoken
+ gradio
+ spaces
+ torch==2.2.0
+ git+https://github.com/huggingface/transformers/
+ optimum
+ accelerate
+ bitsandbytes
+ pytest
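
Note that `flash-attn` is deliberately absent from this list: `app.py` installs it at runtime with `FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`, presumably because the Space's build environment lacks the CUDA toolchain needed to compile its kernels during the normal requirements install.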