SonnySW commited on
Commit
6ed96fd
·
1 Parent(s): ee81140

2024.07.04 complete gernerate image service

Browse files
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 객체검출 -> 삭제 체크박스 적용본
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ from transformers import DetrImageProcessor, DetrForObjectDetection
5
+ from diffusers import StableDiffusionInpaintPipeline
6
+ import gradio as gr
7
+
8
+ # 모델 로드
9
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
10
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
11
+ pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32)
12
+ pipe = pipe.to("cpu")
13
+
14
+ def detect_objects(image):
15
+ # 객체 검출
16
+ inputs = processor(images=image, return_tensors="pt")
17
+ outputs = model(**inputs)
18
+
19
+ # 결과 후처리```````
20
+ target_sizes = torch.tensor([image.size[::-1]])
21
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
22
+
23
+ # 검출된 객체 정보 추출
24
+ detected_objects = []
25
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
26
+ if score > 0.9:
27
+ box = [round(i) for i in box.tolist()]
28
+ detected_objects.append({"label": model.config.id2label[label.item()], "box": box})
29
+
30
+ return detected_objects
31
+
32
+ def display_detected_objects(image):
33
+ detected_objects = detect_objects(image)
34
+ labeled_image = image.copy()
35
+ draw = ImageDraw.Draw(labeled_image)
36
+ object_labels = []
37
+ for obj in detected_objects:
38
+ box = obj["box"]
39
+ label = obj["label"]
40
+ draw.rectangle(box, outline="red", width=3)
41
+ draw.text((box[0], box[1]), label, fill="red")
42
+ object_labels.append(f"{label} at {box}")
43
+ return labeled_image, gr.update(choices=object_labels)
44
+
45
+ def inpaint_image(image, selected_objects):
46
+ detected_objects = detect_objects(image)
47
+
48
+ # 마스크 생성
49
+ mask = Image.new("L", image.size, 0)
50
+ draw = ImageDraw.Draw(mask)
51
+ for obj in detected_objects:
52
+ object_label = f"{obj['label']} at {obj['box']}"
53
+ if object_label in selected_objects:
54
+ box = obj["box"]
55
+ draw.rectangle(box, fill=255)
56
+
57
+ # Inpainting 수행
58
+ image = image.convert("RGB")
59
+ mask = mask.convert("RGB")
60
+ output = pipe(prompt="a modern interior", image=image, mask_image=mask).images[0]
61
+ # output = pipe(prompt="remove", image=image, mask_image=mask).images[0]
62
+
63
+
64
+ return output
65
+
66
+ # Gradio 인터페이스 설정
67
+ with gr.Blocks() as interface:
68
+ with gr.Row():
69
+ image_input = gr.Image(type="pil", label="Input Image")
70
+ objects_list = gr.CheckboxGroup(label="Detected Objects")
71
+
72
+ labeled_image_output = gr.Image(label="Labeled Image")
73
+ final_output = gr.Image(label="Output Image")
74
+
75
+ detect_button = gr.Button("Detect Objects")
76
+ inpaint_button = gr.Button("Remove Selected Objects")
77
+
78
+ detect_button.click(fn=display_detected_objects, inputs=image_input, outputs=[labeled_image_output, objects_list])
79
+ inpaint_button.click(fn=inpaint_image, inputs=[image_input, objects_list], outputs=final_output)
80
+
81
+ # Gradio 인터페이스 실행
82
+ interface.launch()
app.py:Zone.Identifier ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.3.0
3
+ annotated-types==0.7.0
4
+ anyio==4.4.0
5
+ asttokens==2.4.1
6
+ attrs==23.2.0
7
+ certifi==2024.7.4
8
+ charset-normalizer==3.3.2
9
+ click==8.1.7
10
+ comm==0.2.2
11
+ contourpy==1.2.1
12
+ cycler==0.12.1
13
+ debugpy==1.8.2
14
+ decorator==5.1.1
15
+ diffusers==0.29.2
16
+ dnspython==2.6.1
17
+ email_validator==2.2.0
18
+ exceptiongroup==1.2.1
19
+ executing==2.0.1
20
+ fastapi==0.111.0
21
+ fastapi-cli==0.0.4
22
+ ffmpy==0.3.2
23
+ filelock==3.15.4
24
+ fonttools==4.53.0
25
+ fsspec==2024.6.1
26
+ gradio==4.37.2
27
+ gradio_client==1.0.2
28
+ h11==0.14.0
29
+ httpcore==1.0.5
30
+ httptools==0.6.1
31
+ httpx==0.27.0
32
+ huggingface-hub==0.23.4
33
+ idna==3.7
34
+ importlib_metadata==8.0.0
35
+ importlib_resources==6.4.0
36
+ ipykernel==6.29.5
37
+ ipython==8.26.0
38
+ jedi==0.19.1
39
+ Jinja2==3.1.4
40
+ jsonschema==4.22.0
41
+ jsonschema-specifications==2023.12.1
42
+ jupyter_client==8.6.2
43
+ jupyter_core==5.7.2
44
+ kiwisolver==1.4.5
45
+ markdown-it-py==3.0.0
46
+ MarkupSafe==2.1.5
47
+ matplotlib==3.9.0
48
+ matplotlib-inline==0.1.7
49
+ mdurl==0.1.2
50
+ mpmath==1.3.0
51
+ nest-asyncio==1.6.0
52
+ networkx==3.3
53
+ numpy==1.26.4
54
+ orjson==3.10.6
55
+ packaging==24.1
56
+ pandas==2.2.2
57
+ parso==0.8.4
58
+ pexpect==4.9.0
59
+ pillow==10.4.0
60
+ platformdirs==4.2.2
61
+ prompt_toolkit==3.0.47
62
+ psutil==6.0.0
63
+ ptyprocess==0.7.0
64
+ pure-eval==0.2.2
65
+ pydantic==2.8.2
66
+ pydantic_core==2.20.1
67
+ pydub==0.25.1
68
+ Pygments==2.18.0
69
+ pyparsing==3.1.2
70
+ python-dateutil==2.9.0.post0
71
+ python-dotenv==1.0.1
72
+ python-multipart==0.0.9
73
+ pytz==2024.1
74
+ PyYAML==6.0.1
75
+ pyzmq==26.0.3
76
+ referencing==0.35.1
77
+ regex==2024.5.15
78
+ requests==2.32.3
79
+ rich==13.7.1
80
+ rpds-py==0.18.1
81
+ ruff==0.5.0
82
+ safetensors==0.4.3
83
+ semantic-version==2.10.0
84
+ shellingham==1.5.4
85
+ six==1.16.0
86
+ sniffio==1.3.1
87
+ stack-data==0.6.3
88
+ starlette==0.37.2
89
+ sympy==1.12.1
90
+ timm==1.0.7
91
+ tokenizers==0.19.1
92
+ tomlkit==0.12.0
93
+ toolz==0.12.1
94
+ torch==2.3.1
95
+ torchvision==0.18.1
96
+ tornado==6.4.1
97
+ tqdm==4.66.4
98
+ traitlets==5.14.3
99
+ transformers==4.42.3
100
+ triton==2.3.1
101
+ typer==0.12.3
102
+ typing_extensions==4.12.2
103
+ tzdata==2024.1
104
+ ujson==5.10.0
105
+ urllib3==2.2.2
106
+ uvicorn==0.30.1
107
+ uvloop==0.19.0
108
+ watchfiles==0.22.0
109
+ wcwidth==0.2.13
110
+ websockets==11.0.3
111
+ zipp==3.19.2
requirements.txt:Zone.Identifier ADDED
File without changes