suryadev1 committed on
Commit
0b133b0
·
verified ·
1 Parent(s): 7884735

updated with app.py to run subprocess

Browse files
Files changed (1) hide show
  1. app.py +830 -830
app.py CHANGED
@@ -1,831 +1,831 @@
1
- import gradio as gr
2
- from huggingface_hub import hf_hub_download
3
- import pickle
4
- from gradio import Progress
5
- import numpy as np
6
- import subprocess
7
- import shutil
8
- import matplotlib.pyplot as plt
9
- from sklearn.metrics import roc_curve, auc
10
- import pandas as pd
11
- import plotly.graph_objects as go
12
- from sklearn.metrics import roc_auc_score
13
- from matplotlib.figure import Figure
14
- # Define the function to process the input file and model selection
15
-
16
# Maps the model chosen in the UI to the fine-tuning task name.
_FINETUNE_TASKS = {
    "ASTRA-FT-HGR": "highGRschool10",
    "ASTRA-FT-LGR": "lowGRschoolAll",
    "ASTRA-FT-FULL": "fullTest",
}

# Sub-step names that mark each optional task inside a test_info row.
_ER_SUBTASKS = ("DenominatorFactor", "NumeratorFactor", "EquationAnswer")
_ME_SUBTASKS = (
    "FirstRow2:1", "FirstRow2:2", "FirstRow1:1", "FirstRow1:2",
    "SecondRow", "ThirdRow",
)

# Slice colours shared by both optional-task pie charts.
_PIE_COLORS = ["#FF6F61", "#6B5B95", "#88B04B", "#F7CAC9"]

# (optional task 1 done, optional task 2 done) -> pie-chart bucket.
_USAGE_BUCKET = {
    (True, False): "ER",
    (False, True): "ME",
    (True, True): "both",
    (False, False): "none",
}


def _format_auc(value):
    """Format a ROC-AUC score to four decimals, or "N/A" when it is undefined."""
    return "N/A" if value is None else f"{value:.4f}"


def _group_auc(groups, true_labels, scores, wanted):
    """Return the ROC-AUC over the rows tagged `wanted`, or None if undefined.

    The score is undefined (None) when the selected rows contain fewer than
    two distinct true labels, which roc_auc_score cannot handle.
    """
    picked = [(t, p) for g, t, p in zip(groups, true_labels, scores) if g == wanted]
    picked_true = [t for t, _ in picked]
    if len(set(picked_true)) <= 1:
        return None
    return roc_auc_score(picked_true, [p for _, p in picked])


def _has_success(fields, subtasks):
    """Return True if any field mentioning one of `subtasks` also contains "OK"."""
    return any(
        "OK" in field
        for task in subtasks
        for field in fields
        if task in field
    )


def _analyze_row(row):
    """Return (opt1_done, opt2_done) for one raw test_info line."""
    fields = row.split("\t")
    return _has_success(fields, _ER_SUBTASKS), _has_success(fields, _ME_SUBTASKS)


def _usage_pie(counts, title):
    """Build the pie chart showing how the optional tasks were used."""
    fig = go.Figure(data=[go.Pie(
        labels=list(counts),
        values=list(counts.values()),
        textinfo='percent+label',
        textposition='auto',
        marker=dict(colors=_PIE_COLORS),
        sort=False,
    )])
    fig.update_layout(
        title=title,
        title_x=0.5,
        font=dict(family="sans-serif", size=12, color="black"),
        legend=dict(font=dict(family="sans-serif", size=12, color="black")),
    )
    return fig


def _roc_figure(fpr, tpr):
    """Build the ROC plot (curve plus the dashed chance diagonal)."""
    roc_auc = auc(fpr, tpr)
    fig = go.Figure()
    # go.Scatter(mode='lines') replaces the deprecated go.Line alias.
    fig.add_trace(go.Scatter(
        x=list(fpr), y=list(tpr), mode='lines',
        name=f'ROC curve (area = {roc_auc:.2f})',
        line=dict(color='royalblue', width=3),
    ))
    fig.add_trace(go.Scatter(
        x=[0, 1], y=[0, 1], mode='lines', showlegend=False,
        line=dict(color='firebrick', width=2, dash='dash'),
    ))
    fig.update_layout(
        showlegend=True,
        title=dict(text='Receiver Operating Curve (ROC)', x=0.5),
        xaxis=dict(title=dict(text='False Positive Rate')),
        # The y values plotted are tpr, so the axis is the true positive rate.
        yaxis=dict(title=dict(text='True Positive Rate')),
        font=dict(family="sans-serif", color="black"),
        legend=dict(
            x=0.75,
            y=0,
            traceorder="normal",
            font=dict(family="sans-serif", size=12, color="black"),
        ),
    )
    return fig


