akhaliq HF staff commited on
Commit
00b1c44
1 Parent(s): 511b41a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -20
app.py CHANGED
@@ -14,19 +14,13 @@ from matplotlib import pyplot as plt
14
  from torchvision import transforms
15
  from diffusers import DiffusionPipeline
16
 
17
- from share_btn import community_icon_html, loading_icon_html, share_js
18
-
19
- auth_token = os.environ.get("API_TOKEN") or True
20
-
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
- pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", dtype=torch.float16, revision="fp16", use_auth_token=auth_token).to(device)
24
-
25
- transform = transforms.Compose([
26
- transforms.ToTensor(),
27
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
28
- transforms.Resize((512, 512)),
29
- ])
30
 
31
  def read_content(file_path: str) -> str:
32
  """read the content of target file
@@ -36,11 +30,11 @@ def read_content(file_path: str) -> str:
36
 
37
  return content
38
 
39
- def predict(dict, prompt=""):
40
- init_image = dict["image"].convert("RGB").resize((512, 512))
41
- mask = dict["mask"].convert("RGB").resize((512, 512))
42
- output = pipe(prompt = prompt, image=init_image, mask_image=mask,guidance_scale=7.5)
43
- return output.images[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
44
 
45
 
46
  css = '''
@@ -89,9 +83,9 @@ with image_blocks as demo:
89
  with gr.Box():
90
  with gr.Row():
91
  with gr.Column():
92
- image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
 
93
  with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
94
- prompt = gr.Textbox(placeholder = 'Your prompt (what you want in place of what is erased)', show_label=False, elem_id="input-text")
95
  btn = gr.Button("Inpaint!").style(
96
  margin=False,
97
  rounded=(False, True, True, False),
@@ -105,7 +99,7 @@ with image_blocks as demo:
105
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
106
 
107
 
108
- btn.click(fn=predict, inputs=[image, prompt], outputs=[image_out, community_icon, loading_icon, share_button])
109
  share_button.click(None, [], [], _js=share_js)
110
 
111
 
 
14
  from torchvision import transforms
15
  from diffusers import DiffusionPipeline
16
 
17
+ pipe = DiffusionPipeline.from_pretrained(
18
+ "patrickvonplaten/new_inpaint_test",
19
+ torch_dtype=torch.float16,
20
+ )
21
+ pipe = pipe.to("cuda")
22
 
23
+ from share_btn import community_icon_html, loading_icon_html, share_js
 
 
 
 
 
 
24
 
25
  def read_content(file_path: str) -> str:
26
  """read the content of target file
 
30
 
31
  return content
32
 
33
+ def predict(dict, example_image):
34
+ init_image = dict["image"].convert("RGB")
35
+ mask = dict["mask"].convert("RGB")
36
+ image = pipe(image=init_image, mask_image=mask, example_image=example_image).images[0]
37
+ return image, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
38
 
39
 
40
  css = '''
 
83
  with gr.Box():
84
  with gr.Row():
85
  with gr.Column():
86
+ image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload")
87
+ example = gr.Image(source='upload', elem_id="image_upload", type="pil", label="Upload")
88
  with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
 
89
  btn = gr.Button("Inpaint!").style(
90
  margin=False,
91
  rounded=(False, True, True, False),
 
99
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
100
 
101
 
102
+ btn.click(fn=predict, inputs=[image, example], outputs=[image_out, community_icon, loading_icon, share_button])
103
  share_button.click(None, [], [], _js=share_js)
104
 
105