File size: 4,250 Bytes
0412fd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f60c09b
 
 
 
 
 
 
 
 
0412fd5
f60c09b
 
 
 
0412fd5
f60c09b
 
 
 
0412fd5
 
f60c09b
0412fd5
 
 
 
a975033
 
f39e1d1
0412fd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f39e1d1
 
 
 
 
 
 
 
0412fd5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification
import streamlit as st
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Load Hugging Face token from environment (None when unset, which is
# acceptable for a publicly readable model repo).
HF_TOKEN = os.getenv("HF_TOKEN")
model_path = "dejanseo/DEJAN-Taxonomy-Classifier"

# Load the model and tokenizer. `token=` replaces the deprecated
# `use_auth_token=` keyword, which newer transformers releases reject.
tokenizer = DebertaV2Tokenizer.from_pretrained(model_path, token=HF_TOKEN)
model = DebertaV2ForSequenceClassification.from_pretrained(model_path, token=HF_TOKEN)

# --- Label bookkeeping ------------------------------------------------------
# The classification head emits 21 logits; logit position i corresponds to
# the Google taxonomy ID at position i of this ordered list.
_TAXONOMY_IDS = [
    1, 8, 111, 141, 166, 222, 412, 436, 469, 536, 537,
    632, 772, 783, 888, 922, 988, 1239, 2092, 5181, 5605,
]

# Taxonomy ID -> model output index (e.g. 1 -> 0, 8 -> 1, ..., 5605 -> 20).
LABEL_MAPPING = {tax_id: idx for idx, tax_id in enumerate(_TAXONOMY_IDS)}

# Human-readable name for each taxonomy ID.
CATEGORY_NAMES = {
    1: "Animals & Pet Supplies",
    8: "Arts & Entertainment",
    111: "Business & Industrial",
    141: "Cameras & Optics",
    166: "Apparel & Accessories",
    222: "Electronics",
    412: "Food, Beverages & Tobacco",
    436: "Furniture",
    469: "Health & Beauty",
    536: "Home & Garden",
    537: "Baby & Toddler",
    632: "Hardware",
    772: "Mature",
    783: "Media",
    888: "Vehicles & Parts",
    922: "Office Supplies",
    988: "Sporting Goods",
    1239: "Toys & Games",
    2092: "Software",
    5181: "Luggage & Bags",
    5605: "Religious & Ceremonial"
}

# Model output index -> display label, e.g. 0 -> "[1] Animals & Pet Supplies".
INDEX_TO_CATEGORY = {}
for tax_id, model_idx in LABEL_MAPPING.items():
    INDEX_TO_CATEGORY[model_idx] = f"[{tax_id}] {CATEGORY_NAMES[tax_id]}"

# --- Streamlit page chrome --------------------------------------------------
# Title and two introductory paragraphs describing the demo's scope.
st.title("Google Taxonomy Classifier by DEJAN")
st.write("Enter text in the input box, and the model will classify it into one of the 21 top level categories. This demo showcases early model capability while the full 5000+ label model is undergoing extensive training.")
st.write("Works for product descriptions, search queries, articles, social media posts and broadly web text of any style. Suitable for classification pipelines of millions of queries.")

# Free-form text input; value is an empty string until the user types.
input_text = st.text_area("Enter text for classification:")

def classify_text(text):
    """Run the classifier over *text* and return per-class probabilities.

    Returns a list of softmax probabilities, one per model output index
    (see INDEX_TO_CATEGORY), or None when the input is blank or
    whitespace-only.
    """
    if not text.strip():
        return None

    # Encode to model inputs, clipping anything beyond 512 tokens.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Forward pass without gradient tracking — inference only.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Softmax over the class dimension, flattened to a plain Python list.
    return F.softmax(logits, dim=-1).squeeze().tolist()

# --- Results rendering ------------------------------------------------------
# Runs once per "Classify" click; guards for empty input and failed inference
# come first, the happy path renders a ranked probability chart.
if st.button("Classify"):
    if not input_text.strip():
        st.warning("Please enter some text for classification.")
    else:
        st.write("Processing...")
        probabilities = classify_text(input_text)
        if not probabilities:
            st.error("Could not classify the text. Please try again.")
        else:
            # Attach display labels and rank categories best-first.
            scored = {INDEX_TO_CATEGORY[i]: p for i, p in enumerate(probabilities)}
            ranked = sorted(scored.items(), key=lambda kv: kv[1], reverse=True)
            labels = [name for name, _ in ranked]
            scores = [p for _, p in ranked]

            # Horizontal bar chart, highest probability on top.
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.barh(labels, scores)
            ax.set_xlabel("Probability")
            ax.set_ylabel("Category")
            ax.set_title("Classification Probabilities")
            ax.invert_yaxis()  # barh draws bottom-up; flip so rank 1 is on top
            ax.set_xlim(0, 1)  # probabilities live in [0, 1]
            st.pyplot(fig)

            # Call-to-action footer.
            st.divider()
            st.markdown("""
            Interested in using this in an automated pipeline for bulk link prediction?
            Please [book an appointment](https://dejanmarketing.com/conference/) to discuss your needs.
            """)