xjlulu's picture
"~"
1404f4e
raw
history blame
1.42 kB
# -*- coding: utf-8 -*-
from transformers import pipeline
import gradio as gr
from gradio.components import Textbox
import random
import csv
# Hugging Face Hub repo id of the checkpoint loaded into the QA pipeline.
_QA_MODEL_ID = "xjlulu/ntu_adl_span_selection_roberta"

# Question-answering pipeline, built once when the module is loaded and then
# reused by every generate_answer() call. framework="pt" selects the PyTorch
# backend explicitly.
question_answering = pipeline(
    "question-answering",
    model=_QA_MODEL_ID,
    framework="pt",
)
def _load_examples(path):
    """Read the example rows offered by the "Random" button from a CSV file.

    The UI writes a returned row straight into the Question and Context
    boxes, so each row is presumably ``[question, context]`` — confirm
    against the data file.

    Args:
        path: Location of the CSV file.

    Returns:
        A list of rows, each a list of str; blank lines are skipped.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Be explicit about UTF-8 instead of relying on the platform default
    # encoding (not UTF-8 on e.g. Windows); "utf-8-sig" additionally strips a
    # leading byte-order mark that spreadsheet exports often add.
    # newline="" is what the csv module requires so quoted fields that
    # contain line breaks are parsed correctly.
    with open(path, mode="r", encoding="utf-8-sig", newline="") as csv_file:
        # csv.reader yields [] for a blank line; such a row cannot fill the
        # two text boxes, so it is not a usable example.
        return [row for row in csv.reader(csv_file) if row]


examples = _load_examples("sample_data.csv")
def random_sample():
    """Pick one example row at random for the Question and Context boxes.

    Returns:
        A row of the module-level ``examples`` list, chosen uniformly at
        random, or ``["", ""]`` when no examples are loaded so the button
        clears both boxes instead of raising.
    """
    if not examples:
        # random.choice raises IndexError on an empty sequence.
        return ["", ""]
    return random.choice(examples)
def generate_answer(question, context):
    """Extract the answer to *question* from the *context* paragraph.

    Args:
        question: Text from the Question box.
        context: Text from the Context paragraph box.

    Returns:
        The ``"answer"`` field of the question-answering pipeline's result,
        or ``""`` when either input is missing or blank.
    """
    # The transformers QA pipeline raises ValueError for an empty question or
    # context, and Gradio sends "" for an untouched box. A whitespace-only
    # input has nothing to extract either, so leave the Answer box blank.
    if not (question and question.strip()) or not (context and context.strip()):
        return ""
    result = question_answering(question=question, context=context)
    return result["answer"]
description = "Answer questions based on a given context paragraph"

# Layout: Question and Answer side by side, then the Random / Generate
# buttons, then the Context paragraph box below them.
with gr.Blocks(theme=gr.themes.Soft(), title="Question Answering") as demo:
    gr.Markdown(description)
    with gr.Row():
        Q_input = gr.Textbox(lines=3, label="Question")
        A_output = gr.Textbox(lines=3, label="Answer")
    with gr.Row():
        random_button = gr.Button("Random")
        generate_button = gr.Button("Generate")
    C_input = gr.Textbox(lines=8, label="Context paragraph")

    # "Random" fills Question and Context from one example row; "Generate"
    # runs the QA pipeline on the current Question/Context and fills Answer.
    random_button.click(random_sample, inputs=None, outputs=[Q_input, C_input])
    generate_button.click(
        generate_answer, inputs=[Q_input, C_input], outputs=A_output
    )

# Start the web server only when this file is executed as a script, so that
# importing the module does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch()