vilarin commited on
Commit
2a1804b
·
verified ·
1 Parent(s): 20cf7c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -52,13 +52,13 @@ footer {
52
  }
53
  """
54
 
55
- MODEL_ID = "01-ai/Yi-1.5-6B-Chat"
56
 
57
  model = AutoModelForCausalLM.from_pretrained(
58
  MODEL_ID,
59
  torch_dtype=torch.float16,
60
  device_map="auto"
61
- ).eval()
62
 
63
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
64
 
@@ -130,24 +130,27 @@ async def gen_show(script):
130
 
131
  @spaces.GPU
132
  def generator(messages):
133
- input_ids = tokenizer.apply_chat_template(
134
- conversation=messages,
135
- add_generation_prompt=True,
136
- tokenize=True,
137
- return_tensors='pt'
138
  )
 
139
 
140
- output_ids = model.generate(
141
- input_ids.to('cuda'),
142
- eos_token_id=tokenizer.eos_token_id,
143
  max_new_tokens=4096,
144
  temperature=0.5,
145
  repetition_penalty=1.2,
146
  )
147
 
148
- results = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
149
- print(results)
150
- return results
 
 
 
 
151
 
152
  def extract_content(text):
153
  """Extracts the JSON content from the given text."""
 
52
  }
53
  """
54
 
55
+ MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"
56
 
57
  model = AutoModelForCausalLM.from_pretrained(
58
  MODEL_ID,
59
  torch_dtype=torch.float16,
60
  device_map="auto"
61
+ )
62
 
63
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
64
 
 
130
 
131
  @spaces.GPU
132
  def generator(messages):
133
+ text = tokenizer.apply_chat_template(
134
+ messages,
135
+ tokenize=False,
136
+ add_generation_prompt=True
 
137
  )
138
+ model_inputs = tokenizer([text], return_tensors="pt").to(0)
139
 
140
+ generated_ids = model.generate(
141
+ model_inputs.input_ids,
 
142
  max_new_tokens=4096,
143
  temperature=0.5,
144
  repetition_penalty=1.2,
145
  )
146
 
147
+ generated_ids = [
148
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
149
+ ]
150
+
151
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
152
+ print(response)
153
+ return response
154
 
155
  def extract_content(text):
156
  """Extracts the JSON content from the given text."""