import os |
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification |
import streamlit as st |
import torch |
import torch.nn.functional as F |
import matplotlib.pyplot as plt |
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Hub repository the tokenizer and classifier weights are loaded from.
model_path = "dejanseo/DEJAN-Taxonomy-Classifier"


@st.cache_resource
def _load_classifier():
    """Load the tokenizer and the classification model once and reuse them.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects as a resource avoids re-instantiating the model on
    each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): newer transformers releases appear to prefer `token=` over
    # `use_auth_token=` -- confirm against the pinned transformers version.
    loaded_tokenizer = DebertaV2Tokenizer.from_pretrained(
        model_path, use_auth_token=HF_TOKEN
    )
    loaded_model = DebertaV2ForSequenceClassification.from_pretrained(
        model_path, use_auth_token=HF_TOKEN
    )
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_classifier()
# Category ID -> position of that category in the model's output vector.
# The values are the contiguous indices 0..20 that the probabilities are
# enumerated by further down in this script.
LABEL_MAPPING = {
    1: 0, 8: 1, 111: 2, 141: 3, 166: 4, 222: 5, 412: 6, 436: 7,
    469: 8, 536: 9, 537: 10, 632: 11, 772: 12, 783: 13, 888: 14,
    922: 15, 988: 16, 1239: 17, 2092: 18, 5181: 19, 5605: 20,
}
# Category ID -> human-readable name of the top-level category.
CATEGORY_NAMES = {
    1: "Animals & Pet Supplies",
    8: "Arts & Entertainment",
    111: "Business & Industrial",
    141: "Cameras & Optics",
    166: "Apparel & Accessories",
    222: "Electronics",
    412: "Food, Beverages & Tobacco",
    436: "Furniture",
    469: "Health & Beauty",
    536: "Home & Garden",
    537: "Baby & Toddler",
    632: "Hardware",
    772: "Mature",
    783: "Media",
    888: "Vehicles & Parts",
    922: "Office Supplies",
    988: "Sporting Goods",
    1239: "Toys & Games",
    2092: "Software",
    5181: "Luggage & Bags",
    5605: "Religious & Ceremonial",
}
# Model output index -> display label of the form "[<category id>] <name>",
# built by inverting LABEL_MAPPING and looking the name up in CATEGORY_NAMES.
INDEX_TO_CATEGORY = {
    class_index: f"[{category_id}] {CATEGORY_NAMES[category_id]}"
    for category_id, class_index in LABEL_MAPPING.items()
}
# Page header and introductory copy.
st.title("Google Taxonomy Classifier by DEJAN")
st.write(
    "Enter text in the input box, and the model will classify it into one of "
    "the 21 top level categories. This demo showcases early model capability "
    "while the full 5000+ label model is undergoing extensive training."
)
st.write(
    "Works for product descriptions, search queries, articles, social media "
    "posts and broadly web text of any style. Suitable for classification "
    "pipelines of millions of queries."
)

# Free-form text to classify; read again when the Classify button is pressed.
input_text = st.text_area("Enter text for classification:")
def classify_text(text):
    """Return the model's class probabilities for ``text``.

    Args:
        text: The string to classify.

    Returns:
        ``None`` when ``text`` is empty or whitespace-only. Otherwise the
        softmax of the model's logits over the last dimension, squeezed and
        converted to plain Python values (presumably a flat list with one
        probability per class for a single input -- TODO confirm).
    """
    if not text.strip():
        return None

    # Inputs longer than 512 tokens are truncated by the tokenizer.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits

    return F.softmax(logits, dim=-1).squeeze().tolist()
# Classification runs only on the script rerun triggered by the button click.
if st.button("Classify"):
    if input_text.strip():
        st.write("Processing...")
        probabilities = classify_text(input_text)
        if probabilities:
            # Attach the "[id] name" label to every probability, then rank the
            # categories from most to least likely.
            mapped_probs = {
                INDEX_TO_CATEGORY[idx]: prob
                for idx, prob in enumerate(probabilities)
            }
            sorted_categories = sorted(
                mapped_probs.items(), key=lambda x: x[1], reverse=True
            )
            categories = [item[0] for item in sorted_categories]
            values = [item[1] for item in sorted_categories]

            # Horizontal bar chart of the ranked probabilities.
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.barh(categories, values)
            ax.set_xlabel("Probability")
            ax.set_ylabel("Category")
            ax.set_title("Classification Probabilities")
            # barh draws the first entry at the bottom; flip the axis so the
            # most likely category is at the top.
            ax.invert_yaxis()
            ax.set_xlim(0, 1)
            st.pyplot(fig)
            # Figures created through pyplot stay registered until closed;
            # close this one so they do not pile up across reruns.
            plt.close(fig)

            st.divider()
            st.markdown(
                "\n"
                "Interested in using this in an automated pipeline for bulk link prediction?\n"
                "Please [book an appointment](https://dejanmarketing.com/conference/) to discuss your needs.\n"
            )
        else:
            st.error("Could not classify the text. Please try again.")
    else:
        st.warning("Please enter some text for classification.")