jeo053 committed
Commit 218916e · verified · 1 Parent(s): e3ec907

Create app.py

Files changed (1)
  1. app.py +337 -0
app.py ADDED
@@ -0,0 +1,337 @@
+ import streamlit as st
+ from PIL import Image
+ import torch.nn as nn
+ import numpy as np
+ import torch
+ import cv2
+ import pandas as pd
+ import os
+
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
+ from transformers import AutoImageProcessor, DetrForObjectDetection
+ # segmentation
+ processor_seg = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
+ model_seg = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
+ # object detection
+ processor_obj = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
+ model_obj = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
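+ # NOTE: both checkpoints are fetched from the Hugging Face Hub and reloaded on every
+ # Streamlit rerun; wrapping these loads in st.cache_resource would be an optional optimisation.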
+
+ def center_image(image_path, width=700):
+     st.markdown(
+         '<style>img { display: block; margin-left: auto; margin-right: auto; } </style>',
+         unsafe_allow_html=True
+     )
+     st.image(image_path, width=width)
+
+
+ ### INTRO ###
+ st.header('👚 What Should I Wear Today?! 👕')
+ st.markdown('💬 : 🚨 **Wait... you are not seriously going out dressed like that?** 🚨')
+ st.markdown(" **Made for anyone whose fashion sense is running 2% short!** Just upload a photo of a clothing item and we will recommend an outfit that fits current trends and your TPO (time, place, occasion). Take a look at outfits from the fashionistas on Musinsa and OnTheLook right now! ")
+ center_image('./intro_img/fashionista.jpg')
+
+ st.markdown('--------------------------------------------------------------------------------------')
+ st.subheader('PROCESS')
+ center_image('./intro_img/process.png')
+ st.markdown('--------------------------------------------------------------------------------------')
+
+
+ ### INPUT ###
+ st.subheader(' ✅ Upload a clothing image ')
+ input_image = st.file_uploader(" **Upload an image of a clothing item. (A photo with a clean background works best!)** ", type=['png', 'jpg', 'jpeg'])
+ if not input_image:
+     con = st.container()
+     st.stop()
+ center_image(input_image, 400)
+ st.markdown('--------------------------------------------------------------------------------------')
+
+ st.subheader(' ✅ Select the category of the uploaded clothing image ')
+ input_cat = st.radio(
+     "**Choose the category of the clothing image you uploaded.**",
+     ['top👕', 'bottom👖', 'shoes👞', 'hat🧢', 'sunglasses🕶️', 'scarf🧣', 'bag👜'],
+     index=None,
+     horizontal=True)
+
+
+ if not input_cat:
+     con = st.container()
+     st.stop()
+ input_cat = ''.join(ch for ch in input_cat if ch.isascii())  # drop the emoji, keep only the category name
+ st.write('You selected:', input_cat)
+ st.markdown('--------------------------------------------------------------------------------------')
+
+ st.subheader(' ✅ Select the clothing category you want recommended ')
+ output_cat = st.radio(
+     '**Choose the clothing category you would like a recommendation for.**',
+     ['top👕', 'bottom👖', 'shoes👞', 'hat🧢', 'sunglasses🕶️', 'scarf🧣', 'bag👜'],
+     index=None,
+     horizontal=True)
+
+ if not output_cat:
+     con = st.container()
+     st.write('🚫 Note: please pick a category different from the one you uploaded.')
+     st.stop()
+ output_cat = ''.join(ch for ch in output_cat if ch.isascii())  # drop the emoji, keep only the category name
+ st.write('You selected:', output_cat)
+ st.write(' ')
+ st.markdown('--------------------------------------------------------------------------------------')
+
+
+ st.subheader(' ✅ Select a situation category ')
+ situation = st.radio(
+     "**Choose the situation you are dressing for.**",
+     ['Travel🌊', 'Cafe☕️', 'Exhibition🖼️', 'Campus🏫 & Work💼', 'Cold snap🤧', 'Workout💪'],
+     captions=['(sea, travel)', '(cafe, daily)', '(date, wedding)', '', '', ''],
+     index=None,
+     horizontal=True)
+
+ # map the selected situation label to its folder name
+ situation_mapping = {
+     'Travel🌊': 'travel',
+     'Cafe☕️': 'cafe',
+     'Exhibition🖼️': 'exhibit',
+     'Campus🏫 & Work💼': 'campus_work',
+     'Cold snap🤧': 'cold',
+     'Workout💪': 'exercise'}
+
+ if not situation:
+     con = st.container()
+     st.stop()
+ situation = situation_mapping[situation]
+ st.write('You selected:', situation)
+
+ ## Variable names
+ # input_img  : preprocessed image vector
+ # input_cat  : category of the uploaded clothing
+ # output_cat : category to recommend
+ # situation  : selected situation
+
+ st.markdown('--------------------------------------------------------------------------------------')
+
+
+ ### Segmentation, detection, and vector conversion of the uploaded image ###
+ image = Image.open(input_image)
+
+ # object detection & cropping function
+ def cropping(images, start=1,
+              fi=0.0,
+              step=-0.05):
+     image_1 = Image.fromarray(images)
+     inputs = processor_obj(images=image_1, return_tensors="pt")
+     outputs = model_obj(**inputs)
+     for tre in np.arange(start, fi, step):
+         try:
+             # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+             target_sizes = torch.tensor([image_1.size[::-1]])
+             results = processor_obj.post_process_object_detection(outputs, threshold=tre, target_sizes=target_sizes)[0]
+
+             img = None
+             for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+                 box = [round(i, 2) for i in box.tolist()]
+                 xmin, ymin, xmax, ymax = box
+                 img = image_1.crop((xmin, ymin, xmax, ymax))
+
+             if img is None:  # nothing detected at this threshold, try a lower one
+                 continue
+             return img
+         except Exception:
+             continue
+     return images
+
+ # vector conversion function
+ default_path = './'
+ def image_to_vector(image, resize_size=(256, 256)):  # resize the image to (256, 256) before flattening
+     #image = Image.fromarray(image)
+     #image = image.resize(resize_size)
+     image = Image.fromarray(np.copy(image))
+     image = image.resize(resize_size)
+     image_array = np.array(image, dtype=np.float32)
+     image_vector = image_array.flatten()
+     return image_vector
+
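+ # Segmentation label ids handled below (matching the segformer_b2_clothes label map):
+ # 1 hat, 3 sunglasses, 4 upper-clothes (top), 5-7 skirt/pants/dress (grouped as bottom),
+ # 8 belt, 9/10 left/right shoe, 16 bag, 17 scarf; all other labels are ignored.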
+ # overall integration function (segmentation + cropping + vectorisation)
+ def final_image(image):
+     if len(np.array(image).shape) == 2:
+         image = Image.fromarray(image).convert('RGB')
+     # segmentation
+     inputs = processor_seg(images=image, return_tensors="pt")
+     outputs = model_seg(**inputs)
+     logits = outputs.logits.cpu()
+     upsampled_logits = nn.functional.interpolate(
+         logits,
+         size=image.size[::-1],
+         mode="bilinear",
+         align_corners=False,
+     )
+     pred_seg = upsampled_logits.argmax(dim=1)[0]
+     segments = torch.unique(pred_seg)
+     default_path = './'
+
+     for i in segments:
+         if int(i) == 0:
+             continue
+         if int(i) == 1:
+             cloth = 'hat'
+             cloths = 'hat'
+             mask = pred_seg == i
+             image = np.array(image)
+             mask_np = (mask * 255).numpy().astype(np.uint8)
+             result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
+             img = cropping(result)
+             img_vector = image_to_vector(img)
+         elif int(i) == 3:
+             cloth = 'sunglasses'
+             cloths = 'sunglasses'
+             mask = pred_seg == i
+             image = np.array(image)
+             mask_np = (mask * 255).numpy().astype(np.uint8)
+             result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
+             img = cropping(result)
+             img_vector = image_to_vector(img)
+         elif int(i) == 4:
+             cloth = 'top'
+             cloths = 'top'
+             mask = pred_seg == i
+             image = np.array(image)
+             mask_np = (mask * 255).numpy().astype(np.uint8)
+             result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
+             img = cropping(result)
+             img_vector = image_to_vector(img)
+         elif int(i) in [5, 6, 7]:
+             cloth = ['pants', 'skirt', 'dress']
+             cloths = 'bottom'
+             mask = (pred_seg == torch.tensor(5)) | (pred_seg == torch.tensor(6)) | (pred_seg == torch.tensor(7))
+             image = np.array(image)
+             mask_np = (mask * 255).numpy().astype(np.uint8)
+             result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
+             img = cropping(result)
+             img_vector = image_to_vector(img)
+         elif int(i) == 8:
+             cloth = 'belt'
+             cloths = 'belt'
+             mask = pred_seg == torch.tensor(8)
+             image = np.array(image)
+             mask_np = (mask * 255).numpy().astype(np.uint8)
+             result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
+             img = cropping(result)
+             img_vector = image_to_vector(img)
+         elif (int(i) == 9):
+             cloth = 'shoes'
+             cloths = 'shoes'
+             mask = (pred_seg == torch.tensor(9)) | (pred_seg == torch.tensor(10))
+             image = np.array(image)
+             mask_np = (mask * 255).numpy().astype(np.uint8)
+             result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
+             img = cropping(result)
+             img_vector = image_to_vector(img)
+         elif int(i) == 16:
+             cloth = 'bag'
+             cloths = 'bag'
+             mask = pred_seg == torch.tensor(16)
+             image = np.array(image)
+             mask_np = (mask * 255).numpy().astype(np.uint8)
+             result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
+             img = cropping(result)
+             img_vector = image_to_vector(img)
+         elif int(i) == 17:
+             cloth = 'scarf'
+             cloths = 'scarf'
+             mask = pred_seg == torch.tensor(17)
+             image = np.array(image)
+             mask_np = (mask * 255).numpy().astype(np.uint8)
+             result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
+             img = cropping(result)
+             img_vector = image_to_vector(img)
+     return img_vector
+
+ # preprocessing of the uploaded image is done
+ input_img = final_image(image)
+
+ ### Similarity analysis ###
+ # one argument is a vector, the other is a path to a saved vector
+ def cosine_similarity(vec1, vec2_path):
+     vec2 = np.loadtxt(vec2_path)
+     dot_product = np.dot(vec1, vec2)
+     norm_vec1 = np.linalg.norm(vec1)
+     norm_vec2 = np.linalg.norm(vec2)
+     similarity = dot_product / (norm_vec1 * norm_vec2)
+     return similarity
+
+ # both arguments are paths to saved vectors
+ def cosine_similarity_2(vec1_path, vec2_path):
+     vec1 = np.loadtxt(vec1_path)
+     vec2 = np.loadtxt(vec2_path)
+     dot_product = np.dot(vec1, vec2)
+     norm_vec1 = np.linalg.norm(vec1)
+     norm_vec2 = np.linalg.norm(vec2)
+     similarity = dot_product / (norm_vec1 * norm_vec2)
+     return similarity
+
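+ # The search below assumes pre-computed vectors saved as text files (one per garment):
+ #   ./style/<situation>/<category>/*.txt   (curated outfit images)
+ #   ./product/<category>/*.txt             (shop items, paired with photos in ./product/img/<category>/)
+ # Style file names are assumed to line up across category folders, so the best match for the
+ # uploaded item can be looked up again in the requested output category.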
+ with st.spinner('Wait for it...'):
+     # compare the uploaded item with the style images stored in the same category folder
+     sim_list = []
+     file_path = './style/' + situation + '/' + input_cat + '/'  # e.g. './style/cafe/top/'
+     cloths = os.listdir('./style/' + situation + '/' + input_cat + '/')
+     for cloth in cloths:
+         sim_list.append(cosine_similarity(input_img, file_path + cloth))
+     max_idx = np.argmax(sim_list)
+
+     # define target_image: the best-matching style image, but in the requested output category
+     target_image = './style/' + situation + '/' + output_cat + '/' + cloths[max_idx]
+     # compare that style image with the product vectors of the output category
+     sim_list = []
+     file_path = './product/' + output_cat + '/'
+     cloths = os.listdir('./product/' + output_cat + '/')
+     for cloth in cloths:
+         sim_list.append(cosine_similarity_2(target_image, file_path + cloth))
+     max_idx = np.argmax(sim_list)
+     output_name = cloths[max_idx]
+     ## example value: 'bottom_1883.txt'
+
+ # load the product-name lookup tables
+ acc_name = pd.read_csv('acc_name.csv')
+ bottom_name = pd.read_csv('bottom_name.csv')
+ outer_name = pd.read_csv('outer_name.csv')
+ shoes_name = pd.read_csv('shoes_name.csv')
+ top_name = pd.read_csv('top_name.csv')
+
+ # load the product data
+ outer = pd.read_csv('outer.csv')
+ top = pd.read_csv('top.csv')
+ bottom = pd.read_csv('bottom.csv')
+ shoes = pd.read_csv('shoes.csv')
+ acc = pd.read_csv('acc.csv')
+
+ if output_cat == 'bottom':
+     df = bottom.copy()
+     df_name = bottom_name.copy()
+ elif output_cat == 'top':
+     df = top.copy()
+     df_name = top_name.copy()
+ elif output_cat == 'shoes':
+     df = shoes.copy()
+     df_name = shoes_name.copy()
+ elif (output_cat == 'hat') or (output_cat == 'sunglasses') or (output_cat == 'scarf') or (output_cat == 'bag') or (output_cat == 'belt'):
+     df = acc.copy()
+     df_name = acc_name.copy()
+
+ output_name = output_name.split('.')[0]
+ file_name = df_name[df_name['index'] == output_name].iloc[0, 1]  # e.g. 3049906_16754112975667_500.jpg
+ final = df[df['id'] == file_name]
+
+ name = final['name'].values[0].split('\n')[-1]  # product name
+ price = final['price'].values[0]  # product price
+
+ image_path = './product/img/'
+
+ st.subheader('OUTPUT')
+
+ img = Image.open(image_path + output_cat + '/' + file_name)
+
+ col1, col2, col3 = st.columns(3)
+ with col1:
+     st.image(img, width=400)
+ with col3:
+     st.caption('Product name: ' + name)
+     st.caption('Price: ' + str(price))