wjbmattingly commited on
Commit
cc1b97c
1 Parent(s): cd822b1

fixed cuda bug for pixel values

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -95,15 +95,19 @@ def visualize_lines(image, lines):
95
  @spaces.GPU
96
  def transcribe_lines(line_images, model_name):
97
  processor, model = load_model(model_name)
98
-
99
  transcriptions = []
100
  for line_image in line_images:
101
  # Process the line image
102
  pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
103
 
 
 
104
  # Generate (no beam search)
105
  generated_ids = model.generate(pixel_values)
106
 
 
 
107
  # Decode
108
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
109
  transcriptions.append(generated_text)
 
95
  @spaces.GPU
96
  def transcribe_lines(line_images, model_name):
97
  processor, model = load_model(model_name)
98
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
99
  transcriptions = []
100
  for line_image in line_images:
101
  # Process the line image
102
  pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
103
 
104
+ pixel_values = pixel_values.to(device)
105
+
106
  # Generate (no beam search)
107
  generated_ids = model.generate(pixel_values)
108
 
109
+
110
+
111
  # Decode
112
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
113
  transcriptions.append(generated_text)