dejanseo committed on
Commit
99d1428
·
verified ·
1 Parent(s): 91aa528

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -0
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification, AlbertModel
4
+ import numpy as np
5
+ import pandas as pd
6
+ import os
7
+ from torch.nn.functional import softmax
8
+ import torch.nn as nn
9
+
10
# Paths
# Directory name for each taxonomy level, keyed by the 1-based level number.
LEVEL_DIRS = {depth: f'level{depth}' for depth in range(1, 8)}
# CSV file that is loaded into the label-text lookup table.
MAPPING_FILE = 'mapping.csv'
# Base model identifier handed to AlbertModel.from_pretrained.
MODEL_NAME = 'albert/albert-base-v2'
22
+
23
# Load mapping
# Read once at import time; get_label_text filters this frame on its
# 'level' and 'id' columns and returns the matching 'text' value.
mapping_df = pd.read_csv(filepath_or_buffer=MAPPING_FILE)
25
+
26
def get_label_text(level, predicted_id):
    """Look up the description for a predicted label in ``mapping_df``.

    Args:
        level: Level value to match against the ``level`` column; only the
            integers 0 through 6 are accepted.
        predicted_id: Value to match against the ``id`` column.

    Returns:
        The ``text`` of the first matching row, ``"Description not found"``
        when no row matches, or ``"Invalid Level"`` when *level* is outside
        the accepted range.
    """
    accepted_levels = {n: n for n in range(7)}
    if level not in accepted_levels:
        return "Invalid Level"

    level_num = accepted_levels[level]
    same_level = mapping_df['level'] == level_num
    same_id = mapping_df['id'] == predicted_id
    matches = mapping_df[same_level & same_id]
    if matches.empty:
        return "Description not found"
    return matches['text'].iloc[0]
33
+
34
def _load_label_map(level):
    """Return the dict stored in ``<LEVEL_DIRS[level]>/label_map.npy``."""
    # NOTE(security): allow_pickle=True makes np.load unpickle arbitrary
    # objects, so this must only ever read label maps from a trusted source.
    path = os.path.join(LEVEL_DIRS[level], 'label_map.npy')
    return np.load(path, allow_pickle=True).item()


def predict_level(level, text, parent_prediction_id=None, checkpoint_path=None):
    """Classify *text* at one taxonomy level and return the best candidates.

    Level 1 uses a plain ``AlbertForSequenceClassification`` checkpoint.
    Every other level uses an ALBERT encoder whose pooled output is
    concatenated with a one-hot encoding of the parent level's prediction
    before the linear classification layer.

    Args:
        level: Key into ``LEVEL_DIRS`` selecting the level to predict.
        text: Input text passed to the tokenizer.
        parent_prediction_id: Label id chosen at ``level - 1``. Ignored for
            level 1. When it is 0 or absent from the parent label map the
            parent vector stays all zeros.
        checkpoint_path: Directory holding the tokenizer and model weights.

    Returns:
        A list of ``(label_id, probability)`` tuples, highest probability
        first. It holds three entries, or fewer when the level has fewer
        than three labels.

    Raises:
        ValueError: If *checkpoint_path* is not provided.
    """
    if checkpoint_path is None:
        raise ValueError("checkpoint_path is required to load the tokenizer and model")

    tokenizer = AlbertTokenizer.from_pretrained(checkpoint_path)
    label_map = _load_label_map(level)
    num_labels = len(label_map)

    parent_label_map = None
    if level == 1:
        model = AlbertForSequenceClassification.from_pretrained(checkpoint_path)
    else:
        # Loaded once and reused both for sizing the classifier input and
        # for building the parent one-hot vector below.
        parent_label_map = _load_label_map(level - 1)
        num_parent_labels = len(parent_label_map)

        class TaxonomyClassifier(nn.Module):
            """ALBERT encoder plus a linear head that also sees the parent one-hot."""

            def __init__(self, base_model_name, num_parent_labels, num_labels):
                super().__init__()
                # Attribute names must stay as-is: they are the state-dict keys.
                self.albert = AlbertModel.from_pretrained(base_model_name)
                self.dropout = nn.Dropout(0.1)
                self.classifier = nn.Linear(self.albert.config.hidden_size + num_parent_labels, num_labels)

            def forward(self, input_ids, attention_mask, parent_ids):
                outputs = self.albert(input_ids, attention_mask=attention_mask)
                pooled_output = self.dropout(outputs.pooler_output)
                combined_features = torch.cat((pooled_output, parent_ids), dim=1)
                return self.classifier(combined_features)

        model = TaxonomyClassifier(MODEL_NAME, num_parent_labels, num_labels)
        # NOTE(review): the file is named model.safetensors but is read with
        # torch.load, which only understands torch.save output. If the
        # checkpoint is a genuine safetensors file this call fails and
        # safetensors.torch.load_file is needed instead - confirm the format.
        weights_path = os.path.join(checkpoint_path, 'model.safetensors')
        model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

    model.eval()
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        if level > 1:
            parent_one_hot = torch.zeros(len(parent_label_map))
            if parent_prediction_id != 0:
                parent_index = parent_label_map.get(parent_prediction_id)
                if parent_index is not None:
                    parent_one_hot[parent_index] = 1.0
            logits = model(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                parent_ids=parent_one_hot.unsqueeze(0),
            )
        else:
            logits = model(**inputs).logits

    probabilities = softmax(logits, dim=-1)[0]
    # torch.topk raises if k exceeds the number of classes, so clamp it for
    # levels that have fewer than three labels.
    top_k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, top_k)

    # label_map goes label id -> class index; invert it to decode indices.
    index_to_label = {index: label_id for label_id, index in label_map.items()}
    return [
        (index_to_label[index.item()], prob.item())
        for prob, index in zip(top_probs, top_indices)
    ]
