Spaces:
Build error
Build error
import torch | |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
def load_model(model_path, device):
    """Load the fine-tuned DistilBERT classifier ready for inference.

    Args:
        model_path: Directory or hub id passed to
            ``DistilBertForSequenceClassification.from_pretrained``.
        device: torch device the model's parameters are moved to.

    Returns:
        The loaded model, placed on ``device`` and switched to eval mode.
    """
    pretrained = DistilBertForSequenceClassification.from_pretrained(model_path)
    # nn.Module.to() and .eval() both modify the module in place and return
    # the module itself, so the chain hands back the same object.
    return pretrained.to(device).eval()
def run_inference(model, tokenizer, label_decoder, device, user_input):
    """Classify one piece of text and return its decoded label.

    Args:
        model: Sequence-classification model whose forward output exposes
            ``.logits`` (e.g. the one returned by ``load_model``).
        tokenizer: Tokenizer paired with ``model``; must provide ``encode``
            and ``decode``.
        label_decoder: Mapping from predicted class index (int) to a label.
        device: torch device the model lives on; the input ids are moved there.
        user_input: Raw text to classify.

    Returns:
        ``label_decoder[i]``, where ``i`` is the arg-max class index.

    Side effects:
        Puts ``model`` in eval mode and prints the (possibly truncated)
        input text and the prediction to stdout.
    """
    model.eval()  # the caller's model is not guaranteed to be in eval mode

    # Truncate to the tokenizer's maximum length. The model only supports a
    # bounded sequence length, so an over-long input would otherwise make the
    # forward pass fail instead of producing a prediction.
    input_ids = tokenizer.encode(
        user_input, truncation=True, return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(input_ids)
    # Batch of one: arg-max over the class dimension, taken from row 0.
    predicted_index = torch.argmax(outputs.logits, dim=1)[0].item()

    # Decode the ids that were actually classified, so the echoed text
    # reflects any truncation applied above.
    input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    predicted_outcome = label_decoder[predicted_index]

    print(f"Text: {input_text}")
    print(f"Predicted Outcome: {predicted_outcome}")
    print()  # add a blank line for better readability
    return predicted_outcome
# Example usage: Streamlit front end.
import os

import streamlit as st

# Path to the fine-tuned checkpoint. The default is the author's local
# directory, which does not exist on other machines (e.g. a hosted Space);
# set the MODEL_PATH environment variable to point at the model instead.
model_path = os.environ.get(
    "MODEL_PATH",
    "/home/lwasinam/AI_Projects/hate_speech_detection/model6",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Streamlit re-executes this whole script on every widget interaction, so
# uncached loads would re-read the tokenizer and model on every submission.
# st.cache_resource only exists in newer Streamlit releases; when it is
# missing, fall back to an identity decorator (the original uncached behavior).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _get_tokenizer():
    """Load the tokenizer, cached across Streamlit reruns when supported."""
    # Replace with your desired tokenizer. NOTE(review): presumably this
    # should match the tokenizer the model was fine-tuned with - confirm.
    return DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


@_cache_resource
def _get_model(path, device_name):
    """Load the classifier onto ``device_name``, cached across reruns."""
    # The device is passed as a plain string so the cache key is hashable.
    return load_model(path, torch.device(device_name))


tokenizer = _get_tokenizer()
model = _get_model(model_path, str(device))
# Maps the classifier's output index to a human-readable label.
label_decoder = {0: "Not Hate", 1: "Hate"}

st.title("Hate Speech Detection")
user_input = st.text_input("Enter your text:")
if user_input:
    result = run_inference(model, tokenizer, label_decoder, device, user_input)
    st.write("Inference Result:", result)