rombodawg commited on
Commit
f1125eb
·
verified ·
1 Parent(s): 473d783

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -37,13 +37,14 @@ h3 {
37
 
38
  device = "cuda" # for GPU usage or "cpu" for CPU usage
39
 
40
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
41
  model = AutoModelForCausalLM.from_pretrained(
42
  MODEL,
43
  torch_dtype=torch.bfloat16,
44
  device_map="auto",
45
  trust_remote_code=True,
46
- ignore_mismatched_sizes=True)
 
47
 
48
  def format_chat(system_prompt, history, message):
49
  formatted_chat = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
@@ -60,10 +61,8 @@ def stream_chat(
60
  system_prompt: str,
61
  temperature: float = 0.3,
62
  max_new_tokens: int = 256,
63
- top_p: float = 1.0
64
- ,
65
  top_k: int = 20,
66
-
67
  repetition_penalty: float = 1.2,
68
  ):
69
  print(f'message: {message}')
@@ -72,8 +71,7 @@ def stream_chat(
72
  formatted_prompt = format_chat(system_prompt, history, message)
73
  inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
74
 
75
- streamer = TextIteratorStreamer(tokenizer, timeout=5000.0
76
- , skip_prompt=True, skip_special_tokens=True)
77
 
78
  generate_kwargs = dict(
79
  input_ids=inputs.input_ids,
@@ -167,4 +165,4 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
167
  )
168
 
169
  if __name__ == "__main__":
170
- demo.launch()
 
37
 
38
  device = "cuda" # for GPU usage or "cpu" for CPU usage
39
 
40
+ tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=False, force_download=True)
41
  model = AutoModelForCausalLM.from_pretrained(
42
  MODEL,
43
  torch_dtype=torch.bfloat16,
44
  device_map="auto",
45
  trust_remote_code=True,
46
+ ignore_mismatched_sizes=True,
47
+ force_download=True)
48
 
49
  def format_chat(system_prompt, history, message):
50
  formatted_chat = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
 
61
  system_prompt: str,
62
  temperature: float = 0.3,
63
  max_new_tokens: int = 256,
64
+ top_p: float = 1.0,
 
65
  top_k: int = 20,
 
66
  repetition_penalty: float = 1.2,
67
  ):
68
  print(f'message: {message}')
 
71
  formatted_prompt = format_chat(system_prompt, history, message)
72
  inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
73
 
74
+ streamer = TextIteratorStreamer(tokenizer, timeout=5000.0, skip_prompt=True, skip_special_tokens=True)
 
75
 
76
  generate_kwargs = dict(
77
  input_ids=inputs.input_ids,
 
165
  )
166
 
167
  if __name__ == "__main__":
168
+ demo.launch()