def process_file(model_name, inc_slider, progress=Progress(track_tqdm=True)):
    """Sample test schools, summarise stored model results and build the plots.

    Args:
        model_name: Model label selected in the UI; looked up in
            _FINETUNE_TASKS (unknown labels give a finetune task of None).
        inc_slider: Percentage of schools to sample from each graduation-rate
            group; divided by 100 to form the sampling fraction.
        progress: Gradio progress callback.

    Returns:
        A tuple (text_output, roc_figure, er_pie, me_pie).

    Raises:
        KeyError: if result.txt has no 'time_taken_from_start' entry.
    """
    progress(0, desc="Starting the processing")

    parent_location = "ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/"
    test_info_location = parent_location + "fullTest/test_info.txt"
    test_location = parent_location + "fullTest/test.txt"
    finetune_task = _FINETUNE_TASKS.get(model_name)

    # Load the test_info file and the graduation rate file.
    test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')
    grad_rate_data = pd.DataFrame(
        pd.read_pickle('school_grduation_rate.pkl'),
        columns=['school_number', 'grad_rate'],
    )

    # Column 0 of test_info holds the school number of each instance.
    unique_schools = test_info[0].unique()
    schools = grad_rate_data[grad_rate_data['school_number'].isin(unique_schools)]

    # Threshold separating high and low graduation-rate schools.
    grad_rate_threshold = 0.9
    high_grad_schools = schools[schools['grad_rate'] >= grad_rate_threshold]['school_number'].unique()
    low_grad_schools = schools[schools['grad_rate'] < grad_rate_threshold]['school_number'].unique()

    # Sample the requested percentage of schools from each group.
    frac = inc_slider / 100
    high_sample = pd.Series(high_grad_schools).sample(frac=frac, random_state=1).tolist()
    low_sample = pd.Series(low_grad_schools).sample(frac=frac, random_state=1).tolist()
    random_schools = high_sample + low_sample

    # Row indices of the instances that belong to the sampled schools.
    indices = test_info[test_info[0].isin(random_schools)].index.tolist()
    # A set keeps the per-row membership test below O(1).
    high_index_set = set(test_info[test_info[0].isin(high_sample)].index)

    # Select the matching rows of the test file and save them.
    test = pd.read_csv(test_location, sep=',', header=None, engine='python')
    selected_rows_df2 = test.loc[indices]
    selected_rows_df2.to_csv('selected_rows.txt', sep='\t', index=False, header=False, quoting=3, escapechar=' ')

    graduation_groups = [
        'high' if idx in high_index_set else 'low'
        for idx in selected_rows_df2.index
    ]
    # Column 6 of test_info is the ideal-task flag: 0 -> opt_task1, else opt_task2.
    opt_task_groups = [
        'opt_task1' if test_info.loc[idx, 6] == 0 else 'opt_task2'
        for idx in selected_rows_df2.index
    ]

    with open("roc_data2.pkl", 'rb') as file:
        data = pickle.load(file)
    t_label = data[0]
    p_label = data[1]

    high_roc_auc = _group_auc(graduation_groups, t_label, p_label, 'high')
    low_roc_auc = _group_auc(graduation_groups, t_label, p_label, 'low')
    opt_task1_roc_auc = _group_auc(opt_task_groups, t_label, p_label, 'opt_task1')
    opt_task2_roc_auc = _group_auc(opt_task_groups, t_label, p_label, 'opt_task2')

    progress(0.1, desc="Files created and saved")

    # Count, per ideal task, which optional tasks each instance completed.
    with open(test_info_location, "r") as file:
        data = file.readlines()

    ideal_tasks = test_info[6]
    task_counts = {
        1: {"ER": 0, "ME": 0, "both": 0, "none": 0},
        2: {"ER": 0, "ME": 0, "both": 0, "none": 0},
    }
    # Ideal-task flag 0 feeds the ER chart (key 1), flag 1 the ME chart (key 2).
    chart_for_flag = {0: 1, 1: 2}

    for i, row in enumerate(data):
        row = row.strip()
        if not row:
            continue
        chart = chart_for_flag.get(ideal_tasks[i])
        if chart is None:
            continue
        task_counts[chart][_USAGE_BUCKET[_analyze_row(row)]] += 1

    fig_task1 = _usage_pie(task_counts[1], 'Problem Type: ER')
    fig_task2 = _usage_pie(task_counts[2], 'Problem Type: ME')

    progress(0.2, desc="analysis done!! Executing models")
    print("finetuned task: ", finetune_task)
    # NOTE: no model is executed here; the metrics below are read from the
    # result.txt and roc_data.pkl files already present on disk.
    progress(0.6, desc="Model execution completed")

    result = {}
    with open("result.txt", 'r') as file:
        for line in file:
            key, sep, value = line.strip().partition(': ')
            if not sep:
                # Blank or malformed line: nothing to record.
                continue
            result[key] = value if key == 'epoch' else float(value)
    result["ROC score of HGR"] = high_roc_auc
    result["ROC score of LGR"] = low_roc_auc

    with open("roc_data.pkl", "rb") as f:
        fpr, tpr, _ = pickle.load(f)
    fig = _roc_figure(fpr, tpr)

    progress(1.0)

    # AUC values go through _format_auc because a group with a single class
    # has no defined score (None), which a float format spec cannot render.
    text_output = "\n".join([
        "",
        "---------------------------",
        f"Model: {model_name}",
        "---------------------------\n",
        f"Time Taken: {result['time_taken_from_start']:.2f} seconds",
        f"Total Schools in test: {len(unique_schools):.4f}",
        f"Total number of instances having Schools with HGR : {len(high_sample):.4f}",
        f"Total number of instances having Schools with LGR: {len(low_sample):.4f}",
        "",
        f"ROC score of HGR: {_format_auc(high_roc_auc)}",
        f"ROC score of LGR: {_format_auc(low_roc_auc)}",
        "",
        f"ROC-AUC for problems of type ER: {_format_auc(opt_task1_roc_auc)}",
        f"ROC-AUC for problems of type ME: {_format_auc(opt_task2_roc_auc)}",
        "",
    ])
    return text_output, fig, fig_task1, fig_task2
422
-
423
- # List of models for the dropdown menu
424
-
425
- # models = ["ASTRA-FT-HGR", "ASTRA-FT-LGR", "ASTRA-FT-FULL"]
426
# Models offered in the dropdown menu.
models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]

