Zmorell commited on
Commit
8f1c4bc
·
verified ·
1 Parent(s): 1eee92e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- from tensorflow import keras
3
  import pandas as pd
4
  import tensorflow as tf
5
  import nltk
@@ -10,6 +9,7 @@ from nltk.tokenize import word_tokenize
10
  from tensorflow.keras.preprocessing.text import Tokenizer
11
  from tensorflow.keras.preprocessing.sequence import pad_sequences
12
 
 
13
  import spacy.cli
14
  spacy.cli.download("en_core_web_sm")
15
  nltk.download('punkt_tab')
@@ -17,17 +17,20 @@ nltk.download('stopwords')
17
  stop_words = set(stopwords.words('english'))
18
  nlp = spacy.load('en_core_web_sm')
19
 
20
- # Available backend options are: "jax", "torch", "tensorflow".
21
- import os
22
- os.environ["KERAS_BACKEND"] = "jax"
 
23
 
24
- # Ensure the necessary libraries are correctly imported
25
- import keras
 
26
 
27
- # Load the model from the Hugging Face repository
28
- model_path = "https://huggingface.co/Zmorell/HIPA_2/resolve/main/saved_keras_model.keras"
29
- model = tf.keras.models.load_model(model_path)
30
- print(f"Model loaded from {model_path}")
 
31
 
32
  def preprocess_text(text):
33
  text = re.sub(r'[^a-zA-Z0-9\s]', '', text) # Only remove non-alphanumeric characters except spaces
@@ -44,9 +47,11 @@ def preprocess_text(text):
44
  def predict(text):
45
  inputs = preprocess_text(text)
46
  # Ensure the input shape matches what the model expects
47
- inputs = tf.convert_to_tensor([inputs])
48
- outputs = model(inputs)
49
- return "This text is a violation = " + str(outputs[0][0].numpy())
 
50
 
 
51
  demo = gr.Interface(fn=predict, inputs="text", outputs="text")
52
- demo.launch()
 
1
  import gradio as gr
 
2
  import pandas as pd
3
  import tensorflow as tf
4
  import nltk
 
9
  from tensorflow.keras.preprocessing.text import Tokenizer
10
  from tensorflow.keras.preprocessing.sequence import pad_sequences
11
 
12
+ # Download and load necessary resources
13
  import spacy.cli
14
  spacy.cli.download("en_core_web_sm")
15
  nltk.download('punkt_tab')
 
17
  stop_words = set(stopwords.words('english'))
18
  nlp = spacy.load('en_core_web_sm')
19
 
20
+ # Download the model file from Hugging Face
21
+ import requests
22
+ model_url = "https://huggingface.co/Zmorell/HIPA_2/resolve/main/saved_keras_model.keras"
23
+ local_model_path = "saved_keras_model.keras"
24
 
25
+ response = requests.get(model_url)
26
+ with open(local_model_path, 'wb') as f:
27
+ f.write(response.content)
28
 
29
+ print(f"Model downloaded to {local_model_path}")
30
+
31
+ # Load the downloaded model
32
+ model = tf.keras.models.load_model(local_model_path)
33
+ print(f"Model loaded from {local_model_path}")
34
 
35
  def preprocess_text(text):
36
  text = re.sub(r'[^a-zA-Z0-9\s]', '', text) # Only remove non-alphanumeric characters except spaces
 
47
  def predict(text):
48
  inputs = preprocess_text(text)
49
  # Ensure the input shape matches what the model expects
50
+ inputs = tokenizer.texts_to_sequences([inputs])
51
+ inputs = pad_sequences(inputs, maxlen=1000, padding='post')
52
+ outputs = model.predict(inputs)
53
+ return f"This text is a violation = {outputs[0][0]:.2f}"
54
 
55
+ # Set up the Gradio interface
56
  demo = gr.Interface(fn=predict, inputs="text", outputs="text")
57
+ demo.launch()