File size: 3,181 Bytes
300388f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr

# Exact-match precision / recall / F1 between two collections of spans.
def compute_metrics(gt_spans, pred_spans):
    """Score predicted spans against ground-truth spans by exact match.

    A predicted span counts as a true positive only if an identical span is
    present in the ground truth. Duplicate spans within either input are
    counted once.

    Args:
        gt_spans: Iterable of ground-truth spans (e.g. ``(start, end)``
            pairs). ``None`` is treated as empty.
        pred_spans: Iterable of predicted spans in the same representation.
            ``None`` is treated as empty.

    Returns:
        A dict with float values under the keys ``"precision"``,
        ``"recall"`` and ``"f1_score"``. A metric whose denominator would be
        zero is reported as ``0.0``.
    """
    def as_key(span):
        # Spans may arrive as lists (e.g. after a JSON round trip), which are
        # unhashable; normalise those to tuples so [0, 1] matches (0, 1).
        return tuple(span) if isinstance(span, list) else span

    # Build each set once instead of re-creating it for every count.
    gt = {as_key(span) for span in (gt_spans or ())}
    pred = {as_key(span) for span in (pred_spans or ())}

    tp = len(gt & pred)
    fp = len(pred - gt)
    fn = len(gt - pred)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if precision + recall > 0:
        f1_score = 2 * precision * recall / (precision + recall)
    else:
        f1_score = 0.0

    return {"precision": precision, "recall": recall, "f1_score": f1_score}

def create_app():
    """Build the Gradio UI for marking spans in a text and scoring them.

    The user enters a text and a substring of it, then adds the substring's
    first occurrence either to the ground-truth spans or to the predicted
    spans. Both span lists are held in ``gr.State`` components as
    ``(start, end)`` character offsets, and a button computes
    precision / recall / F1 over them via ``compute_metrics``.

    Returns:
        The assembled ``gr.Blocks`` app; the caller launches it.
    """
    highlight_label = "highlight"

    def highlight_spans(text, spans):
        """Render spans as the (segment, label) list HighlightedText expects.

        Unhighlighted stretches get the label ``None``. Spans are clamped to
        the text and overlapping spans are trimmed so segments never repeat
        characters.
        """
        text = text or ""
        segments = []
        cursor = 0
        for start, end in sorted(spans or ()):
            start = max(start, cursor)
            end = min(end, len(text))
            if start >= end:
                # Empty, fully overlapped, or beyond the end of the text.
                continue
            if start > cursor:
                segments.append((text[cursor:start], None))
            segments.append((text[start:end], highlight_label))
            cursor = end
        if cursor < len(text):
            segments.append((text[cursor:], None))
        return segments

    def add_span(text, span, spans):
        """Return ``spans`` plus the first occurrence of ``span`` in ``text``.

        The input list is never mutated. It is returned unchanged when
        ``span`` is empty, does not occur in ``text``, or is already present.
        """
        spans = list(spans or ())
        # str.find returns -1 on a miss; never record that as an offset.
        start = (text or "").find(span) if span else -1
        if start < 0:
            return spans
        new_span = (start, start + len(span))
        if new_span not in spans:
            spans.append(new_span)
        return spans

    def update_spans(text, span, gt_spans, pred_spans, is_gt):
        """Add ``span`` to one span list and re-render both highlight views."""
        if is_gt:
            gt_spans = add_span(text, span, gt_spans)
        else:
            pred_spans = add_span(text, span, pred_spans)
        return (
            gt_spans,
            pred_spans,
            highlight_spans(text, gt_spans),
            highlight_spans(text, pred_spans),
        )

    def add_to_ground_truth(text, span, gt_spans, pred_spans):
        return update_spans(text, span, gt_spans, pred_spans, True)

    def add_to_predictions(text, span, gt_spans, pred_spans):
        return update_spans(text, span, gt_spans, pred_spans, False)

    def refresh_highlights(text, gt_spans, pred_spans):
        """Re-render both highlight views for the current text."""
        # NOTE: stored offsets are kept when the text is edited, so they may
        # no longer line up with the substring that was originally selected.
        return highlight_spans(text, gt_spans), highlight_spans(text, pred_spans)

    with gr.Blocks() as demo:
        # Input components
        text_input = gr.Textbox(label="Input Text")
        highlight_input = gr.Textbox(label="Highlight Text and Press Add")

        gt_spans_state = gr.State([])
        pred_spans_state = gr.State([])

        # Buttons for ground truth and prediction
        add_gt_button = gr.Button("Add to Ground Truth")
        add_pred_button = gr.Button("Add to Predictions")

        # Outputs for highlighted spans
        gt_output = gr.HighlightedText(label="Ground Truth Spans")
        pred_output = gr.HighlightedText(label="Predicted Spans")

        # Compute metrics button and its output
        compute_button = gr.Button("Compute Metrics")
        metrics_output = gr.JSON(label="Metrics")

        # Event handlers for the add buttons
        span_inputs = [text_input, highlight_input, gt_spans_state, pred_spans_state]
        span_outputs = [gt_spans_state, pred_spans_state, gt_output, pred_output]
        add_gt_button.click(fn=add_to_ground_truth, inputs=span_inputs, outputs=span_outputs)
        add_pred_button.click(fn=add_to_predictions, inputs=span_inputs, outputs=span_outputs)

        # Score the current span lists
        compute_button.click(
            fn=compute_metrics,
            inputs=[gt_spans_state, pred_spans_state],
            outputs=metrics_output,
        )

        # Keep both highlight views in sync with the text box
        text_input.change(
            fn=refresh_highlights,
            inputs=[text_input, gt_spans_state, pred_spans_state],
            outputs=[gt_output, pred_output],
        )

    return demo

# Run the app: build the Blocks UI at module level and start the Gradio server.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so this also
# runs when the module is imported — confirm that is intended.
demo = create_app()
demo.launch()