# Introductory HTML rendered at the top of the page. Every line starts at
# column 0 so the Markdown renderer never treats a line as an indented code
# block. Model names in the instructions match the dropdown choices exactly.
content = """
<h1 style="color: black;">A S T R A</h1>
<h2 style="color: black;">An AI Model for Analyzing Math Strategies</h2>

<h3 style="color: white; text-align: center">
<a href="https://drive.google.com/file/d/1lbEpg8Se1ugTtkjreD8eXIg7qrplhWan/view" style="color: #1E90FF; text-decoration: none;">Link To Paper</a> |
<a href="https://github.com/Syudu41/ASTRA---Gates-Project" style="color: #1E90FF; text-decoration: none;">GitHub</a> |
<a href="https://sites.google.com/view/astra-research/home" style="color: #1E90FF; text-decoration: none;">Project Page</a>
</h3>

<p style="color: white;">Welcome to a demo of ASTRA. ASTRA is a collaborative research project between researchers at the
<a href="https://sites.google.com/site/dvngopal/" style="color: #1E90FF; text-decoration: none;">University of Memphis</a> and
<a href="https://www.carnegielearning.com" style="color: #1E90FF; text-decoration: none;">Carnegie Learning</a>
to utilize AI to improve our understanding of math learning strategies.</p>

<p style="color: white;">This demo has been developed with a pre-trained model (based on an architecture similar to BERT) that learns math strategies using data
collected from hundreds of schools in the U.S. who have used Carnegie Learning’s MATHia (formerly known as Cognitive Tutor), the flagship Intelligent Tutor that is part of a core, blended math curriculum.
For this demo, we have used data from a specific domain (teaching ratio and proportions) within 7th grade math. The fine-tuning based on the pre-trained model learns to predict which strategies lead to correct vs incorrect solutions.
</p>

<p style="color: white;">In this math domain, students were given word problems related to ratio and proportions. Further, the students
were given a choice of optional tasks to work on in parallel to the main problem to demonstrate their thinking (metacognition).
The optional tasks are designed based on solving problems using Equivalent Ratios (ER) and solving using Means and Extremes/cross-multiplication (ME).
When the equivalent ratios are easy to compute (integral values), ER is much more efficient compared to ME and switching between the tasks appropriately demonstrates cognitive flexibility.
</p>

<p style="color: white;">To use the demo, please follow these steps:</p>

<ol style="color: white;">
<li style="color: white;">Select a fine-tuned model:
<ul style="color: white;">
<li style="color: white;">ASTRA-FT-HGR: Fine-tuned with a small sample of data from schools that have a high graduation rate.</li>
<li style="color: white;">ASTRA-FT-FULL: Fine-tuned with a small sample of data from a mix of schools that have high/low graduation rates.</li>
</ul>
</li>
<li style="color: white;">Select a percentage of schools to analyze (selecting a large percentage may take a long time). Note that the selected percentage is applied to both High Graduation Rate (HGR) schools and Low Graduation Rate (LGR) schools.
</li>
<li style="color: white;">The results from the fine-tuned model are displayed in the dashboard:
<ul>
<li style="color: white;">The model accuracy is computed using the ROC-AUC metric.
</li>
<li style="color: white;">The results are shown for HGR, LGR schools and for different problem types (ER/ME).
</li>
<li style="color: white;">The distribution over how students utilized the optional tasks (whether they utilized ER/ME, used both of them or none of them) is shown for each problem type.
</li>
</ul>
</li>
</ol>
"""
476
- # CSS styling for white text
477
- # Create the Gradio interface
478
# Built-in Gradio themes, keyed by a short lowercase name.
available_themes = dict(
    default=gr.themes.Default(),
    soft=gr.themes.Soft(),
    monochrome=gr.themes.Monochrome(),
    glass=gr.themes.Glass(),
    base=gr.themes.Base(),
)
485
-
486
- # Comprehensive CSS for all HTML elements
487
- custom_css = '''
488
- /* Import Fira Sans font */
489
- @import url('https://fonts.googleapis.com/css2?family=Fira+Sans:wght@400;500;600;700&family=Inter:wght@400;500;600;700&display=swap');
490
- @import url('https://fonts.googleapis.com/css2?family=Libre+Caslon+Text:ital,wght@0,400;0,700;1,400&family=Spectral+SC:wght@600&display=swap');
491
- /* Container modifications for centering */
492
- .gradio-container {
493
- color: var(--block-label-text-color) !important;
494
- max-width: 1000px !important;
495
- margin: 0 auto !important;
496
- padding: 2rem !important;
497
- font-family: Arial, sans-serif !important;
498
- }
499
-
500
- /* Main title (ASTRA) */
501
- #title {
502
- text-align: center !important;
503
- margin: 1rem auto !important; /* Reduced margin */
504
- font-size: 2.5em !important;
505
- font-weight: 600 !important;
506
- font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
507
- padding-bottom: 0 !important; /* Remove bottom padding */
508
- }
509
-
510
- /* Subtitle (An AI Model...) */
511
- h1 {
512
- text-align: center !important;
513
- font-size: 30pt !important;
514
- font-weight: 600 !important;
515
- font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
516
- margin-top: 0.5em !important; /* Reduced top margin */
517
- margin-bottom: 0.3em !important;
518
- }
519
-
520
- h2 {
521
- text-align: center !important;
522
- font-size: 22pt !important;
523
- font-weight: 600 !important;
524
- font-family: "Spectral SC",'Fira Sans', sans-serif !important;
525
- margin-top: 0.2em !important; /* Reduced top margin */
526
- margin-bottom: 0.3em !important;
527
- }
528
-
529
- /* Links container styling */
530
- .links-container {
531
- text-align: center !important;
532
- margin: 1em auto !important;
533
- font-family: 'Inter' ,'Fira Sans', sans-serif !important;
534
- }
535
-
536
- /* Links */
537
- a {
538
- color: #2563eb !important;
539
- text-decoration: none !important;
540
- font-family:'Inter' , 'Fira Sans', sans-serif !important;
541
- }
542
-
543
- a:hover {
544
- text-decoration: underline !important;
545
- opacity: 0.8;
546
- }
547
-
548
- /* Regular text */
549
- p, li, .description, .markdown-text {
550
- font-family: 'Inter', Arial, sans-serif !important;
551
- color: black !important;
552
- font-size: 11pt;
553
- line-height: 1.6;
554
- font-weight: 500 !important;
555
- color: var(--block-label-text-color) !important;
556
- }
557
-
558
- /* Other headings */
559
- h3, h4, h5 {
560
- font-family: 'Fira Sans', sans-serif !important;
561
- color: var(--block-label-text-color) !important;
562
- margin-top: 1.5em;
563
- margin-bottom: 0.75em;
564
- }
565
-
566
-
567
- h3 { font-size: 1.5em; font-weight: 600; }
568
- h4 { font-size: 1.25em; font-weight: 500; }
569
- h5 { font-size: 1.1em; font-weight: 500; }
570
-
571
- /* Form elements */
572
- .select-wrap select, .wrap select,
573
- input, textarea {
574
- font-family: 'Inter' ,Arial, sans-serif !important;
575
- color: var(--block-label-text-color) !important;
576
- }
577
-
578
- /* Lists */
579
- ul, ol {
580
- margin-left: 0 !important;
581
- margin-bottom: 1.25em;
582
- padding-left: 2em;
583
- }
584
-
585
- li {
586
- margin-bottom: 0.75em;
587
- }
588
-
589
- /* Form container */
590
- .form-container {
591
- max-width: 1000px !important;
592
- margin: 0 auto !important;
593
- padding: 1rem !important;
594
- }
595
-
596
- /* Dashboard */
597
- .dashboard {
598
- margin-top: 2rem !important;
599
- padding: 1rem !important;
600
- border-radius: 8px !important;
601
- }
602
-
603
- /* Slider styling */
604
- .gradio-slider-row {
605
- display: flex;
606
- align-items: center;
607
- justify-content: space-between;
608
- margin: 1.5em 0;
609
- max-width: 100% !important;
610
- }
611
-
612
- .gradio-slider {
613
- flex-grow: 1;
614
- margin-right: 15px;
615
- }
616
-
617
- .slider-percentage {
618
- font-family: 'Inter', Arial, sans-serif !important;
619
- flex-shrink: 0;
620
- min-width: 60px;
621
- font-size: 1em;
622
- font-weight: bold;
623
- text-align: center;
624
- background-color: #f0f8ff;
625
- border: 1px solid #004080;
626
- border-radius: 5px;
627
- padding: 5px 10px;
628
- }
629
-
630
- .progress-bar-wrap.progress-bar-wrap.progress-bar-wrap
631
- {
632
- border-radius: var(--input-radius);
633
- height: 1.25rem;
634
- margin-top: 1rem;
635
- overflow: hidden;
636
- width: 70%;
637
- font-family: 'Inter', Arial, sans-serif !important;
638
- }
639
-
640
- /* Add these new styles after your existing CSS */
641
-
642
- /* Card-like appearance for the dashboard */
643
- .dashboard {
644
- background: #ffffff !important;
645
- box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
646
- border-radius: 12px !important;
647
- padding: 2rem !important;
648
- margin-top: 2.5rem !important;
649
- }
650
-
651
- /* Enhance ROC graph container */
652
- #roc {
653
- background: #ffffff !important;
654
- padding: 1.5rem !important;
655
- border-radius: 8px !important;
656
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
657
- margin: 1.5rem 0 !important;
658
- }
659
-
660
- /* Style the dropdown select */
661
- select {
662
- background-color: #ffffff !important;
663
- border: 1px solid #e2e8f0 !important;
664
- border-radius: 8px !important;
665
- padding: 0.5rem 1rem !important;
666
- transition: all 0.2s ease-in-out !important;
667
- box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05) !important;
668
- }
669
-
670
- select:hover {
671
- border-color: #cbd5e1 !important;
672
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
673
- }
674
-
675
- /* Enhance slider appearance */
676
- .progress-bar-wrap {
677
- background: #f8fafc !important;
678
- border: 1px solid #e2e8f0 !important;
679
- box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.05) !important;
680
- }
681
-
682
- /* Style metrics in dashboard */
683
- .dashboard p {
684
- padding: 0.5rem 0 !important;
685
- border-bottom: 1px solid #f1f5f9 !important;
686
- }
687
-
688
- /* Add spacing between sections */
689
- .dashboard > div {
690
- margin-bottom: 1.5rem !important;
691
- }
692
-
693
- /* Style the ROC curve title */
694
- .dashboard h4 {
695
- color: #1e293b !important;
696
- font-weight: 600 !important;
697
- margin-bottom: 1rem !important;
698
- padding-bottom: 0.5rem !important;
699
- border-bottom: 2px solid #e2e8f0 !important;
700
- }
701
-
702
- /* Enhance link appearances */
703
- a {
704
- position: relative !important;
705
- padding-bottom: 2px !important;
706
- transition: all 0.2s ease-in-out !important;
707
- }
708
-
709
- a:after {
710
- content: '' !important;
711
- position: absolute !important;
712
- width: 0 !important;
713
- height: 1px !important;
714
- bottom: 0 !important;
715
- left: 0 !important;
716
- background-color: #2563eb !important;
717
- transition: width 0.3s ease-in-out !important;
718
- }
719
-
720
- a:hover:after {
721
- width: 100% !important;
722
- }
723
-
724
- /* Add subtle dividers between sections */
725
- .form-container > div {
726
- padding-bottom: 1.5rem !important;
727
- margin-bottom: 1.5rem !important;
728
- border-bottom: 1px solid #f1f5f9 !important;
729
- }
730
-
731
- /* Style model selection section */
732
- .select-wrap {
733
- background: #ffffff !important;
734
- padding: 1.5rem !important;
735
- border-radius: 8px !important;
736
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
737
- margin-bottom: 2rem !important;
738
- }
739
-
740
- /* Style the metrics display */
741
- .dashboard span {
742
- font-family: 'Inter', sans-serif !important;
743
- font-weight: 500 !important;
744
- color: #334155 !important;
745
- }
746
-
747
- /* Add subtle animation to interactive elements */
748
- button, select, .slider-percentage {
749
- transition: all 0.2s ease-in-out !important;
750
- }
751
-
752
- /* Style the ROC curve container */
753
- .plot-container {
754
- background: #ffffff !important;
755
- border-radius: 8px !important;
756
- padding: 1rem !important;
757
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
758
- }
759
-
760
- /* Add container styles for opt1 and opt2 sections */
761
- #opt1, #opt2 {
762
- background: #ffffff !important;
763
- border-radius: 8px !important;
764
- padding: 1.5rem !important;
765
- margin-top: 1.5rem !important;
766
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
767
- }
768
-
769
- /* Style the distribution titles */
770
- .distribution-title {
771
- font-family: 'Inter', sans-serif !important;
772
- font-weight: 600 !important;
773
- color: #1e293b !important;
774
- margin-bottom: 1rem !important;
775
- text-align: center !important;
776
- }
777
-
778
- '''
779
-
780
# Assemble the page: intro text, the two inputs, a submit button and the
# dashboard (summary text, ROC plot and the two optional-task pie charts).
with gr.Blocks(css=custom_css, theme='gstaff/sketch') as demo:

    gr.Markdown(content)

    # Inputs: which fine-tuned model to use and what share of schools to sample.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Select Fine-tuned Model",
            choices=models,
            elem_classes="dropdown-menu",
        )
        increment_slider = gr.Slider(
            label="Schools Percentage",
            minimum=1,
            maximum=100,
            step=1,
            value=1,
            elem_id="increment-slider",
            elem_classes="gradio-slider",
        )

    with gr.Row():
        btn = gr.Button("Submit")

    gr.Markdown("<p class='description'>Dashboard</p>")

    # Outputs, in the order process_file returns them.
    with gr.Row():
        output_text = gr.Textbox(label="")
    with gr.Row():
        plot_output = gr.Plot(label="ROC")
    with gr.Row():
        opt1_pie = gr.Plot(label="ER")
        opt2_pie = gr.Plot(label="ME")

    btn.click(
        process_file,
        [model_dropdown, increment_slider],
        [output_text, plot_output, opt1_pie, opt2_pie],
    )

