|
import json
import os
import tempfile

import gradio as gr
from groq import Groq
from openai import OpenAI
from paddleocr import PaddleOCR
|
|
|
|
|
|
|
|
|
# Startup banner printed once when the module is executed.
print("Loading PaddleOCR model...")

# UI display name -> language code passed to PaddleOCR's ``lang`` argument.
AVAILABLE_LANGUAGES = {
    "English": "en",
    "Chinese Simplified": "ch",
    "French": "fr",
    "German": "german",
    "Korean": "korean",
    "Japanese": "japan",
    "Italian": "it",
    "Spanish": "es",
    "Portuguese": "pt",
    "Russian": "ru",
    "Arabic": "ar",
    "Hindi": "hi",
    "Vietnamese": "vi",
    "Thai": "th",
}

# Choices for the LLM provider dropdown; "None" means raw OCR text only.
PROVIDERS = ["None", "Groq", "OpenAI"]
|
|
|
|
|
# Cache of PaddleOCR instances, keyed by PaddleOCR language code.
ocr_models = {}


def get_ocr_model(lang_code):
    """Return the PaddleOCR instance for ``lang_code``, building it on first use.

    Instances are stored in the module-level ``ocr_models`` dict, so each
    language code is only constructed once per process.

    Args:
        lang_code: Language code forwarded to ``PaddleOCR(lang=...)``.

    Returns:
        The cached (or newly created) ``PaddleOCR`` instance.
    """
    model = ocr_models.get(lang_code)
    if model is None:
        # Cache miss: construct the model and remember it for later calls.
        model = PaddleOCR(
            use_angle_cls=True,
            lang=lang_code,
            show_log=False,
            enable_mkldnn=True,
        )
        ocr_models[lang_code] = model
    return model
|
|
|
|
|
|
|
|
|
def format_with_groq(text: str, api_key: str) -> str:
    """Run the first Groq pass: turn raw OCR text into loosely structured data.

    Sends ``text`` to the ``llama3-8b-8192`` chat model with a receipt
    extraction system prompt, consumes the streamed reply and returns it as
    one string.

    Args:
        text: Raw receipt text (as produced by OCR).
        api_key: Groq API key used to build the client.

    Returns:
        The concatenated model reply with surrounding whitespace stripped.
        The reply is not validated here; it may or may not be valid JSON.

    Raises:
        Any exception raised by the Groq client is propagated to the caller.
    """
    client = Groq(api_key=api_key)
    completion = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {
                "role": "system",
                "content": (
                    "You are a receipt data extraction expert. Extract and format the receipt data into a clear JSON structure.\n"
                    "Look for these key pieces of information:\n"
                    "1. Restaurant/store name\n"
                    "2. Restaurant Address /store address\n"
                    "3. Date and time\n"
                    "4. Individual items with quantities and prices\n"
                    "5. Table number if present\n"
                    "6. Server name if present\n"
                    "7. Payment details\n"
                    "8. Receipt/order number\n"
                    "Format numbers as actual numbers, not strings."
                ),
            },
            {
                "role": "user",
                "content": f"Convert this receipt text to structured data:\n\n{text}",
            },
        ],
        temperature=0.1,
        max_tokens=1024,
        top_p=1,
        stream=True,
    )

    pieces = []
    for chunk in completion:
        # Defensive: a streamed chunk without any choices would otherwise
        # raise IndexError on ``choices[0]``.
        if not chunk.choices:
            continue
        content = getattr(chunk.choices[0].delta, "content", None)
        if content:
            pieces.append(content)

    return "".join(pieces).strip()
