# Hugging Face Space: HIPAA-violation text classifier (Gradio app)
# Standard library
import pickle
import re

# Third-party
import gradio as gr
import nltk
import pandas as pd
import requests
import spacy
import spacy.cli
import tensorflow as tf
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Download runtime language resources needed by preprocess_text():
# spaCy model for lemmatization, NLTK punkt for tokenization, stopword list.
spacy.cli.download("en_core_web_sm")
nltk.download('punkt_tab')
nltk.download('stopwords')

# Module-level globals reused on every prediction (built once at startup).
stop_words = set(stopwords.words('english'))
nlp = spacy.load('en_core_web_sm')
# Fetch the trained Keras model from the Hugging Face Hub and cache it locally.
model_url = "https://huggingface.co./Zmorell/HIPA_2/resolve/main/saved_keras_model.keras"
local_model_path = "saved_keras_model.keras"

response = requests.get(model_url)
# Fail fast on a bad download instead of writing an HTML error page to disk
# and then crashing later inside load_model with a confusing error.
response.raise_for_status()
with open(local_model_path, 'wb') as f:
    f.write(response.content)
print(f"Model downloaded to {local_model_path}")

# Load the downloaded model into memory for use by predict().
model = tf.keras.models.load_model(local_model_path)
print(f"Model loaded from {local_model_path}")

# The tokenizer must be the exact one fitted at training time so that
# word -> index mappings match the embedding layer.
tokenizer_file_path = "tokenizer.pickle"
with open(tokenizer_file_path, 'rb') as handle:
    # NOTE(review): pickle.load is only safe because this file ships with the
    # Space itself; never load untrusted pickles.
    tokenizer = pickle.load(handle)
print("Tokenizer loaded from tokenizer.pickle")
def preprocess_text(text):
    """Normalize raw input text for the model.

    Pipeline: strip non-alphanumeric characters, lowercase and tokenize,
    drop English stopwords, then lemmatize with spaCy.

    Returns a single space-joined string of lemmas.
    """
    # Keep only letters, digits and whitespace.
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    tokens = word_tokenize(text.lower())
    tokens = [word for word in tokens if word not in stop_words]
    # Re-join so spaCy sees a sentence, then collect the lemma of each token.
    doc = nlp(' '.join(tokens))
    lemmas = [token.lemma_ for token in doc]
    return ' '.join(lemmas)
def predict(text):
    """Classify *text* as a HIPAA violation.

    Returns a human-readable "True = 0.87" / "False = 0.12" string built
    from the model's sigmoid score, or an error message string if any
    stage of the pipeline fails.
    """
    try:
        print(f"Input text: {text}")
        inputs = preprocess_text(text)
        print(f"Preprocessed text: {inputs}")
        # Map words to the training-time integer indices.
        inputs = tokenizer.texts_to_sequences([inputs])
        print(f"Tokenized text: {inputs}")
        # Pad/truncate to the fixed input length the model was trained on.
        inputs = pad_sequences(inputs, maxlen=750, padding='post')
        print(f"Padded text: {inputs}")
        outputs = model.predict(inputs)
        print(f"Model outputs: {outputs}")
        # Single sigmoid output: >= 0.5 means "violation".
        prediction = outputs[0][0]
        if prediction >= 0.5:
            result = f"True = {prediction:.2f}"
        else:
            result = f"False = {prediction:.2f}"
        return result
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing the app.
        print(f"Error during prediction: {e}")
        return f"Error during prediction: {e}"
# Custom CSS applied to the Gradio Blocks layout below (matched by elem_id).
ui_css = """
#body {
    height: 700px;
    width: 500px;
    background-color: rgb(108, 207, 239);
    border-radius: 15px;
}
#hipaa-image {
    width: 75px;
}
#input-box {
    width: 480px;
    border: 2px solid black;
    margin-left: 8px;
    margin-right: 8px;
    overflow-y: scroll;
    height: 150px;
    max-height: 150px;
}
#output-elems {
    width: 480px;
    border: 2px solid black;
    margin-left: 8px;
    margin-right: 8px;
    padding: 1em;
}
#submit-button, #clear-button {
    color: white;
    height: 45px;
    width: 60px;
    margin: 10px;
    border-radius: 5px;
    border: 5px solid black;
}
#submit-button {
    background-color: red;
}
#clear-button {
    background-color: grey;
}
#addinfo {
    /* a unitless non-zero font-size is invalid CSS and ignored by browsers */
    font-size: 16px;
    justify-self: center;
}
"""
# Set up the Gradio interface: header image, instructions, input box,
# output box, and Submit/Clear buttons wired to predict().
with gr.Blocks(css=ui_css) as demo:
    with gr.Column(elem_id="body"):
        with gr.Row(elem_id="header"):
            with gr.Row(elem_id="hipaa-image"):
                gr.Image(value="hipaa-e1638383751916.png")
        with gr.Row():
            gr.Markdown("Enter text below to determine if it is a HIPAA violation. Smaller inputs may be less accurate.", elem_id="addinfo")
        with gr.Row(elem_id="interactives"):
            inputs = gr.Textbox(label="Enter Input Text Here", elem_id="input-box", lines=5)
        with gr.Row(elem_id="output-elems"):
            gr.Markdown("This text is a violation: ")
            outputs = gr.Textbox(label="", elem_id="output-box", interactive=False)
        with gr.Row():
            submit_button = gr.Button("Submit", elem_id="submit-button")
            clear_button = gr.Button("Clear", elem_id="clear-button")

    # Submit runs the classifier; Clear resets both text boxes.
    submit_button.click(predict, inputs=inputs, outputs=outputs)
    clear_button.click(lambda: ("", ""), inputs=None, outputs=[inputs, outputs])

demo.launch()