SivaResearch committed on
Commit 9e51870 · verified · 1 Parent(s): 69fcc4b

updated with interface

Files changed (1)
  1. app.py +63 -98
app.py CHANGED
@@ -1,100 +1,65 @@
- import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM

- # Load model and tokenizer directly
- tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Airavata")
- model = AutoModelForCausalLM.from_pretrained("ai4bharat/Airavata")
-
- def chat_interface(user_input, assistant_input):
-     # Concatenate the user and assistant inputs to simulate a chat conversation
-     chat_history = f"{assistant_input} User: {user_input}"
-
-     # Tokenize the chat history and generate the response
-     inputs = tokenizer(chat_history, return_tensors="pt", max_length=256, truncation=True)
-     outputs = model.generate(**inputs)
-     response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-
-     return response, chat_history
-
- # Define Gradio Chat Interface
- iface = gr.ChatInterface(
-     chat_model=chat_interface,
-     title="GPT-2 Chat Interface",
-     inputs=["text", "text"],
-     outputs=["text", "text"],
- )
-
- # Launch Gradio Chat Interface
- iface.launch()
-
-
-
- # import torch
- # from transformers import AutoTokenizer, AutoModelForCausalLM
- # import gradio as gr
-
- # device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
- # def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
- #     formatted_text = ""
- #     for message in messages:
- #         if message["role"] == "system":
- #             formatted_text += "<|system|>\n" + message["content"] + "\n"
- #         elif message["role"] == "user":
- #             formatted_text += "<|user|>\n" + message["content"] + "\n"
- #         elif message["role"] == "assistant":
- #             formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
- #         else:
- #             raise ValueError(
- #                 "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
- #                     message["role"]
- #                 )
- #             )
- #     formatted_text += "<|assistant|>\n"
- #     formatted_text = bos + formatted_text if add_bos else formatted_text
- #     return formatted_text
-
-
- # def inference(input_prompts, model, tokenizer):
- #     input_prompts = [
- #         create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
- #         for input_prompt in input_prompts
- #     ]
-
- #     encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
- #     encodings = encodings.to(device)
-
- #     with torch.inference_mode():
- #         outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
-
- #     output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
-
- #     input_prompts = [
- #         tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
- #     ]
- #     output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
- #     return output_texts
-
-
- # model_name = "ai4bharat/Airavata"
-
- # tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
- # tokenizer.pad_token = tokenizer.eos_token
- # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
- # print(f"Loading model: {model_name}")
-
- # examples = [
- #     ["मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं।"],
- #     ["मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं और उनका वर्णन करें।"],
- # ]
-
- # def chat_interface(input_prompts):
- #     outputs = inference(input_prompts, model, tokenizer)
- #     return outputs
-
- # gr.Interface(fn=chat_interface,
- #              inputs="text",
- #              outputs="text",
- #              examples=examples,
- #              title="CAMAI ChatBot").launch()
 
+
+
+ import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model_name = "ai4bharat/Airavata"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+ tokenizer.pad_token = tokenizer.eos_token
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
+
+
+ def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
+     formatted_text = ""
+     for message in messages:
+         if message["role"] == "system":
+             formatted_text += "<|system|>\n" + message["content"] + "\n"
+         elif message["role"] == "user":
+             formatted_text += "<|user|>\n" + message["content"] + "\n"
+         elif message["role"] == "assistant":
+             formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
+         else:
+             raise ValueError(
+                 "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
+                     message["role"]
+                 )
+             )
+     formatted_text += "<|assistant|>\n"
+     formatted_text = bos + formatted_text if add_bos else formatted_text
+     return formatted_text
+
+
+ def inference(input_prompts, model, tokenizer):
+     input_prompts = [
+         create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
+         for input_prompt in input_prompts
+     ]
+
+     encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
+     encodings = encodings.to(device)
+
+     with torch.inference_mode():  # run generation without gradient tracking
+         outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
+
+     output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
+
+     input_prompts = [
+         tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
+     ]
+     output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
+     return output_texts
+
+
+ def chat_interface(input_prompt):
+     # Gradio passes a single string, while inference() expects a list of prompts
+     outputs = inference([input_prompt], model, tokenizer)
+     return outputs[0]
+
+
+ inputs = gr.Textbox(lines=2, label="User Input")
+ outputs = gr.Textbox(label="Assistant Response")

+ gr.Interface(fn=chat_interface, inputs=inputs, outputs=outputs, title="Chat Interface").launch()
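
A quick way to sanity-check the committed pipeline outside the Gradio UI is to call inference() directly. This is a minimal sketch, assuming app.py above has already been executed so that model, tokenizer, and inference() are defined; the prompt is one of the examples from the earlier revision:

    # Smoke test (assumes model, tokenizer, and inference() from app.py above are in scope)
    prompt = "मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं।"
    responses = inference([prompt], model, tokenizer)  # inference() takes a list of prompts
    print(responses[0])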