# newmodel/app.py
import gradio as gr
import tensorflow as tf
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Download and load the trained Keras model from the Hugging Face Hub
model_path = hf_hub_download(repo_id="anand6572r/my-keras-model11", filename="trained_model.keras")
model = tf.keras.models.load_model(model_path)

# Load a Hugging Face tokenizer (use a tokenizer compatible with how the model was trained)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
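
# Optional sanity check (a hedged sketch; assumes the loaded Keras model exposes
# the standard input/output spec attributes). The shapes should match the
# max_length=100 used in preprocess_input below, e.g. (None, 100) -> (None, 1).
# print(model.input_shape)
# print(model.output_shape)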

# Define preprocessing function
def preprocess_input(query):
    """
    Tokenizes and formats the input text for the model using Hugging Face's tokenizer.
    """
    inputs = tokenizer(
        query,
        padding="max_length",
        max_length=100,
        truncation=True,
        return_tensors="tf",
    )
    return inputs["input_ids"]
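
# Example (hypothetical input): preprocess_input("hello world") should return a
# tf.Tensor of token ids with shape (1, 100), padded/truncated to max_length=100.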

# Define postprocessing function
def postprocess_output(prediction):
    """
    Converts model predictions into a user-friendly response.
    """
    if prediction[0] > 0.5:  # Adjust threshold if necessary
        return "Yes, this relates to a Wikipedia article."
    else:
        return "No, this does not relate to a Wikipedia article."

# Define prediction function
def predict(query):
    """
    Predicts whether the query relates to a Wikipedia article.
    """
    # Preprocess the input query
    input_data = preprocess_input(query)
    # Get prediction from the model; predict() returns a batch of outputs,
    # so take the first (and only) row before postprocessing
    prediction = model.predict(input_data)
    # Postprocess the prediction
    response = postprocess_output(prediction[0])
    return response
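
# Quick local smoke test before launching the UI (hypothetical query; uncomment to run):
# print(predict("Albert Einstein"))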

# Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs=gr.Textbox(label="Response"),
    title="Wikipedia Article Query Predictor",
    description="This model predicts whether a query relates to a Wikipedia article.",
)

# Launch the Gradio app
if __name__ == "__main__":
    interface.launch()