Spaces:
Sleeping
Sleeping
import time

import gradio as gr
from langchain.prompts import PromptTemplate
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_huggingface import HuggingFaceEndpoint
# Hugging Face endpoint for Mistral-7B-Instruct; classify_text sends its
# formatted prompt here via llm.invoke().
# NOTE(review): temperature=0.7 is combined with do_sample=False. Depending on
# the serving backend the temperature is either ignored (greedy decoding) or
# itself switches sampling on — confirm which is intended, since a classifier
# usually wants repeatable answers.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    task="text-generation",
    do_sample=False,
    temperature=0.7,
    max_new_tokens=128,  # generation budget for the short JSON answer
)
# Provide the family labels directly in the prompt | |
family_labels = [ | |
"Batteries and generators and kinetic power transmission", | |
"Building and facility maintenance and repair services", | |
"Business administration services", | |
"Communications Devices and Accessories", | |
"Components for information technology or broadcasting or telecommunications", | |
"Computer Equipment and Accessories", | |
"Consumer electronics", | |
"Data Voice or Multimedia Network Equipment or Platforms and Accessories", | |
"Domestic appliances", | |
"Electrical equipment and components and supplies", | |
"Electrical wire and cable and harness", | |
"Electronic hardware and component parts and accessories", | |
"Electronic manufacturing machinery and equipment and accessories", | |
"General agreements and contracts", | |
"Heating and ventilation and air circulation", | |
"Heavy construction machinery and equipment", | |
"Industrial process machinery and equipment and supplies", | |
"Management advisory services", | |
"Marketing and distribution", | |
"Office and desk accessories", | |
"Office machines and their supplies and accessories", | |
"Office supply", | |
"Power generation", | |
"Power sources", | |
"Printing and publishing equipment", | |
"Software", | |
"Structural components and basic shapes" | |
] | |
# Classification prompt for PromptTemplate (f-string syntax):
#   {family_labels} - newline-separated candidate labels
#   {TEXT}          - the user's text to classify
# The doubled braces in the example render as literal { } after formatting.
template_classify = """
You are a classifier bot that assigns a UNSPSC family label to the given text.
Your task is to classify the text into one of the following UNSPSC family labels:
{family_labels}
Provide only the family label in your answer. If unsure, label as "Unknown".
Convert it to JSON format using 'Answer' as the key and return it.
Your final response MUST contain only the response, no other text.
Example:
{{"Answer":["Family Label"]}}
What is the UNSPSC family label for the following text?:
<text>
{TEXT}
</text>
"""
# Parser for the model's JSON reply; classify_text reuses this single instance.
json_output_parser = JsonOutputParser()
def classify_text(text):
    """Ask the LLM for the UNSPSC family label that best fits *text*.

    The prompt (``template_classify``) lists every entry of ``family_labels``
    and asks the model to reply with JSON shaped like
    ``{"Answer": ["<label>"]}``.

    Args:
        text: Free-form text to classify.

    Returns:
        A ``(label, duration)`` tuple.

        ``label`` is the label string reported by the model. It is
        ``"Unknown"`` when the input is blank, the reply is not valid JSON,
        or the reply carries no answer.

        ``duration`` is the elapsed time in seconds, measured with
        ``time.perf_counter``.
    """
    start = time.perf_counter()  # monotonic clock intended for interval timing

    # Blank input gives the model nothing to classify; skip the remote call.
    if text is None or not str(text).strip():
        return "Unknown", time.perf_counter() - start

    family_labels_str = "\n".join(family_labels)
    prompt_classify = PromptTemplate(
        template=template_classify,
        input_variables=["TEXT", "family_labels"],
    )
    formatted_prompt = prompt_classify.format(TEXT=text, family_labels=family_labels_str)
    classify = llm.invoke(formatted_prompt)

    try:
        parsed_output = json_output_parser.parse(classify)
    except OutputParserException:
        # Unparseable reply: fall back to the same "Unknown" label the prompt
        # asks the model to use when unsure, instead of crashing the caller.
        parsed_output = None

    answer = parsed_output.get("Answer") if isinstance(parsed_output, dict) else None
    # The prompt asks for a one-element list, but a bare string is accepted
    # too; indexing a string with [0] would keep only its first character.
    if isinstance(answer, list):
        answer = answer[0] if answer else None
    label = "" if answer is None else str(answer).strip()

    duration = time.perf_counter() - start
    return label or "Unknown", duration
def create_gradio_interface():
    """Build the Gradio UI for UNSPSC family classification and launch it.

    The page has:
      - a text input;
      - a button that runs ``classify_text``;
      - two output textboxes: the detected family label and the time taken.
    """
    with gr.Blocks() as iface:
        text_input = gr.Textbox(label="Text")
        output_text = gr.Textbox(label="Detected UNSPSC Family")
        time_taken = gr.Textbox(label="Time Taken (seconds)")
        submit_btn = gr.Button("Classify UNSPSC Family")

        def on_submit(text):
            """Classify *text* and format the results for the two output boxes."""
            try:
                classification, duration = classify_text(text)
            except Exception as exc:
                # UI boundary: report the failure (for example an endpoint
                # error) to the user via gr.Error. Without this, Gradio shows
                # only a generic error with no message.
                raise gr.Error(f"Classification failed: {exc}") from exc
            return classification, f"Time taken: {duration:.2f} seconds"

        submit_btn.click(fn=on_submit, inputs=text_input, outputs=[output_text, time_taken])
    iface.launch()
# Start the web UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    create_gradio_interface()