Spaces:
Sleeping
Sleeping
File size: 3,668 Bytes
ae5819c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.output_parsers import JsonOutputParser
import time
# Shared Hugging Face text-generation client; classify_text() sends every
# rendered prompt through it via llm.invoke().
# NOTE(review): do_sample=False is combined with temperature=0.7. Some
# text-generation backends switch to sampling whenever temperature != 1.0,
# so replies may not be deterministic. Confirm which decoding mode is intended.
_llm_settings = {
    "repo_id": "mistralai/Mistral-7B-Instruct-v0.3",
    "task": "text-generation",
    "max_new_tokens": 128,
    "temperature": 0.7,
    "do_sample": False,
}
llm = HuggingFaceEndpoint(**_llm_settings)
# Candidate UNSPSC family labels. classify_text() joins them one per line and
# inserts them into the prompt, so the model picks from exactly this list.
# Written as one label per line of a block string; strip() drops the blank
# first/last lines before splitting.
family_labels = """
Batteries and generators and kinetic power transmission
Building and facility maintenance and repair services
Business administration services
Communications Devices and Accessories
Components for information technology or broadcasting or telecommunications
Computer Equipment and Accessories
Consumer electronics
Data Voice or Multimedia Network Equipment or Platforms and Accessories
Domestic appliances
Electrical equipment and components and supplies
Electrical wire and cable and harness
Electronic hardware and component parts and accessories
Electronic manufacturing machinery and equipment and accessories
General agreements and contracts
Heating and ventilation and air circulation
Heavy construction machinery and equipment
Industrial process machinery and equipment and supplies
Management advisory services
Marketing and distribution
Office and desk accessories
Office machines and their supplies and accessories
Office supply
Power generation
Power sources
Printing and publishing equipment
Software
Structural components and basic shapes
""".strip().splitlines()
# Prompt asking the model to pick one UNSPSC family label from a fixed list.
# classify_text() renders it through PromptTemplate, filling two placeholders:
#   {family_labels} - the candidate labels, one per line
#   {TEXT}          - the user's input, wrapped in <text> tags
# The doubled braces on the example line are escapes, so the rendered prompt
# shows the literal JSON {"Answer":["Family Label"]} for the model to imitate.
# NOTE(review): the prompt is sent as raw text without Mistral's
# [INST] ... [/INST] instruction markers. Confirm whether the endpoint applies a
# chat template; otherwise the model may continue the text instead of answering.
template_classify = '''
You are a classifier bot that assigns a UNSPSC family label to the given text.
Your task is to classify the text into one of the following UNSPSC family labels:
{family_labels}
Provide only the family label in your answer. If unsure, label as "Unknown".
Convert it to JSON format using 'Answer' as the key and return it.
Your final response MUST contain only the response, no other text.
Example:
{{"Answer":["Family Label"]}}
What is the UNSPSC family label for the following text?:
<text>
{TEXT}
</text>
'''
# Parses the model's raw reply text into Python objects; classify_text() then
# reads the "Answer" entry of the parsed result.
json_output_parser = JsonOutputParser()
# Define the classify_text function
def classify_text(text):
    """Classify ``text`` into one of the UNSPSC families in ``family_labels``.

    The text and the candidate labels are rendered into ``template_classify``.
    The prompt is sent to ``llm`` and the reply is parsed with
    ``json_output_parser``.

    Args:
        text: Free-form description to classify.

    Returns:
        A ``(label, duration)`` tuple.

        ``label`` is the string from the model's ``"Answer"`` field, taking
        the first element when it is a list, with surrounding whitespace
        stripped. It is ``"Unknown"`` when ``text`` is blank or the reply has
        no usable answer.

        ``duration`` is the elapsed time in seconds.

    Raises:
        Any exception raised by ``llm.invoke``, such as endpoint errors.
        Any exception raised by ``json_output_parser.parse``, such as a reply
        that is not valid JSON.
    """
    # perf_counter() is monotonic, so system clock adjustments cannot skew the
    # interval the way a difference of time.time() values can.
    start = time.perf_counter()

    # Nothing to classify: skip the remote call instead of prompting the model
    # with an empty <text> section.
    if not text or not text.strip():
        return "Unknown", time.perf_counter() - start

    # One label per line, matching the list layout the prompt presents.
    family_labels_str = "\n".join(family_labels)
    prompt_classify = PromptTemplate(
        template=template_classify,
        input_variables=["TEXT", "family_labels"],
    )
    formatted_prompt = prompt_classify.format(TEXT=text, family_labels=family_labels_str)
    raw_reply = llm.invoke(formatted_prompt)
    parsed_output = json_output_parser.parse(raw_reply)
    duration = time.perf_counter() - start

    # The prompt asks for {"Answer": ["<label>"]}, but the model may return a
    # bare string or an empty or missing answer. Indexing a bare string with
    # [0] would silently return only its first character, so normalize here.
    answer = parsed_output.get("Answer") if isinstance(parsed_output, dict) else None
    if isinstance(answer, list):
        answer = answer[0] if answer else None
    if not isinstance(answer, str) or not answer.strip():
        return "Unknown", duration
    return answer.strip(), duration
# Create the Gradio interface
def create_gradio_interface():
    """Build the UNSPSC classification UI and serve it with ``launch()``.

    The page has three text boxes and a button. One box holds the input text.
    The button runs :func:`classify_text` on that text. The other two boxes
    show the detected family and the time the call took.

    Returns:
        The ``gr.Blocks`` app, returned after ``launch()`` returns. The
        original function returned ``None``, so callers that ignore the
        result are unaffected.
    """
    with gr.Blocks() as iface:
        text_input = gr.Textbox(label="Text")
        output_text = gr.Textbox(label="Detected UNSPSC Family")
        time_taken = gr.Textbox(label="Time Taken (seconds)")
        submit_btn = gr.Button("Classify UNSPSC Family")

        def on_submit(text):
            # Unless the app is launched with show_error=True, Gradio does not
            # display the message of an arbitrary exception, while gr.Error
            # messages are always shown. Re-raise so the user sees the cause,
            # such as an endpoint failure or an unparseable model reply.
            try:
                classification, duration = classify_text(text)
            except Exception as exc:
                raise gr.Error(f"Classification failed: {exc}") from exc
            return classification, f"Time taken: {duration:.2f} seconds"

        submit_btn.click(fn=on_submit, inputs=text_input, outputs=[output_text, time_taken])

    # Launch after leaving the `with` block, as in Gradio's documented usage,
    # so the layout context is closed before the server starts.
    iface.launch()
    return iface


if __name__ == "__main__":
    create_gradio_interface()
|