"""Gradio app that classifies free text into a UNSPSC family label via an LLM."""

import time

import gradio as gr
from langchain.prompts import PromptTemplate
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_huggingface import HuggingFaceEndpoint

# Initialize the LLM and other components.
# NOTE(review): do_sample=False together with temperature=0.7 looks
# contradictory; depending on the serving backend the temperature may be
# ignored (greedy decoding) or may itself enable sampling. For a classifier,
# deterministic output is usually wanted — confirm which behavior is intended.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    task="text-generation",
    max_new_tokens=128,
    temperature=0.7,
    do_sample=False,
)

# Label returned when the input is blank or the model's reply cannot be
# interpreted; it matches the fallback the prompt asks the model to use.
UNKNOWN_LABEL = "Unknown"

# Provide the family labels directly in the prompt.
family_labels = [
    "Batteries and generators and kinetic power transmission",
    "Building and facility maintenance and repair services",
    "Business administration services",
    "Communications Devices and Accessories",
    "Components for information technology or broadcasting or telecommunications",
    "Computer Equipment and Accessories",
    "Consumer electronics",
    "Data Voice or Multimedia Network Equipment or Platforms and Accessories",
    "Domestic appliances",
    "Electrical equipment and components and supplies",
    "Electrical wire and cable and harness",
    "Electronic hardware and component parts and accessories",
    "Electronic manufacturing machinery and equipment and accessories",
    "General agreements and contracts",
    "Heating and ventilation and air circulation",
    "Heavy construction machinery and equipment",
    "Industrial process machinery and equipment and supplies",
    "Management advisory services",
    "Marketing and distribution",
    "Office and desk accessories",
    "Office machines and their supplies and accessories",
    "Office supply",
    "Power generation",
    "Power sources",
    "Printing and publishing equipment",
    "Software",
    "Structural components and basic shapes",
]

# Prompt asking the model to pick one UNSPSC family label from the given list.
# Doubled braces escape the literal JSON example for PromptTemplate's
# f-string formatting.
template_classify = '''
You are a classifier bot that assigns a UNSPSC family label to the given text.
Your task is to classify the text into one of the following UNSPSC family labels:
{family_labels}
Provide only the family label in your answer. If unsure, label as "Unknown".
Convert it to JSON format using 'Answer' as the key and return it.
Your final response MUST contain only the response, no other text.
Example: {{"Answer":["Family Label"]}}
What is the UNSPSC family label for the following text?:
{TEXT}
'''

# Built once at import time instead of on every request.
prompt_classify = PromptTemplate(
    template=template_classify,
    input_variables=["TEXT", "family_labels"],
)

json_output_parser = JsonOutputParser()


def _extract_label(parsed_output):
    """Return the label from a parsed ``{"Answer": ...}`` reply, or UNKNOWN_LABEL.

    Accepts both the requested list form (``{"Answer": ["Label"]}``) and a
    bare string (``{"Answer": "Label"}``). The bare string is a common model
    deviation: indexing it with ``[0]`` would silently yield only the first
    character.
    """
    if not isinstance(parsed_output, dict):
        return UNKNOWN_LABEL
    answer = parsed_output.get("Answer")
    if isinstance(answer, list):
        answer = answer[0] if answer else None
    if isinstance(answer, str) and answer.strip():
        return answer.strip()
    return UNKNOWN_LABEL


def classify_text(text):
    """Classify *text* into one of ``family_labels`` using the LLM.

    Args:
        text: Free-text description to classify.

    Returns:
        A ``(label, duration)`` tuple. ``label`` is the family label chosen by
        the model, or ``"Unknown"`` when the input is blank or the reply is not
        usable JSON. ``duration`` is the elapsed time in seconds.

    Raises:
        Any exception raised by ``llm.invoke`` (e.g. endpoint/network errors).
    """
    start = time.perf_counter()
    if not text or not text.strip():
        # Nothing to classify; skip the remote call entirely.
        return UNKNOWN_LABEL, time.perf_counter() - start

    # Join the family labels into one line per label for the prompt.
    formatted_prompt = prompt_classify.format(
        TEXT=text,
        family_labels="\n".join(family_labels),
    )
    raw_reply = llm.invoke(formatted_prompt)
    try:
        parsed_output = json_output_parser.parse(raw_reply)
    except OutputParserException:
        # The model did not return valid JSON; report "Unknown" rather than
        # surfacing a traceback in the UI.
        parsed_output = None

    label = _extract_label(parsed_output)
    return label, time.perf_counter() - start


def create_gradio_interface():
    """Build the Gradio UI for the classifier and launch it (blocks while serving)."""
    with gr.Blocks() as iface:
        text_input = gr.Textbox(label="Text")
        output_text = gr.Textbox(label="Detected UNSPSC Family")
        time_taken = gr.Textbox(label="Time Taken (seconds)")
        submit_btn = gr.Button("Classify UNSPSC Family")

        def on_submit(text):
            classification, duration = classify_text(text)
            return classification, f"Time taken: {duration:.2f} seconds"

        submit_btn.click(
            fn=on_submit,
            inputs=text_input,
            outputs=[output_text, time_taken],
        )

    # Launch after the Blocks context has closed so the layout is finalized.
    iface.launch()


if __name__ == "__main__":
    create_gradio_interface()