Vinay15 committed on
Commit
a8781ff
·
verified ·
1 Parent(s): 0a8a545

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -27
app.py CHANGED
@@ -1,39 +1,32 @@
1
  import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
4
- import tempfile
5
 
6
# Load the tokenizer first: its EOS token id is reused as the model's pad
# token id below. No GPU/device settings are applied here.
tokenizer = AutoTokenizer.from_pretrained(
    'ucaslcl/GOT-OCR2_0',
    trust_remote_code=True,
)

# Load the OCR model and switch it to evaluation mode in one step.
model = AutoModel.from_pretrained(
    'ucaslcl/GOT-OCR2_0',
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
).eval()
 
10
 
11
- # Define the OCR function with error handling
12
def perform_ocr(image):
    """Run OCR on an uploaded image and return the extracted text.

    The image is written to a temporary JPEG file because the model's
    ``chat`` method is called with a file path. The temporary file is
    always removed afterwards, whether or not OCR succeeds.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``).

    Returns:
        The text returned by the model, or a string starting with
        ``"An error occurred: "`` describing the failure.
    """
    import os  # local import: only needed for temp-file cleanup

    image_path = None
    try:
        if image is None:
            raise ValueError("No image was provided.")

        # JPEG cannot store alpha/palette modes, so normalise to RGB first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Reserve a unique path, then close the handle *before* saving so the
        # file can be reopened by name on every platform (Windows refuses to
        # reopen a NamedTemporaryFile that is still open).
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
            image_path = temp_file.name
        image.save(image_path)

        # The model is given the file path, not the PIL object.
        return model.chat(tokenizer, image_path, ocr_type='ocr')
    except Exception as e:
        # UI boundary: surface the failure as text in the output box.
        return f"An error occurred: {str(e)}"
    finally:
        # delete=False means nothing else removes the file; without this a
        # temp file would be leaked on every request.
        if image_path is not None:
            try:
                os.remove(image_path)
            except OSError:
                # Best-effort cleanup; the file may already be gone.
                pass
28
 
29
# Wire the OCR function into a simple Gradio UI: one PIL image in,
# one textbox with the extracted text out.
interface = gr.Interface(
    perform_ocr,
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(label="Extracted Text"),
    title="OCR and Document Search Web Application",
    description="Upload an image to extract text using the GOT-OCR2_0 model.",
)

# Start serving the app.
interface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
4
+ import torch
5
 
6
# Load the tokenizer first: its EOS token id is reused as the model's pad
# token id below.
tokenizer = AutoTokenizer.from_pretrained(
    'ucaslcl/GOT-OCR2_0',
    trust_remote_code=True,
)

# Load the OCR model and put it in evaluation mode.
model = AutoModel.from_pretrained(
    'ucaslcl/GOT-OCR2_0',
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    pad_token_id=tokenizer.eos_token_id,
)
model = model.eval()

# Prefer the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)
 
 
 
 
 
 
 
15
 
16
def perform_ocr(image):
    """Run OCR on an uploaded image and return the extracted text.

    The previous revision of this function passed a file path to
    ``model.chat``; handing it the in-memory PIL object instead does not
    match that call convention. The image is therefore written to a
    temporary file first and the path is passed on.

    Args:
        image: A PIL image (the interface requests ``type="pil"``).

    Returns:
        Whatever ``model.chat`` returns for ``ocr_type='ocr'``.

    Raises:
        ValueError: If no image was provided.
    """
    # Local imports: this revision of the file no longer imports these at
    # module level.
    import os
    import tempfile

    if image is None:
        raise ValueError("No image was provided.")

    # Normalise unusual modes (palette, CMYK, ...) so saving cannot fail on
    # an unsupported mode.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # The directory and everything in it is removed when the block exits,
    # even if OCR raises, so no temp files accumulate across requests.
    with tempfile.TemporaryDirectory() as work_dir:
        image_path = os.path.join(work_dir, "input.png")  # PNG is lossless
        image.save(image_path)
        return model.chat(tokenizer, image_path, ocr_type='ocr')
 
21
 
22
# Create the Gradio interface.
interface = gr.Interface(
    fn=perform_ocr,
    # Use the top-level component: the legacy `gr.inputs` namespace is
    # deprecated and absent from current Gradio releases, so
    # `gr.inputs.Image` raises AttributeError there.
    inputs=gr.Image(type="pil"),  # Accepts an image input
    outputs="text",  # Outputs extracted text
    title="OCR Web App",
    description="Upload an image to extract text using the GOT-OCR2.0 model.",
)

# Launch the app.
interface.launch()