|
import os |
|
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification |
|
import streamlit as st |
|
import torch |
|
import torch.nn.functional as F |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Hugging Face access token (may be None for public models) and the model repo id.
HF_TOKEN = os.getenv("HF_TOKEN")

model_path = "dejanseo/DEJAN-Taxonomy-Classifier"


@st.cache_resource
def _load_model(path, token):
    """Load the tokenizer and classifier once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model weights would be reloaded on every click.
    ``st.cache_resource`` keeps a single shared instance keyed on the args.

    Args:
        path: Hugging Face repo id or local directory of the model.
        token: Hugging Face access token, or None for anonymous access.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is put in eval mode.
    """
    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    tok = DebertaV2Tokenizer.from_pretrained(path, token=token)
    clf = DebertaV2ForSequenceClassification.from_pretrained(path, token=token)
    clf.eval()  # inference only: disable dropout explicitly
    return tok, clf


tokenizer, model = _load_model(model_path, HF_TOKEN)
|
|
|
|
|
# Google Product Taxonomy top-level category ID -> model output index.
# NOTE(review): assumed to match the label order used when the model was
# trained (IDs sorted ascending) — confirm against the model's config.
LABEL_MAPPING = {
    1: 0, 8: 1, 111: 2, 141: 3, 166: 4, 222: 5, 412: 6, 436: 7,
    469: 8, 536: 9, 537: 10, 632: 11, 772: 12, 783: 13, 888: 14,
    922: 15, 988: 16, 1239: 17, 2092: 18, 5181: 19, 5605: 20,
}

# Official English names of the 21 Google Product Taxonomy top-level
# categories, keyed by taxonomy ID. The previous table attached the wrong
# names to most IDs (e.g. 166 is Apparel & Accessories, 222 is Electronics).
CATEGORY_NAMES = {
    1: "Animals & Pet Supplies",
    8: "Arts & Entertainment",
    111: "Business & Industrial",
    141: "Cameras & Optics",
    166: "Apparel & Accessories",
    222: "Electronics",
    412: "Food, Beverages & Tobacco",
    436: "Furniture",
    469: "Health & Beauty",
    536: "Home & Garden",
    537: "Baby & Toddler",
    632: "Hardware",
    772: "Mature",
    783: "Media",
    888: "Vehicles & Parts",
    922: "Office Supplies",
    988: "Sporting Goods",
    1239: "Toys & Games",
    2092: "Software",
    5181: "Luggage & Bags",
    5605: "Religious & Ceremonial",
}

# Fail fast if the two tables ever drift apart; otherwise the comprehension
# below would raise an opaque KeyError or silently drop a category.
if LABEL_MAPPING.keys() != CATEGORY_NAMES.keys():
    raise RuntimeError("LABEL_MAPPING and CATEGORY_NAMES must cover the same taxonomy IDs")

# Model output index -> display label such as "[166] Apparel & Accessories".
INDEX_TO_CATEGORY = {v: f"[{k}] {CATEGORY_NAMES[k]}" for k, v in LABEL_MAPPING.items()}
|
|
|
|
|
# Static copy for the page header.
_PAGE_TITLE = "Google Taxonomy Classifier by DEJAN"
_PAGE_INTRO = (
    "Enter text in the input box, and the model will classify it into one of "
    "the 21 top level categories. This demo showcases early model capability "
    "while the full 5000+ label model is undergoing extensive training."
)

st.title(_PAGE_TITLE)
st.write(_PAGE_INTRO)

# Free-form text the user wants classified; consumed by the Classify handler.
input_text = st.text_area("Enter text for classification:")
|
|
|
|
|
def classify_text(text):
    """Score *text* against every class of the module-level ``model``.

    Blank or whitespace-only input short-circuits to ``None``. Otherwise the
    text is tokenized (truncated to at most 512 tokens), passed through the
    model with autograd disabled, and the logits are softmax-normalised over
    the last dimension.

    Returns:
        ``None`` for blank input, else the squeezed probability tensor
        converted with ``tolist()`` — a flat list with one probability per
        class for a single input (assumes the model has more than one label).
    """
    # Guard clause: nothing meaningful to classify.
    if not text.strip():
        return None

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    with torch.no_grad():
        scores = model(**encoded).logits

    distribution = F.softmax(scores, dim=-1)
    return distribution.squeeze().tolist()
|
|
|
|
|
# Run classification when the user clicks the button and chart the results.
if st.button("Classify"):

    if input_text.strip():
        # A spinner disappears once the work finishes, unlike a plain
        # st.write message that would linger above the results.
        with st.spinner("Processing..."):
            probabilities = classify_text(input_text)

        if probabilities:
            # Pair each model output index with its display label, highest
            # probability first. Fall back to a generic label rather than
            # crashing if the model emits an index missing from the table.
            ranked = sorted(
                (
                    (INDEX_TO_CATEGORY.get(idx, f"Label {idx}"), prob)
                    for idx, prob in enumerate(probabilities)
                ),
                key=lambda item: item[1],
                reverse=True,
            )
            categories = [label for label, _ in ranked]
            values = [prob for _, prob in ranked]

            fig, ax = plt.subplots(figsize=(10, 6))
            ax.barh(categories, values)
            ax.set_xlabel("Probability")
            ax.set_ylabel("Category")
            ax.set_title("Classification Probabilities")
            ax.invert_yaxis()  # put the most likely category at the top
            ax.set_xlim(0, 1)
            st.pyplot(fig)
            # The Streamlit server process outlives each rerun; without an
            # explicit close every figure stays in pyplot's global registry.
            plt.close(fig)

        else:
            st.error("Could not classify the text. Please try again.")

    else:
        st.warning("Please enter some text for classification.")
|
|