90
+
91
st.title("Taxonomy Model Inference")

# Sample product description pre-filled in the text box.
_default_text = "Experience the magic of music with the Clavinova CLP-800 series. This versatile range of digital pianos is designed to delight everyone, from budding musicians to seasoned pianists. Each model combines state-of-the-art technology with the realistic touch and tone of world-renowned grand pianos, enhanced by GrandTouch keyboard action and Virtual Resonance Modeling. With seamless Bluetooth® connectivity, built-in lessons, and elegant design, the CLP-800 series offers the perfect blend of tradition and innovation. Elevate your musical journey with the warmth and sophistication of the Yamaha Clavinova, our finest series of digital pianos."
input_text = st.text_area("Enter text to classify", value=_default_text)

# Minimum top-1 probability a level must reach for inference to continue.
softmax_threshold = st.slider(
    "Softmax Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    step=0.05,
)
96
+
97
# Checkpoint Selection
def _checkpoint_sort_key(name):
    """Order 'step' checkpoints by their numeric suffix, then everything else.

    Names containing 'step' sort first, by the integer after the last
    'step'. A suffix that is not an integer (e.g. 'steps', 'step_final')
    falls back to -1 instead of raising ValueError.
    """
    if 'step' not in name:
        return (True, -1)
    try:
        return (False, int(name.split('step')[-1]))
    except ValueError:
        return (False, -1)


available_levels = []    # levels that have at least one usable checkpoint
level_checkpoints = {}   # level -> list of checkpoint directory paths
for level, level_dir in LEVEL_DIRS.items():
    level_checkpoints[level] = []
    # isdir rather than exists: a stray file with the level's name would
    # otherwise make os.listdir raise.
    if not os.path.isdir(level_dir):
        continue
    options = sorted(
        (
            d for d in os.listdir(level_dir)
            if os.path.isdir(os.path.join(level_dir, d)) and ('step' in d or d == 'model')
        ),
        key=_checkpoint_sort_key,
    )
    level_checkpoints[level] = [os.path.join(level_dir, opt) for opt in options]
    if level_checkpoints[level]:
        available_levels.append(level)
111
+
112
# One checkpoint picker per level that has checkpoints available;
# maps level -> the checkpoint path chosen in the UI.
selected_checkpoints = {
    level: st.selectbox(f"Select Level {level} Checkpoint", options=level_checkpoints[level])
    for level in available_levels
}
115
+
116
if st.button("Run Inference"):
    if input_text:
        all_level_results = {}        # level -> list of (label_id, probability)
        current_prediction_id = None  # best label id from the previous level
        last_level = 0

        # Walk the levels in ascending order, feeding each level's best
        # label id into the next level as its parent prediction.
        for level in sorted(available_levels):
            checkpoint_path = selected_checkpoints[level]
            if not checkpoint_path:
                st.warning(f"Skipping Level {level} as no checkpoint is selected.")
                break

            if level == 1:
                level_results = predict_level(level, input_text, checkpoint_path=checkpoint_path)
            else:
                # A predicted id of 0 is treated as the end of the taxonomy path.
                if current_prediction_id == 0:
                    st.info(f"Taxonomy terminated at Level {last_level} with ID 0.")
                    break
                level_results = predict_level(level, input_text, parent_prediction_id=current_prediction_id, checkpoint_path=checkpoint_path)

            # Guard against an empty candidate list before indexing into it.
            if not level_results:
                st.warning(f"Level {level} returned no predictions.")
                break

            if level_results[0][1] < softmax_threshold:
                st.info(f"Inference stopped at Level {level} due to softmax probability ({level_results[0][1]:.3f}) being below the threshold.")
                break

            all_level_results[level] = level_results
            current_prediction_id = level_results[0][0]
            last_level = level

        data = []
        for level in sorted(all_level_results):
            results = all_level_results[level]
            best_id, best_prob = results[0]
            # get_label_text is called with level - 1, as it accepts levels 0-6.
            row = {
                'level': level,
                'text': get_label_text(level - 1, best_id),
                'softmax': f"{best_prob:.3f}",
            }
            # Runner-up columns are always present; they are left blank when
            # a level produced fewer than three candidates instead of
            # raising IndexError.
            for rank in (1, 2):
                if rank < len(results):
                    runner_id, runner_prob = results[rank]
                    row[f'runner_up_{rank}_id'] = runner_id
                    row[f'runner_up_{rank}_text'] = get_label_text(level - 1, runner_id)
                    row[f'runner_up_{rank}_softmax'] = f"{runner_prob:.3f}"
                else:
                    row[f'runner_up_{rank}_id'] = None
                    row[f'runner_up_{rank}_text'] = ""
                    row[f'runner_up_{rank}_softmax'] = ""
            data.append(row)

        if data:
            df = pd.DataFrame(data)
            st.dataframe(df)
        else:
            st.info("No predictions made or inference stopped.")

    else:
        st.warning("Please enter text for classification.")