SivaResearch commited on
Commit
cdf298d
·
verified ·
1 Parent(s): 0d58b13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -21
app.py CHANGED
@@ -1,26 +1,89 @@
1
- import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
- tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Airavata")
5
- model = AutoModelForCausalLM.from_pretrained("ai4bharat/Airavata")
6
-
7
- def generate_response(prompt):
8
- input_ids = tokenizer.encode(prompt, return_tensors="pt", max_length=50)
9
- output_ids = model.generate(input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2)
10
- response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
11
- return response
12
-
13
- iface = gr.Interface(
14
- fn=generate_response,
15
- inputs="text",
16
- outputs="text",
17
- live=True,
18
- title="Airavata LLMs Chatbot",
19
- description="Ask me anything, and I'll generate a response!",
20
- theme="light",
21
- )
22
-
23
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
 
 
1
+ import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
+ device = "cuda" if torch.cuda.is_available() else "cpu"
5
+
6
+
7
+ def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
8
+ formatted_text = ""
9
+ for message in messages:
10
+ if message["role"] == "system":
11
+ formatted_text += "<|system|>\n" + message["content"] + "\n"
12
+ elif message["role"] == "user":
13
+ formatted_text += "<|user|>\n" + message["content"] + "\n"
14
+ elif message["role"] == "assistant":
15
+ formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
16
+ else:
17
+ raise ValueError(
18
+ "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
19
+ message["role"]
20
+ )
21
+ )
22
+ formatted_text += "<|assistant|>\n"
23
+ formatted_text = bos + formatted_text if add_bos else formatted_text
24
+ return formatted_text
25
+
26
+
27
+ def inference(input_prompts, model, tokenizer):
28
+ input_prompts = [
29
+ create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
30
+ for input_prompt in input_prompts
31
+ ]
32
+
33
+ encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
34
+ encodings = encodings.to(device)
35
+
36
+ with torch.inference_mode():
37
+ outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
38
+
39
+ output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
40
+
41
+ input_prompts = [
42
+ tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
43
+ ]
44
+ output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
45
+ return output_texts
46
+
47
+
48
+ model_name = "ai4bharat/Airavata"
49
+
50
+ tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
51
+ tokenizer.pad_token = tokenizer.eos_token
52
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
53
+
54
+ input_prompts = [
55
+ "मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं।",
56
+ "मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं और उनका वर्णन करें।",
57
+ ]
58
+ outputs = inference(input_prompts, model, tokenizer)
59
+ print(outputs)
60
+
61
+
62
+
63
+
64
+ # import gradio as gr
65
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
66
+
67
+ # tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Airavata")
68
+ # model = AutoModelForCausalLM.from_pretrained("ai4bharat/Airavata")
69
+
70
+ # def generate_response(prompt):
71
+ # input_ids = tokenizer.encode(prompt, return_tensors="pt", max_length=50)
72
+ # output_ids = model.generate(input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2)
73
+ # response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
74
+ # return response
75
+
76
+ # iface = gr.Interface(
77
+ # fn=generate_response,
78
+ # inputs="text",
79
+ # outputs="text",
80
+ # live=True,
81
+ # title="Airavata LLMs Chatbot",
82
+ # description="Ask me anything, and I'll generate a response!",
83
+ # theme="light",
84
+ # )
85
+
86
+ # iface.launch()
87
 
88
 
89