suryadev1 Shan41 committed on
Commit
7884735
·
verified ·
1 Parent(s): 31e9520

Updated app.py (#3)

Browse files

- Updated app.py (25c56052a34f58dc325846186d90936b739492c3)


Co-authored-by: Sudarshan Balaji <[email protected]>

Files changed (1) hide show
  1. app.py +830 -562
app.py CHANGED
@@ -1,563 +1,831 @@
1
- import gradio as gr
2
- from huggingface_hub import hf_hub_download
3
- import pickle
4
- from gradio import Progress
5
- import numpy as np
6
- import subprocess
7
- import shutil
8
- import matplotlib.pyplot as plt
9
- from sklearn.metrics import roc_curve, auc
10
- import pandas as pd
11
- from sklearn.metrics import roc_auc_score
12
- from matplotlib.figure import Figure
13
- # Define the function to process the input file and model selection
14
-
15
- def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
16
- # progress = gr.Progress(track_tqdm=True)
17
-
18
- progress(0, desc="Starting the processing")
19
- # with open(file.name, 'r') as f:
20
- # content = f.read()
21
- # saved_test_dataset = "train.txt"
22
- # saved_test_label = "train_label.txt"
23
- # saved_train_info="train_info.txt"
24
- # Save the uploaded file content to a specified location
25
- # shutil.copyfile(file.name, saved_test_dataset)
26
- # shutil.copyfile(label.name, saved_test_label)
27
- # shutil.copyfile(info.name, saved_train_info)
28
- parent_location="ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/"
29
- test_info_location=parent_location+"fullTest/test_info.txt"
30
- test_location=parent_location+"fullTest/test.txt"
31
- if(model_name=="ASTRA-FT-HGR"):
32
- finetune_task="highGRschool10"
33
- # test_info_location=parent_location+"fullTest/test_info.txt"
34
- # test_location=parent_location+"fullTest/test.txt"
35
- elif(model_name== "ASTRA-FT-LGR" ):
36
- finetune_task="lowGRschoolAll"
37
- # test_info_location=parent_location+"lowGRschoolAll/test_info.txt"
38
- # test_location=parent_location+"lowGRschoolAll/test.txt"
39
- elif(model_name=="ASTRA-FT-FULL"):
40
- # test_info_location=parent_location+"fullTest/test_info.txt"
41
- # test_location=parent_location+"fullTest/test.txt"
42
- finetune_task="fullTest"
43
- else:
44
- finetune_task=None
45
- # Load the test_info file and the graduation rate file
46
- test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')
47
- grad_rate_data = pd.DataFrame(pd.read_pickle('school_grduation_rate.pkl'),columns=['school_number','grad_rate']) # Load the grad_rate data
48
-
49
- # Step 1: Extract unique school numbers from test_info
50
- unique_schools = test_info[0].unique()
51
-
52
- # Step 2: Filter the grad_rate_data using the unique school numbers
53
- schools = grad_rate_data[grad_rate_data['school_number'].isin(unique_schools)]
54
-
55
- # Define a threshold for high and low graduation rates (adjust as needed)
56
- grad_rate_threshold = 0.9
57
-
58
- # Step 4: Divide schools into high and low graduation rate groups
59
- high_grad_schools = schools[schools['grad_rate'] >= grad_rate_threshold]['school_number'].unique()
60
- low_grad_schools = schools[schools['grad_rate'] < grad_rate_threshold]['school_number'].unique()
61
-
62
- # Step 5: Sample percentage of schools from each group
63
- high_sample = pd.Series(high_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()
64
- low_sample = pd.Series(low_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()
65
-
66
- # Step 6: Combine the sampled schools
67
- random_schools = high_sample + low_sample
68
-
69
- # Step 7: Get indices for the sampled schools
70
- indices = test_info[test_info[0].isin(random_schools)].index.tolist()
71
- high_indices = test_info[(test_info[0].isin(high_sample))].index.tolist()
72
- low_indices = test_info[(test_info[0].isin(low_sample))].index.tolist()
73
-
74
- # Load the test file and select rows based on indices
75
- test = pd.read_csv(test_location, sep=',', header=None, engine='python')
76
- selected_rows_df2 = test.loc[indices]
77
-
78
- # Save the selected rows to a file
79
- selected_rows_df2.to_csv('selected_rows.txt', sep='\t', index=False, header=False, quoting=3, escapechar=' ')
80
-
81
- graduation_groups = [
82
- 'high' if idx in high_indices else 'low' for idx in selected_rows_df2.index
83
- ]
84
- # Group data by opt_task1 and opt_task2 based on test_info[6]
85
- opt_task_groups = ['opt_task1' if test_info.loc[idx, 6] == 0 else 'opt_task2' for idx in selected_rows_df2.index]
86
-
87
- with open("roc_data2.pkl", 'rb') as file:
88
- data = pickle.load(file)
89
- t_label=data[0]
90
- p_label=data[1]
91
- # Step 1: Align graduation_group, t_label, and p_label
92
- aligned_labels = list(zip(graduation_groups, t_label, p_label))
93
- opt_task_aligned = list(zip(opt_task_groups, t_label, p_label))
94
- # Step 2: Separate the labels for high and low groups
95
- high_t_labels = [t for grad, t, p in aligned_labels if grad == 'high']
96
- low_t_labels = [t for grad, t, p in aligned_labels if grad == 'low']
97
-
98
- high_p_labels = [p for grad, t, p in aligned_labels if grad == 'high']
99
- low_p_labels = [p for grad, t, p in aligned_labels if grad == 'low']
100
-
101
- opt_task1_t_labels = [t for task, t, p in opt_task_aligned if task == 'opt_task1']
102
- opt_task1_p_labels = [p for task, t, p in opt_task_aligned if task == 'opt_task1']
103
-
104
- opt_task2_t_labels = [t for task, t, p in opt_task_aligned if task == 'opt_task2']
105
- opt_task2_p_labels = [p for task, t, p in opt_task_aligned if task == 'opt_task2']
106
-
107
- high_roc_auc = roc_auc_score(high_t_labels, high_p_labels) if len(set(high_t_labels)) > 1 else None
108
- low_roc_auc = roc_auc_score(low_t_labels, low_p_labels) if len(set(low_t_labels)) > 1 else None
109
-
110
- opt_task1_roc_auc = roc_auc_score(opt_task1_t_labels, opt_task1_p_labels) if len(set(opt_task1_t_labels)) > 1 else None
111
- opt_task2_roc_auc = roc_auc_score(opt_task2_t_labels, opt_task2_p_labels) if len(set(opt_task2_t_labels)) > 1 else None
112
-
113
- # For demonstration purposes, we'll just return the content with the selected model name
114
-
115
- # print(checkpoint)
116
- progress(0.1, desc="Files created and saved")
117
- # if (inc_val<5):
118
- # model_name="highGRschool10"
119
- # elif(inc_val>=5 & inc_val<10):
120
- # model_name="highGRschool10"
121
- # else:
122
- # model_name="highGRschool10"
123
- # Function to analyze each row
124
- def analyze_row(row):
125
- # Split the row into fields
126
- fields = row.split("\t")
127
-
128
- # Define tasks for OptionalTask_1, OptionalTask_2, and FinalAnswer
129
- optional_task_1_subtasks = ["DenominatorFactor", "NumeratorFactor", "EquationAnswer"]
130
- optional_task_2_subtasks = [
131
- "FirstRow2:1", "FirstRow2:2", "FirstRow1:1", "FirstRow1:2",
132
- "SecondRow", "ThirdRow"
133
- ]
134
-
135
- # Helper function to evaluate task attempts
136
- def evaluate_tasks(fields, tasks):
137
- task_status = {}
138
- for task in tasks:
139
- relevant_attempts = [f for f in fields if task in f]
140
- if any("OK" in attempt for attempt in relevant_attempts):
141
- task_status[task] = "Attempted (Successful)"
142
- elif any("ERROR" in attempt for attempt in relevant_attempts):
143
- task_status[task] = "Attempted (Error)"
144
- elif any("JIT" in attempt for attempt in relevant_attempts):
145
- task_status[task] = "Attempted (JIT)"
146
- else:
147
- task_status[task] = "Unattempted"
148
- return task_status
149
-
150
- # Evaluate tasks for each category
151
- optional_task_1_status = evaluate_tasks(fields, optional_task_1_subtasks)
152
- optional_task_2_status = evaluate_tasks(fields, optional_task_2_subtasks)
153
-
154
- # Check if tasks have any successful attempt
155
- opt1_done = any(status == "Attempted (Successful)" for status in optional_task_1_status.values())
156
- opt2_done = any(status == "Attempted (Successful)" for status in optional_task_2_status.values())
157
-
158
- return opt1_done, opt2_done
159
-
160
- # Read data from test_info.txt
161
- # Read data from test_info.txt
162
- with open(test_info_location, "r") as file:
163
- data = file.readlines()
164
-
165
- # Assuming test_info[7] is a list with ideal tasks for each instance
166
- ideal_tasks = test_info[6] # A list where each element is either 1 or 2
167
-
168
- # Initialize counters
169
- task_counts = {
170
- 1: {"only_opt1": 0, "only_opt2": 0, "both": 0,"none":0},
171
- 2: {"only_opt1": 0, "only_opt2": 0, "both": 0,"none":0}
172
- }
173
-
174
- # Analyze rows
175
- for i, row in enumerate(data):
176
- row = row.strip()
177
- if not row:
178
- continue
179
-
180
- ideal_task = ideal_tasks[i] # Get the ideal task for the current row
181
- opt1_done, opt2_done = analyze_row(row)
182
-
183
- if ideal_task == 0:
184
- if opt1_done and not opt2_done:
185
- task_counts[1]["only_opt1"] += 1
186
- elif not opt1_done and opt2_done:
187
- task_counts[1]["only_opt2"] += 1
188
- elif opt1_done and opt2_done:
189
- task_counts[1]["both"] += 1
190
- else:
191
- task_counts[1]["none"] +=1
192
- elif ideal_task == 1:
193
- if opt1_done and not opt2_done:
194
- task_counts[2]["only_opt1"] += 1
195
- elif not opt1_done and opt2_done:
196
- task_counts[2]["only_opt2"] += 1
197
- elif opt1_done and opt2_done:
198
- task_counts[2]["both"] += 1
199
- else:
200
- task_counts[2]["none"] +=1
201
-
202
- # Create a string output for results
203
- # output_summary = "Task Analysis Summary:\n"
204
- # output_summary += "-----------------------\n"
205
-
206
- # for ideal_task, counts in task_counts.items():
207
- # output_summary += f"Ideal Task = OptionalTask_{ideal_task}:\n"
208
- # output_summary += f" Only OptionalTask_1 done: {counts['only_opt1']}\n"
209
- # output_summary += f" Only OptionalTask_2 done: {counts['only_opt2']}\n"
210
- # output_summary += f" Both done: {counts['both']}\n"
211
-
212
- # Generate pie chart for Task 1
213
- task1_labels = list(task_counts[1].keys())
214
- task1_values = list(task_counts[1].values())
215
-
216
- fig_task1 = Figure()
217
- ax1 = fig_task1.add_subplot(1, 1, 1)
218
- ax1.pie(task1_values, labels=task1_labels, autopct='%1.1f%%', startangle=90)
219
- ax1.set_title('Ideal Task 1 Distribution')
220
-
221
- # Generate pie chart for Task 2
222
- task2_labels = list(task_counts[2].keys())
223
- task2_values = list(task_counts[2].values())
224
-
225
- fig_task2 = Figure()
226
- ax2 = fig_task2.add_subplot(1, 1, 1)
227
- ax2.pie(task2_values, labels=task2_labels, autopct='%1.1f%%', startangle=90)
228
- ax2.set_title('Ideal Task 2 Distribution')
229
-
230
- # print(output_summary)
231
-
232
- progress(0.2, desc="analysis done!! Executing models")
233
- print("finetuned task: ",finetune_task)
234
- # subprocess.run([
235
- # "python", "new_test_saved_finetuned_model.py",
236
- # "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
237
- # "-finetune_task", finetune_task,
238
- # "-test_dataset_path","../../../../selected_rows.txt",
239
- # # "-test_label_path","../../../../train_label.txt",
240
- # "-finetuned_bert_classifier_checkpoint",
241
- # "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
242
- # "-e",str(1),
243
- # "-b",str(1000)
244
- # ])
245
- progress(0.6,desc="Model execution completed")
246
- result = {}
247
- with open("result.txt", 'r') as file:
248
- for line in file:
249
- key, value = line.strip().split(': ', 1)
250
- # print(type(key))
251
- if key=='epoch':
252
- result[key]=value
253
- else:
254
- result[key]=float(value)
255
- result["ROC score of HGR"]=high_roc_auc
256
- result["ROC score of LGR"]=low_roc_auc
257
- # Create a plot
258
- with open("roc_data.pkl", "rb") as f:
259
- fpr, tpr, _ = pickle.load(f)
260
- # print(fpr,tpr)
261
- roc_auc = auc(fpr, tpr)
262
-
263
-
264
- # Create a matplotlib figure
265
- fig = Figure()
266
- ax = fig.add_subplot(1, 1, 1)
267
- ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
268
- ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
269
- ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'Receiver Operating Curve (ROC)')
270
- ax.legend(loc="lower right")
271
- ax.grid()
272
-
273
- # Save plot to a file
274
- plot_path = "plot.png"
275
- fig.savefig(plot_path)
276
- plt.close(fig)
277
-
278
-
279
-
280
-
281
- progress(1.0)
282
- # Prepare text output
283
- text_output = f"Model: {model_name}\nResult:\n{result}"
284
- # Prepare text output with HTML formatting
285
- text_output = f"""
286
- Model: {model_name}\n
287
- -----------------\n
288
- Time Taken: {result['time_taken_from_start']:.2f} seconds\n
289
- Total Schools in test: {len(unique_schools):.4f}\n
290
- Total number of instances having Schools with HGR : {len(high_sample):.4f}\n
291
- Total number of instances having Schools with LGR: {len(low_sample):.4f}\n
292
-
293
- ROC score of HGR: {high_roc_auc}\n
294
- ROC score of LGR: {low_roc_auc}\n
295
-
296
- ROC score of opt1: {opt_task1_roc_auc}\n
297
- ROC score of opt2: {opt_task2_roc_auc}\n
298
- -----------------\n
299
- """
300
- return text_output,fig,fig_task1,fig_task2
301
-
302
- # List of models for the dropdown menu
303
-
304
- # models = ["ASTRA-FT-HGR", "ASTRA-FT-LGR", "ASTRA-FT-FULL"]
305
- models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]
306
- content = """
307
- <h1 style="color: white;">ASTRA: An AI Model for Analyzing Math Strategies</h1>
308
-
309
- <h3 style="color: white;">
310
- <a href="https://drive.google.com/file/d/1lbEpg8Se1ugTtkjreD8eXIg7qrplhWan/view" style="color: #1E90FF; text-decoration: none;">Link To Paper</a> |
311
- <a href="https://github.com/Syudu41/ASTRA---Gates-Project" style="color: #1E90FF; text-decoration: none;">GitHub</a> |
312
- <a href="#" style="color: #1E90FF; text-decoration: none;">Project Page</a>
313
- </h3>
314
-
315
- <p style="color: white;">Welcome to a demo of ASTRA. ASTRA is a collaborative research project between researchers at the
316
- <a href="https://www.memphis.edu" style="color: #1E90FF; text-decoration: none;">University of Memphis</a> and
317
- <a href="https://www.carnegielearning.com" style="color: #1E90FF; text-decoration: none;">Carnegie Learning</a>
318
- to utilize AI to improve our understanding of math learning strategies.</p>
319
-
320
- <p style="color: white;">This demo has been developed with a pre-trained model (based on an architecture similar to BERT)
321
- that learns math strategies using data collected from hundreds of schools in the U.S. who have used
322
- Carnegie Learning's MATHia (formerly known as Cognitive Tutor), the flagship Intelligent Tutor
323
- that is part of a core, blended math curriculum.</p>
324
-
325
- <p style="color: white;">For this demo, we have used data from a specific domain (teaching ratio and proportions) within
326
- 7th grade math. The fine-tuning based on the pre-trained models learns to predict which strategies
327
- lead to correct vs. incorrect solutions.</p>
328
-
329
- <p style="color: white;">To use the demo, please follow these steps:</p>
330
-
331
- <ol style="color: white;">
332
- <li style="color: white;">Select a fine-tuned model:
333
- <ul style="color: white;">
334
- <li style="color: white;">ASTRA-FT-HGR: Fine-tuned with a small sample of data from schools that have a high graduation rate.</li>
335
- <li style="color: white;">ASTRA-FT-Full: Fine-tuned with a small sample of data from a mix of schools that have high/low graduation rates.</li>
336
- </ul>
337
- </li>
338
- <li style="color: white;">Select a percentage of schools to analyze (selecting a large percentage may take a long time).</li>
339
- <li style="color: white;">View Results:
340
- <ul>
341
- <li style="color: white;">The results from the fine-tuned model are displayed on the dashboard.</li>
342
- <li style="color: white;">The results are shown separately for schools that have high and low graduation rates.</li>
343
- </ul>
344
- </li>
345
- </ol>
346
- """
347
- # CSS styling for white text
348
- # Create the Gradio interface
349
- with gr.Blocks(css="""
350
- body {
351
- background-color: #1e1e1e!important;
352
- font-family: 'Arial', sans-serif;
353
- color: #f5f5f5!important;;
354
- }
355
-
356
- .gradio-container {
357
- max-width: 850px!important;
358
- margin: 0 auto!important;;
359
- padding: 20px!important;;
360
- background-color: #292929!important;
361
- border-radius: 10px;
362
- box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
363
- }
364
- .gradio-container-4-44-0 .prose h1 {
365
- font-size: var(--text-xxl);
366
- color: #ffffff!important;
367
- }
368
- #title {
369
- color: white!important;
370
- font-size: 2.3em;
371
- font-weight: bold;
372
- text-align: center!important;
373
- margin-bottom: 20px;
374
- }
375
- .description {
376
- text-align: center;
377
- font-size: 1.1em;
378
- color: #bfbfbf;
379
- margin-bottom: 30px;
380
- }
381
- .file-box {
382
- max-width: 180px;
383
- padding: 5px;
384
- background-color: #444!important;
385
- border: 1px solid #666!important;
386
- border-radius: 6px;
387
- height: 80px!important;;
388
- margin: 0 auto!important;;
389
- text-align: center;
390
- color: transparent;
391
- }
392
- .file-box span {
393
- color: #f5f5f5!important;
394
- font-size: 1em;
395
- line-height: 45px; /* Vertically center text */
396
- }
397
- .dropdown-menu {
398
- max-width: 220px;
399
- margin: 0 auto!important;
400
- background-color: #444!important;
401
- color:#444!important;
402
- border-radius: 6px;
403
- padding: 8px;
404
- font-size: 1.1em;
405
- border: 1px solid #666;
406
- }
407
- .button {
408
- background-color: #4CAF50!important;
409
- color: white!important;
410
- font-size: 1.1em;
411
- padding: 10px 25px;
412
- border-radius: 6px;
413
- cursor: pointer;
414
- transition: background-color 0.2s ease-in-out;
415
- }
416
- .button:hover {
417
- background-color: #45a049!important;
418
- }
419
- .output-text {
420
- background-color: #333!important;
421
- padding: 12px;
422
- border-radius: 8px;
423
- border: 1px solid #666;
424
- font-size: 1.1em;
425
- }
426
- .footer {
427
- text-align: center;
428
- margin-top: 50px;
429
- font-size: 0.9em;
430
- color: #b0b0b0;
431
- }
432
- .svelte-12ioyct .wrap {
433
- display: none !important;
434
- }
435
- .file-label-text {
436
- display: none !important;
437
- }
438
-
439
- div.svelte-sfqy0y {
440
- display: flex;
441
- flex-direction: inherit;
442
- flex-wrap: wrap;
443
- gap: var(--form-gap-width);
444
- box-shadow: var(--block-shadow);
445
- border: var(--block-border-width) solid var(--border-color-primary);
446
- border-radius: var(--block-radius);
447
- background: #1f2937!important;
448
- overflow-y: hidden;
449
- }
450
-
451
- .block.svelte-12cmxck {
452
- position: relative;
453
- margin: 0;
454
- box-shadow: var(--block-shadow);
455
- border-width: var(--block-border-width);
456
- border-color: var(--block-border-color);
457
- border-radius: var(--block-radius);
458
- background: #1f2937!important;
459
- width: 100%;
460
- line-height: var(--line-sm);
461
- }
462
-
463
- .svelte-12ioyct .wrap {
464
- display: none !important;
465
- }
466
- .file-label-text {
467
- display: none !important;
468
- }
469
- input[aria-label="file upload"] {
470
- display: none !important;
471
- }
472
-
473
- gradio-app .gradio-container.gradio-container-4-44-0 .contain .file-box span {
474
- font-size: 1em;
475
- line-height: 45px;
476
- color: #1f2937 !important;
477
- }
478
- .wrap.svelte-12ioyct {
479
- display: flex;
480
- flex-direction: column;
481
- justify-content: center;
482
- align-items: center;
483
- min-height: var(--size-60);
484
- color: #1f2937 !important;
485
- line-height: var(--line-md);
486
- height: 100%;
487
- padding-top: var(--size-3);
488
- text-align: center;
489
- margin: auto var(--spacing-lg);
490
- }
491
- span.svelte-1gfkn6j:not(.has-info) {
492
- margin-bottom: var(--spacing-lg);
493
- color: white!important;
494
- }
495
- label.float.svelte-1b6s6s {
496
- position: relative!important;
497
- top: var(--block-label-margin);
498
- left: var(--block-label-margin);
499
- }
500
- label.svelte-1b6s6s {
501
- display: inline-flex;
502
- align-items: center;
503
- z-index: var(--layer-2);
504
- box-shadow: var(--block-label-shadow);
505
- border: var(--block-label-border-width) solid var(--border-color-primary);
506
- border-top: none;
507
- border-left: none;
508
- border-radius: var(--block-label-radius);
509
- background: rgb(120 151 180)!important;
510
- padding: var(--block-label-padding);
511
- pointer-events: none;
512
- color: #1f2937!important;
513
- font-weight: var(--block-label-text-weight);
514
- font-size: var(--block-label-text-size);
515
- line-height: var(--line-sm);
516
- }
517
- .file.svelte-18wv37q.svelte-18wv37q {
518
- display: block!important;
519
- width: var(--size-full);
520
- }
521
-
522
- tbody.svelte-18wv37q>tr.svelte-18wv37q:nth-child(odd) {
523
- background: ##7897b4!important;
524
- color: white;
525
- background: #aca7b2;
526
- }
527
-
528
- .gradio-container-4-31-4 .prose h1, .gradio-container-4-31-4 .prose h2, .gradio-container-4-31-4 .prose h3, .gradio-container-4-31-4 .prose h4, .gradio-container-4-31-4 .prose h5 {
529
-
530
- color: white;
531
- }
532
- """) as demo:
533
-
534
- gr.Markdown("<h1 id='title'>ASTRA</h1>", elem_id="title")
535
- gr.Markdown(content)
536
-
537
- with gr.Row():
538
- # file_input = gr.File(label="Upload a test file", file_types=['.txt'], elem_classes="file-box")
539
- # label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
540
-
541
- # info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")
542
-
543
- model_dropdown = gr.Dropdown(choices=models, label="Select Fine-tuned Model", elem_classes="dropdown-menu")
544
-
545
-
546
- increment_slider = gr.Slider(minimum=1, maximum=100, step=1, label="Schools Percentage", value=1)
547
- gr.Markdown("<p class='description'>Dashboard</p>")
548
- with gr.Row():
549
- output_text = gr.Textbox(label="")
550
- # output_image = gr.Image(label="ROC")
551
- plot_output = gr.Plot(label="roc")
552
- with gr.Row():
553
- opt1_pie = gr.Plot(label="opt1")
554
- opt2_pie = gr.Plot(label="opt2")
555
- # output_summary = gr.Textbox(label="Summary")
556
-
557
- btn = gr.Button("Submit")
558
-
559
- btn.click(fn=process_file, inputs=[model_dropdown,increment_slider], outputs=[output_text,plot_output,opt1_pie,opt2_pie])
560
-
561
-
562
- # Launch the app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
  demo.launch()
 
1
+ import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ import pickle
4
+ from gradio import Progress
5
+ import numpy as np
6
+ import subprocess
7
+ import shutil
8
+ import matplotlib.pyplot as plt
9
+ from sklearn.metrics import roc_curve, auc
10
+ import pandas as pd
11
+ import plotly.graph_objects as go
12
+ from sklearn.metrics import roc_auc_score
13
+ from matplotlib.figure import Figure
14
+ # Define the function to process the input file and model selection
15
+
16
+ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
17
+ # progress = gr.Progress(track_tqdm=True)
18
+
19
+ progress(0, desc="Starting the processing")
20
+ # with open(file.name, 'r') as f:
21
+ # content = f.read()
22
+ # saved_test_dataset = "train.txt"
23
+ # saved_test_label = "train_label.txt"
24
+ # saved_train_info="train_info.txt"
25
+ # Save the uploaded file content to a specified location
26
+ # shutil.copyfile(file.name, saved_test_dataset)
27
+ # shutil.copyfile(label.name, saved_test_label)
28
+ # shutil.copyfile(info.name, saved_train_info)
29
+ parent_location="ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/"
30
+ test_info_location=parent_location+"fullTest/test_info.txt"
31
+ test_location=parent_location+"fullTest/test.txt"
32
+ if(model_name=="ASTRA-FT-HGR"):
33
+ finetune_task="highGRschool10"
34
+ # test_info_location=parent_location+"fullTest/test_info.txt"
35
+ # test_location=parent_location+"fullTest/test.txt"
36
+ elif(model_name== "ASTRA-FT-LGR" ):
37
+ finetune_task="lowGRschoolAll"
38
+ # test_info_location=parent_location+"lowGRschoolAll/test_info.txt"
39
+ # test_location=parent_location+"lowGRschoolAll/test.txt"
40
+ elif(model_name=="ASTRA-FT-FULL"):
41
+ # test_info_location=parent_location+"fullTest/test_info.txt"
42
+ # test_location=parent_location+"fullTest/test.txt"
43
+ finetune_task="fullTest"
44
+ else:
45
+ finetune_task=None
46
+ # Load the test_info file and the graduation rate file
47
+ test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')
48
+ grad_rate_data = pd.DataFrame(pd.read_pickle('school_grduation_rate.pkl'),columns=['school_number','grad_rate']) # Load the grad_rate data
49
+
50
+ # Step 1: Extract unique school numbers from test_info
51
+ unique_schools = test_info[0].unique()
52
+
53
+ # Step 2: Filter the grad_rate_data using the unique school numbers
54
+ schools = grad_rate_data[grad_rate_data['school_number'].isin(unique_schools)]
55
+
56
+ # Define a threshold for high and low graduation rates (adjust as needed)
57
+ grad_rate_threshold = 0.9
58
+
59
+ # Step 4: Divide schools into high and low graduation rate groups
60
+ high_grad_schools = schools[schools['grad_rate'] >= grad_rate_threshold]['school_number'].unique()
61
+ low_grad_schools = schools[schools['grad_rate'] < grad_rate_threshold]['school_number'].unique()
62
+
63
+ # Step 5: Sample percentage of schools from each group
64
+ high_sample = pd.Series(high_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()
65
+ low_sample = pd.Series(low_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()
66
+
67
+ # Step 6: Combine the sampled schools
68
+ random_schools = high_sample + low_sample
69
+
70
+ # Step 7: Get indices for the sampled schools
71
+ indices = test_info[test_info[0].isin(random_schools)].index.tolist()
72
+ high_indices = test_info[(test_info[0].isin(high_sample))].index.tolist()
73
+ low_indices = test_info[(test_info[0].isin(low_sample))].index.tolist()
74
+
75
+ # Load the test file and select rows based on indices
76
+ test = pd.read_csv(test_location, sep=',', header=None, engine='python')
77
+ selected_rows_df2 = test.loc[indices]
78
+
79
+ # Save the selected rows to a file
80
+ selected_rows_df2.to_csv('selected_rows.txt', sep='\t', index=False, header=False, quoting=3, escapechar=' ')
81
+
82
+ graduation_groups = [
83
+ 'high' if idx in high_indices else 'low' for idx in selected_rows_df2.index
84
+ ]
85
+ # Group data by opt_task1 and opt_task2 based on test_info[6]
86
+ opt_task_groups = ['opt_task1' if test_info.loc[idx, 6] == 0 else 'opt_task2' for idx in selected_rows_df2.index]
87
+
88
+ with open("roc_data2.pkl", 'rb') as file:
89
+ data = pickle.load(file)
90
+ t_label=data[0]
91
+ p_label=data[1]
92
+ # Step 1: Align graduation_group, t_label, and p_label
93
+ aligned_labels = list(zip(graduation_groups, t_label, p_label))
94
+ opt_task_aligned = list(zip(opt_task_groups, t_label, p_label))
95
+ # Step 2: Separate the labels for high and low groups
96
+ high_t_labels = [t for grad, t, p in aligned_labels if grad == 'high']
97
+ low_t_labels = [t for grad, t, p in aligned_labels if grad == 'low']
98
+
99
+ high_p_labels = [p for grad, t, p in aligned_labels if grad == 'high']
100
+ low_p_labels = [p for grad, t, p in aligned_labels if grad == 'low']
101
+
102
+ opt_task1_t_labels = [t for task, t, p in opt_task_aligned if task == 'opt_task1']
103
+ opt_task1_p_labels = [p for task, t, p in opt_task_aligned if task == 'opt_task1']
104
+
105
+ opt_task2_t_labels = [t for task, t, p in opt_task_aligned if task == 'opt_task2']
106
+ opt_task2_p_labels = [p for task, t, p in opt_task_aligned if task == 'opt_task2']
107
+
108
+ high_roc_auc = roc_auc_score(high_t_labels, high_p_labels) if len(set(high_t_labels)) > 1 else None
109
+ low_roc_auc = roc_auc_score(low_t_labels, low_p_labels) if len(set(low_t_labels)) > 1 else None
110
+
111
+ opt_task1_roc_auc = roc_auc_score(opt_task1_t_labels, opt_task1_p_labels) if len(set(opt_task1_t_labels)) > 1 else None
112
+ opt_task2_roc_auc = roc_auc_score(opt_task2_t_labels, opt_task2_p_labels) if len(set(opt_task2_t_labels)) > 1 else None
113
+
114
+ # For demonstration purposes, we'll just return the content with the selected model name
115
+
116
+ # print(checkpoint)
117
+ progress(0.1, desc="Files created and saved")
118
+ # if (inc_val<5):
119
+ # model_name="highGRschool10"
120
+ # elif(inc_val>=5 & inc_val<10):
121
+ # model_name="highGRschool10"
122
+ # else:
123
+ # model_name="highGRschool10"
124
+ # Function to analyze each row
125
+ def analyze_row(row):
126
+ # Split the row into fields
127
+ fields = row.split("\t")
128
+
129
+ # Define tasks for OptionalTask_1, OptionalTask_2, and FinalAnswer
130
+ optional_task_1_subtasks = ["DenominatorFactor", "NumeratorFactor", "EquationAnswer"]
131
+ optional_task_2_subtasks = [
132
+ "FirstRow2:1", "FirstRow2:2", "FirstRow1:1", "FirstRow1:2",
133
+ "SecondRow", "ThirdRow"
134
+ ]
135
+
136
+ # Helper function to evaluate task attempts
137
+ def evaluate_tasks(fields, tasks):
138
+ task_status = {}
139
+ for task in tasks:
140
+ relevant_attempts = [f for f in fields if task in f]
141
+ if any("OK" in attempt for attempt in relevant_attempts):
142
+ task_status[task] = "Attempted (Successful)"
143
+ elif any("ERROR" in attempt for attempt in relevant_attempts):
144
+ task_status[task] = "Attempted (Error)"
145
+ elif any("JIT" in attempt for attempt in relevant_attempts):
146
+ task_status[task] = "Attempted (JIT)"
147
+ else:
148
+ task_status[task] = "Unattempted"
149
+ return task_status
150
+
151
+ # Evaluate tasks for each category
152
+ optional_task_1_status = evaluate_tasks(fields, optional_task_1_subtasks)
153
+ optional_task_2_status = evaluate_tasks(fields, optional_task_2_subtasks)
154
+
155
+ # Check if tasks have any successful attempt
156
+ opt1_done = any(status == "Attempted (Successful)" for status in optional_task_1_status.values())
157
+ opt2_done = any(status == "Attempted (Successful)" for status in optional_task_2_status.values())
158
+
159
+ return opt1_done, opt2_done
160
+
161
+ # Read data from test_info.txt
162
+ with open(test_info_location, "r") as file:
163
+ data = file.readlines()
164
+
165
+ # Column 6 of test_info holds the ideal-task flag for each instance (compared against 0 and 1 below)
166
+ ideal_tasks = test_info[6] # A list where each element is either 1 or 2
167
+
168
+ # Initialize counters
169
+ task_counts = {
170
+ 1: {"ER": 0, "ME": 0, "both": 0,"none":0},
171
+ 2: {"ER": 0, "ME": 0, "both": 0,"none":0}
172
+ }
173
+
174
+ # Analyze rows
175
+ for i, row in enumerate(data):
176
+ row = row.strip()
177
+ if not row:
178
+ continue
179
+
180
+ ideal_task = ideal_tasks[i] # Get the ideal task for the current row
181
+ opt1_done, opt2_done = analyze_row(row)
182
+
183
+ if ideal_task == 0:
184
+ if opt1_done and not opt2_done:
185
+ task_counts[1]["ER"] += 1
186
+ elif not opt1_done and opt2_done:
187
+ task_counts[1]["ME"] += 1
188
+ elif opt1_done and opt2_done:
189
+ task_counts[1]["both"] += 1
190
+ else:
191
+ task_counts[1]["none"] +=1
192
+ elif ideal_task == 1:
193
+ if opt1_done and not opt2_done:
194
+ task_counts[2]["ER"] += 1
195
+ elif not opt1_done and opt2_done:
196
+ task_counts[2]["ME"] += 1
197
+ elif opt1_done and opt2_done:
198
+ task_counts[2]["both"] += 1
199
+ else:
200
+ task_counts[2]["none"] +=1
201
+
202
+ # Create a string output for results
203
+ # output_summary = "Task Analysis Summary:\n"
204
+ # output_summary += "-----------------------\n"
205
+
206
+ # for ideal_task, counts in task_counts.items():
207
+ # output_summary += f"Ideal Task = OptionalTask_{ideal_task}:\n"
208
+ # output_summary += f" Only OptionalTask_1 done: {counts['ER']}\n"
209
+ # output_summary += f" Only OptionalTask_2 done: {counts['ME']}\n"
210
+ # output_summary += f" Both done: {counts['both']}\n"
211
+
212
+ # colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
213
+ colors = ["#FF6F61", "#6B5B95", "#88B04B", "#F7CAC9"]
214
+
215
+ # Generate pie chart for Task 1
216
+ task1_labels = list(task_counts[1].keys())
217
+ task1_values = list(task_counts[1].values())
218
+
219
+ # fig_task1 = Figure()
220
+ # ax1 = fig_task1.add_subplot(1, 1, 1)
221
+ # ax1.pie(task1_values, labels=task1_labels, autopct='%1.1f%%', startangle=90)
222
+ # ax1.set_title('Ideal Task 1 Distribution')
223
+
224
+ fig_task1 = go.Figure(data=[go.Pie(
225
+ labels=task1_labels,
226
+ values=task1_values,
227
+ textinfo='percent+label',
228
+ textposition='auto',
229
+ marker=dict(colors=colors),
230
+ sort=False
231
+
232
+ )])
233
+
234
+ fig_task1.update_layout(
235
+ title='Problem Type: ER',
236
+ title_x=0.5,
237
+ font=dict(
238
+ family="sans-serif",
239
+ size=12,
240
+ color="black"
241
+ ),
242
+ )
243
+
244
+ fig_task1.update_layout(
245
+ legend=dict(
246
+ font=dict(
247
+ family="sans-serif",
248
+ size=12,
249
+ color="black"
250
+ ),
251
+ )
252
+ )
253
+
254
+
255
+
256
+ # fig.show()
257
+
258
+ # Generate pie chart for Task 2
259
+ task2_labels = list(task_counts[2].keys())
260
+ task2_values = list(task_counts[2].values())
261
+
262
+ fig_task2 = go.Figure(data=[go.Pie(
263
+ labels=task2_labels,
264
+ values=task2_values,
265
+ textinfo='percent+label',
266
+ textposition='auto',
267
+ marker=dict(colors=colors),
268
+ sort=False
269
+ # pull=[0, 0.2, 0, 0] # for pulling part of pie chart out (depends on position)
270
+
271
+ )])
272
+
273
+ fig_task2.update_layout(
274
+ title='Problem Type: ME',
275
+ title_x=0.5,
276
+ font=dict(
277
+ family="sans-serif",
278
+ size=12,
279
+ color="black"
280
+ ),
281
+ )
282
+
283
+ fig_task2.update_layout(
284
+ legend=dict(
285
+ font=dict(
286
+ family="sans-serif",
287
+ size=12,
288
+ color="black"
289
+ ),
290
+ )
291
+ )
292
+
293
+
294
+ # fig_task2 = Figure()
295
+ # ax2 = fig_task2.add_subplot(1, 1, 1)
296
+ # ax2.pie(task2_values, labels=task2_labels, autopct='%1.1f%%', startangle=90)
297
+ # ax2.set_title('Ideal Task 2 Distribution')
298
+
299
+ # print(output_summary)
300
+
301
+ progress(0.2, desc="analysis done!! Executing models")
302
+ print("finetuned task: ",finetune_task)
303
+ # subprocess.run([
304
+ # "python", "new_test_saved_finetuned_model.py",
305
+ # "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
306
+ # "-finetune_task", finetune_task,
307
+ # "-test_dataset_path","../../../../selected_rows.txt",
308
+ # # "-test_label_path","../../../../train_label.txt",
309
+ # "-finetuned_bert_classifier_checkpoint",
310
+ # "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
311
+ # "-e",str(1),
312
+ # "-b",str(1000)
313
+ # ])
314
+ progress(0.6,desc="Model execution completed")
315
+ result = {}
316
+ with open("result.txt", 'r') as file:
317
+ for line in file:
318
+ key, value = line.strip().split(': ', 1)
319
+ # print(type(key))
320
+ if key=='epoch':
321
+ result[key]=value
322
+ else:
323
+ result[key]=float(value)
324
+ result["ROC score of HGR"]=high_roc_auc
325
+ result["ROC score of LGR"]=low_roc_auc
326
+ # Create a plot
327
+ with open("roc_data.pkl", "rb") as f:
328
+ fpr, tpr, _ = pickle.load(f)
329
+ # print(fpr,tpr)
330
+ roc_auc = auc(fpr, tpr)
331
+
332
+
333
+ # Create a matplotlib figure
334
+ # fig = Figure()
335
+ # ax = fig.add_subplot(1, 1, 1)
336
+ # ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
337
+ # ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
338
+ # ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'Receiver Operating Curve (ROC)')
339
+ # ax.legend(loc="lower right")
340
+ # ax.grid()
341
+
342
+ fig = go.Figure()
343
+ # Create and style traces
344
+ fig.add_trace(go.Line(x = list(fpr), y = list(tpr), name=f'ROC curve (area = {roc_auc:.2f})',
345
+ line=dict(color='royalblue', width=3,
346
+ ) # dash options include 'dash', 'dot', and 'dashdot'
347
+ ))
348
+ fig.add_trace(go.Line(x = [0,1], y = [0,1], showlegend = False,
349
+ line=dict(color='firebrick', width=2,
350
+ dash='dash',) # dash options include 'dash', 'dot', and 'dashdot'
351
+ ))
352
+
353
+ # Edit the layout
354
+ fig.update_layout(
355
+ showlegend = True,
356
+ title_x=0.5,
357
+ title=dict(
358
+ text='Receiver Operating Curve (ROC)'
359
+ ),
360
+ xaxis=dict(
361
+ title=dict(
362
+ text='False Positive Rate'
363
+ )
364
+ ),
365
+ yaxis=dict(
366
+ title=dict(
367
+ text='False Negative Rate'
368
+ )
369
+ ),
370
+ font=dict(
371
+ family="sans-serif",
372
+ color="black"
373
+ ),
374
+
375
+ )
376
+ fig.update_layout(
377
+ legend=dict(
378
+ x=0.75,
379
+ y=0,
380
+ traceorder="normal",
381
+ font=dict(
382
+ family="sans-serif",
383
+ size=12,
384
+ color="black"
385
+ ),
386
+ )
387
+ )
388
+
389
+
390
+
391
+
392
+
393
+
394
+ # Save plot to a file
395
+ # plot_path = "plot.png"
396
+ # fig.savefig(plot_path)
397
+ # plt.close(fig)
398
+
399
+
400
+
401
+
402
+ progress(1.0)
403
+ # Prepare text output
404
+ text_output = f"Model: {model_name}\nResult:\n{result}"
405
+ # Prepare text output with HTML formatting
406
+ text_output = f"""
407
+ ---------------------------
408
+ Model: {model_name}
409
+ ---------------------------\n
410
+ Time Taken: {result['time_taken_from_start']:.2f} seconds
411
+ Total Schools in test: {len(unique_schools):.4f}
412
+ Total number of instances having Schools with HGR : {len(high_sample):.4f}
413
+ Total number of instances having Schools with LGR: {len(low_sample):.4f}
414
+
415
+ ROC score of HGR: {high_roc_auc:.4f}
416
+ ROC score of LGR: {low_roc_auc:.4f}
417
+
418
+ ROC-AUC for problems of type ER: {opt_task1_roc_auc:.4f}
419
+ ROC-AUC for problems of type ME: {opt_task2_roc_auc:.4f}
420
+ """
421
+ return text_output,fig,fig_task1,fig_task2
422
+
423
# Fine-tuned models offered in the dropdown menu.
# The "ASTRA-FT-LGR" option is currently disabled.
models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]

# Introductory HTML shown at the top of the page (rendered through gr.Markdown).
# Model names in the instructions must match the entries of `models` exactly,
# because those are the labels the user sees in the dropdown.
content = """
<h1 style="color: black;">A S T R A</h1>
<h2 style="color: black;">An AI Model for Analyzing Math Strategies</h2>

<h3 style="color: white; text-align: center">
<a href="https://drive.google.com/file/d/1lbEpg8Se1ugTtkjreD8eXIg7qrplhWan/view" style="color: #1E90FF; text-decoration: none;">Link To Paper</a> |
<a href="https://github.com/Syudu41/ASTRA---Gates-Project" style="color: #1E90FF; text-decoration: none;">GitHub</a> |
<a href="https://sites.google.com/view/astra-research/home" style="color: #1E90FF; text-decoration: none;">Project Page</a>
</h3>

<p style="color: white;">Welcome to a demo of ASTRA. ASTRA is a collaborative research project between researchers at the
<a href="https://sites.google.com/site/dvngopal/" style="color: #1E90FF; text-decoration: none;">University of Memphis</a> and
<a href="https://www.carnegielearning.com" style="color: #1E90FF; text-decoration: none;">Carnegie Learning</a>
to utilize AI to improve our understanding of math learning strategies.</p>

<p style="color: white;">This demo has been developed with a pre-trained model (based on an architecture similar to BERT) that learns math strategies using data
collected from hundreds of schools in the U.S. who have used Carnegie Learning’s MATHia (formerly known as Cognitive Tutor), the flagship Intelligent Tutor that is part of a core, blended math curriculum.
For this demo, we have used data from a specific domain (teaching ratio and proportions) within 7th grade math. The fine-tuning based on the pre-trained model learns to predict which strategies lead to correct vs incorrect solutions.
</p>

<p style="color: white;">In this math domain, students were given word problems related to ratio and proportions. Further, the students
were given a choice of optional tasks to work on in parallel to the main problem to demonstrate their thinking (metacognition).
The optional tasks are designed based on solving problems using Equivalent Ratios (ER) and solving using Means and Extremes/cross-multiplication (ME).
When the equivalent ratios are easy to compute (integral values), ER is much more efficient compared to ME and switching between the tasks appropriately demonstrates cognitive flexibility.
</p>

<p style="color: white;">To use the demo, please follow these steps:</p>

<ol style="color: white;">
<li style="color: white;">Select a fine-tuned model:
<ul style="color: white;">
<li style="color: white;">ASTRA-FT-HGR: Fine-tuned with a small sample of data from schools that have a high graduation rate.</li>
<li style="color: white;">ASTRA-FT-FULL: Fine-tuned with a small sample of data from a mix of schools that have high/low graduation rates.</li>
</ul>
</li>
<li style="color: white;">Select a percentage of schools to analyze (selecting a large percentage may take a long time). Note that the selected percentage is applied to both High Graduation Rate (HGR) schools and Low Graduation Rate (LGR) schools.
</li>
<li style="color: white;">The results from the fine-tuned model are displayed in the dashboard:
<ul>
<li style="color: white;">The model accuracy is computed using the ROC-AUC metric.
</li>
<li style="color: white;">The results are shown for HGR, LGR schools and for different problem types (ER/ME).
</li>
<li style="color: white;">The distribution over how students utilized the optional tasks (whether they utilized ER/ME, used both of them or none of them) is shown for each problem type.
</li>
</ul>
</li>
</ol>
"""
476
# Built-in Gradio themes, instantiated once and keyed by a short lowercase name.
# Each key maps to the gr.themes class of the same name with its first letter
# capitalised (e.g. "soft" -> gr.themes.Soft()).
# NOTE(review): nothing in the visible part of this file reads this mapping;
# confirm it is still needed before relying on or removing it.
available_themes = {
    theme_name: getattr(gr.themes, theme_name.capitalize())()
    for theme_name in ("default", "soft", "monochrome", "glass", "base")
}
485
+
486
# Stylesheet passed to gr.Blocks(css=...). It styles generic HTML elements
# (headings, links, lists, form controls) plus a few id/class hooks
# (#title, #roc, #opt1, #opt2, .dashboard, .form-container, ...).
# Several selectors (.dashboard, a) appear twice on purpose of the original
# author: the later rule adds to / overrides the earlier one.
# The "Regular text" rule used to declare `color` twice with !important; only
# the last declaration in a rule can win, so the dead `black` one was dropped.
custom_css = '''
/* Import Fira Sans font */
@import url('https://fonts.googleapis.com/css2?family=Fira+Sans:wght@400;500;600;700&family=Inter:wght@400;500;600;700&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Libre+Caslon+Text:ital,wght@0,400;0,700;1,400&family=Spectral+SC:wght@600&display=swap');
/* Container modifications for centering */
.gradio-container {
    color: var(--block-label-text-color) !important;
    max-width: 1000px !important;
    margin: 0 auto !important;
    padding: 2rem !important;
    font-family: Arial, sans-serif !important;
}

/* Main title (ASTRA) */
#title {
    text-align: center !important;
    margin: 1rem auto !important; /* Reduced margin */
    font-size: 2.5em !important;
    font-weight: 600 !important;
    font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
    padding-bottom: 0 !important; /* Remove bottom padding */
}

/* Subtitle (An AI Model...) */
h1 {
    text-align: center !important;
    font-size: 30pt !important;
    font-weight: 600 !important;
    font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
    margin-top: 0.5em !important; /* Reduced top margin */
    margin-bottom: 0.3em !important;
}

h2 {
    text-align: center !important;
    font-size: 22pt !important;
    font-weight: 600 !important;
    font-family: "Spectral SC",'Fira Sans', sans-serif !important;
    margin-top: 0.2em !important; /* Reduced top margin */
    margin-bottom: 0.3em !important;
}

/* Links container styling */
.links-container {
    text-align: center !important;
    margin: 1em auto !important;
    font-family: 'Inter' ,'Fira Sans', sans-serif !important;
}

/* Links */
a {
    color: #2563eb !important;
    text-decoration: none !important;
    font-family:'Inter' , 'Fira Sans', sans-serif !important;
}

a:hover {
    text-decoration: underline !important;
    opacity: 0.8;
}

/* Regular text */
p, li, .description, .markdown-text {
    font-family: 'Inter', Arial, sans-serif !important;
    font-size: 11pt;
    line-height: 1.6;
    font-weight: 500 !important;
    color: var(--block-label-text-color) !important;
}

/* Other headings */
h3, h4, h5 {
    font-family: 'Fira Sans', sans-serif !important;
    color: var(--block-label-text-color) !important;
    margin-top: 1.5em;
    margin-bottom: 0.75em;
}


h3 { font-size: 1.5em; font-weight: 600; }
h4 { font-size: 1.25em; font-weight: 500; }
h5 { font-size: 1.1em; font-weight: 500; }

/* Form elements */
.select-wrap select, .wrap select,
input, textarea {
    font-family: 'Inter' ,Arial, sans-serif !important;
    color: var(--block-label-text-color) !important;
}

/* Lists */
ul, ol {
    margin-left: 0 !important;
    margin-bottom: 1.25em;
    padding-left: 2em;
}

li {
    margin-bottom: 0.75em;
}

/* Form container */
.form-container {
    max-width: 1000px !important;
    margin: 0 auto !important;
    padding: 1rem !important;
}

/* Dashboard */
.dashboard {
    margin-top: 2rem !important;
    padding: 1rem !important;
    border-radius: 8px !important;
}

/* Slider styling */
.gradio-slider-row {
    display: flex;
    align-items: center;
    justify-content: space-between;
    margin: 1.5em 0;
    max-width: 100% !important;
}

.gradio-slider {
    flex-grow: 1;
    margin-right: 15px;
}

.slider-percentage {
    font-family: 'Inter', Arial, sans-serif !important;
    flex-shrink: 0;
    min-width: 60px;
    font-size: 1em;
    font-weight: bold;
    text-align: center;
    background-color: #f0f8ff;
    border: 1px solid #004080;
    border-radius: 5px;
    padding: 5px 10px;
}

.progress-bar-wrap.progress-bar-wrap.progress-bar-wrap
{
    border-radius: var(--input-radius);
    height: 1.25rem;
    margin-top: 1rem;
    overflow: hidden;
    width: 70%;
    font-family: 'Inter', Arial, sans-serif !important;
}

/* Add these new styles after your existing CSS */

/* Card-like appearance for the dashboard */
.dashboard {
    background: #ffffff !important;
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
    border-radius: 12px !important;
    padding: 2rem !important;
    margin-top: 2.5rem !important;
}

/* Enhance ROC graph container */
#roc {
    background: #ffffff !important;
    padding: 1.5rem !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
    margin: 1.5rem 0 !important;
}

/* Style the dropdown select */
select {
    background-color: #ffffff !important;
    border: 1px solid #e2e8f0 !important;
    border-radius: 8px !important;
    padding: 0.5rem 1rem !important;
    transition: all 0.2s ease-in-out !important;
    box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05) !important;
}

select:hover {
    border-color: #cbd5e1 !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
}

/* Enhance slider appearance */
.progress-bar-wrap {
    background: #f8fafc !important;
    border: 1px solid #e2e8f0 !important;
    box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.05) !important;
}

/* Style metrics in dashboard */
.dashboard p {
    padding: 0.5rem 0 !important;
    border-bottom: 1px solid #f1f5f9 !important;
}

/* Add spacing between sections */
.dashboard > div {
    margin-bottom: 1.5rem !important;
}

/* Style the ROC curve title */
.dashboard h4 {
    color: #1e293b !important;
    font-weight: 600 !important;
    margin-bottom: 1rem !important;
    padding-bottom: 0.5rem !important;
    border-bottom: 2px solid #e2e8f0 !important;
}

/* Enhance link appearances */
a {
    position: relative !important;
    padding-bottom: 2px !important;
    transition: all 0.2s ease-in-out !important;
}

a:after {
    content: '' !important;
    position: absolute !important;
    width: 0 !important;
    height: 1px !important;
    bottom: 0 !important;
    left: 0 !important;
    background-color: #2563eb !important;
    transition: width 0.3s ease-in-out !important;
}

a:hover:after {
    width: 100% !important;
}

/* Add subtle dividers between sections */
.form-container > div {
    padding-bottom: 1.5rem !important;
    margin-bottom: 1.5rem !important;
    border-bottom: 1px solid #f1f5f9 !important;
}

/* Style model selection section */
.select-wrap {
    background: #ffffff !important;
    padding: 1.5rem !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
    margin-bottom: 2rem !important;
}

/* Style the metrics display */
.dashboard span {
    font-family: 'Inter', sans-serif !important;
    font-weight: 500 !important;
    color: #334155 !important;
}

/* Add subtle animation to interactive elements */
button, select, .slider-percentage {
    transition: all 0.2s ease-in-out !important;
}

/* Style the ROC curve container */
.plot-container {
    background: #ffffff !important;
    border-radius: 8px !important;
    padding: 1rem !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
}

/* Add container styles for opt1 and opt2 sections */
#opt1, #opt2 {
    background: #ffffff !important;
    border-radius: 8px !important;
    padding: 1.5rem !important;
    margin-top: 1.5rem !important;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
}

/* Style the distribution titles */
.distribution-title {
    font-family: 'Inter', sans-serif !important;
    font-weight: 600 !important;
    color: #1e293b !important;
    margin-bottom: 1rem !important;
    text-align: center !important;
}

'''
779
+
780
# Build the Gradio UI: intro text, the two inputs, a Submit button and the
# results dashboard (summary text, ROC plot, and one pie chart per problem type).
with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:

    gr.Markdown(content)

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=models,
            # Preselect the first model so Submit is never pressed with an
            # empty selection.
            value=models[0],
            label="Select Fine-tuned Model",
            elem_classes="dropdown-menu"
        )
        increment_slider = gr.Slider(
            minimum=1,
            maximum=100,
            step=1,
            label="Schools Percentage",
            value=1,
            elem_id="increment-slider",
            elem_classes="gradio-slider"
        )

    with gr.Row():
        btn = gr.Button("Submit")

    gr.Markdown("<p class='description'>Dashboard</p>")

    with gr.Row():
        output_text = gr.Textbox(label="")
    with gr.Row():
        # elem_id ties this plot to the "#roc" rule in custom_css.
        plot_output = gr.Plot(label="ROC", elem_id="roc")

    with gr.Row():
        # elem_ids tie these plots to the "#opt1, #opt2" rule in custom_css.
        opt1_pie = gr.Plot(label="ER", elem_id="opt1")
        opt2_pie = gr.Plot(label="ME", elem_id="opt2")

    # process_file returns (text, ROC figure, ER pie, ME pie), in the same
    # order as the output components listed here.
    btn.click(
        fn=process_file,
        inputs=[model_dropdown, increment_slider],
        outputs=[output_text, plot_output, opt1_pie, opt2_pie]
    )


# Launch the app
demo.launch()