akhaliq HF staff commited on
Commit
be1ceb8
·
1 Parent(s): c072566

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -49
app.py CHANGED
@@ -34,7 +34,6 @@ from data import get_dataset
34
  import torchvision.transforms as transforms
35
 
36
  import gradio as gr
37
- import streamlit as st
38
 
39
  model_name = "convnext_xlarge_in22k"
40
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -332,57 +331,42 @@ def load_model():
332
  """
333
  lseg_model, lseg_transform = load_model()
334
 
335
- # to be revised
336
- uploaded_file = gr.inputs.Image(type='pil')
337
- input_labels = st.text_input("Input labels", value="dog, grass, other")
338
- gr.outputs.Label(type="confidences",num_top_classes=5)
339
- st.write("The labels are", input_labels)
340
 
341
- image = Image.open(uploaded_file)
342
- pimage = lseg_transform(np.array(image)).unsqueeze(0)
343
 
344
- labels = []
345
- for label in input_labels.split(","):
346
- labels.append(label.strip())
347
-
348
- with torch.no_grad():
349
- outputs = lseg_model.parallel_forward(pimage, labels)
350
 
351
- predicts = [
352
- torch.max(output, 1)[1].cpu().numpy()
353
- for output in outputs
354
- ]
 
 
 
 
 
 
 
355
 
356
- image = pimage[0].permute(1,2,0)
357
- image = image * 0.5 + 0.5
358
- image = Image.fromarray(np.uint8(255*image)).convert("RGBA")
359
-
360
- pred = predicts[0]
361
- new_palette = get_new_pallete(len(labels))
362
- mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
363
- seg = mask.convert("RGBA")
364
-
365
- fig = plt.figure()
366
- plt.subplot(121)
367
- plt.imshow(image)
368
- plt.axis('off')
369
-
370
- plt.subplot(122)
371
- plt.imshow(seg)
372
- plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
373
- plt.axis('off')
374
-
375
- plt.tight_layout()
376
-
377
- #st.image([image,seg], width=700, caption=["Input image", "Segmentation"])
378
- st.pyplot(fig)
379
-
380
- title = "LSeg"
381
-
382
- description = "Gradio demo for LSeg for semantic segmentation. To use it, simply upload your image, or click one of the examples to load them, then add any label set"
383
-
384
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.03546' target='_blank'>Language-driven Semantic Segmentation</a> | <a href='hhttps://github.com/isl-org/lang-seg' target='_blank'>Github Repo</a></p>"
385
 
386
- examples = ['test.jpeg']
387
 
388
- gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, analytics_enabled=False, examples=examples).launch(enable_queue=True)
 
34
  import torchvision.transforms as transforms
35
 
36
  import gradio as gr
 
37
 
38
  model_name = "convnext_xlarge_in22k"
39
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
331
  """
332
  lseg_model, lseg_transform = load_model()
333
 
334
+ def inference(image,text):
335
+ input_labels = text
 
 
 
336
 
337
+ pimage = lseg_transform(np.array(image)).unsqueeze(0)
 
338
 
339
+ labels = []
340
+ for label in input_labels.split(","):
341
+ labels.append(label.strip())
 
 
 
342
 
343
+ with torch.no_grad():
344
+ outputs = lseg_model.parallel_forward(pimage, labels)
345
+
346
+ predicts = [
347
+ torch.max(output, 1)[1].cpu().numpy()
348
+ for output in outputs
349
+ ]
350
+
351
+ image = pimage[0].permute(1,2,0)
352
+ image = image * 0.5 + 0.5
353
+ image = Image.fromarray(np.uint8(255*image)).convert("RGBA")
354
 
355
+ pred = predicts[0]
356
+ new_palette = get_new_pallete(len(labels))
357
+ mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
358
+ seg = mask.convert("RGBA")
359
+
360
+ fig = plt.figure()
361
+ plt.subplot(121)
362
+ plt.axis('off')
363
+
364
+ plt.subplot(122)
365
+ plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
366
+ plt.axis('off')
367
+
368
+ plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
+ return plt
371
 
372
+ gr.Interface(inference,["image","text"],"plot").launch(debug=True)