nisten committed · Commit 720352d · verified · 1 Parent(s): adaf527

Update app.py

Files changed (1)
  1. app.py (+10, -8)
app.py CHANGED
@@ -1,13 +1,13 @@
 import gradio as gr
 import spaces
-from transformers import OlmoeForCausalLM, AutoTokenizer
 import torch
 import subprocess
 import sys
-import os
 
-# Force upgrade transformers to the latest version
-subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"])
+# Force install the specific transformers version from the GitHub PR
+subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
+
+from transformers import OlmoeForCausalLM, AutoTokenizer
 
 model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
 
@@ -37,18 +37,20 @@ def generate_response(message, history, temperature, max_new_tokens):
     if model is None or tokenizer is None:
         return "Model or tokenizer not loaded properly. Please check the logs."
 
-    messages = [{"role": "user", "content": message}]
+    messages = [{"role": "system", "content": system_prompt},
+                {"role": "user", "content": message}]
+
     inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
 
     with torch.no_grad():
         generate_ids = model.generate(
-            **inputs,
+            inputs,
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=temperature,
             eos_token_id=tokenizer.eos_token_id,
         )
-    response = tokenizer.decode(generate_ids[0], skip_special_tokens=True)
+    response = tokenizer.decode(generate_ids[0, inputs.shape[1]:], skip_special_tokens=True)
     return response.strip()
 
 css = """
@@ -84,4 +86,4 @@ with gr.Blocks(css=css) as demo:
 
 if __name__ == "__main__":
     demo.queue(api_open=False)
-    demo.launch(debug=True, show_api=False, share=True )
+    demo.launch(debug=True, show_api=False)
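
For reference, here is a minimal, self-contained sketch of the generation path this commit arrives at: a system-plus-user message list run through the tokenizer's chat template, the resulting id tensor passed to generate() positionally (apply_chat_template with tokenize=True returns a plain tensor, which is why the earlier **inputs unpacking had to go), and the decode sliced at inputs.shape[1] so only the newly generated tokens come back rather than an echo of the prompt. The model name and system prompt below are illustrative stand-ins, not part of the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative stand-in: any chat-tuned causal LM with a chat template works.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# System + user messages, mirroring the shape of the committed code
# (the system prompt string here is hypothetical).
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is a mixture-of-experts model?"},
]

# With tokenize=True and return_tensors="pt", this yields a tensor of
# input ids, so it is passed to generate() positionally, not unpacked.
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)

with torch.no_grad():
    generate_ids = model.generate(
        inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        eos_token_id=tokenizer.eos_token_id,
    )

# generate() returns prompt + completion; slicing at inputs.shape[1]
# decodes only the new tokens, so the user's message is not echoed back.
response = tokenizer.decode(generate_ids[0, inputs.shape[1]:], skip_special_tokens=True)
print(response.strip())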