# Launch the app
demo.launch()
 
1
+ import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ import pickle
4
+ from gradio import Progress
5
+ import numpy as np
6
+ import subprocess
7
+ import shutil
8
+ import matplotlib.pyplot as plt
9
+ from sklearn.metrics import roc_curve, auc
10
+ import pandas as pd
11
+ import plotly.graph_objects as go
12
+ from sklearn.metrics import roc_auc_score
13
+ from matplotlib.figure import Figure
14
+ # Define the function to process the input file and model selection
15
+
16
def _analyze_row(row):
    """Report which optional tasks have at least one successful step in a row.

    The row is split on tabs. A field is attributed to a sub-task when it
    contains the sub-task's name, and counts as successful when it also
    contains the marker "OK".

    Args:
        row: One stripped line of the test-info file.

    Returns:
        Tuple ``(opt1_done, opt2_done)`` of booleans for OptionalTask_1 and
        OptionalTask_2 respectively.
    """
    fields = row.split("\t")
    optional_task_1_subtasks = ("DenominatorFactor", "NumeratorFactor", "EquationAnswer")
    optional_task_2_subtasks = (
        "FirstRow2:1", "FirstRow2:2", "FirstRow1:1", "FirstRow1:2",
        "SecondRow", "ThirdRow",
    )

    def has_success(subtasks):
        return any(
            "OK" in field
            for task in subtasks
            for field in fields
            if task in field
        )

    return has_success(optional_task_1_subtasks), has_success(optional_task_2_subtasks)


