K00B404 committed
Commit 0a751e6 · verified · 1 Parent(s): 02b7231

Update app.py

Files changed (1)
  1. app.py +41 -140
app.py CHANGED
@@ -9,146 +9,47 @@ from threading import Thread
 MODEL_LIST = ["mistralai/Mistral-Nemo-Instruct-2407"]
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = os.environ.get("MODEL_ID")
-
-TITLE = "<h1><center>Mistral-Nemo</center></h1>"
-
-PLACEHOLDER = """
-<center>
-<p>The Mistral-Nemo is a pretrained generative text model of 12B parameters trained jointly by Mistral AI and NVIDIA.</p>
-</center>
-"""
-
-
-CSS = """
-.duplicate-button {
-    margin: auto !important;
-    color: white !important;
-    background: black !important;
-    border-radius: 100vh !important;
-}
-h3 {
-    text-align: center;
-}
-"""
-
-device = "cpu" # for GPU usage or "cpu" for CPU usage
-
-tokenizer = AutoTokenizer.from_pretrained(MODEL)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    ignore_mismatched_sizes=True)
-
-@spaces.GPU()
-def stream_chat(
-    message: str,
-    history: list,
-    temperature: float = 0.3,
-    max_new_tokens: int = 1024,
-    top_p: float = 1.0,
-    top_k: int = 20,
-    penalty: float = 1.2,
-):
-    print(f'message: {message}')
-    print(f'history: {history}')
-
-    conversation = []
-    for prompt, answer in history:
-        conversation.extend([
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": answer},
-        ])
-
-    conversation.append({"role": "user", "content": message})
-
-    input_text=tokenizer.apply_chat_template(conversation, tokenize=False)
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=inputs,
-        max_new_tokens = max_new_tokens,
-        do_sample = False if temperature == 0 else True,
-        top_p = top_p,
-        top_k = top_k,
-        temperature = temperature,
-        streamer=streamer,
-        repetition_penalty=penalty,
-        pad_token_id = 10,
-    )
-
-    with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        yield buffer
-
-
-chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
-
-with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
-    gr.HTML(TITLE)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
-    gr.ChatInterface(
-        fn=stream_chat,
-        chatbot=chatbot,
-        fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-        additional_inputs=[
-            gr.Slider(
-                minimum=0,
-                maximum=1,
-                step=0.1,
-                value=0.3,
-                label="Temperature",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=128,
-                maximum=8192,
-                step=1,
-                value=1024,
-                label="Max new tokens",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                step=0.1,
-                value=1.0,
-                label="top_p",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1,
-                maximum=20,
-                step=1,
-                value=20,
-                label="top_k",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=2.0,
-                step=0.1,
-                value=1.2,
-                label="Repetition penalty",
-                render=False,
-            ),
-        ],
-        examples=[
-            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
-            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
-            ["Tell me a random fun fact about the Roman Empire."],
-            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
-        ],
-        cache_examples=False,
-    )
-
-
+# filename: gradio_app.py
+
+import gradio as gr
+from huggingface_hub import InferenceClient
+
+# Initialize the InferenceClient
+client = InferenceClient(
+    "mistralai/Mistral-Nemo-Instruct-2407",
+    token="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+)
+
+def chat_with_model(system_prompt, user_message):
+    # Prepare messages for the chat completion
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_message}
+    ]
+
+    # Collect the response from the model
+    response = ""
+    for message in client.chat_completion(
+        messages=messages,
+        max_tokens=500,
+        stream=True
+    ):
+        response += message.choices[0].delta.content
+
+    return response
+
+# Create the Gradio interface
+iface = gr.Interface(
+    fn=chat_with_model,
+    inputs=[
+        gr.Textbox(label="System Prompt", placeholder="Enter the system prompt here..."),
+        gr.Textbox(label="User Message", placeholder="Ask a question..."),
+    ],
+    outputs=gr.Textbox(label="Response"),
+    title="Mistral Chatbot",
+    description="Chat with Mistral model using your own system prompts."
+)
+
+# Launch the app
 if __name__ == "__main__":
-    demo.launch()
+    iface.launch(show_api=True, share=False,show_error=True)
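
For reference, a minimal standalone sketch (not part of the commit) of the same huggingface_hub InferenceClient streaming call that the new chat_with_model relies on. It assumes the token is read from the HF_TOKEN environment variable rather than hard-coded, and it skips stream chunks whose delta carries no text (for example the final stop chunk), which the committed loop concatenates unconditionally.

# Standalone sketch; assumes HF_TOKEN is set in the environment (not part of the committed app.py)
import os

from huggingface_hub import InferenceClient

client = InferenceClient(
    "mistralai/Mistral-Nemo-Instruct-2407",
    token=os.environ.get("HF_TOKEN"),
)

def stream_reply(system_prompt: str, user_message: str) -> str:
    """Stream a chat completion and return the concatenated text."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    response = ""
    for chunk in client.chat_completion(messages=messages, max_tokens=500, stream=True):
        delta = chunk.choices[0].delta.content
        if delta:  # some chunks (e.g. the final one) may carry no content
            response += delta
    return response

if __name__ == "__main__":
    print(stream_reply("You are a helpful assistant.", "Say hello in one sentence."))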