import streamlit as st
from PIL import Image
import torch.nn as nn
import numpy as np
import torch
import cv2
import pandas as pd
import os
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from transformers import AutoImageProcessor, DetrForObjectDetection
# segmentation model
processor_seg = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model_seg = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
# object detection model
processor_obj = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model_obj = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
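# Added: both models are used for inference only, so switch them to eval mode
# (deterministic dropout / batch-norm). The pipeline below does not depend on this.
model_seg.eval()
model_obj.eval()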
def center_image(image_path, width=700):
    # center all images on the page via a small CSS override
    st.markdown(
        '<style>img { display: block; margin-left: auto; margin-right: auto; }</style>',
        unsafe_allow_html=True
    )
    st.image(image_path, width=width)
### INTRO ###
st.header('👗 What Should I Wear Today?! 👗')
st.markdown("💬 : 🚨 **Wait, you're not going out dressed like that, are you?** 🚨")
st.markdown('**Made for everyone whose fashion sense is just 2% short!** Upload a single photo and we will recommend an outfit based on current trends and your TPO (time, place, occasion). Check out the looks of Musinsa and OnTheLook fashionistas right now!')
center_image('./intro_img/fashionista.jpg')
st.markdown('--------------------------------------------------------------------------------------')
st.subheader('PROCESS')
center_image('./intro_img/process.png')
st.markdown('--------------------------------------------------------------------------------------')
### INPUT ###
st.subheader(' ① Upload a clothing image ')
input_image = st.file_uploader(" **Upload a clothing image. (A photo with a clean background works best!)** ", type=['png', 'jpg', 'jpeg'])
if not input_image:
    st.stop()
center_image(input_image, 400)
st.markdown('--------------------------------------------------------------------------------------')
st.subheader(' ② Select the category of the uploaded clothing image ')
input_cat = st.radio(
    "**Please choose the category of the clothing image you uploaded.**",
    ['top👕', 'bottom👖', 'shoes👟', 'hat🧢', 'sunglasses🕶️', 'scarf🧣', 'bag👜'],
    index=None,
    horizontal=True)
if not input_cat:
    st.stop()
# Strip the decorative emoji to recover the plain category name used in file paths.
input_cat = ''.join(ch for ch in input_cat if ch.isascii())
st.write('You selected:', input_cat)
st.markdown('--------------------------------------------------------------------------------------')
st.subheader(' ③ Select the clothing category you want recommended ')
output_cat = st.radio(
    '**Please select the clothing category you would like recommended.**',
    ['top👕', 'bottom👖', 'shoes👟', 'hat🧢', 'sunglasses🕶️', 'scarf🧣', 'bag👜'],
    index=None,
    horizontal=True)
if not output_cat:
    st.write('🚫 Note: please select a category different from the uploaded clothing category.')
    st.stop()
output_cat = ''.join(ch for ch in output_cat if ch.isascii())
st.write('You selected:', output_cat)
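# Added guard (assumption, based on the note above): the recommendation target
# should differ from the uploaded category.
if output_cat == input_cat:
    st.write('🚫 Note: please select a category different from the uploaded clothing category.')
    st.stop()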
st.write(' ')
st.markdown('--------------------------------------------------------------------------------------')
st.subheader(' ④ Select a situation category ')
situation = st.radio(
    "**Please select a situation category.**",
    ['Travel🏖', 'Cafe☕️', 'Exhibition🖼️', 'Campus🏫 & Work💼', 'Cold snap🤧', 'Exercise💪'],
    captions=['(beach, trips)', '(cafe, daily)', '(dates, weddings)', '', '', ''],
    index=None,
    horizontal=True)
# Map the selected label to the English folder name used on disk.
situation_mapping = {
    'Travel🏖': 'travel',
    'Cafe☕️': 'cafe',
    'Exhibition🖼️': 'exhibit',
    'Campus🏫 & Work💼': 'campus_work',
    'Cold snap🤧': 'cold',
    'Exercise💪': 'exercise'}
if not situation:
    st.stop()
situation = situation_mapping[situation]
st.write('You selected:', situation)
## Variables so far:
# input_image : uploaded image file
# input_cat   : category of the user's clothing
# output_cat  : category to recommend
# situation   : selected situation
st.markdown('--------------------------------------------------------------------------------------')
### Segmentation, detection & vector conversion of the uploaded image ###
image = Image.open(input_image)
# Object detection & cropping: detect the garment with DETR and crop to the
# first bounding box, lowering the confidence threshold until a box is found.
def cropping(images, start=1.0, stop=0.0, step=-0.05):
    image_1 = Image.fromarray(images)
    inputs = processor_obj(images=image_1, return_tensors="pt")
    with torch.no_grad():
        outputs = model_obj(**inputs)
    # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
    target_sizes = torch.tensor([image_1.size[::-1]])
    for tre in np.arange(start, stop, step):
        results = processor_obj.post_process_object_detection(outputs, threshold=tre, target_sizes=target_sizes)[0]
        for box in results["boxes"]:
            xmin, ymin, xmax, ymax = [round(i, 2) for i in box.tolist()]
            return image_1.crop((xmin, ymin, xmax, ymax))
    # fall back to the uncropped input if nothing is detected at any threshold
    return images
# Vector conversion: resize to a fixed (256, 256) and flatten to 1-D float32.
def image_to_vector(image, resize_size=(256, 256)):
    image = Image.fromarray(np.copy(image))
    image = image.resize(resize_size)
    image_array = np.array(image, dtype=np.float32)
    image_vector = image_array.flatten()
    return image_vector
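# Added illustration: regardless of input size, an RGB image maps to a fixed
# 256 * 256 * 3 = 196,608-dimensional vector, which is what the saved .txt
# vectors compared below are assumed to contain.
assert image_to_vector(np.zeros((10, 20, 3), dtype=np.uint8)).shape == (256 * 256 * 3,)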
# Overall pipeline: segment the garment, mask it out, crop it with DETR, and
# flatten the crop to a feature vector.
# Segment ids of mattmdjaga/segformer_b2_clothes, grouped into this app's categories.
segment_groups = {
    'hat': [1],
    'sunglasses': [3],
    'top': [4],
    'bottom': [5, 6, 7],   # pants, skirt, dress
    'belt': [8],
    'shoes': [9, 10],      # left shoe, right shoe
    'bag': [16],
    'scarf': [17],
}
def final_image(image):
    if len(np.array(image).shape) == 2:
        image = Image.fromarray(image).convert('RGB')
    # segmentation
    inputs = processor_seg(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model_seg(**inputs)
    logits = outputs.logits.cpu()
    # upsample the low-resolution logits back to the original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    present = {int(i) for i in torch.unique(pred_seg)}
    image = np.array(image)
    img_vector = None
    for category, ids in segment_groups.items():
        if not present.intersection(ids):
            continue
        # union mask over every id in the group (e.g. both shoes)
        mask = torch.zeros_like(pred_seg, dtype=torch.bool)
        for seg_id in ids:
            mask |= pred_seg == seg_id
        mask_np = (mask * 255).numpy().astype(np.uint8)
        # keep only the masked pixels, then crop tightly around the garment
        result = cv2.bitwise_and(image.astype(np.uint8), image.astype(np.uint8), mask=mask_np)
        img = cropping(result)
        img_vector = image_to_vector(img)
    # vector of the last garment group found (None if no garment was segmented)
    return img_vector
# Preprocessing of the uploaded image is complete.
input_img = final_image(image)
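# Added guard: final_image returns None when no known garment class is found,
# and the similarity step below would fail on it.
if input_img is None:
    st.error('No clothing item could be detected in the uploaded image. Please try another photo.')
    st.stop()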
### Similarity analysis ###
# One argument is an in-memory vector, the other a path to a saved vector file.
def cosine_similarity(vec1, vec2_path):
    vec2 = np.loadtxt(vec2_path)
    dot_product = np.dot(vec1, vec2)
    norm_vec1 = np.linalg.norm(vec1)
    norm_vec2 = np.linalg.norm(vec2)
    similarity = dot_product / (norm_vec1 * norm_vec2)
    return similarity
# Both arguments are paths to saved vector files.
def cosine_similarity_2(vec1_path, vec2_path):
    vec1 = np.loadtxt(vec1_path)
    return cosine_similarity(vec1, vec2_path)
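# Added sanity check (illustrative): cosine similarity is 1.0 for identical
# vectors and 0.0 for orthogonal ones.
_a, _b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
assert np.isclose(np.dot(_a, _a) / (np.linalg.norm(_a) * np.linalg.norm(_a)), 1.0)
assert np.isclose(np.dot(_a, _b) / (np.linalg.norm(_a) * np.linalg.norm(_b)), 0.0)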
with st.spinner('Wait for it...'):
    # Compare the uploaded image's vector with the style images saved for the
    # same situation and category.
    sim_list = []
    file_path = './style/' + situation + '/' + input_cat + '/'  # e.g. './style/cafe/top/'
    cloths = os.listdir(file_path)
    for cloth in cloths:
        sim_list.append(cosine_similarity(input_img, file_path + cloth))
    max_idx = np.argmax(sim_list)
    # Pick the target image: the style folders are assumed to share file names
    # per outfit, so the best match is looked up in the output category folder.
    target_image = './style/' + situation + '/' + output_cat + '/' + cloths[max_idx]
    # Compare the matched style image with the segmented product vectors.
    sim_list = []
    file_path = './product/' + output_cat + '/'
    cloths = os.listdir(file_path)
    for cloth in cloths:
        sim_list.append(cosine_similarity_2(target_image, file_path + cloth))
    max_idx = np.argmax(sim_list)
    output_name = cloths[max_idx]  # e.g. 'bottom_1883.txt'
# Load the product-name lookup tables.
acc_name = pd.read_csv('acc_name.csv')
bottom_name = pd.read_csv('bottom_name.csv')
outer_name = pd.read_csv('outer_name.csv')
shoes_name = pd.read_csv('shoes_name.csv')
top_name = pd.read_csv('top_name.csv')
# Load the product data.
outer = pd.read_csv('outer.csv')
top = pd.read_csv('top.csv')
bottom = pd.read_csv('bottom.csv')
shoes = pd.read_csv('shoes.csv')
acc = pd.read_csv('acc.csv')
if output_cat == 'bottom':
    df = bottom.copy()
    df_name = bottom_name.copy()
elif output_cat == 'top':
    df = top.copy()
    df_name = top_name.copy()
elif output_cat == 'shoes':
    df = shoes.copy()
    df_name = shoes_name.copy()
elif output_cat in ('hat', 'sunglasses', 'scarf', 'bag', 'belt'):
    df = acc.copy()
    df_name = acc_name.copy()
output_name = output_name.split('.')[0]  # drop the '.txt' extension
file_name = df_name[df_name['index'] == output_name].iloc[0, 1]  # e.g. 3049906_16754112975667_500.jpg
final = df[df['id'] == file_name]
name = final['name'].values[0].split('\n')[-1]  # product name
price = final['price'].values[0]  # product price
image_path = './product/img/'
st.subheader('OUTPUT')
img = Image.open(image_path + output_cat + '/' + file_name)
col1, col2, col3 = st.columns(3)
with col1:
    st.image(img, width=400)
with col3:
    st.caption('Product: ' + name)
    st.caption('Price: ' + str(price))