dejanseo's picture
a975033 verified
history blame
3.75 kB
import os
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification
import streamlit as st
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Hugging Face access token, read from the environment (None when HF_TOKEN is unset)
HF_TOKEN = os.getenv("HF_TOKEN")
model_path = "dejanseo/DEJAN-Taxonomy-Classifier"


@st.cache_resource
def _load_model_and_tokenizer(path, token):
    """Load the tokenizer and classifier once and reuse them across reruns.

    Streamlit re-executes the whole script on every widget interaction; without
    caching, the tokenizer and model would be loaded again on each button click.

    Args:
        path: Hub repository id (or local path) of the model.
        token: Access token passed to ``from_pretrained``; may be ``None``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): `use_auth_token` is deprecated in recent transformers
    # releases in favour of `token` - switch once the pinned version is confirmed.
    loaded_tokenizer = DebertaV2Tokenizer.from_pretrained(path, use_auth_token=token)
    loaded_model = DebertaV2ForSequenceClassification.from_pretrained(path, use_auth_token=token)
    return loaded_tokenizer, loaded_model


# Module-level names used by the rest of the script
tokenizer, model = _load_model_and_tokenizer(model_path, HF_TOKEN)
# LABEL_MAPPING: Google product-taxonomy category ID -> model output index
LABEL_MAPPING = {
    1: 0, 8: 1, 111: 2, 141: 3, 166: 4, 222: 5, 412: 6, 436: 7,
    469: 8, 536: 9, 537: 10, 632: 11, 772: 12, 783: 13, 888: 14,
    922: 15, 988: 16, 1239: 17, 2092: 18, 5181: 19, 5605: 20,
}

# CATEGORY_NAMES: Google product-taxonomy category ID -> top-level category name.
# The names are keyed by the real taxonomy ID (not assigned alphabetically to the
# sorted IDs), so each label shown to the user matches the ID printed next to it.
CATEGORY_NAMES = {
    1: "Animals & Pet Supplies",
    8: "Arts & Entertainment",
    111: "Business & Industrial",
    141: "Cameras & Optics",
    166: "Apparel & Accessories",
    222: "Electronics",
    412: "Food, Beverages & Tobacco",
    436: "Furniture",
    469: "Health & Beauty",
    536: "Home & Garden",
    537: "Baby & Toddler",
    632: "Hardware",
    772: "Mature",
    783: "Media",
    888: "Vehicles & Parts",
    922: "Office Supplies",
    988: "Sporting Goods",
    1239: "Toys & Games",
    2092: "Software",
    5181: "Luggage & Bags",
    5605: "Religious & Ceremonial",
}

# Reverse mapping: model output index -> "[taxonomy id] category name" label
INDEX_TO_CATEGORY = {v: f"[{k}] {CATEGORY_NAMES[k]}" for k, v in LABEL_MAPPING.items()}
# Page header: app title followed by a short description of the demo
st.title("Google Taxonomy Classifier by DEJAN")
st.write(
    "Enter text in the input box, and the model will classify it into one of the 21 top level categories. "
    "This demo showcases early model capability while the full 5000+ label model is undergoing extensive training."
)

# Free-form text area holding the content to classify
input_text = st.text_area("Enter text for classification:")
# Inference function
def classify_text(text):
    """Return the model's class probabilities for ``text``.

    Args:
        text: Raw input string to classify.

    Returns:
        A list of floats, one probability per model output index, or ``None``
        when ``text`` is ``None``, empty, or whitespace only.
    """
    if text is None or not text.strip():
        return None

    # Tokenize and encode input text; inputs longer than 512 tokens are truncated
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension, then drop only the batch dimension so
    # the result is always a list (a bare squeeze() would collapse a
    # single-output model down to a scalar).
    return F.softmax(logits, dim=-1).squeeze(0).tolist()
# Display results when the user clicks the button
if st.button("Classify"):
    if input_text.strip():
        # Classify the input text
        probabilities = classify_text(input_text)
        if probabilities:
            # Map each model output index to its "[id] name" label
            mapped_probs = {INDEX_TO_CATEGORY[idx]: prob for idx, prob in enumerate(probabilities)}

            # Sort categories by probability in descending order
            sorted_categories = sorted(mapped_probs.items(), key=lambda x: x[1], reverse=True)
            categories = [item[0] for item in sorted_categories]
            values = [item[1] for item in sorted_categories]

            # Create horizontal bar chart
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.barh(categories, values)
            ax.set_title("Classification Probabilities")
            ax.invert_yaxis()  # Ensure highest probability is at the top
            ax.set_xlim(0, 1)  # Set the x-axis range to 0-1 for probabilities

            # Render the chart, then close it so figures do not accumulate
            # in pyplot's global state across Streamlit reruns
            st.pyplot(fig)
            plt.close(fig)
        else:
            # classify_text returned nothing usable
            st.error("Could not classify the text. Please try again.")
    else:
        # Button clicked with an empty / whitespace-only input
        st.warning("Please enter some text for classification.")