# Spaces:
# Runtime error
# Runtime error
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Load the model
MODEL_REPO_ID = "anand6572r/my-keras-model11"
MODEL_FILENAME = "trained_model.keras"

# Download the serialized Keras model from the Hugging Face Hub, then
# deserialize it from the local path returned by the download helper.
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
model = tf.keras.models.load_model(model_path)

# Tokenizer used to turn query text into input ids (bert-base-uncased).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Define preprocessing function
def preprocess_input(query):
    """
    Encodes the query with the module-level tokenizer and returns its token ids.

    The text is padded or truncated to exactly 100 tokens and encoded as
    TensorFlow tensors; only the "input_ids" entry of the encoding is returned.
    """
    tokenizer_options = {
        "padding": "max_length",
        "max_length": 100,
        "truncation": True,
        "return_tensors": "tf",
    }
    encoded = tokenizer(query, **tokenizer_options)
    return encoded["input_ids"]
# Define postprocessing function
def postprocess_output(prediction, threshold=0.5):
    """
    Converts model predictions into a user-friendly response.

    Args:
        prediction: The score for one query. May be a bare number or any
            (possibly nested) sequence or array of numbers; only its first
            value is used.
        threshold: Scores strictly greater than this value are reported as
            relating to a Wikipedia article. Defaults to 0.5.

    Returns:
        The "Yes" sentence when the first score exceeds ``threshold``,
        otherwise the "No" sentence.

    Raises:
        IndexError: If ``prediction`` contains no values.
    """
    # Flatten so a scalar, a 1-D row and a nested (1, 1) output are all
    # handled the same way instead of failing on the indexing/comparison.
    scores = np.asarray(prediction, dtype=float).ravel()
    if scores.size == 0:
        raise IndexError("prediction is empty; expected at least one score")
    if scores[0] > threshold:
        return "Yes, this relates to a Wikipedia article."
    return "No, this does not relate to a Wikipedia article."
# Define prediction function
def predict(query):
    """
    Predicts whether the query relates to a Wikipedia article.

    Args:
        query: Text entered by the user.

    Returns:
        The response sentence produced by ``postprocess_output``, or a short
        prompt asking for input when ``query`` is ``None`` or blank.
    """
    # Reject missing/blank input up front: there is nothing meaningful to
    # classify, and None would otherwise fail inside the tokenizer.
    if query is None or (isinstance(query, str) and not query.strip()):
        return "Please enter a query."

    # Preprocess the input query
    input_data = preprocess_input(query)
    # Get prediction from the model; verbose=0 keeps the per-request
    # progress bar out of the application logs.
    prediction = model.predict(input_data, verbose=0)
    # Postprocess the prediction for the single query in the batch
    return postprocess_output(prediction[0])
# Gradio interface: one free-text input box and one text box for the answer.
query_box = gr.Textbox(lines=2, placeholder="Enter your query here...")
response_box = gr.Textbox(label="Response")

interface = gr.Interface(
    fn=predict,
    inputs=query_box,
    outputs=response_box,
    title="Wikipedia Article Query Predictor",
    description="This model predicts whether a query relates to a Wikipedia article.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()