distilbert-base-q-cat

Model Description

distilbert-base-q-cat is a lightweight DistilBERT model fine-tuned for text classification. It assigns each question to one of three categories: fact, opinion, or hypothetical. The model was trained on Quora questions that were labeled automatically using keyword-based rules and sentiment analysis.

Features

Lightweight: Built on DistilBERT, which gives faster inference and lower computational requirements than standard BERT.

Three Class Categories:

  • Fact: Questions seeking factual or objective information.
  • Opinion: Questions that elicit subjective views or opinions.
  • Hypothetical: Questions exploring hypothetical scenarios or speculative ideas.

Pretrained and Fine-Tuned: Utilizes DistilBERT’s pretrained weights with additional fine-tuning on labeled data.

Dataset

The model was trained using a custom dataset derived from Quora questions:

Data Preparation:

  • Labeling involved keyword-based rules for fact and hypothetical questions.

  • Sentiment analysis was used to identify opinion-based questions.
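The labeling rules themselves are not documented here, so the following is only a minimal sketch of how rule-based labeling of this kind could look. The keyword patterns, the word list standing in for a sentiment analyzer, the order in which rules are checked, and the function name weak_label are all assumptions made for illustration, not the rules used to build the dataset.

```python
import re

# Hypothetical keyword patterns; the actual rules are not documented.
HYPOTHETICAL_PATTERNS = [r"\bwhat if\b", r"\bwould\b", r"\bimagine\b", r"\bsuppose\b"]
FACT_PATTERNS = [r"^(what|who|when|where) (is|are|was|were)\b", r"^how (many|much)\b"]

# Hypothetical stand-in for a sentiment analyzer: words that signal subjectivity.
SUBJECTIVE_WORDS = {"best", "worst", "favorite", "think", "opinion", "prefer", "should"}


def weak_label(question):
    """Return 'hypothetical', 'opinion', 'fact', or None when no rule fires."""
    text = question.lower().strip()
    # Rule order is an assumption: hypothetical cues win over everything else.
    if any(re.search(p, text) for p in HYPOTHETICAL_PATTERNS):
        return "hypothetical"
    tokens = set(re.findall(r"[a-z']+", text))
    if tokens & SUBJECTIVE_WORDS:
        return "opinion"
    if any(re.search(p, text) for p in FACT_PATTERNS):
        return "fact"
    return None  # leave unlabeled rather than guess


for q in ["What if the moon disappeared?",
          "What is the best programming language?",
          "What is artificial intelligence?"]:
    print(q, "->", weak_label(q))
```

Questions that match no rule return None in this sketch, so they can be dropped instead of receiving a guessed label.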

Dataset Size: ~50k samples, split into training, validation, and test sets.
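The split proportions are not stated, so the sketch below assumes an 80/10/10 split purely for illustration. The function name split_dataset, the fractions, and the seed are hypothetical.

```python
import random


def split_dataset(samples, train_frac=0.8, val_frac=0.1, seed=42):
    """Shuffle the samples and split them into train, validation, and test lists."""
    shuffled = list(samples)
    # A fixed seed keeps the split reproducible across runs.
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * train_frac)
    n_val = int(len(shuffled) * val_frac)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # whatever remains
    return train, val, test


train, val, test = split_dataset(range(100))
print(len(train), len(val), len(test))
```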

Performance

The model achieves the following metrics on the validation set:

  • Accuracy: 93.33%
  • Precision: 93.41%
  • Recall: 93.33%
  • F1-Score: 93.32%
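The averaging method behind these figures is not stated. Recall equals accuracy in the numbers above, which is always true of support-weighted averaging, so the sketch below assumes that method. The function name weighted_metrics and the toy labels are illustrative only.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Accuracy plus support-weighted precision, recall, and F1."""
    n = len(y_true)
    support = Counter(y_true)  # number of true samples per class
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n
    precision = recall = f1 = 0.0
    for label, count in support.items():
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        predicted = sum(p == label for p in y_pred)
        p_c = tp / predicted if predicted else 0.0
        r_c = tp / count
        f_c = 2 * p_c * r_c / (p_c + r_c) if (p_c + r_c) else 0.0
        weight = count / n  # each class contributes in proportion to its support
        precision += weight * p_c
        recall += weight * r_c
        f1 += weight * f_c
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Toy labels for illustration, not the model's validation data
y_true = ["fact", "fact", "opinion", "hypothetical"]
y_pred = ["fact", "opinion", "opinion", "hypothetical"]
print(weighted_metrics(y_true, y_pred))
```

With weighted averaging, the per-class recalls are weighted by class frequency, so they sum to the fraction of correctly classified samples. This is why recall and accuracy coincide.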

Installation

To use this model, install the required dependencies:

pip install transformers torch

Usage

Load Model and Tokenizer

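from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned model and its tokenizer from the Hugging Face Hub
model_name = "alwanadi17/distilbert-base-q-cat"
# ignore_mismatched_sizes is omitted: it would silently re-initialize the
# classification head if the checkpoint shape did not match
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(model_name)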

Inference Example

import torch

# Mapping from class index to category name
label_map = {0: "fact", 1: "opinion", 2: "hypothetical"}

def predict_question(question):
    # Tokenize a single question; truncation guards against over-long inputs
    inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True)
    # Gradients are not needed at inference time
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = logits.argmax(dim=-1).item()
    return label_map[predicted_class]

# Example usage
question = "What is artificial intelligence?"
print(predict_question(question))
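
The function above returns only the top label. When a confidence score per category is useful, the logits can be passed through a softmax. The following is a minimal, dependency-free sketch. The helper names softmax and label_scores are hypothetical, and the example logits are made up rather than real model output. In practice the input could come from something like logits[0].tolist().

```python
import math

LABELS = ["fact", "opinion", "hypothetical"]  # same order as label_map above


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_scores(logits):
    """Pair each category name with its probability."""
    return dict(zip(LABELS, softmax(logits)))


# Made-up logits for illustration only
scores = label_scores([2.0, 0.5, -1.0])
best = max(scores, key=scores.get)
print(best, scores)
```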