Upload app (5).py
Browse files- app (5).py +110 -0
app (5).py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import logging
|
5 |
+
|
6 |
+
# Configure logging
|
7 |
+
logging.basicConfig(
|
8 |
+
level=logging.INFO,
|
9 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
10 |
+
)
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
class QuestionAnsweringSystem:
|
14 |
+
def __init__(self, model_name):
|
15 |
+
"""Initialize the QA system with the specified model."""
|
16 |
+
logger.info(f"Loading model and tokenizer from {model_name}")
|
17 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
+
self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
19 |
+
self.max_length = 384
|
20 |
+
|
21 |
+
def answer_question(self, context, question):
|
22 |
+
"""Process the question and context to generate an answer."""
|
23 |
+
try:
|
24 |
+
# Tokenize input text
|
25 |
+
inputs = self.tokenizer(
|
26 |
+
question,
|
27 |
+
context,
|
28 |
+
max_length=self.max_length,
|
29 |
+
truncation="only_second",
|
30 |
+
padding="max_length",
|
31 |
+
return_tensors="pt"
|
32 |
+
)
|
33 |
+
|
34 |
+
# Generate model predictions
|
35 |
+
with torch.no_grad():
|
36 |
+
outputs = self.model(**inputs)
|
37 |
+
|
38 |
+
# Extract answer span
|
39 |
+
answer_start = torch.argmax(outputs.start_logits)
|
40 |
+
answer_end = torch.argmax(outputs.end_logits)
|
41 |
+
|
42 |
+
# Decode tokens to get the answer
|
43 |
+
tokens = inputs.input_ids[0][answer_start:answer_end + 1]
|
44 |
+
answer = self.tokenizer.decode(tokens, skip_special_tokens=True)
|
45 |
+
|
46 |
+
# Calculate confidence score
|
47 |
+
confidence_start = torch.softmax(outputs.start_logits, dim=1).max().item()
|
48 |
+
confidence_end = torch.softmax(outputs.end_logits, dim=1).max().item()
|
49 |
+
confidence = (confidence_start + confidence_end) / 2
|
50 |
+
|
51 |
+
return answer, float(confidence)
|
52 |
+
|
53 |
+
except Exception as e:
|
54 |
+
logger.error(f"Error in answer generation: {str(e)}")
|
55 |
+
return "An error occurred while processing your question.", 0.0
|
56 |
+
|
57 |
+
# Initialize the QA system
|
58 |
+
qa_system = QuestionAnsweringSystem("aman-augurs/bert-fine-tuned-qa3e") # Replace with your model name
|
59 |
+
|
60 |
+
def process_query(context, question):
|
61 |
+
"""Handle the user query and return formatted results."""
|
62 |
+
if not context or not question:
|
63 |
+
return "Please provide both a context and a question."
|
64 |
+
|
65 |
+
try:
|
66 |
+
answer, confidence = qa_system.answer_question(context, question)
|
67 |
+
|
68 |
+
if confidence < 0.1:
|
69 |
+
return "I'm not confident enough to provide an answer based on the given context."
|
70 |
+
|
71 |
+
response = f"Answer: {answer}\nConfidence: {confidence:.2%}"
|
72 |
+
return response
|
73 |
+
|
74 |
+
except Exception as e:
|
75 |
+
logger.error(f"Error processing query: {str(e)}")
|
76 |
+
return "An error occurred while processing your request."
|
77 |
+
|
78 |
+
# Create the Gradio interface
|
79 |
+
def create_interface():
|
80 |
+
"""Create and configure the Gradio interface."""
|
81 |
+
return gr.Interface(
|
82 |
+
fn=process_query,
|
83 |
+
inputs=[
|
84 |
+
gr.Textbox(
|
85 |
+
label="Context",
|
86 |
+
placeholder="Enter the context passage here...",
|
87 |
+
lines=10
|
88 |
+
),
|
89 |
+
gr.Textbox(
|
90 |
+
label="Question",
|
91 |
+
placeholder="Enter your question here..."
|
92 |
+
)
|
93 |
+
],
|
94 |
+
outputs=gr.Textbox(label="Response"),
|
95 |
+
title="Question Answering System",
|
96 |
+
description="""This application uses a fine-tuned BERT model to answer questions based on provided context.
|
97 |
+
Enter a passage of text as context and ask a specific question about it.""",
|
98 |
+
examples=[
|
99 |
+
["The Golden Gate Bridge is a suspension bridge spanning the Golden Gate strait, the one-mile-wide strait connecting San Francisco Bay and the Pacific Ocean. The structure links the U.S. city of San Francisco, California to Marin County, carrying both U.S. Route 101 and California State Route 1 across the strait.",
|
100 |
+
"How wide is the Golden Gate strait?"],
|
101 |
+
["Python is a high-level, interpreted programming language. Python's design philosophy emphasizes code readability with the use of significant indentation. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects.",
|
102 |
+
"What is Python's design philosophy?"]
|
103 |
+
],
|
104 |
+
theme=gr.themes.Base()
|
105 |
+
)
|
106 |
+
|
107 |
+
# Create app.py for Hugging Face Spaces
|
108 |
+
if __name__ == "__main__":
|
109 |
+
interface = create_interface()
|
110 |
+
interface.launch()
|