arshaan-nazir's picture
add address
5afd2ae verified
raw
history blame
11 kB
import json
import os
import tempfile

import gradio as gr
from paddleocr import PaddleOCR
from groq import Groq
from openai import OpenAI
##################################
# Initialize Models
##################################
print("Loading PaddleOCR model...")

# Display name shown in the UI -> language code handed to PaddleOCR.
# Insertion order is the order the language dropdown lists its choices.
AVAILABLE_LANGUAGES = {
    "English": "en",
    "Chinese Simplified": "ch",
    "French": "fr",
    "German": "german",
    "Korean": "korean",
    "Japanese": "japan",
    "Italian": "it",
    "Spanish": "es",
    "Portuguese": "pt",
    "Russian": "ru",
    "Arabic": "ar",
    "Hindi": "hi",
    "Vietnamese": "vi",
    "Thai": "th",
}

# LLM back-ends selectable in the UI; "None" means raw OCR text only.
PROVIDERS = ["None", "Groq", "OpenAI"]

# Cache of PaddleOCR instances keyed by language code (filled on demand).
ocr_models = {}
def get_ocr_model(lang_code):
    """Return the PaddleOCR instance for ``lang_code``, creating it on first use.

    Instances are kept in the module-level ``ocr_models`` dict so each
    language's model is constructed at most once per process.

    Args:
        lang_code: PaddleOCR language code (a value of ``AVAILABLE_LANGUAGES``).

    Returns:
        The cached ``PaddleOCR`` instance for that language.
    """
    model = ocr_models.get(lang_code)
    if model is None:
        model = PaddleOCR(
            use_angle_cls=True,
            lang=lang_code,
            show_log=False,
            enable_mkldnn=True,  # Better CPU performance
        )
        ocr_models[lang_code] = model
    return model
##################################
# Groq Processing Functions
##################################
def format_with_groq(text: str, api_key: str) -> str:
    """Ask Groq for a first-pass structured rendering of receipt text.

    Sends ``text`` to the ``llama3-8b-8192`` chat model together with a system
    prompt listing the receipt fields to look for, consumes the streamed
    reply, and returns it as one string.

    Args:
        text: Receipt text to convert.
        api_key: Groq API key used to build the client.

    Returns:
        The concatenated streamed reply with surrounding whitespace stripped.
    """
    system_prompt = "\n".join(
        [
            "You are a receipt data extraction expert. Extract and format the receipt data into a clear JSON structure.",
            "Look for these key pieces of information:",
            "1. Restaurant/store name",
            "2. Restaurant Address /store address",
            "3. Date and time",
            "4. Individual items with quantities and prices",
            "5. Table number if present",
            "6. Server name if present",
            "7. Payment details",
            "8. Receipt/order number",
            "Format numbers as actual numbers, not strings.",
        ]
    )
    user_prompt = f"Convert this receipt text to structured data:\n\n{text}"

    stream = Groq(api_key=api_key).chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.1,
        max_tokens=1024,
        top_p=1,
        stream=True,
    )

    # Collect only the chunks that actually carry text; a delta whose
    # content is missing or empty contributes nothing.
    pieces = []
    for chunk in stream:
        piece = getattr(chunk.choices[0].delta, "content", None)
        if piece:
            pieces.append(piece)
    return "".join(pieces).strip()
