vilarin commited on
Commit
d381360
·
verified ·
1 Parent(s): 22d8950

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -33
app.py CHANGED
@@ -1,15 +1,16 @@
1
  import os
2
  import time
3
- #import spaces
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import gradio as gr
7
  from threading import Thread
8
 
9
- MODEL_LIST = ["HuggingFaceTB/SmolLM-1.7B-Instruct", "HuggingFaceTB/SmolLM-135M-Instruct", "HuggingFaceTB/SmolLM-360M-Instruct"]
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
11
 
12
- TITLE = "<h1><center>SmolLM-Instruct</center></h1>"
13
 
14
  PLACEHOLDER = """
15
  <center>
@@ -30,21 +31,12 @@ h3 {
30
  }
31
  """
32
 
33
- # pip install transformers
34
- from transformers import AutoModelForCausalLM, AutoTokenizer
35
 
36
- device = "cpu" # for GPU usage or "cpu" for CPU usage
 
37
 
38
- tokenizer0 = AutoTokenizer.from_pretrained(MODEL_LIST[0])
39
- model0 = AutoModelForCausalLM.from_pretrained(MODEL_LIST[0]).to(device)
40
-
41
- tokenizer1 = AutoTokenizer.from_pretrained(MODEL_LIST[1])
42
- model1 = AutoModelForCausalLM.from_pretrained(MODEL_LIST[1]).to(device)
43
-
44
- tokenizer2 = AutoTokenizer.from_pretrained(MODEL_LIST[2])
45
- model2 = AutoModelForCausalLM.from_pretrained(MODEL_LIST[2]).to(device)
46
-
47
- #@spaces.GPU()
48
  def stream_chat(
49
  message: str,
50
  history: list,
@@ -53,7 +45,6 @@ def stream_chat(
53
  top_p: float = 1.0,
54
  top_k: int = 20,
55
  penalty: float = 1.2,
56
- choice: str = "135M"
57
  ):
58
  print(f'message: {message}')
59
  print(f'history: {history}')
@@ -67,16 +58,6 @@ def stream_chat(
67
 
68
  conversation.append({"role": "user", "content": message})
69
 
70
- if choice == "1.7B":
71
- tokenizer = tokenizer0
72
- model = model0
73
- elif choice == "135M":
74
- model = model1
75
- tokenizer = tokenizer1
76
- else:
77
- model = model2
78
- tokenizer = tokenizer2
79
-
80
  input_text=tokenizer.apply_chat_template(conversation, tokenize=False)
81
  inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
82
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
@@ -154,12 +135,6 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
154
  label="Repetition penalty",
155
  render=False,
156
  ),
157
- gr.Radio(
158
- ["135M", "360M", "1.7B"],
159
- value="135M",
160
- label="Load Model",
161
- render=False,
162
- ),
163
  ],
164
  examples=[
165
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
 
1
  import os
2
  import time
3
+ import spaces
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import gradio as gr
7
  from threading import Thread
8
 
9
+ MODEL_LIST = ["mistralai/Mistral-Nemo-Instruct-2407"]
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+ MODEL = os.environ.get("MODEL_ID")
12
 
13
+ TITLE = "<h1><center>Mistral-Nemo</center></h1>"
14
 
15
  PLACEHOLDER = """
16
  <center>
 
31
  }
32
  """
33
 
34
+ device = "cuda" # for GPU usage or "cpu" for CPU usage
 
35
 
36
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
37
+ model = AutoModelForCausalLM.from_pretrained(MODEL).to(device)
38
 
39
+ @spaces.GPU()
 
 
 
 
 
 
 
 
 
40
  def stream_chat(
41
  message: str,
42
  history: list,
 
45
  top_p: float = 1.0,
46
  top_k: int = 20,
47
  penalty: float = 1.2,
 
48
  ):
49
  print(f'message: {message}')
50
  print(f'history: {history}')
 
58
 
59
  conversation.append({"role": "user", "content": message})
60
 
 
 
 
 
 
 
 
 
 
 
61
  input_text=tokenizer.apply_chat_template(conversation, tokenize=False)
62
  inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
63
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
135
  label="Repetition penalty",
136
  render=False,
137
  ),
 
 
 
 
 
 
138
  ],
139
  examples=[
140
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],