Mohammed-Altaf commited on
Commit
5b47eb8
1 Parent(s): b52e32b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -15
app.py CHANGED
@@ -2,31 +2,43 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  model_id = "Mohammed-Altaf/medical_chatbot-8bit"
5
- model = AutoModelForCausalLM.from_pretrained(model_id)
6
  tokenizer = AutoTokenizer.from_pretrained(model_id)
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def generate_text(input_text):
10
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
11
 
12
  output = model.generate(
13
- input_ids,
14
- max_length=200,
15
  )
16
 
17
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
18
- print(output_text)
19
 
20
- # Remove Prompt Echo from Generated Text
21
- cleaned_output_text = output_text.replace(input_text, "")
22
- return cleaned_output_text
23
 
24
 
25
- text_generation_interface = gr.Interface(
26
- fn=generate_text,
27
- inputs=[
28
- gr.inputs.Textbox(label="Input Text"),
29
- ],
30
- outputs=gr.inputs.Textbox(label="Generated Text"),
31
- title="Medical ChatBot",
32
- ).launch()
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  model_id = "Mohammed-Altaf/medical_chatbot-8bit"
5
+ model = AutoModelForCausalLM.from_pretrained(model_id,ignore_mismatched_sizes=True)
6
  tokenizer = AutoTokenizer.from_pretrained(model_id)
7
 
8
 
9
+ def get_clean_response(response):
10
+ if type(response) == list:
11
+ response = response[0].split("\n")
12
+ else:
13
+ response = response.split("\n")
14
+
15
+ ans = ''
16
+ cnt = 0 # to verify if we have seen Human before
17
+ for answer in response:
18
+ if answer.startswith("[|Human|]"): cnt += 1
19
+
20
+ elif answer.startswith('[|AI|]'):
21
+ answer = answer.split(' ')
22
+ ans += ' '.join(char for char in answer[1:])
23
+ ans += '\n'
24
+
25
+ elif cnt:
26
+ ans += answer + '\n'
27
+ return ans
28
+
29
+
30
  def generate_text(input_text):
31
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
32
 
33
  output = model.generate(
34
+ **input_ids,
35
+ max_length=100,
36
  )
37
 
38
  output_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
39
 
40
+ return get_clean_response(output_text)
 
 
41
 
42
 
43
+ iface = gr.Interface(fn = generate_text, inputs = 'text', outputs = ['text'], title ='Medical ChatBot')
44
+ iface.launch()