def refine_json_with_groq(initial_text: str, api_key: str) -> str:
    """Ask Groq to coerce first-pass receipt data into the fixed JSON schema.

    The streamed reply is reduced to its outermost ``{...}`` span, parsed, and
    re-serialized with two-space indentation.

    Args:
        initial_text: Receipt data to normalize (typically the output of
            ``format_with_groq``).
        api_key: Groq API key used to build the client.

    Returns:
        Pretty-printed JSON text when the reply parses as JSON; otherwise the
        (possibly trimmed) reply text unchanged, so the caller can decide how
        to report the failure.
    """
    client = Groq(api_key=api_key)
    completion = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {
                "role": "system",
                "content": (
                    "Convert the receipt data into this exact JSON format:\n"
                    "{\n"
                    " 'restaurant_name': string,\n"
                    # Keep the address requested by the first Groq pass; it
                    # was previously missing from this schema and got dropped.
                    " 'restaurant_address': string,\n"
                    " 'date': string,\n"
                    " 'time': string,\n"
                    " 'table_number': string or number,\n"
                    " 'server_name': string,\n"
                    " 'payment_method': string,\n"
                    " 'items': [{'name': string, 'quantity': number, 'price': number}],\n"
                    " 'subtotal': number,\n"
                    " 'tax': number,\n"
                    " 'tip': number or null,\n"
                    " 'total': number,\n"
                    " 'receipt_number': string or null\n"
                    "}\n"
                    "Rules:\n"
                    "1. Use ONLY double quotes for JSON compliance\n"
                    "2. All numbers must be actual numbers, not strings\n"
                    "3. Return ONLY the JSON, no explanations\n"
                    "4. Ensure math is correct"
                ),
            },
            {
                "role": "user",
                "content": f"Format this receipt data as valid JSON:\n\n{initial_text}",
            },
        ],
        temperature=0.1,
        max_tokens=1024,
        top_p=1,
        stream=True,
    )

    pieces = []
    for chunk in completion:
        content = getattr(chunk.choices[0].delta, "content", None)
        if content:
            pieces.append(content)
    refined_text = "".join(pieces)

    # Drop any prose or code fences around the JSON object. Only slice when
    # an opening brace really precedes the last closing brace; otherwise the
    # slice would be empty and the reply would be thrown away.
    json_start = refined_text.find("{")
    json_end = refined_text.rfind("}") + 1
    if 0 <= json_start < json_end:
        refined_text = refined_text[json_start:json_end]

    try:
        # Validate JSON and reformat
        return json.dumps(json.loads(refined_text), indent=2)
    except json.JSONDecodeError:
        return refined_text
##################################
# OpenAI Processing Functions
##################################
def process_with_openai(text: str, api_key: str) -> str:
    """Ask OpenAI to convert receipt text into JSON text in the fixed schema.

    Args:
        text: Receipt text to convert.
        api_key: OpenAI API key used to build the client.

    Returns:
        JSON *text* (not a parsed object). On success this is the model's
        reply trimmed to its outermost ``{...}`` span (an empty string if the
        reply had no content). If the API call raises, a JSON object of the
        form ``{"error": "<message>"}`` is returned instead.
    """
    client = OpenAI(api_key=api_key)
    try:
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Convert the receipt data into this exact JSON format:\n"
                        "{\n"
                        " 'restaurant_name': string,\n"
                        " 'restaurant_address': string,\n"
                        " 'date': string,\n"
                        " 'time': string,\n"
                        " 'table_number': string or number,\n"
                        " 'server_name': string,\n"
                        " 'payment_method': string,\n"
                        " 'items': [{'name': string, 'quantity': number, 'price': number}],\n"
                        " 'subtotal': number,\n"
                        " 'tax': number,\n"
                        " 'tip': number or null,\n"
                        " 'total': number,\n"
                        " 'receipt_number': string or null\n"
                        "}\n"
                        "Rules:\n"
                        "1. Use ONLY double quotes for JSON compliance\n"
                        "2. All numbers must be actual numbers, not strings\n"
                        "3. Return ONLY the JSON, no explanations"
                    ),
                },
                {
                    "role": "user",
                    "content": f"Convert this receipt text to JSON:\n\n{text}",
                },
            ],
            temperature=0.1,
        )
        # The message content may be None; normalize to a string so the
        # caller's json.loads reports a parse failure, not a TypeError.
        content = completion.choices[0].message.content or ""
    except Exception as e:
        return json.dumps({"error": str(e)})

    # Mirror the Groq path: discard any prose or code fences around the
    # JSON object so a wrapped-but-valid reply still parses.
    json_start = content.find("{")
    json_end = content.rfind("}") + 1
    if 0 <= json_start < json_end:
        content = content[json_start:json_end]
    return content
##################################
# Main Processing
##################################
def _ocr_lines(ocr_result):
    """Flatten a PaddleOCR result into the list of recognized text strings."""
    # The result holds one entry per page; an entry can be None when no text
    # was detected, so empty pages are skipped instead of iterated.
    return [
        line[1][0]
        for page in (ocr_result or [])
        if page
        for line in page
    ]


