gdurkin committed on
Commit
9dfb0a4
1 Parent(s): 16649f2

Create app.py

Files changed (1)
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+ from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
+
+ # Load the model from Hugging Face
+ model_name = "gdurkin/cdl_mask2former_hi_res_v3"
+ processor = Mask2FormerImageProcessor.from_pretrained(model_name)
+ model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
+ device = torch.device('cpu')
+
+ # Define the inference function
+ def predict(img):
+     if isinstance(img, np.ndarray):
+         img = torch.from_numpy(img).float()
+     if torch.is_tensor(img):
+         input_tensor = img
+     else:
+         raise ValueError("Unsupported image format")
+
+     if input_tensor.ndim == 3:
+         input_tensor = input_tensor.unsqueeze(0)
+     elif input_tensor.ndim != 4:
+         raise ValueError("Input tensor must be 3D or 4D")
+
+     input_tensor = input_tensor.permute(0, 3, 1, 2)  # Convert from channels-last (N, H, W, C) to channels-first (N, C, H, W)
+
+     with torch.no_grad():
+         outputs = model(input_tensor.to(device))
+
+     target_sizes = [(input_tensor.shape[2], input_tensor.shape[3])]
+     predicted_segmentation_maps = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
+
+     return predicted_segmentation_maps[0].cpu().numpy()
+
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="numpy", image_mode='RGB'),
+     outputs="numpy",
+     live=True
+ )
+
+ # Launch the interface
+ iface.launch()
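
For reference, a minimal local smoke test of predict() outside the Gradio UI might look like the sketch below. The 512x512 RGB shape and the random input are assumptions for illustration (the actual channel count and value range expected by the gdurkin/cdl_mask2former_hi_res_v3 checkpoint may differ), not part of this commit.

    import numpy as np

    # Hypothetical check: feed a random RGB frame through predict()
    # and confirm the segmentation map has one class index per pixel.
    dummy = np.random.randint(0, 255, size=(512, 512, 3), dtype=np.uint8)  # assumed H x W x C test input
    mask = predict(dummy)
    print(mask.shape)       # expected: (512, 512)
    print(np.unique(mask))  # class IDs present in the predicted map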