import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
import torch
from transformers import pipeline


class ContentFilter:
    def __init__(self):
        # Initialize toxic content detection model
        self.toxicity_classifier = pipeline(
            'text-classification',
            model='unitary/toxic-bert',
            return_all_scores=True
        )

        # Keyword blacklist
        self.blacklist = [
            'hate', 'discriminate', 'violent',
            'offensive', 'inappropriate', 'racist',
            'sexist', 'homophobic', 'transphobic'
        ]

    def filter_toxicity(self, text, toxicity_threshold=0.5):
        """
        Detect toxic content using the pre-trained model.

        Args:
            text (str): Input text to check
            toxicity_threshold (float): Score above which a label counts as toxic

        Returns:
            dict: 'is_toxic' flag and per-label 'toxicity_scores'
        """
        results = self.toxicity_classifier(text)[0]

        # Convert results to a {label: score} dictionary
        toxicity_scores = {
            result['label']: result['score']
            for result in results
        }

        # Check whether any toxic category exceeds the threshold
        is_toxic = any(
            score > toxicity_threshold
            for score in toxicity_scores.values()
        )

        return {
            'is_toxic': is_toxic,
            'toxicity_scores': toxicity_scores
        }

    def filter_keywords(self, text):
        """
        Check text against the keyword blacklist.

        Args:
            text (str): Input text to check

        Returns:
            list: Matched blacklisted keywords
        """
        # Blacklist entries are already lowercase, so only the input needs folding
        matched_keywords = [
            keyword for keyword in self.blacklist
            if keyword in text.lower()
        ]
        return matched_keywords

    def comprehensive_filter(self, text):
        """
        Perform comprehensive content filtering.

        Args:
            text (str): Input text to filter

        Returns:
            dict: Combined toxicity and keyword filtering results
        """
        # Toxicity model filtering
        toxicity_results = self.filter_toxicity(text)

        # Keyword blacklist filtering
        blacklisted_keywords = self.filter_keywords(text)

        # Combine results
        return {
            'toxicity': toxicity_results,
            'blacklisted_keywords': blacklisted_keywords,
            'is_safe': (not toxicity_results['is_toxic']
                        and len(blacklisted_keywords) == 0)
        }
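

# Rough usage sketch (not executed here): exact scores depend on the model,
# and the label set below assumes unitary/toxic-bert's six labels
# (toxic, severe_toxic, obscene, threat, insult, identity_hate):
#
#     f = ContentFilter()
#     f.comprehensive_filter("Have a nice day")
#     # -> {'toxicity': {'is_toxic': False, 'toxicity_scores': {...}},
#     #     'blacklisted_keywords': [],
#     #     'is_safe': True}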

# Initialize content filter
content_filter = ContentFilter()

# Initialize Hugging Face client. Note: google-t5/t5-small is an
# encoder-decoder model with no chat-completion support, so the streaming
# chat call in `respond` requires a conversational model such as zephyr.
# client = InferenceClient("google-t5/t5-small")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Load dataset (optional; loaded but not used elsewhere in this script)
dataset = load_dataset("JustKiddo/KiddosVault")


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p
):
    # First, filter the incoming user message
    message_filter_result = content_filter.comprehensive_filter(message)

    # If the message is not safe, emit a warning and stop. Because this
    # function is a generator, the warning must be yielded rather than
    # returned, or Gradio will never display it.
    if not message_filter_result['is_safe']:
        toxicity_details = message_filter_result['toxicity']['toxicity_scores']
        blacklisted_keywords = message_filter_result['blacklisted_keywords']

        warning_message = "Message flagged for inappropriate content. "
        warning_message += "Detected issues: "

        # Add toxicity details
        for category, score in toxicity_details.items():
            if score > 0.5:
                warning_message += f"{category} (Score: {score:.2f}), "

        # Add blacklisted keywords
        if blacklisted_keywords:
            warning_message += f"Blacklisted keywords: {', '.join(blacklisted_keywords)}"

        yield warning_message
        return

    # Prepare messages for chat completion
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Generate the response, streaming partial output as it arrives. The
    # loop variable must not shadow the `message` parameter, and the final
    # stream chunk may carry no content, so guard against None.
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p
    ):
        token = chunk.choices[0].delta.content
        if token:
            response += token
        yield response
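
# Minimal sketch for exercising the generator outside Gradio (hypothetical
# argument values), e.g. to check streaming and the content-filter branch
# without launching the UI:
#
#     for partial in respond("Hello!", [], "You are helpful.", 256, 0.7, 0.95):
#         print(partial)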


# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a professional and friendly assistant.",
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=6144,
            value=6144,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=1.0,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ]
)

if __name__ == "__main__":
    demo.launch(debug=True)