wjbmattingly committed on
Commit
0c1fe20
1 Parent(s): ea32023

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -28
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
- import requests
4
- from PIL import Image
5
 
6
  # Dictionary of model names and their corresponding HuggingFace model IDs
7
  MODEL_OPTIONS = {
@@ -18,49 +17,57 @@ MODEL_OPTIONS = {
18
  "Medieval Print": "medieval-data/trocr-medieval-print"
19
  }
20
 
# Download the example images at startup and store them as local files named
# image_<index>.png.
urls = [
    'https://huggingface.co/medieval-data/trocr-medieval-base/resolve/main/images/caroline-1.png'
]

for idx, url in enumerate(urls):
    # The timeout keeps a stalled download from hanging startup indefinitely;
    # the context manager releases the streamed connection when done.
    with requests.get(url, stream=True, timeout=30) as response:
        # Fail loudly on an HTTP error instead of handing an error page to PIL.
        response.raise_for_status()
        # Have urllib3 undo any Content-Encoding before PIL reads the raw bytes.
        response.raw.decode_content = True
        image = Image.open(response.raw)
        image.save(f"image_{idx}.png")
29
 
# Most recently loaded (model_name, processor, model); None until the first load.
_last_loaded = None


def load_model(model_name):
    """Return the TrOCR processor and model for a MODEL_OPTIONS display name.

    The most recently loaded pair is kept in a single-slot cache, so repeated
    requests for the same model do not call ``from_pretrained`` again.

    Args:
        model_name: A key of MODEL_OPTIONS.

    Returns:
        A ``(processor, model)`` tuple.

    Raises:
        KeyError: If ``model_name`` is not a key of MODEL_OPTIONS.
    """
    global _last_loaded

    if _last_loaded is None or _last_loaded[0] != model_name:
        model_id = MODEL_OPTIONS[model_name]
        processor = TrOCRProcessor.from_pretrained(model_id)
        model = VisionEncoderDecoderModel.from_pretrained(model_id)
        # Store the slot only after both loads succeeded, so a failed load
        # leaves the previous entry intact.
        _last_loaded = (model_name, processor, model)

    _, processor, model = _last_loaded
    return processor, model
 
 
 
 
 
 
 
 
 
35
 
def process_image(image, model_name):
    """Transcribe the text in an image with the selected TrOCR model.

    Args:
        image: The PIL image supplied by the Gradio image input, or None when
            nothing was uploaded.
        model_name: A key of MODEL_OPTIONS selecting which model to use.

    Returns:
        The decoded transcription string for the image.

    Raises:
        gr.Error: If no image was provided.
    """
    # Submitting the form without an image passes None; report that clearly
    # instead of letting the processor fail with an opaque error.
    if image is None:
        raise gr.Error("Please provide an image to transcribe.")

    processor, model = load_model(model_name)

    # Prepare image (normalised to RGB so grayscale/RGBA inputs are accepted).
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values

    # Generate (no beam search)
    generated_ids = model.generate(pixel_values)

    # Decode; a single image was passed, so take the first result of the batch.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
48
 
# Heading and blurb passed to the interface below.
title = "Interactive demo: TrOCR Model Switcher"
description = "Demo for the Medieval TrOCR HTR Models."

# Input widgets, in the positional order process_image expects them.
image_input = gr.Image(type="pil")
model_selector = gr.Dropdown(choices=list(MODEL_OPTIONS), label="Select Model")

iface = gr.Interface(
    fn=process_image,
    inputs=[image_input, model_selector],
    outputs=gr.Textbox(),
    title=title,
    description=description,
    # One example row: [image file, model display name].
    examples=[["image_0.png", "Medieval Latin Caroline"]],
)

iface.launch(debug=True)
 
1
  import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ import torch
 
4
 
5
  # Dictionary of model names and their corresponding HuggingFace model IDs
6
  MODEL_OPTIONS = {
 
17
  "Medieval Print": "medieval-data/trocr-medieval-print"
18
  }
19
 
# Global variables to store the current model and processor, plus the
# MODEL_OPTIONS key they were loaded for (a single-slot cache).
current_model = None
current_processor = None
current_model_name = None


def load_model(model_name):
    """Return the TrOCR processor and model for a MODEL_OPTIONS display name.

    The most recently loaded pair is kept in module globals and reused while
    the same ``model_name`` is requested; a different name triggers a reload.

    Args:
        model_name: A key of MODEL_OPTIONS.

    Returns:
        A ``(processor, model)`` tuple; the model is on the GPU when CUDA is
        available.

    Raises:
        KeyError: If ``model_name`` is not a key of MODEL_OPTIONS.
    """
    global current_model, current_processor, current_model_name

    if model_name != current_model_name:
        model_id = MODEL_OPTIONS[model_name]

        # Load into locals first: if either load raises, the globals still
        # describe the previous, consistent processor/model/name triple
        # instead of a new processor paired with the old model.
        processor = TrOCRProcessor.from_pretrained(model_id)
        model = VisionEncoderDecoderModel.from_pretrained(model_id)

        # Move model to GPU if available
        if torch.cuda.is_available():
            model = model.to('cuda')

        # Commit all three values together once everything succeeded.
        current_processor, current_model, current_model_name = processor, model, model_name

    return current_processor, current_model
39
 
def process_image(image, model_name):
    """Transcribe the text in an image with the selected TrOCR model.

    Args:
        image: The PIL image supplied by the Gradio image input, or None when
            nothing was uploaded.
        model_name: A key of MODEL_OPTIONS selecting which model to use.

    Returns:
        The decoded transcription string for the image.

    Raises:
        gr.Error: If no image was provided.
    """
    # Submitting the form without an image passes None; report that clearly
    # instead of letting the processor fail with an opaque error.
    if image is None:
        raise gr.Error("Please provide an image to transcribe.")

    processor, model = load_model(model_name)

    # Prepare image (normalised to RGB so grayscale/RGBA inputs are accepted).
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values

    # Put the input on whatever device the model's weights live on, rather
    # than hard-coding 'cuda'.
    pixel_values = pixel_values.to(model.device)

    # Generate (no beam search)
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    # Decode; a single image was passed, so take the first result of the batch.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
57
 
# Gradio interface: widgets are built first, then wired to process_image.
image_input = gr.Image(type="pil", label="Input Image")
model_selector = gr.Dropdown(
    choices=list(MODEL_OPTIONS),
    label="Select Model",
    value="Medieval Base",
)
transcription_output = gr.Textbox(label="Transcription")

iface = gr.Interface(
    fn=process_image,
    inputs=[image_input, model_selector],
    outputs=transcription_output,
    title="Medieval TrOCR Model Switcher",
    description="Upload an image of medieval text and select a model to transcribe it.",
    # One example row: [image URL, model display name].
    examples=[
        [
            "https://huggingface.co/medieval-data/trocr-medieval-base/resolve/main/images/caroline-1.png",
            "Medieval Latin Caroline",
        ],
    ],
)

iface.launch()