Ashishkr committed
Commit
6bbd1a8
Parent: bdb3c76

Update model.py

Files changed (1):
  1. model.py +21 -37
model.py CHANGED
@@ -56,33 +56,15 @@ tokenizer = transformers.AutoTokenizer.from_pretrained(
     # return ''.join(texts)
 
 
-
-# def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
-#     texts = [f'{system_prompt}\n']
-
-#     for user_input, response in chat_history[:-1]:
-#         texts.append(f'{user_input} {response}\n')
-
-#     # Getting the user input and response from the last tuple in the chat history
-#     last_user_input, last_response = chat_history[-1]
-#     texts.append(f' input: {last_user_input} {last_response} {message} response: ')
-
-#     return ''.join(texts)
-
 def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
     texts = [f'{system_prompt}\n']
 
-    # If chat_history is not empty, process all but the last entry
-    if chat_history:
-        for user_input, response in chat_history[:-1]:
-            texts.append(f'{user_input} {response}\n')
+    for user_input, response in chat_history[:-1]:
+        texts.append(f'{user_input} {response}\n')
 
-        # Getting the user input and response from the last tuple in the chat history
-        last_user_input, last_response = chat_history[-1]
-        texts.append(f' input: {last_user_input} {last_response} {message} Response: ')
-    else:
-        # If chat_history is empty, just add the message with 'Response:' at the end
-        texts.append(f' input: {message} Response: ')
+    # Getting the user input and response from the last tuple in the chat history
+    last_user_input, last_response = chat_history[-1]
+    texts.append(f' input: {last_user_input} {last_response} {message} Response: ')
 
     return ''.join(texts)
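Note on this hunk: dropping the `if chat_history:` guard means the rewritten get_prompt assumes a non-empty history; with an empty one, `chat_history[-1]` now raises an IndexError where the removed else branch used to emit `f' input: {message} Response: '`. A runnable sketch of the post-commit behavior (the function body is copied from the + side above; the example history and system prompt are invented):

def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    texts = [f'{system_prompt}\n']

    for user_input, response in chat_history[:-1]:
        texts.append(f'{user_input} {response}\n')

    # Getting the user input and response from the last tuple in the chat history
    last_user_input, last_response = chat_history[-1]
    texts.append(f' input: {last_user_input} {last_response} {message} Response: ')

    return ''.join(texts)

history = [('hello', 'hi there'), ('how are you?', 'doing well')]
print(get_prompt('tell me a joke', history, 'You are a helpful assistant.'))
# You are a helpful assistant.
# hello hi there
#  input: how are you? doing well tell me a joke Response: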
 
@@ -99,26 +81,28 @@ def run(message: str,
         max_new_tokens: int = 256,
         temperature: float = 0.8,
         top_p: float = 0.95,
-        top_k: int = 50) -> str:
+        top_k: int = 50) -> Iterator[str]:
     prompt = get_prompt(message, chat_history, system_prompt)
     inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
 
-    # Generate tokens using the model
-    output = model.generate(
-        input_ids=inputs['input_ids'],
-        attention_mask=inputs['attention_mask'],
-        max_length=max_new_tokens + inputs['input_ids'].shape[-1],
+    streamer = TextIteratorStreamer(tokenizer,
+                                    timeout=10.,
+                                    skip_prompt=True,
+                                    skip_special_tokens=True)
+    generate_kwargs = dict(
+        inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
-        num_beams=1
+        num_beams=1,
     )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
 
-    # Decode the output tokens back to a string
-    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
-
-    # Remove everything including and after "instruct: "
-    output_text = output_text.split("instruct: ")[0]
-
-    return output_text
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield ''.join(outputs)
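This hunk turns run from a blocking call that returned one decoded string into a generator: model.generate is started on a background thread, TextIteratorStreamer decodes tokens as they are produced, and every loop iteration yields the text accumulated so far. It also quietly drops the old post-processing that truncated output at "instruct: ". The sketch below shows the same pattern in isolation; run_stream is a hypothetical name, and model, tokenizer, and device are assumed to be defined earlier in model.py, as are the three imports the new code needs (Thread, Iterator, TextIteratorStreamer), which do not appear in this diff:

from threading import Thread
from typing import Iterator

from transformers import TextIteratorStreamer

def run_stream(prompt: str, max_new_tokens: int = 256) -> Iterator[str]:
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)

    # skip_prompt=True keeps the echoed input prompt out of the stream;
    # timeout=10. makes iteration fail fast if generation stalls.
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)

    # model.generate blocks until generation finishes, so it runs on a
    # background thread while this generator drains the streamer.
    Thread(target=model.generate,
           kwargs=dict(inputs, streamer=streamer,
                       max_new_tokens=max_new_tokens)).start()

    outputs = []
    for text in streamer:       # each item is newly decoded text
        outputs.append(text)
        yield ''.join(outputs)  # yield the full response so far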
 
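For completeness, a hedged usage sketch of the new streaming run: since every yielded value is cumulative, a consumer only ever needs the latest one. The argument order follows the get_prompt(message, chat_history, system_prompt) call visible in the diff; the strings are invented:

# Hypothetical caller: each yielded value is the response accumulated so far,
# so a streaming UI can simply re-render the latest one.
last = ''
for last in run('tell me a joke',
                [('hi', 'hello!')],               # chat_history
                'You are a helpful assistant.'):  # system_prompt
    pass  # a streaming UI would display `last` here on every iteration
print(last)  # the complete response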