|
|
|
def refine_json_with_groq(initial_text: str, api_key: str) -> str:
    """Run the second Groq pass: coerce extracted receipt data into fixed JSON.

    Sends ``initial_text`` to the ``llama3-8b-8192`` chat model with a system
    prompt describing the target schema, consumes the streamed reply, keeps
    only the span from the first ``{`` to the last ``}`` and tries to parse
    it as JSON.

    Args:
        initial_text: Receipt data to be reformatted.
        api_key: Groq API key used to build the client.

    Returns:
        The parsed object re-serialised with ``indent=2`` when the reply is
        valid JSON; otherwise the (brace-trimmed) reply text unchanged, so the
        caller can decide how to report the parse failure.

    Raises:
        Any exception raised by the Groq client is propagated to the caller.
    """
    client = Groq(api_key=api_key)
    completion = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[
            {
                "role": "system",
                "content": (
                    "Convert the receipt data into this exact JSON format:\n"
                    "{\n"
                    "    'restaurant_name': string,\n"
                    # Kept in step with the OpenAI schema so the address that
                    # the first pass extracts is not dropped here.
                    "    'restaurant_address': string,\n"
                    "    'date': string,\n"
                    "    'time': string,\n"
                    "    'table_number': string or number,\n"
                    "    'server_name': string,\n"
                    "    'payment_method': string,\n"
                    "    'items': [{'name': string, 'quantity': number, 'price': number}],\n"
                    "    'subtotal': number,\n"
                    "    'tax': number,\n"
                    "    'tip': number or null,\n"
                    "    'total': number,\n"
                    "    'receipt_number': string or null\n"
                    "}\n"
                    "Rules:\n"
                    "1. Use ONLY double quotes for JSON compliance\n"
                    "2. All numbers must be actual numbers, not strings\n"
                    "3. Return ONLY the JSON, no explanations\n"
                    "4. Ensure math is correct"
                ),
            },
            {
                "role": "user",
                "content": f"Format this receipt data as valid JSON:\n\n{initial_text}",
            },
        ],
        temperature=0.1,
        max_tokens=1024,
        top_p=1,
        stream=True,
    )

    pieces = []
    for chunk in completion:
        # Defensive: a streamed chunk without any choices would otherwise
        # raise IndexError on ``choices[0]``.
        if not chunk.choices:
            continue
        content = getattr(chunk.choices[0].delta, "content", None)
        if content:
            pieces.append(content)
    refined_text = "".join(pieces)

    # Drop anything the model wrapped around the object (prose, code fences).
    # Only slice when a closing brace actually follows the opening one;
    # otherwise the slice would be empty and the reply would be lost.
    json_start = refined_text.find("{")
    json_end = refined_text.rfind("}") + 1
    if 0 <= json_start < json_end:
        refined_text = refined_text[json_start:json_end]

    try:
        parsed_json = json.loads(refined_text)
    except json.JSONDecodeError:
        return refined_text
    return json.dumps(parsed_json, indent=2)
|
|
|
|
|
|
|
|
|
def process_with_openai(text: str, api_key: str) -> str:
    """Ask OpenAI's ``gpt-3.5-turbo`` to convert receipt text into JSON text.

    Args:
        text: Raw receipt text (as produced by OCR).
        api_key: OpenAI API key used to build the client.

    Returns:
        A string meant to be passed to ``json.loads`` by the caller:
        the model reply trimmed to the span between the first ``{`` and the
        last ``}`` when such a span exists, or a JSON object of the form
        ``{"error": "..."}`` when the request fails or the reply is empty.
        The reply itself is not validated as JSON here.
    """
    try:
        client = OpenAI(api_key=api_key)
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Convert the receipt data into this exact JSON format:\n"
                        "{\n"
                        "    'restaurant_name': string,\n"
                        "    'restaurant_address': string,\n"
                        "    'date': string,\n"
                        "    'time': string,\n"
                        "    'table_number': string or number,\n"
                        "    'server_name': string,\n"
                        "    'payment_method': string,\n"
                        "    'items': [{'name': string, 'quantity': number, 'price': number}],\n"
                        "    'subtotal': number,\n"
                        "    'tax': number,\n"
                        "    'tip': number or null,\n"
                        "    'total': number,\n"
                        "    'receipt_number': string or null\n"
                        "}\n"
                        "Rules:\n"
                        "1. Use ONLY double quotes for JSON compliance\n"
                        "2. All numbers must be actual numbers, not strings\n"
                        "3. Return ONLY the JSON, no explanations"
                    ),
                },
                {
                    "role": "user",
                    "content": f"Convert this receipt text to JSON:\n\n{text}",
                },
            ],
            temperature=0.1,
        )

        reply = completion.choices[0].message.content
        if not reply:
            # ``content`` can be None/empty; hand back a parseable error
            # instead of letting the caller call json.loads(None).
            return json.dumps({"error": "Empty response from OpenAI"})

        # Same trimming as the Groq path: discard code fences or prose the
        # model may have put around the JSON object.
        json_start = reply.find("{")
        json_end = reply.rfind("}") + 1
        if 0 <= json_start < json_end:
            reply = reply[json_start:json_end]
        return reply
    except Exception as e:
        # Boundary with a remote service: report any failure as JSON so the
        # caller's json.loads yields an error payload for the UI.
        return json.dumps({"error": str(e)})
