Vinay15 committed · Commit dec293d · verified · 1 Parent(s): b94f0a2

Update app.py

Files changed (1): app.py (+25 -32)
app.py CHANGED
@@ -2,44 +2,37 @@ import torch
 from transformers import AutoModel, AutoTokenizer
 from PIL import Image
 import gradio as gr
-import os
+import tempfile
 
-# Specify the revision ID you want to pin to
-revision_id = "your_revision_id_here"  # Replace with the actual revision ID
-
-# Load the OCR model and tokenizer with pinned revision
-tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', revision=revision_id, trust_remote_code=True)
-model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', revision=revision_id, trust_remote_code=True,
+# Load the OCR model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
+model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
                                   low_cpu_mem_usage=True,
                                   pad_token_id=tokenizer.eos_token_id).eval()
 
-# Move model to CPU
-device = torch.device('cpu')
+# Check if GPU is available and use it, else use CPU
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = model.to(device)
 
-# Function to perform OCR on an image file
-def perform_ocr(image_file_path):
-    # Open the image using PIL
-    image = Image.open(image_file_path)
-
-    # Save the image temporarily
-    temp_image_path = "temp_image.png"
-    image.save(temp_image_path)
-
-    # Use torch.no_grad() to avoid unnecessary memory usage
-    with torch.no_grad():
-        # Perform OCR using the model on CPU (pass the file path of the saved image)
-        result = model.chat(tokenizer, temp_image_path, ocr_type='ocr')
+# Function to perform OCR on the image
+def perform_ocr(image):
+    # Save the image to a temporary file
+    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
+        image.save(temp_file.name)          # Save the image
+        temp_image_path = temp_file.name    # Get the file path for the saved image
 
-    # Clean up the temporary image file
-    os.remove(temp_image_path)
-
-    # Return the extracted text
+    # Perform OCR using the model
+    result = model.chat(tokenizer, temp_image_path, ocr_type='ocr')
     return result
 
-# Gradio interface for file upload and OCR
-iface = gr.Interface(fn=perform_ocr, inputs="file", outputs="text",
-                     title="OCR Application", description="Upload an image to extract text.")
-
-# Launch the Gradio app
-iface.launch()
+# Create the Gradio interface using the new syntax
+interface = gr.Interface(
+    fn=perform_ocr,
+    inputs=gr.Image(type="pil"),   # Updated to gr.Image
+    outputs=gr.Textbox(),          # Updated to gr.Textbox
+    title="OCR Web App",
+    description="Upload an image to extract text using the GOT-OCR2.0 model."
+)
+
+# Launch the app
+interface.launch()
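
A possible follow-up, sketched here and not part of this commit: the updated perform_ocr creates its temporary file with delete=False and never removes it, and it drops the torch.no_grad() guard the previous version used. Assuming the same model, tokenizer, and model.chat call as above, the handler could clean up after itself like this:

    import os
    import tempfile

    import torch

    def perform_ocr(image):
        # Save the uploaded PIL image to a temporary file for model.chat
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
            image.save(temp_file.name)
            temp_image_path = temp_file.name
        try:
            # Run OCR without tracking gradients (as the previous version did)
            with torch.no_grad():
                result = model.chat(tokenizer, temp_image_path, ocr_type='ocr')
        finally:
            # Remove the temporary image whether or not OCR succeeded
            os.remove(temp_image_path)
        return result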