# today_cloth/app.py
import streamlit as st
from PIL import Image
import torch.nn as nn
import numpy as np
import torch
import cv2
import pandas as pd
import os
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from transformers import AutoImageProcessor, DetrForObjectDetection
# segmentation
processor_seg = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model_seg = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
# object detection
processor_obj = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model_obj = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
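# NOTE: Streamlit reruns this whole script on every widget interaction, so the
# four from_pretrained() calls above run again on each rerun. A minimal sketch
# of the usual fix, assuming Streamlit >= 1.18 (which provides st.cache_resource):
#
#     @st.cache_resource
#     def load_seg_model():
#         processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
#         model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
#         return processor, model
#
# (and likewise for the DETR pair); load_seg_model is a hypothetical name.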
def center_image(image_path, width=700):
    # the injected CSS targets every <img> on the page, so all images rendered
    # by st.image afterwards are centered as well
    st.markdown(
        '<style>img { display: block; margin-left: auto; margin-right: auto; }</style>',
        unsafe_allow_html=True,
    )
    st.image(image_path, width=width)
### INTRO ###
st.header('👚 What Should I Wear Today?! 👕')
st.markdown('💬 : 🚨 **You are not seriously going out dressed like that, are you?** 🚨')
st.markdown(' **Made for everyone whose fashion sense is 2% short!** Just upload a photo and we will recommend an outfit that fits current trends and your TPO (time, place, occasion). Take your cue from the fashionistas of Musinsa and OnTheLook right now! ')
center_image('./intro_img/fashionista.jpg')
st.markdown('--------------------------------------------------------------------------------------')
st.subheader('PROCESS')
center_image('./intro_img/process.png')
st.markdown('--------------------------------------------------------------------------------------')
### INPUT ###
st.subheader(' ✅ Upload a clothing image ')
input_image = st.file_uploader(" **Upload a clothing image. (A photo with a clean background works best!)** ", type=['png', 'jpg', 'jpeg'])
if not input_image:
    st.stop()
center_image(input_image, 400)
st.markdown('--------------------------------------------------------------------------------------')
st.subheader(' ✅ Select the category of the uploaded clothing image ')
input_cat = st.radio(
    "**Choose the category of the clothing image you uploaded.**",
    ['top👕', 'bottom👖', 'shoes👞', 'hat🧢', 'sunglasses🕶️', 'scarf🧣', 'bag👜'],
    index=None,
    horizontal=True)
if not input_cat:
    st.stop()
# strip the emoji suffix; slicing one character off breaks on multi-codepoint
# emoji such as '🕶️' (emoji + variation selector), so keep letters only
input_cat = ''.join(ch for ch in input_cat if ch.isalpha())
st.write('You selected:', input_cat)
st.markdown('--------------------------------------------------------------------------------------')
st.subheader(' ✅ Select the clothing category you want recommended ')
output_cat = st.radio(
    '**Choose the clothing category you would like recommended.**',
    ['top👕', 'bottom👖', 'shoes👞', 'hat🧢', 'sunglasses🕶️', 'scarf🧣', 'bag👜'],
    index=None,
    horizontal=True)
if not output_cat:
    st.write('🚫 Note: pick a category different from the one you uploaded.')
    st.stop()
output_cat = ''.join(ch for ch in output_cat if ch.isalpha())
st.write('You selected:', output_cat)
st.write(' ')
st.markdown('--------------------------------------------------------------------------------------')
st.subheader(' ✅ Select a situation category ')
situation = st.radio(
    "**Choose a situation category.**",
    ['travel🌊', 'cafe☕️', 'exhibition🖼️', 'campus🏫 & work💼', 'sudden cold🤧', 'exercise💪'],
    captions=['(beach, trip)', '(cafe, daily)', '(date, wedding)', '', '', ''],
    index=None,
    horizontal=True)
# map the selected label to the English folder name used on disk
situation_mapping = {
    'travel🌊': 'travel',
    'cafe☕️': 'cafe',
    'exhibition🖼️': 'exhibit',
    'campus🏫 & work💼': 'campus_work',
    'sudden cold🤧': 'cold',
    'exercise💪': 'exercise'}
if not situation:
    st.stop()
situation = situation_mapping[situation]
st.write('You selected:', situation)
## Variables so far
# input_image : uploaded image file
# input_cat   : category of the garment in the photo
# output_cat  : category to recommend
# situation   : situation folder name
st.markdown('--------------------------------------------------------------------------------------')
### Segmentation, object detection & vector conversion of the input image ###
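# st.file_uploader returns a file-like UploadedFile (a BytesIO subclass), so
# PIL can open it directly without writing it to disk first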
image = Image.open(input_image)
# object detection & cropping: sweep the DETR confidence threshold from strict
# to loose and return a crop from the first threshold that yields a detection
def cropping(images, start=1.0, stop=0.0, step=-0.05):
    image_1 = Image.fromarray(images)
    inputs = processor_obj(images=image_1, return_tensors="pt")
    with torch.no_grad():
        outputs = model_obj(**inputs)
    # convert outputs (bounding boxes and class logits) to Pascal VOC format
    # (xmin, ymin, xmax, ymax)
    target_sizes = torch.tensor([image_1.size[::-1]])
    for threshold in np.arange(start, stop, step):
        results = processor_obj.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes)[0]
        img = None
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            xmin, ymin, xmax, ymax = [round(i, 2) for i in box.tolist()]
            img = image_1.crop((xmin, ymin, xmax, ymax))
        if img is not None:
            return img
    # nothing detected at any threshold: fall back to the uncropped input
    return images
# vector conversion
def image_to_vector(image, resize_size=(256, 256)):
    # resize to a fixed 256x256 so every vector has the same length
    image = Image.fromarray(np.copy(image))
    image = image.resize(resize_size)
    image_array = np.array(image, dtype=np.float32)
    # flatten the HxWxC pixel grid into a single 1-D feature vector
    image_vector = image_array.flatten()
    return image_vector
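# e.g. image_to_vector(np.zeros((64, 64, 3), dtype=np.uint8)) returns a vector
# of shape (196608,), since 256 * 256 * 3 = 196,608; the precomputed .txt
# vectors compared against below are assumed to have been built the same way.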
# end-to-end preprocessing: segment the garments, mask, crop, vectorize.
# Segformer class ids -> app category; ids 5/6/7 (pants/skirt/dress) are merged
# into 'bottom' and ids 9/10 (left/right shoe) into 'shoes'
SEG_ID_TO_CATEGORY = {
    1: 'hat',
    3: 'sunglasses',
    4: 'top',
    5: 'bottom', 6: 'bottom', 7: 'bottom',
    8: 'belt',
    9: 'shoes', 10: 'shoes',
    16: 'bag',
    17: 'scarf'}
def final_image(image):
    # normalize to 3-channel RGB (handles grayscale and RGBA uploads alike)
    image = image.convert('RGB')
    # segmentation
    inputs = processor_seg(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model_seg(**inputs)
    logits = outputs.logits.cpu()
    # upsample the low-resolution logits back to the original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    image_np = np.array(image).astype(np.uint8)
    img_vector = None
    for i in torch.unique(pred_seg):
        category = SEG_ID_TO_CATEGORY.get(int(i))
        if category is None:
            continue
        # merge every segment id mapped to this category into one boolean mask
        mask = torch.zeros_like(pred_seg, dtype=torch.bool)
        for seg_id, cat in SEG_ID_TO_CATEGORY.items():
            if cat == category:
                mask |= pred_seg == seg_id
        mask_np = (mask.numpy() * 255).astype(np.uint8)
        # keep only the masked pixels, crop to the garment, vectorize; only the
        # vector of the last garment found is kept and returned
        result = cv2.bitwise_and(image_np, image_np, mask=mask_np)
        img = cropping(result)
        img_vector = image_to_vector(img)
    # None if no known garment class appeared in the segmentation map
    return img_vector
# preprocessing of the uploaded image is complete
input_img = final_image(image)
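# guard: final_image returns None when no known garment class is found, which
# would crash the similarity loop below
if input_img is None:
    st.error('No clothing could be detected in the uploaded image. Please try a clearer photo.')
    st.stop()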
### Similarity analysis ###
# one argument is an in-memory vector, the other a path to a saved vector
def cosine_similarity(vec1, vec2_path):
    vec2 = np.loadtxt(vec2_path)
    dot_product = np.dot(vec1, vec2)
    norm_vec1 = np.linalg.norm(vec1)
    norm_vec2 = np.linalg.norm(vec2)
    similarity = dot_product / (norm_vec1 * norm_vec2)
    return similarity
# both arguments are paths to saved vectors
def cosine_similarity_2(vec1_path, vec2_path):
    return cosine_similarity(np.loadtxt(vec1_path), vec2_path)
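# Both helpers compute cos(theta) = (v1 . v2) / (||v1|| * ||v2||): 1.0 for
# identical directions, approaching 0 as the flattened pixel vectors diverge
# (pixel values are non-negative, so the score stays in [0, 1]). Note the
# division is undefined if either vector is all zeros.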
with st.spinner('Wait for it...'):
    # compare the uploaded garment with the style images stored in the
    # same-category folder, e.g. './style/cafe/top/'
    sim_list = []
    file_path = './style/' + situation + '/' + input_cat + '/'
    cloths = os.listdir(file_path)
    for cloth in cloths:
        sim_list.append(cosine_similarity(input_img, file_path + cloth))
    max_idx = np.argmax(sim_list)
    # define target_image: the style folders are assumed to use parallel file
    # names, so the same name under output_cat is the matching garment of the
    # most similar outfit
    target_image = './style/' + situation + '/' + output_cat + '/' + cloths[max_idx]
    # compare the matched style segmentation image with the product
    # segmentation vectors (product_seg)
    sim_list = []
    file_path = './product/' + output_cat + '/'
    cloths = os.listdir(file_path)
    for cloth in cloths:
        sim_list.append(cosine_similarity_2(target_image, file_path + cloth))
    max_idx = np.argmax(sim_list)
    output_name = cloths[max_idx]
# example output_name value: 'bottom_1883.txt'
# load the product-name lookup tables
acc_name = pd.read_csv('acc_name.csv')
bottom_name = pd.read_csv('bottom_name.csv')
outer_name = pd.read_csv('outer_name.csv')
shoes_name = pd.read_csv('shoes_name.csv')
top_name = pd.read_csv('top_name.csv')
# load the product data
outer = pd.read_csv('outer.csv')
top = pd.read_csv('top.csv')
bottom = pd.read_csv('bottom.csv')
shoes = pd.read_csv('shoes.csv')
acc = pd.read_csv('acc.csv')
if output_cat == 'bottom':
    df = bottom.copy()
    df_name = bottom_name.copy()
elif output_cat == 'top':
    df = top.copy()
    df_name = top_name.copy()
elif output_cat == 'shoes':
    df = shoes.copy()
    df_name = shoes_name.copy()
elif output_cat in ('hat', 'sunglasses', 'scarf', 'bag', 'belt'):
    df = acc.copy()
    df_name = acc_name.copy()
output_name = output_name.split('.')[0]
file_name = df_name[df_name['index'] == output_name].iloc[0, 1]  # e.g. '3049906_16754112975667_500.jpg'
final = df[df['id'] == file_name]
name = final['name'].values[0].split('\n')[-1]  # product name
price = final['price'].values[0]                # product price
image_path = './product/img/'
st.subheader('OUTPUT')
img = Image.open(image_path + output_cat + '/' + file_name)
col1, col2, col3 = st.columns(3)
with col1:
    st.image(img, width=400)
with col3:
    st.caption('Product: ' + name)
    # price may be parsed as a number, so cast before concatenating
    st.caption('Price: ' + str(price))