|
|
|
|
|
|
|
|
|
def process_receipt(image, selected_language, provider="None", api_key=""):
    """OCR a receipt image and optionally structure the text with an LLM.

    The image is written to a uniquely named JPEG under ``temp/``, read by
    the PaddleOCR model for the selected language, and the recognised lines
    are joined with newlines. When both a provider and an API key are given,
    the text is sent to that provider and the parsed JSON reply is returned.

    Args:
        image: Image object exposing ``mode``, ``convert`` and ``save``
            (the UI passes a PIL image), or ``None`` when nothing was uploaded.
        selected_language: A key of ``AVAILABLE_LANGUAGES``.
        provider: ``"None"``, ``"Groq"`` or ``"OpenAI"``.
        api_key: API key for the chosen provider; empty means "no LLM step".

    Returns:
        A JSON-serialisable value for the UI: the parsed LLM output, or a dict
        with ``raw_ocr_text`` when no LLM is used, or a dict with an
        ``error`` key when something went wrong. This function does not raise
        for ordinary failures; they are reported in the returned dict.
    """
    if image is None:
        # Report a missing upload explicitly instead of an AttributeError text.
        return {"error": "No image provided", "type": "processing_error"}

    # Bound before the try so the cleanup in ``finally`` can always test it.
    image_path = None
    try:
        os.makedirs("temp", exist_ok=True)
        # Unique file per call: a fixed name would be shared (and deleted)
        # by requests that run at the same time.
        fd, image_path = tempfile.mkstemp(suffix=".jpg", dir="temp")
        os.close(fd)

        # JPEG cannot store alpha/palette modes, so normalise to RGB first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(image_path, format="JPEG")

        lang_code = AVAILABLE_LANGUAGES[selected_language]
        ocr_model = get_ocr_model(lang_code)
        ocr_result = ocr_model.ocr(image_path, cls=True)

        # Each recognised line is indexed as line[1][0] for its text.
        # Pages with no detections may come back as None, so skip falsy pages.
        extracted_text = "\n".join(
            line[1][0]
            for page in (ocr_result or [])
            if page
            for line in page
        )

        if not api_key or provider == "None":
            return {
                "raw_ocr_text": extracted_text,
                "note": "Provide API key and select a provider for structured JSON output",
            }

        try:
            if provider == "Groq":
                # Two passes: free-form extraction, then strict JSON shaping.
                initial_text = format_with_groq(extracted_text, api_key)
                final_json = refine_json_with_groq(initial_text, api_key)
                return json.loads(final_json)

            if provider == "OpenAI":
                llm_reply = process_with_openai(extracted_text, api_key)
                return json.loads(llm_reply)
        except json.JSONDecodeError:
            return {
                "error": "Failed to parse response",
                "raw_ocr_text": extracted_text,
            }

        # Neither branch matched: say so rather than returning None.
        return {
            "error": f"Unsupported provider: {provider}",
            "type": "processing_error",
            "raw_ocr_text": extracted_text,
        }

    except Exception as e:
        # UI boundary: surface any failure as a JSON payload.
        return {
            "error": str(e),
            "type": "processing_error",
        }
    finally:
        if image_path is not None:
            try:
                os.remove(image_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass
|
|
|
|
|
|
|
|
|
# Page-level CSS: cap the width of the Gradio container.
css = """
.gradio-container {max-width: 1100px !important}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Multi-Language Receipt OCR")

    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Receipt Image",
                type="pil",
                height=400,
            )
            language_dropdown = gr.Dropdown(
                label="Select Language",
                info="Choose the primary language of the receipt",
                choices=list(AVAILABLE_LANGUAGES),
                value="English",
            )

            with gr.Row():
                provider_dropdown = gr.Dropdown(
                    label="Select LLM Provider",
                    info="Choose provider for JSON formatting",
                    choices=PROVIDERS,
                    value="None",
                )
                api_key_input = gr.Textbox(
                    label="API Key",
                    info="Required for JSON formatting",
                    placeholder="Enter your API key",
                    type="password",
                )

            submit_button = gr.Button("Process Receipt", variant="primary")

        # Right column: the JSON result viewer.
        with gr.Column(scale=1):
            json_output = gr.JSON(
                label="Extracted Receipt Data",
                height=500,
            )

    gr.Markdown("""
    ### Usage Instructions
    1. Upload a clear image of your receipt
    2. Select the receipt's primary language
    3. (Optional) Choose a provider and enter API key for JSON formatting
    4. Click 'Process Receipt'

    ### Notes
    - Without an API key, you'll receive raw OCR text
    - For best results, ensure receipt image is clear and well-lit
    - Supported languages include English, Chinese, French, German, and more
    """)

    # Wire the button: the four inputs map positionally onto
    # process_receipt(image, selected_language, provider, api_key).
    receipt_inputs = [
        image_input,
        language_dropdown,
        provider_dropdown,
        api_key_input,
    ]
    submit_button.click(
        fn=process_receipt,
        inputs=receipt_inputs,
        outputs=[json_output],
    )

# Close any Gradio servers already running before starting this one.
gr.close_all()

# Cap the request queue at 10 entries, then serve on all interfaces.
demo.queue(max_size=10)
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_api=False,
    share=False,
)