import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "ai4bharat/Airavata"

# Left padding is needed for correct generation with decoder-only models, and the model
# ships without a dedicated pad token, so the EOS token is reused for padding.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
# Load the weights in bfloat16 to reduce memory use, then move the model to the chosen device.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
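# Note (assumption, not in the original script): bfloat16 can be slow or unsupported on
# CPU-only machines, so torch.float32 is a reasonable fallback when device == "cpu".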


def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
    # Render a list of {"role", "content"} messages into the Tulu chat format that
    # Airavata was instruction-tuned on, ending with an open "<|assistant|>" turn.
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"] + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"] + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    message["role"]
                )
            )
    formatted_text += "<|assistant|>\n"
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text

def inference(input_prompt, model, tokenizer):
    # Wrap the raw user message in the chat format expected by the model.
    input_prompt = create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)

    encodings = tokenizer(input_prompt, padding=True, return_tensors="pt")
    encodings = encodings.to(device)

    # Greedy decoding with gradient tracking disabled for faster, deterministic generation.
    # Passing the attention mask explicitly avoids warnings now that pad == eos.
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encodings.input_ids,
            attention_mask=encodings.attention_mask,
            do_sample=False,
            max_new_tokens=250,
        )

    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The generated sequence echoes the prompt, so strip the (re-decoded) prompt prefix
    # and return only the newly generated answer.
    input_prompt = tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True)
    output_text = output_text[len(input_prompt):]

    return output_text
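# Optional quick sanity check outside the UI; the prompt below is only an illustrative example.
# print(inference("भारत में शीर्ष पांच प्रमुख फसलें कौन सी हैं?", model, tokenizer))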


def chat_interface(message, history):
    # Each turn is answered independently; the accumulated chat history is not passed to the model.
    outputs = inference(message, model, tokenizer)
    return outputs


# Example prompts are in Hindi and Hinglish, e.g. "Top five best places to visit in Delhi",
# "What are the top five major crops in India?", "How to protect a paddy crop from bacterial disease?".
demo = gr.ChatInterface(
    chat_interface,
    title="CAMAI - Centralized Actionable Multimodal Agri Assistant on Edge Intelligence for Farmers",
    theme="adam-haile/DSTheme",
    examples=[
        "दिल्ली में घूमने के लिए शीर्ष पांच सर्वोत्तम स्थान",
        "भारत में शीर्ष पांच प्रमुख फसलें कौन सी हैं?",
        "धान की फसल में जीवाणु रोग से कैसे बचें?",
        "Kya aap Close Up karte hain?",
        "Beej chayan mein kaun-kaun si baatein dhyan mein rakhni chahiye, kripaya vistar se batayein.",
        "Crop rotation aur intercropping mein kya benefits hain, agricultural experts se jankari leni hai.",
        "Hydroponic farming kya hai, aur ismein kis tarah ki challenges aati hain, is par experts se discuss karna hai.",
        "हाइड्रोपोनिक फार्मिंग क्या है, और इसमें कौन-कौन से challenges occur होते हैं, इस पर experts से discuss करना है।",
    ],
)
demo.launch()
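
# launch() serves the app locally; if needed, gr.ChatInterface.launch() also accepts
# share=True for a temporary public link and server_name="0.0.0.0" to listen on all interfaces.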