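# Gradio chat demo for CAMAI (Centralized Actionable Multimodal Agri Assistant),
# serving the ai4bharat/Airavata model for Hindi / Hinglish agricultural queries.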
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
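# Load Airavata; left padding with pad_token = eos_token is the usual setup for
# generation with decoder-only models.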
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "ai4bharat/Airavata"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
    """Format chat messages into the Tulu-style prompt format used by Airavata."""
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"] + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"] + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    message["role"]
                )
            )
    formatted_text += "<|assistant|>\n"
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text
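# For a single user turn, create_prompt_with_chat_format yields e.g.
#   "<|user|>\n<question>\n<|assistant|>\n"
# (prefixed with "<s>" when add_bos=True), which cues the model to answer.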
def inference(input_prompt, model, tokenizer):
    """Generate a response for a single user prompt and strip the prompt from the output."""
    input_prompt = create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
    encodings = tokenizer(input_prompt, padding=True, return_tensors="pt")
    encodings = encodings.to(device)
    with torch.inference_mode():  # no gradients needed during generation
        outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Re-decode the prompt so its detokenized form matches the decoded output, then drop it.
    input_prompt = tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True)
    output_text = output_text[len(input_prompt):]
    return output_text
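# Quick sanity check (illustrative, outside Gradio):
# print(inference("भारत में शीर्ष पांच प्रमुख फसलें कौन सी हैं?", model, tokenizer))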
def chat_interface(message, history):
    """Gradio ChatInterface callback: answer the latest message; chat history is not used."""
    return inference(message, model, tokenizer)
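# Build the chat UI. The example prompts below are in Hindi and Romanized Hindi (Hinglish),
# matching the farming audience the assistant targets.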
demo = gr.ChatInterface(
    chat_interface,
    title="CAMAI - Centralized Actionable Multimodal Agri Assistant on Edge Intelligence for Farmers",
    theme="adam-haile/DSTheme",
    examples=[
        "दिल्ली में घूमने के लिए शीर्ष पांच सर्वोत्तम स्थान",
        "भारत में शीर्ष पांच प्रमुख फसलें कौन सी हैं?",
        "धान की फसल में जीवाणु रोग से कैसे बचें?",
        "Kya aap Close Up karte hain?",
        "Beej chayan mein kaun-kaun si baatein dhyan mein rakhni chahiye, kripaya vistar se batayein.",
        "Crop rotation aur intercropping mein kya benefits hain, agricultural experts se jankari leni hai.",
        "Hydroponic farming kya hai, aur ismein kis tarah ki challenges aati hain, is par experts se discuss karna hai.",
        "हाइड्रोपोनिक फार्मिंग क्या है, और इसमें कौन-कौन से challenges occur होते हैं, इस पर experts से discuss करना है।",
    ],
)

demo.launch()