def _count_task_usage(info_path, ideal_tasks):
    """Count how the optional tasks were used, per ideal task.

    Args:
        info_path: Path of the test-info text file, read line by line.
        ideal_tasks: Column indexed by line position; 0 selects bucket 1 and
            1 selects bucket 2. Any other value is ignored.

    Returns:
        ``{1: counts, 2: counts}`` where each ``counts`` dict has the keys
        "ER", "ME", "both" and "none".
    """
    counts = {
        1: {"ER": 0, "ME": 0, "both": 0, "none": 0},
        2: {"ER": 0, "ME": 0, "both": 0, "none": 0},
    }
    with open(info_path, "r") as handle:
        for i, row in enumerate(handle):
            row = row.strip()
            if not row:
                continue

            ideal_task = ideal_tasks[i]
            if ideal_task == 0:
                bucket = counts[1]
            elif ideal_task == 1:
                bucket = counts[2]
            else:
                continue

            opt1_done, opt2_done = _analyze_row(row)
            if opt1_done and opt2_done:
                bucket["both"] += 1
            elif opt1_done:
                bucket["ER"] += 1
            elif opt2_done:
                bucket["ME"] += 1
            else:
                bucket["none"] += 1
    return counts


def _usage_pie(counts, title):
    """Build a pie chart of one ``counts`` dict with the given title."""
    pie = go.Figure(data=[go.Pie(
        labels=list(counts.keys()),
        values=list(counts.values()),
        textinfo='percent+label',
        textposition='auto',
        marker=dict(colors=["#FF6F61", "#6B5B95", "#88B04B", "#F7CAC9"]),
        sort=False,  # keep slice order stable so colours match across charts
    )])
    pie.update_layout(
        title=title,
        title_x=0.5,
        font=dict(family="sans-serif", size=12, color="black"),
        legend=dict(font=dict(family="sans-serif", size=12, color="black")),
    )
    return pie


def _roc_figure(fpr, tpr):
    """Build the ROC curve figure with a dashed chance diagonal."""
    roc_auc = auc(fpr, tpr)
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=list(fpr), y=list(tpr), mode='lines',
        name=f'ROC curve (area = {roc_auc:.2f})',
        line=dict(color='royalblue', width=3),
    ))
    fig.add_trace(go.Scatter(
        x=[0, 1], y=[0, 1], mode='lines', showlegend=False,
        line=dict(color='firebrick', width=2, dash='dash'),
    ))
    fig.update_layout(
        showlegend=True,
        title=dict(text='Receiver Operating Curve (ROC)', x=0.5),
        xaxis=dict(title=dict(text='False Positive Rate')),
        # A ROC curve plots the true positive rate against the false positive rate.
        yaxis=dict(title=dict(text='True Positive Rate')),
        font=dict(family="sans-serif", color="black"),
        legend=dict(
            x=0.75,
            y=0,
            traceorder="normal",
            font=dict(family="sans-serif", size=12, color="black"),
        ),
    )
    return fig


def _group_auc(groups, true_labels, scores, wanted):
    """Return the ROC-AUC of the rows tagged ``wanted``.

    Returns None when the group holds fewer than two distinct true labels,
    because ROC-AUC is undefined in that case.
    """
    picked = [(t, p) for g, t, p in zip(groups, true_labels, scores) if g == wanted]
    group_true = [t for t, _ in picked]
    if len(set(group_true)) > 1:
        return roc_auc_score(group_true, [p for _, p in picked])
    return None


def _format_metric(value, spec=".4f"):
    """Format a possibly-missing numeric metric, using "N/A" for None."""
    return "N/A" if value is None else format(value, spec)


