dejanseo committed
Commit bf7efe3 · verified · 1 Parent(s): e5ccff9

Update app.py

Files changed (1):
app.py +105 -105
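The only substantive change in this commit is line 5 of app.py: the TFLite Interpreter import is switched from the standalone tflite_runtime package to the tensorflow.lite.python.interpreter module bundled with the full tensorflow package. The other 104 lines are identical in both versions; a compatibility sketch follows the diff below.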
app.py CHANGED
@@ -1,105 +1,105 @@
-import streamlit as st
-import trafilatura
-import numpy as np
-import pandas as pd
-from tflite_runtime.interpreter import Interpreter
-import requests
-
-# File paths
-MODEL_PATH = "./model.tflite"
-VOCAB_PATH = "./vocab.txt"
-LABELS_PATH = "./taxonomy_v2.csv"
-
-@st.cache_resource
-def load_vocab():
-    with open(VOCAB_PATH, 'r') as f:
-        vocab = [line.strip() for line in f]
-    return vocab
-
-@st.cache_resource
-def load_labels():
-    # Load labels from the CSV file
-    taxonomy = pd.read_csv(LABELS_PATH)
-    taxonomy["ID"] = taxonomy["ID"].astype(int)
-    labels_dict = taxonomy.set_index("ID")["Topic"].to_dict()
-    return labels_dict
-
-@st.cache_resource
-def load_model():
-    try:
-        # Use TensorFlow Lite Interpreter
-        interpreter = Interpreter(model_path=MODEL_PATH)
-        interpreter.allocate_tensors()
-        input_details = interpreter.get_input_details()
-        output_details = interpreter.get_output_details()
-        return interpreter, input_details, output_details
-    except Exception as e:
-        st.error(f"Failed to load the model: {e}")
-        raise
-
-def preprocess_text(text, vocab, max_length=128):
-    # Tokenize the text using the provided vocabulary
-    words = text.split()[:max_length]  # Split and truncate
-    token_ids = [vocab.index(word) if word in vocab else vocab.index("[UNK]") for word in words]
-    token_ids = np.array(token_ids + [0] * (max_length - len(token_ids)), dtype=np.int32)  # Pad to max length
-    attention_mask = np.array([1 if i < len(words) else 0 for i in range(max_length)], dtype=np.int32)
-    token_type_ids = np.zeros_like(attention_mask, dtype=np.int32)
-    return token_ids[np.newaxis, :], attention_mask[np.newaxis, :], token_type_ids[np.newaxis, :]
-
-def classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids):
-    interpreter.set_tensor(input_details[0]["index"], input_word_ids)
-    interpreter.set_tensor(input_details[1]["index"], input_mask)
-    interpreter.set_tensor(input_details[2]["index"], input_type_ids)
-    interpreter.invoke()
-    output = interpreter.get_tensor(output_details[0]["index"])
-    return output[0]
-
-def fetch_url_content(url):
-    headers = {
-        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
-        "Accept-Language": "en-US,en;q=0.9",
-        "Accept-Encoding": "gzip, deflate, br",
-    }
-    try:
-        response = requests.get(url, headers=headers, cookies={}, timeout=10)
-        if response.status_code == 200:
-            return response.text
-        else:
-            st.error(f"Failed to fetch content. Status code: {response.status_code}")
-            return None
-    except Exception as e:
-        st.error(f"Error fetching content: {e}")
-        return None
-
-# Streamlit app
-st.title("Topic Classification from URL")
-
-url = st.text_input("Enter a URL:", "")
-if url:
-    st.write("Extracting content from the URL...")
-    raw_content = fetch_url_content(url)
-    if raw_content:
-        content = trafilatura.extract(raw_content)
-        if content:
-            st.write("Content extracted successfully!")
-            st.write(content[:500])  # Display a snippet of the content
-
-            # Load resources
-            vocab = load_vocab()
-            labels_dict = load_labels()
-            interpreter, input_details, output_details = load_model()
-
-            # Preprocess content and classify
-            input_word_ids, input_mask, input_type_ids = preprocess_text(content, vocab)
-            predictions = classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids)
-
-            # Display classification
-            st.write("Topic Classification:")
-            sorted_indices = np.argsort(predictions)[::-1][:5]  # Top 5 topics
-            for idx in sorted_indices:
-                topic = labels_dict.get(idx, "Unknown Topic")
-                st.write(f"ID: {idx} - Topic: {topic} - Score: {predictions[idx]:.4f}")
-        else:
-            st.error("Unable to extract content from the fetched HTML.")
-    else:
-        st.error("Failed to fetch the URL.")
+import streamlit as st
+import trafilatura
+import numpy as np
+import pandas as pd
+from tensorflow.lite.python.interpreter import Interpreter
+import requests
+
+# File paths
+MODEL_PATH = "./model.tflite"
+VOCAB_PATH = "./vocab.txt"
+LABELS_PATH = "./taxonomy_v2.csv"
+
+@st.cache_resource
+def load_vocab():
+    with open(VOCAB_PATH, 'r') as f:
+        vocab = [line.strip() for line in f]
+    return vocab
+
+@st.cache_resource
+def load_labels():
+    # Load labels from the CSV file
+    taxonomy = pd.read_csv(LABELS_PATH)
+    taxonomy["ID"] = taxonomy["ID"].astype(int)
+    labels_dict = taxonomy.set_index("ID")["Topic"].to_dict()
+    return labels_dict
+
+@st.cache_resource
+def load_model():
+    try:
+        # Use TensorFlow Lite Interpreter
+        interpreter = Interpreter(model_path=MODEL_PATH)
+        interpreter.allocate_tensors()
+        input_details = interpreter.get_input_details()
+        output_details = interpreter.get_output_details()
+        return interpreter, input_details, output_details
+    except Exception as e:
+        st.error(f"Failed to load the model: {e}")
+        raise
+
+def preprocess_text(text, vocab, max_length=128):
+    # Tokenize the text using the provided vocabulary
+    words = text.split()[:max_length]  # Split and truncate
+    token_ids = [vocab.index(word) if word in vocab else vocab.index("[UNK]") for word in words]
+    token_ids = np.array(token_ids + [0] * (max_length - len(token_ids)), dtype=np.int32)  # Pad to max length
+    attention_mask = np.array([1 if i < len(words) else 0 for i in range(max_length)], dtype=np.int32)
+    token_type_ids = np.zeros_like(attention_mask, dtype=np.int32)
+    return token_ids[np.newaxis, :], attention_mask[np.newaxis, :], token_type_ids[np.newaxis, :]
+
+def classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids):
+    interpreter.set_tensor(input_details[0]["index"], input_word_ids)
+    interpreter.set_tensor(input_details[1]["index"], input_mask)
+    interpreter.set_tensor(input_details[2]["index"], input_type_ids)
+    interpreter.invoke()
+    output = interpreter.get_tensor(output_details[0]["index"])
+    return output[0]
+
+def fetch_url_content(url):
+    headers = {
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
+        "Accept-Language": "en-US,en;q=0.9",
+        "Accept-Encoding": "gzip, deflate, br",
+    }
+    try:
+        response = requests.get(url, headers=headers, cookies={}, timeout=10)
+        if response.status_code == 200:
+            return response.text
+        else:
+            st.error(f"Failed to fetch content. Status code: {response.status_code}")
+            return None
+    except Exception as e:
+        st.error(f"Error fetching content: {e}")
+        return None
+
+# Streamlit app
+st.title("Topic Classification from URL")
+
+url = st.text_input("Enter a URL:", "")
+if url:
+    st.write("Extracting content from the URL...")
+    raw_content = fetch_url_content(url)
+    if raw_content:
+        content = trafilatura.extract(raw_content)
+        if content:
+            st.write("Content extracted successfully!")
+            st.write(content[:500])  # Display a snippet of the content
+
+            # Load resources
+            vocab = load_vocab()
+            labels_dict = load_labels()
+            interpreter, input_details, output_details = load_model()
+
+            # Preprocess content and classify
+            input_word_ids, input_mask, input_type_ids = preprocess_text(content, vocab)
+            predictions = classify_text(interpreter, input_details, output_details, input_word_ids, input_mask, input_type_ids)
+
+            # Display classification
+            st.write("Topic Classification:")
+            sorted_indices = np.argsort(predictions)[::-1][:5]  # Top 5 topics
+            for idx in sorted_indices:
+                topic = labels_dict.get(idx, "Unknown Topic")
+                st.write(f"ID: {idx} - Topic: {topic} - Score: {predictions[idx]:.4f}")
+        else:
+            st.error("Unable to extract content from the fetched HTML.")
+    else:
+        st.error("Failed to fetch the URL.")
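Since tensorflow.lite.python.interpreter is an internal module path, a guarded import keeps the app working whether the environment provides the lightweight tflite_runtime wheel or the full tensorflow package. A minimal sketch, not part of this commit, using the public tf.lite.Interpreter alias as the fallback:

    try:
        # Prefer the standalone TFLite runtime when it is installed
        from tflite_runtime.interpreter import Interpreter
    except ImportError:
        # Fall back to the interpreter bundled with the full TensorFlow package
        import tensorflow as tf
        Interpreter = tf.lite.Interpreter

Either way, the rest of app.py runs unchanged, since both classes expose the same interface (Interpreter(model_path=...), allocate_tensors(), set_tensor(), invoke(), get_tensor()).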
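One caveat in the unchanged code: classify_text assumes get_input_details() lists the word-id, mask, and type-id tensors in that exact order, which depends on how the model was converted. Matching on the tensor names the interpreter reports is more robust. A sketch under the assumption (hypothetical for this model) that the names contain "mask" and "type" or "segment":

    def set_inputs_by_name(interpreter, input_details, input_word_ids, input_mask, input_type_ids):
        # Route each array to the tensor whose reported name matches,
        # instead of relying on positional order.
        for detail in input_details:
            name = detail["name"].lower()
            if "mask" in name:
                interpreter.set_tensor(detail["index"], input_mask)
            elif "type" in name or "segment" in name:
                interpreter.set_tensor(detail["index"], input_type_ids)
            else:
                interpreter.set_tensor(detail["index"], input_word_ids)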