NegiTurkey commited on
Commit
1ea5bb8
·
verified ·
1 Parent(s): 5bb0782

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -18
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import gradio as gr
2
  from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
4
- import spaces
5
  from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
8
- import zipfile
9
  import os
 
10
 
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
 
@@ -22,12 +21,10 @@ transform_image = transforms.Compose(
22
  ]
23
  )
24
 
25
- @spaces.GPU
26
  def fn(image):
27
  im = load_img(image, output_type="pil")
28
  im = im.convert("RGB")
29
  image_size = im.size
30
- origin = im.copy()
31
  input_images = transform_image(im).unsqueeze(0).to("cpu")
32
 
33
  with torch.no_grad():
@@ -35,34 +32,32 @@ def fn(image):
35
  pred = preds[0].squeeze()
36
  pred_pil = transforms.ToPILImage()(pred)
37
  mask = pred_pil.resize(image_size)
 
38
  im.putalpha(mask)
39
-
40
  output_file_path = os.path.join("output_images", "output_image_single.png")
41
  im.save(output_file_path)
42
 
43
- return (im, origin)
44
 
45
- @spaces.GPU
46
  def fn_url(url):
47
  im = load_img(url, output_type="pil")
48
  im = im.convert("RGB")
49
- origin = im.copy()
50
  image_size = im.size
51
  input_images = transform_image(im).unsqueeze(0).to("cpu")
52
 
 
53
  with torch.no_grad():
54
  preds = birefnet(input_images)[-1].sigmoid().cpu()
55
  pred = preds[0].squeeze()
56
  pred_pil = transforms.ToPILImage()(pred)
57
  mask = pred_pil.resize(image_size)
 
58
  im.putalpha(mask)
59
-
60
  output_file_path = os.path.join("output_images", "output_image_url.png")
61
  im.save(output_file_path)
62
 
63
- return [im, origin]
64
 
65
- @spaces.GPU
66
  def batch_fn(images):
67
  output_paths = []
68
  for idx, image_path in enumerate(images):
@@ -76,6 +71,7 @@ def batch_fn(images):
76
  pred = preds[0].squeeze()
77
  pred_pil = transforms.ToPILImage()(pred)
78
  mask = pred_pil.resize(image_size)
 
79
  im.putalpha(mask)
80
 
81
  output_file_path = os.path.join("output_images", f"output_image_batch_{idx + 1}.png")
@@ -89,7 +85,7 @@ def batch_fn(images):
89
 
90
  return zip_file_path
91
 
92
- batch_image = gr.File(label="Upload multiple images", type="filepath", file_count="multiple") # 複数画像のアップロードを許可
93
 
94
  slider1 = ImageSlider(label="Processed Image", type="pil")
95
  slider2 = ImageSlider(label="Processed Image from URL", type="pil")
@@ -109,12 +105,7 @@ tab3 = gr.Interface(
109
  batch_fn,
110
  inputs=batch_image,
111
  outputs=gr.File(label="Download Processed Files"),
112
- api_name="batch",
113
- css="""
114
- #component-37 {
115
- display: none;
116
- }
117
- """
118
  )
119
 
120
  demo = gr.TabbedInterface(
 
1
  import gradio as gr
2
  from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
 
4
  from transformers import AutoModelForImageSegmentation
5
  import torch
6
  from torchvision import transforms
 
7
  import os
8
+ import zipfile
9
 
10
  torch.set_float32_matmul_precision(["high", "highest"][0])
11
 
 
21
  ]
22
  )
23
 
 
24
  def fn(image):
25
  im = load_img(image, output_type="pil")
26
  im = im.convert("RGB")
27
  image_size = im.size
 
28
  input_images = transform_image(im).unsqueeze(0).to("cpu")
29
 
30
  with torch.no_grad():
 
32
  pred = preds[0].squeeze()
33
  pred_pil = transforms.ToPILImage()(pred)
34
  mask = pred_pil.resize(image_size)
35
+
36
  im.putalpha(mask)
 
37
  output_file_path = os.path.join("output_images", "output_image_single.png")
38
  im.save(output_file_path)
39
 
40
+ return [mask, im]
41
 
 
42
  def fn_url(url):
43
  im = load_img(url, output_type="pil")
44
  im = im.convert("RGB")
 
45
  image_size = im.size
46
  input_images = transform_image(im).unsqueeze(0).to("cpu")
47
 
48
+ # Prediction
49
  with torch.no_grad():
50
  preds = birefnet(input_images)[-1].sigmoid().cpu()
51
  pred = preds[0].squeeze()
52
  pred_pil = transforms.ToPILImage()(pred)
53
  mask = pred_pil.resize(image_size)
54
+
55
  im.putalpha(mask)
 
56
  output_file_path = os.path.join("output_images", "output_image_url.png")
57
  im.save(output_file_path)
58
 
59
+ return [mask, im]
60
 
 
61
  def batch_fn(images):
62
  output_paths = []
63
  for idx, image_path in enumerate(images):
 
71
  pred = preds[0].squeeze()
72
  pred_pil = transforms.ToPILImage()(pred)
73
  mask = pred_pil.resize(image_size)
74
+
75
  im.putalpha(mask)
76
 
77
  output_file_path = os.path.join("output_images", f"output_image_batch_{idx + 1}.png")
 
85
 
86
  return zip_file_path
87
 
88
+ batch_image = gr.File(label="Upload multiple images", type="filepath", file_count="multiple")
89
 
90
  slider1 = ImageSlider(label="Processed Image", type="pil")
91
  slider2 = ImageSlider(label="Processed Image from URL", type="pil")
 
105
  batch_fn,
106
  inputs=batch_image,
107
  outputs=gr.File(label="Download Processed Files"),
108
+ api_name="batch"
 
 
 
 
 
109
  )
110
 
111
  demo = gr.TabbedInterface(