ONNX model - a fine-tuned version of DistilBERT that can be used to classify text as one of:
- neutral, offensive_language, harmful_behaviour, hate_speech

The model was trained using the csfy tool on the dataset seanius/toxic-or-neutral-text-labelled.
The base model (distilbert-base-uncased) is required.
For an example of how to run the model, see the Usage section below - or see the csfy tool.
The output is a number indicating the class; it is decoded via the label_mapping.json file.
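The exact contents of label_mapping.json are not reproduced here; based on the loading code in the Usage section, it is assumed to store the class names under a top-level `labels` key, in class-id order. A minimal decoding sketch under that assumption:

```python
import json

# Assumed layout of label_mapping.json (illustrative, not the shipped file):
# {"labels": ["neutral", "offensive_language", "harmful_behaviour", "hate_speech"]}
with open("./label_mapping.json", encoding="utf-8") as f:
    labels = json.load(f)["labels"]

class_id = 2              # the integer class id the model returns
print(labels[class_id])   # -> "harmful_behaviour" under the assumed ordering
```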
## Usage
```python
# Loading the label mappings
import json

def load_label_mappings():
    with open("./label_mapping.json", encoding="utf-8") as f:
        data = json.load(f)
    return data['labels']

label_mappings = load_label_mappings()
```
```python
# Loading the model
import onnxruntime as ort
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
ort_session = ort.InferenceSession("./toxic-or-neutral-text-labelled.onnx")
```
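Optionally, you can confirm what the exported graph expects before running inference. This is a small sketch reusing the `ort_session` created above; the input/output names and shapes depend on how the model was exported:

```python
# Print the ONNX graph's declared inputs and outputs (names, shapes, types)
for i in ort_session.get_inputs():
    print("input :", i.name, i.shape, i.type)
for o in ort_session.get_outputs():
    print("output:", o.name, o.shape, o.type)
```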
```python
# Predicting the label for a given text
import numpy as np

def predict_via_onnx(text, ort_session, tokenizer, label_mappings):
    # The exported model has a fixed sequence length; read it from the session
    model_expected_input_shape = ort_session.get_inputs()[0].shape
    print("Model expects input shape:", model_expected_input_shape)

    inputs = tokenizer(text, return_tensors="np", padding="max_length",
                       truncation=True, max_length=model_expected_input_shape[1])
    print("input shape", inputs['input_ids'].shape)

    input_ids = inputs['input_ids']
    if input_ids.ndim == 1:
        input_ids = input_ids[np.newaxis, :]   # ensure a batch dimension
    input_ids = input_ids.astype(np.int64)     # ONNX Runtime expects int64 token ids

    ort_inputs = {ort_session.get_inputs()[0].name: input_ids}
    ort_outputs = ort_session.run(None, ort_inputs)

    # ort_outputs[0] holds the logits; the highest-scoring index is the class id
    predictions = np.argmax(ort_outputs[0], axis=-1)
    predicted_label = label_mappings[predictions.item()]
    return predicted_label
```
```python
predicted_label = predict_via_onnx("How do I get to the beach?", ort_session, tokenizer, label_mappings)
print(predicted_label)
```
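If you also want per-class confidence scores rather than just the top label, you can apply a softmax to the logits. This is a minimal sketch under the assumption that the model's single output is the raw logits for the four classes; `predict_with_scores` is a hypothetical helper, not part of the csfy tooling:

```python
import numpy as np

def predict_with_scores(text, ort_session, tokenizer, label_mappings):
    # Tokenize to the fixed length the exported model expects
    max_len = ort_session.get_inputs()[0].shape[1]
    inputs = tokenizer(text, return_tensors="np", padding="max_length",
                       truncation=True, max_length=max_len)
    input_ids = inputs["input_ids"].astype(np.int64)

    # First (and assumed only) output, first item in the batch: the class logits
    logits = ort_session.run(None, {ort_session.get_inputs()[0].name: input_ids})[0][0]

    # Softmax turns the logits into probabilities that sum to 1
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return {label: float(p) for label, p in zip(label_mappings, probs)}

print(predict_with_scores("How do I get to the beach?", ort_session, tokenizer, label_mappings))
```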
License: MIT