def process_file(model_name, inc_slider, progress=Progress(track_tqdm=True)):
    """Score a sample of schools with a fine-tuned model and build the dashboard.

    Steps: sample ``inc_slider`` percent of the high and low graduation-rate
    schools, write their test rows to ``selected_rows.txt``, run
    ``new_test_saved_finetuned_model.py`` on that file, then read the files it
    leaves behind to compute the metrics and figures.

    Args:
        model_name: One of "ASTRA-FT-HGR", "ASTRA-FT-LGR" or "ASTRA-FT-FULL".
        inc_slider: Percentage (0-100) of schools sampled from each group.
        progress: Gradio progress tracker.

    Returns:
        Tuple ``(text_output, roc_figure, er_pie, me_pie)``.

    Raises:
        gr.Error: If the model name is unknown, the sample is empty, or the
            model subprocess exits with a non-zero status.
    """
    progress(0, desc="Starting the processing")

    finetune_tasks = {
        "ASTRA-FT-HGR": "highGRschool10",
        "ASTRA-FT-LGR": "lowGRschoolAll",
        "ASTRA-FT-FULL": "fullTest",
    }
    finetune_task = finetune_tasks.get(model_name)
    if finetune_task is None:
        # Fail early: a None task would otherwise crash subprocess.run later.
        raise gr.Error("Please select a fine-tuned model before submitting.")

    parent_location = "ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/"
    test_info_location = parent_location + "fullTest/test_info.txt"
    test_location = parent_location + "fullTest/test.txt"

    # Column 0 of test_info is used as the school number, column 6 as the ideal task.
    test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')
    grad_rate_data = pd.DataFrame(
        pd.read_pickle('school_grduation_rate.pkl'),
        columns=['school_number', 'grad_rate'],
    )

    unique_schools = test_info[0].unique()
    schools = grad_rate_data[grad_rate_data['school_number'].isin(unique_schools)]

    # Schools at or above this graduation rate count as "high" (HGR).
    grad_rate_threshold = 0.9
    is_high = schools['grad_rate'] >= grad_rate_threshold
    high_grad_schools = schools[is_high]['school_number'].unique()
    low_grad_schools = schools[~is_high & schools['grad_rate'].notna()]['school_number'].unique()

    # Fixed random_state keeps the sample reproducible for a given percentage.
    sample_fraction = inc_slider / 100
    high_sample = pd.Series(high_grad_schools).sample(frac=sample_fraction, random_state=1).tolist()
    low_sample = pd.Series(low_grad_schools).sample(frac=sample_fraction, random_state=1).tolist()
    random_schools = high_sample + low_sample

    indices = test_info[test_info[0].isin(random_schools)].index.tolist()
    if not indices:
        raise gr.Error("No schools were sampled at this percentage; please increase it.")
    # Set membership keeps the per-row group lookup below O(1).
    high_indices = set(test_info[test_info[0].isin(high_sample)].index)

    # NOTE(review): assumes test.txt and test_info.txt are row-aligned — confirm.
    test = pd.read_csv(test_location, sep=',', header=None, engine='python')
    selected_rows_df2 = test.loc[indices]
    # quoting=3 is csv.QUOTE_NONE.
    selected_rows_df2.to_csv('selected_rows.txt', sep='\t', index=False, header=False,
                             quoting=3, escapechar=' ')

    graduation_groups = [
        'high' if idx in high_indices else 'low' for idx in selected_rows_df2.index
    ]
    opt_task_groups = [
        'opt_task1' if test_info.loc[idx, 6] == 0 else 'opt_task2'
        for idx in selected_rows_df2.index
    ]

    progress(0.1, desc="Files created and saved")

    # NOTE(review): usage is counted over the whole test-info file, not only
    # the sampled schools — confirm this is intended.
    task_counts = _count_task_usage(test_info_location, test_info[6])
    fig_task1 = _usage_pie(task_counts[1], 'Problem Type: ER')
    fig_task2 = _usage_pie(task_counts[2], 'Problem Type: ME')

    progress(0.2, desc="analysis done!! Executing models")
    print("finetuned task: ", finetune_task)
    # NOTE(review): the checkpoint path is the highGRschool10 one for every
    # model choice — confirm whether it should follow finetune_task.
    completed = subprocess.run([
        "python", "new_test_saved_finetuned_model.py",
        "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
        "-finetune_task", finetune_task,
        "-test_dataset_path", "../../../../selected_rows.txt",
        "-finetuned_bert_classifier_checkpoint",
        "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
        "-e", str(1),
        "-b", str(1000),
    ])
    if completed.returncode != 0:
        # Without this the dashboard would silently show a previous run's files.
        raise gr.Error(f"Model execution failed (exit code {completed.returncode}).")
    progress(0.6, desc="Model execution completed")

    # result.txt holds "key: value" lines; 'epoch' stays text, the rest are floats.
    result = {}
    with open("result.txt", 'r') as file:
        for line in file:
            key, separator, value = line.strip().partition(': ')
            if not separator:
                continue  # blank or malformed line
            result[key] = value if key == 'epoch' else float(value)

    with open("roc_data.pkl", "rb") as f:
        fpr, tpr, _ = pickle.load(f)
    fig = _roc_figure(fpr, tpr)

    # Read after the model run so the labels belong to the rows just scored.
    # NOTE(review): assumes the model script writes roc_data2.pkl as
    # (true_labels, scores) in the row order of selected_rows.txt — confirm.
    with open("roc_data2.pkl", 'rb') as file:
        data = pickle.load(file)
    t_label = data[0]
    p_label = data[1]

    high_roc_auc = _group_auc(graduation_groups, t_label, p_label, 'high')
    low_roc_auc = _group_auc(graduation_groups, t_label, p_label, 'low')
    opt_task1_roc_auc = _group_auc(opt_task_groups, t_label, p_label, 'opt_task1')
    opt_task2_roc_auc = _group_auc(opt_task_groups, t_label, p_label, 'opt_task2')

    progress(1.0)

    # NOTE(review): the two "instances" lines report the number of sampled
    # schools, not the number of rows — confirm which one is wanted.
    text_output = "\n".join([
        "",
        "---------------------------",
        f"Model: {model_name}",
        "---------------------------\n",
        f"Time Taken: {_format_metric(result.get('time_taken_from_start'), '.2f')} seconds",
        f"Total Schools in test: {len(unique_schools)}",
        f"Total number of instances having Schools with HGR : {len(high_sample)}",
        f"Total number of instances having Schools with LGR: {len(low_sample)}",
        "",
        f"ROC score of HGR: {_format_metric(high_roc_auc)}",
        f"ROC score of LGR: {_format_metric(low_roc_auc)}",
        "",
        f"ROC-AUC for problems of type ER: {_format_metric(opt_task1_roc_auc)}",
        f"ROC-AUC for problems of type ME: {_format_metric(opt_task2_roc_auc)}",
        "",
    ])
    return text_output, fig, fig_task1, fig_task2
