vilarin committed on
Commit 389126e · verified · 1 Parent(s): 2a1804b

Update app.py

Files changed (1)
  1. app.py +13 -16
app.py CHANGED
@@ -52,13 +52,13 @@ footer {
 }
 """
 
-MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"
+MODEL_ID = "01-ai/Yi-1.5-6B-Chat"
 
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
     torch_dtype=torch.float16,
     device_map="auto"
-)
+).eval()
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
@@ -130,27 +130,24 @@ async def gen_show(script):
 
 @spaces.GPU
 def generator(messages):
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
+    input_ids = tokenizer.apply_chat_template(
+        conversation=messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_tensors='pt'
     )
-    model_inputs = tokenizer([text], return_tensors="pt").to(0)
 
-    generated_ids = model.generate(
-        model_inputs.input_ids,
+    output_ids = model.generate(
+        input_ids.to('cuda'),
+        eos_token_id=tokenizer.eos_token_id,
         max_new_tokens=4096,
         temperature=0.5,
         repetition_penalty=1.2,
     )
 
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
-
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    print(response)
-    return response
+    results = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
+    print(results)
+    return results
 
 def extract_content(text):
     """Extracts the JSON content from the given text."""