def process_receipt(image, selected_language, provider="None", api_key=""):
    """Run OCR on a receipt image and optionally structure the text via an LLM.

    Args:
        image: The uploaded image (the UI supplies a PIL image), or ``None``
            when nothing was uploaded.
        selected_language: A key of ``AVAILABLE_LANGUAGES``.
        provider: ``"None"``, ``"Groq"`` or ``"OpenAI"``.
        api_key: API key for the chosen provider; empty means OCR only.

    Returns:
        A JSON-serializable value for the output widget:
        - ``{"raw_ocr_text": ..., "note": ...}`` when no provider/key is given;
        - the parsed LLM JSON when a provider succeeds;
        - ``{"error": "Failed to parse response", "raw_ocr_text": ...}`` when
          the LLM reply is not valid JSON;
        - ``{"error": ..., "type": "processing_error"}`` for any other failure.
    """
    if image is None:
        return {"error": "No image provided", "type": "processing_error"}

    # Bound before the try so the cleanup in `finally` can never hit an
    # unassigned name.
    image_path = None
    try:
        # A unique file per request: concurrent requests must not overwrite
        # each other's image.
        fd, image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)

        # JPEG cannot store alpha or palette modes; normalize before saving.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        image.save(image_path)

        # Get OCR model and process image
        lang_code = AVAILABLE_LANGUAGES[selected_language]
        ocr_model = get_ocr_model(lang_code)
        ocr_result = ocr_model.ocr(image_path, cls=True)
        extracted_text = "\n".join(_ocr_lines(ocr_result))

        # If no provider/api key, return raw OCR
        if not api_key or provider == "None":
            return {
                "raw_ocr_text": extracted_text,
                "note": "Provide API key and select a provider for structured JSON output",
            }

        if provider == "Groq":
            # Two-step Groq processing
            initial_text = format_with_groq(extracted_text, api_key)
            llm_output = refine_json_with_groq(initial_text, api_key)
        elif provider == "OpenAI":
            llm_output = process_with_openai(extracted_text, api_key)
        else:
            # Previously an unknown provider silently returned None.
            return {
                "error": f"Unsupported provider: {provider}",
                "type": "processing_error",
            }

        try:
            return json.loads(llm_output)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers a provider handing back None instead of text.
            return {
                "error": "Failed to parse response",
                "raw_ocr_text": extracted_text,
            }
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return {
            "error": str(e),
            "type": "processing_error",
        }
    finally:
        if image_path is not None:
            try:
                os.remove(image_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass
##################################
# Gradio Interface
##################################
css = """
.gradio-container {max-width: 1100px !important}
"""

# NOTE(review): the source had lost its indentation, so the nesting of the
# layout containers below is reconstructed (inputs column | output column,
# instructions underneath) -- confirm against the intended layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Multi-Language Receipt OCR")

    with gr.Row():
        # Input side: image, language, optional LLM provider + key.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Receipt Image",
                height=400,
            )
            language_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_LANGUAGES),
                value="English",
                label="Select Language",
                info="Choose the primary language of the receipt",
            )
            with gr.Row():
                provider_dropdown = gr.Dropdown(
                    choices=PROVIDERS,
                    value="None",
                    label="Select LLM Provider",
                    info="Choose provider for JSON formatting",
                )
                api_key_input = gr.Textbox(
                    label="API Key",
                    placeholder="Enter your API key",
                    type="password",
                    info="Required for JSON formatting",
                )
            submit_button = gr.Button("Process Receipt", variant="primary")

        # Output side: whatever process_receipt returns, rendered as JSON.
        with gr.Column(scale=1):
            json_output = gr.JSON(
                label="Extracted Receipt Data",
                height=500,
            )

    gr.Markdown("""
### Usage Instructions
1. Upload a clear image of your receipt
2. Select the receipt's primary language
3. (Optional) Choose a provider and enter API key for JSON formatting
4. Click 'Process Receipt'
### Notes
- Without an API key, you'll receive raw OCR text
- For best results, ensure receipt image is clear and well-lit
- Supported languages include English, Chinese, French, German, and more
""")

    # Input order must match process_receipt's positional parameters.
    submit_button.click(
        fn=process_receipt,
        inputs=[image_input, language_dropdown, provider_dropdown, api_key_input],
        outputs=[json_output],
    )
# Shut down any Gradio instances that are still open before starting anew.
gr.close_all()

# Enable request queuing (queue size capped at 10), then start the server
# on all network interfaces at port 7860 without a public share link.
demo.queue(max_size=10)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_api=False)