422
+
423
# Fine-tuned models offered in the dropdown menu.
# (ASTRA-FT-LGR is currently not offered.)
models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]

# Introductory HTML shown at the top of the page. Lines are deliberately kept
# unindented: Markdown renders text indented by four or more spaces as code.
content = """
<h1 style="color: black;">A S T R A</h1>
<h2 style="color: black;">An AI Model for Analyzing Math Strategies</h2>

<h3 style="color: white; text-align: center">
<a href="https://drive.google.com/file/d/1lbEpg8Se1ugTtkjreD8eXIg7qrplhWan/view" style="color: #1E90FF; text-decoration: none;">Link To Paper</a> |
<a href="https://github.com/Syudu41/ASTRA---Gates-Project" style="color: #1E90FF; text-decoration: none;">GitHub</a> |
<a href="https://sites.google.com/view/astra-research/home" style="color: #1E90FF; text-decoration: none;">Project Page</a>
</h3>

<p style="color: white;">Welcome to a demo of ASTRA. ASTRA is a collaborative research project between researchers at the
<a href="https://sites.google.com/site/dvngopal/" style="color: #1E90FF; text-decoration: none;">University of Memphis</a> and
<a href="https://www.carnegielearning.com" style="color: #1E90FF; text-decoration: none;">Carnegie Learning</a>
to utilize AI to improve our understanding of math learning strategies.</p>

<p style="color: white;">This demo has been developed with a pre-trained model (based on an architecture similar to BERT) that learns math strategies using data
collected from hundreds of schools in the U.S. who have used Carnegie Learning’s MATHia (formerly known as Cognitive Tutor), the flagship Intelligent Tutor that is part of a core, blended math curriculum.
For this demo, we have used data from a specific domain (teaching ratio and proportions) within 7th grade math. The fine-tuning based on the pre-trained model learns to predict which strategies lead to correct vs incorrect solutions.
</p>

<p style="color: white;">In this math domain, students were given word problems related to ratio and proportions. Further, the students
were given a choice of optional tasks to work on in parallel to the main problem to demonstrate their thinking (metacognition).
The optional tasks are designed based on solving problems using Equivalent Ratios (ER) and solving using Means and Extremes/cross-multiplication (ME).
When the equivalent ratios are easy to compute (integral values), ER is much more efficient compared to ME and switching between the tasks appropriately demonstrates cognitive flexibility.
</p>

<p style="color: white;">To use the demo, please follow these steps:</p>

<ol style="color: white;">
<li style="color: white;">Select a fine-tuned model:
<ul style="color: white;">
<li style="color: white;">ASTRA-FT-HGR: Fine-tuned with a small sample of data from schools that have a high graduation rate.</li>
<li style="color: white;">ASTRA-FT-FULL: Fine-tuned with a small sample of data from a mix of schools that have high/low graduation rates.</li>
</ul>
</li>
<li style="color: white;">Select a percentage of schools to analyze (selecting a large percentage may take a long time). Note that the selected percentage is applied to both High Graduation Rate (HGR) schools and Low Graduation Rate (LGR) schools.
</li>
<li style="color: white;">The results from the fine-tuned model are displayed in the dashboard:
<ul>
<li style="color: white;">The model accuracy is computed using the ROC-AUC metric.
</li>
<li style="color: white;">The results are shown for HGR, LGR schools and for different problem types (ER/ME).
</li>
<li style="color: white;">The distribution over how students utilized the optional tasks (whether they utilized ER/ME, used both of them or none of them) is shown for each problem type.
</li>
</ul>
</li>
</ol>
"""
476
# Built-in Gradio themes, keyed by their lowercase name. Each value is a
# freshly constructed instance of the matching class in gr.themes.
available_themes = {
    theme_name: getattr(gr.themes, theme_name.capitalize())()
    for theme_name in ("default", "soft", "monochrome", "glass", "base")
}
485
+
486
# Stylesheet injected into the Gradio page. Declarations that a later rule of
# equal specificity fully overrode have been dropped, so each selector's
# effective style is stated once.
custom_css = '''
/* Import Fira Sans font */
@import url('https://fonts.googleapis.com/css2?family=Fira+Sans:wght@400;500;600;700&family=Inter:wght@400;500;600;700&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Libre+Caslon+Text:ital,wght@0,400;0,700;1,400&family=Spectral+SC:wght@600&display=swap');
/* Container modifications for centering */
.gradio-container {
    color: var(--block-label-text-color) !important;
    max-width: 1000px !important;
    margin: 0 auto !important;
    padding: 2rem !important;
    font-family: Arial, sans-serif !important;
}

/* Main title (ASTRA) */
#title {
    text-align: center !important;
    margin: 1rem auto !important;  /* Reduced margin */
    font-size: 2.5em !important;
    font-weight: 600 !important;
    font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
    padding-bottom: 0 !important;  /* Remove bottom padding */
}

/* Subtitle (An AI Model...) */
h1 {
    text-align: center !important;
    font-size: 30pt !important;
    font-weight: 600 !important;
    font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
    margin-top: 0.5em !important;  /* Reduced top margin */
    margin-bottom: 0.3em !important;
}

h2 {
    text-align: center !important;
    font-size: 22pt !important;
    font-weight: 600 !important;
    font-family: "Spectral SC",'Fira Sans', sans-serif !important;
    margin-top: 0.2em !important;  /* Reduced top margin */
    margin-bottom: 0.3em !important;
}

/* Links container styling */
.links-container {
    text-align: center !important;
    margin: 1em auto !important;
    font-family: 'Inter' ,'Fira Sans', sans-serif !important;
}

/* Links */
a {
    color: #2563eb !important;
    text-decoration: none !important;
    font-family:'Inter' , 'Fira Sans', sans-serif !important;
}

a:hover {
    text-decoration: underline !important;
    opacity: 0.8;
}

/* Regular text */
p, li, .description, .markdown-text {
    font-family: 'Inter', Arial, sans-serif !important;
    font-size: 11pt;
    line-height: 1.6;
    font-weight: 500 !important;
    color: var(--block-label-text-color) !important;
}

/* Other headings */
h3, h4, h5 {
    font-family: 'Fira Sans', sans-serif !important;
    color: var(--block-label-text-color) !important;
    margin-top: 1.5em;
    margin-bottom: 0.75em;
}


h3 { font-size: 1.5em; font-weight: 600; }
h4 { font-size: 1.25em; font-weight: 500; }
h5 { font-size: 1.1em; font-weight: 500; }

/* Form elements */
.select-wrap select, .wrap select,
input, textarea {
    font-family: 'Inter' ,Arial, sans-serif !important;
    color: var(--block-label-text-color) !important;
}

/* Lists */
ul, ol {
    margin-left: 0 !important;
    margin-bottom: 1.25em;
    padding-left: 2em;
}

li {
    margin-bottom: 0.75em;
}

/* Form container */
.form-container {
    max-width: 1000px !important;
    margin: 0 auto !important;
    padding: 1rem !important;
}

/* Slider styling */
.gradio-slider-row {
    display: flex;
    align-items: center;
    justify-content: space-between;
    margin: 1.5em 0;
    max-width: 100% !important;
}

.gradio-slider {
    flex-grow: 1;
    margin-right: 15px;
}

.slider-percentage {
    font-family: 'Inter', Arial, sans-serif !important;
    flex-shrink: 0;
    min-width: 60px;
    font-size: 1em;
    font-weight: bold;
    text-align: center;
    background-color: #f0f8ff;
    border: 1px solid #004080;
    border-radius: 5px;
    padding: 5px 10px;
}

.progress-bar-wrap.progress-bar-wrap.progress-bar-wrap
{
    border-radius: var(--input-radius);
    height: 1.25rem;
    margin-top: 1rem;
    overflow: hidden;
    width: 70%;
    font-family: 'Inter', Arial, sans-serif !important;
}

/* Card-like appearance for the dashboard */
.dashboard {
    background: #ffffff !important;
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
    border-radius: 12px !important;
    padding: 2rem !important;
    margin-top: 2.5rem !important;
}

/* Enhance ROC graph container */
#roc {
    background: #ffffff !important;
    padding: 1.5rem !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
    margin: 1.5rem 0 !important;
}

/* Style the dropdown select */
select {
    background-color: #ffffff !important;
    border: 1px solid #e2e8f0 !important;
    border-radius: 8px !important;
    padding: 0.5rem 1rem !important;
    transition: all 0.2s ease-in-out !important;
    box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05) !important;
}

select:hover {
    border-color: #cbd5e1 !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
}

/* Enhance progress bar appearance */
.progress-bar-wrap {
    background: #f8fafc !important;
    border: 1px solid #e2e8f0 !important;
    box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.05) !important;
}

/* Style metrics in dashboard */
.dashboard p {
    padding: 0.5rem 0 !important;
    border-bottom: 1px solid #f1f5f9 !important;
}

/* Add spacing between sections */
.dashboard > div {
    margin-bottom: 1.5rem !important;
}

/* Style the ROC curve title */
.dashboard h4 {
    color: #1e293b !important;
    font-weight: 600 !important;
    margin-bottom: 1rem !important;
    padding-bottom: 0.5rem !important;
    border-bottom: 2px solid #e2e8f0 !important;
}

/* Enhance link appearances */
a {
    position: relative !important;
    padding-bottom: 2px !important;
    transition: all 0.2s ease-in-out !important;
}

a:after {
    content: '' !important;
    position: absolute !important;
    width: 0 !important;
    height: 1px !important;
    bottom: 0 !important;
    left: 0 !important;
    background-color: #2563eb !important;
    transition: width 0.3s ease-in-out !important;
}

a:hover:after {
    width: 100% !important;
}

/* Add subtle dividers between sections */
.form-container > div {
    padding-bottom: 1.5rem !important;
    margin-bottom: 1.5rem !important;
    border-bottom: 1px solid #f1f5f9 !important;
}

/* Style model selection section */
.select-wrap {
    background: #ffffff !important;
    padding: 1.5rem !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
    margin-bottom: 2rem !important;
}

/* Style the metrics display */
.dashboard span {
    font-family: 'Inter', sans-serif !important;
    font-weight: 500 !important;
    color: #334155 !important;
}

/* Add subtle animation to interactive elements */
button, select, .slider-percentage {
    transition: all 0.2s ease-in-out !important;
}

/* Style the ROC curve container */
.plot-container {
    background: #ffffff !important;
    border-radius: 8px !important;
    padding: 1rem !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
}

/* Add container styles for opt1 and opt2 sections */
#opt1, #opt2 {
    background: #ffffff !important;
    border-radius: 8px !important;
    padding: 1.5rem !important;
    margin-top: 1.5rem !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
}

/* Style the distribution titles */
.distribution-title {
    font-family: 'Inter', sans-serif !important;
    font-weight: 600 !important;
    color: #1e293b !important;
    margin-bottom: 1rem !important;
    text-align: center !important;
}

'''
779
+
780
# Page layout: intro text, the two inputs, a submit button and the dashboard
# (text summary, ROC curve, and one usage pie chart per problem type).
with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:

    gr.Markdown(content)

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=models,
            # Preselect a model so Submit never sends None to process_file.
            value=models[0],
            label="Select Fine-tuned Model",
            elem_classes="dropdown-menu"
        )
        increment_slider = gr.Slider(
            minimum=1,
            maximum=100,
            step=1,
            label="Schools Percentage",
            value=1,
            elem_id="increment-slider",
            elem_classes="gradio-slider"
        )

    with gr.Row():
        btn = gr.Button("Submit")

    gr.Markdown("<p class='description'>Dashboard</p>")

    with gr.Row():
        output_text = gr.Textbox(label="")

    # The elem_id values below are the ids the stylesheet targets
    # (#roc, #opt1, #opt2); without them those rules match nothing.
    with gr.Row():
        plot_output = gr.Plot(label="ROC", elem_id="roc")

    with gr.Row():
        opt1_pie = gr.Plot(label="ER", elem_id="opt1")
        opt2_pie = gr.Plot(label="ME", elem_id="opt2")

    # Outputs are listed in the order process_file returns them.
    btn.click(
        fn=process_file,
        inputs=[model_dropdown, increment_slider],
        outputs=[output_text, plot_output, opt1_pie, opt2_pie]
    )


# Launch the app
demo.launch()