Vinay15 commited on
Commit
c920662
·
verified ·
1 Parent(s): 9fca578

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -1,33 +1,38 @@
1
  import torch
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
 
 
4
 
5
- # Load the OCR model and tokenizer with low memory usage in mind
6
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
7
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
8
  low_cpu_mem_usage=True,
9
  pad_token_id=tokenizer.eos_token_id).eval()
10
 
11
- # Ensure we are using CPU
12
- device = torch.device('cpu')
13
  model = model.to(device)
14
 
15
  # Function to perform OCR on the image
16
- def perform_ocr(image_path):
17
- # Open the image file
18
- image = Image.open(image_path)
 
 
19
 
20
- # Use torch.no_grad() to avoid unnecessary memory usage
21
- with torch.no_grad():
22
- # Perform OCR using the model
23
- result = model.chat(tokenizer, image_path, ocr_type='ocr')
24
-
25
- # Return the extracted text
26
  return result
27
 
28
- # Example usage with an image file path
29
- image_path = "/content/id.jpg"
30
- extracted_text = perform_ocr(image_path)
 
 
 
 
 
31
 
32
- # Output the extracted text
33
- print("Extracted Text:", extracted_text)
 
1
  import torch
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
4
+ import gradio as gr
5
+ import tempfile
6
 
7
+ # Load the OCR model and tokenizer
8
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
9
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
10
  low_cpu_mem_usage=True,
11
  pad_token_id=tokenizer.eos_token_id).eval()
12
 
13
+ # Check if GPU is available and use it, else use CPU
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  model = model.to(device)
16
 
17
  # Function to perform OCR on the image
18
+ def perform_ocr(image):
19
+ # Save the image to a temporary file
20
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
21
+ image.save(temp_file.name) # Save the image
22
+ temp_image_path = temp_file.name # Get the file path for the saved image
23
 
24
+ # Perform OCR using the model
25
+ result = model.chat(tokenizer, temp_image_path, ocr_type='ocr')
 
 
 
 
26
  return result
27
 
28
+ # Create the Gradio interface using the new syntax
29
+ interface = gr.Interface(
30
+ fn=perform_ocr,
31
+ inputs=gr.Image(type="pil"), # Updated to gr.Image
32
+ outputs=gr.Textbox(), # Updated to gr.Textbox
33
+ title="OCR Web App",
34
+ description="Upload an image to extract text using the GOT-OCR2.0 model."
35
+ )
36
 
37
+ # Launch the app
38
+ interface.launch()