wjbmattingly committed on
Commit
3fc0241
1 Parent(s): 546d56f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -7,6 +7,7 @@ from PIL import Image, ImageDraw
7
  import os
8
  import tempfile
9
  import numpy as np
 
10
  # Dictionary of model names and their corresponding HuggingFace model IDs
11
  MODEL_OPTIONS = {
12
  "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
@@ -36,12 +37,12 @@ def load_model(model_name):
36
  current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
37
  current_model_name = model_name
38
 
39
- # Move model to GPU
40
- current_model = current_model.to('cuda')
 
41
 
42
  return current_processor, current_model
43
 
44
-
45
  def process_image(image, model_name):
46
  # Save the uploaded image to a temporary file
47
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
@@ -59,6 +60,9 @@ def process_image(image, model_name):
59
 
60
  processor, model = load_model(model_name)
61
 
 
 
 
62
  # Process each line
63
  transcriptions = []
64
  for line in lines_data['lines']:
@@ -79,7 +83,7 @@ def process_image(image, model_name):
79
 
80
  # Prepare image for TrOCR
81
  pixel_values = processor(images=line_image_np, return_tensors="pt").pixel_values
82
- pixel_values = pixel_values.to('cuda')
83
 
84
  # Generate (no beam search)
85
  with torch.no_grad():
@@ -117,4 +121,4 @@ with gr.Blocks() as iface:
117
  submit_button = gr.Button("Transcribe")
118
  submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])
119
 
120
- iface.launch()
 
7
  import os
8
  import tempfile
9
  import numpy as np
10
+
11
  # Dictionary of model names and their corresponding HuggingFace model IDs
12
  MODEL_OPTIONS = {
13
  "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
 
37
  current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
38
  current_model_name = model_name
39
 
40
+ # Move model to GPU if available, else use CPU
41
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
+ current_model = current_model.to(device)
43
 
44
  return current_processor, current_model
45
 
 
46
  def process_image(image, model_name):
47
  # Save the uploaded image to a temporary file
48
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
 
60
 
61
  processor, model = load_model(model_name)
62
 
63
+ # Determine device
64
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65
+
66
  # Process each line
67
  transcriptions = []
68
  for line in lines_data['lines']:
 
83
 
84
  # Prepare image for TrOCR
85
  pixel_values = processor(images=line_image_np, return_tensors="pt").pixel_values
86
+ pixel_values = pixel_values.to(device)
87
 
88
  # Generate (no beam search)
89
  with torch.no_grad():
 
121
  submit_button = gr.Button("Transcribe")
122
  submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])
123
 
124
+ iface.launch(share=True)