Spaces:
Runtime error
Runtime error
from curses import A_ATTRIBUTES | |
import numpy | |
import torch | |
from pip import main | |
from sentence_transformers import SentenceTransformer, util | |
# Predefined vocabularies for the length attributes.  Each ``*_text`` list
# holds the phrases a user segment is compared against (by list position),
# and the matching ``*_attr`` dict maps each phrase to its attribute code.

upper_length_text = [
    'sleeveless', 'without sleeves', 'sleeves have been cut off',
    'tank top', 'tank shirt', 'muscle shirt',
    'short-sleeve', 'short sleeves', 'with short sleeves',
    'medium-sleeve', 'medium sleeves', 'with medium sleeves',
    'sleeves reach elbow',
    'long-sleeve', 'long sleeves', 'with long sleeves',
]
# 0 = sleeveless, 1 = short, 2 = medium, 3 = long sleeves.
upper_length_attr = {
    **dict.fromkeys(['sleeveless', 'without sleeves',
                     'sleeves have been cut off', 'tank top', 'tank shirt',
                     'muscle shirt'], 0),
    **dict.fromkeys(['short-sleeve', 'with short sleeves',
                     'short sleeves'], 1),
    **dict.fromkeys(['medium-sleeve', 'with medium sleeves',
                     'medium sleeves', 'sleeves reach elbow'], 2),
    **dict.fromkeys(['long-sleeve', 'long sleeves',
                     'with long sleeves'], 3),
}

lower_length_text = [
    'three-point', 'medium', 'short', 'covering knee', 'cropped',
    'three-quarter', 'long', 'slack', 'of long length',
]
lower_length_attr = {
    'three-point': 0,
    **dict.fromkeys(['medium', 'covering knee', 'short'], 1),
    **dict.fromkeys(['cropped', 'three-quarter'], 2),
    **dict.fromkeys(['long', 'slack', 'of long length'], 3),
}

socks_length_text = [
    'socks', 'stocking', 'pantyhose', 'leggings', 'sheer hosiery',
]
# Plain socks map to 0; every hosiery-like phrase maps to 1.
socks_length_attr = {
    'socks': 0,
    **dict.fromkeys(['stocking', 'pantyhose', 'leggings',
                     'sheer hosiery'], 1),
}
# Accessory vocabularies compared against user segments.
hat_text = ['hat', 'cap', 'chapeau']
eyeglasses_text = ['sunglasses']
belt_text = ['belt', 'with a dress tied around the waist']

# Phrases describing how outer clothing is worn, and their codes:
# open / unzipped -> 0, covering / zipped -> 1.
outer_shape_text = [
    'with outer clothing open',
    'with outer clothing unzipped',
    'covering inner clothes',
    'with outer clothing zipped',
]
outer_shape_attr = {
    **dict.fromkeys(['with outer clothing open',
                     'with outer clothing unzipped'], 0),
    **dict.fromkeys(['covering inner clothes',
                     'with outer clothing zipped'], 1),
}
# Garment vocabularies used to decide which clothing categories are present.
upper_types = [
    'T-shirt', 'shirt', 'sweater', 'hoodie', 'tops', 'blouse', 'Basic Tee',
]
outer_types = [
    'jacket', 'outer clothing', 'coat', 'overcoat',
    'blazer', 'outerwear', 'duffle', 'cardigan',
]
skirt_types = ['skirt']
dress_types = ['dress']
pant_types = ['jeans', 'pants', 'trousers']
rompers_types = ['rompers', 'bodysuit', 'jumpsuit']

# Labels printed next to each entry of the attribute vector: two human
# attributes followed by the 13 numbered clothing attributes.
attr_names_list = [
    'gender',
    'hair length',
    '0 upper clothing length',
    '1 lower clothing length',
    '2 socks',
    '3 hat',
    '4 eyeglasses',
    '5 belt',
    '6 opening of outer clothing',
    '7 upper clothes',
    '8 outer clothing',
    '9 skirt',
    '10 dress',
    '11 pants',
    '12 rompers',
]
_SHAPE_MODEL_NAME = 'all-MiniLM-L6-v2'
# Lazily-populated caches: the sentence model is loaded, and each fixed
# vocabulary list is encoded, once per process instead of on every call.
_MODEL_CACHE = {}
_VOCAB_EMBEDDING_CACHE = {}


def _load_model(model_name=_SHAPE_MODEL_NAME):
    """Return the SentenceTransformer for ``model_name``, loading it once."""
    model = _MODEL_CACHE.get(model_name)
    if model is None:
        model = SentenceTransformer(model_name)
        _MODEL_CACHE[model_name] = model
    return model


def _best_match(user_embedding, vocab, model_name=_SHAPE_MODEL_NAME):
    """Return ``(index, score)`` of the ``vocab`` phrase closest to the text.

    ``score`` is the ``util.dot_score`` similarity between ``user_embedding``
    and the best-matching phrase embedding, as a Python float.  On ties the
    first maximal index reported by ``torch.argmax`` is used.
    """
    key = (model_name, tuple(vocab))
    vocab_embeddings = _VOCAB_EMBEDDING_CACHE.get(key)
    if vocab_embeddings is None:
        vocab_embeddings = _load_model(model_name).encode(vocab)
        _VOCAB_EMBEDDING_CACHE[key] = vocab_embeddings
    similarities = util.dot_score(user_embedding, vocab_embeddings)
    arg_idx = torch.argmax(similarities).item()
    return arg_idx, similarities[0][arg_idx].item()


def generate_shape_attributes(user_shape_texts):
    """Translate a comma-separated outfit description into shape attributes.

    ``user_shape_texts`` is split on commas and each segment is matched
    against the module-level vocabularies according to its position:

    * segment 0 -- gender / hair keywords ('man', 'woman', 'lady');
    * segment 1 -- upper-clothing length, plus the upper garment type when
      there are more than two segments; with exactly two segments it also
      decides lower-clothing length and dress / rompers;
    * segment 2 -- lower-clothing length and skirt / pants (only when there
      are more than two segments);
    * segment 3 -- opening of outer clothing and outer garment type;
    * segments 3 and later -- socks, hat, eyeglasses and belt.

    Phrase matching uses ``util.dot_score`` between sentence embeddings with
    fixed thresholds (0.5-0.7).  ``changed`` marks attributes that have been
    decided so that their checks are skipped for later segments.

    Args:
        user_shape_texts: the user's free-text request.

    Returns:
        A list of 15 ints: ``[gender, hair length]`` followed by the 13
        clothing attributes, in the order of ``attr_names_list``.
    """
    model = _load_model()
    segments = [segment.strip() for segment in user_shape_texts.split(',')]
    text_num = len(segments)
    human_attr = [0, 0]
    # Defaults used for any attribute no segment decides.
    attr = [1, 3, 0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0]
    changed = [0] * len(attr)

    for text_id, text in enumerate(segments):
        user_embeddings = model.encode(text)

        if text_id == 0:
            # Case-insensitive keyword test.  'woman' is checked after 'man'
            # so it wins even though 'man' is a substring of it.
            lowered = text.lower()
            if 'man' in lowered:
                human_attr[0] = 0
                human_attr[1] = 0
            if 'woman' in lowered or 'lady' in lowered:
                human_attr[0] = 1
                human_attr[1] = 2

        if not changed[0] and text_id == 1:
            # upper clothing length
            idx, _ = _best_match(user_embeddings, upper_length_text)
            attr[0] = upper_length_attr[upper_length_text[idx]]
            changed[0] = 1

        if not changed[1] and ((text_num == 2 and text_id == 1) or
                               (text_num > 2 and text_id == 2)):
            # lower clothing length
            idx, _ = _best_match(user_embeddings, lower_length_text)
            attr[1] = lower_length_attr[lower_length_text[idx]]
            changed[1] = 1

        if not changed[2] and text_id > 2:
            # socks: map through socks_length_attr like the other length
            # checks, shifted by one so a match never equals the default 0.
            # NOTE(review): previously arg_idx + 1 was stored (codes 1-5) and
            # socks_length_attr was unused; confirm downstream expects 0-2.
            idx, score = _best_match(user_embeddings, socks_length_text)
            if score > 0.7:
                attr[2] = socks_length_attr[socks_length_text[idx]] + 1
                changed[2] = 1

        if not changed[3] and text_id > 2:
            # hat: compare against every synonym, not only 'hat'
            _, score = _best_match(user_embeddings, hat_text)
            if score > 0.7:
                attr[3] = 1
                changed[3] = 1

        if not changed[4] and text_id > 2:
            # eyeglasses
            idx, score = _best_match(user_embeddings, eyeglasses_text)
            if score > 0.7:
                attr[4] = idx + 1
                changed[4] = 1

        if not changed[5] and text_id > 2:
            # belt
            idx, score = _best_match(user_embeddings, belt_text)
            if score > 0.7:
                attr[5] = idx + 1
                changed[5] = 1

        if not changed[6] and text_id == 3:
            # opening of outer clothing, mapped to its 0/1 code
            idx, score = _best_match(user_embeddings, outer_shape_text)
            if score > 0.7:
                attr[6] = outer_shape_attr[outer_shape_text[idx]]
                changed[6] = 1

        if not changed[10] and text_num == 2 and text_id == 1:
            # dress: must also look more like 'dress' than 'skirt'
            _, dress_score = _best_match(user_embeddings, dress_types)
            _, skirt_score = _best_match(user_embeddings, skirt_types)
            if dress_score > 0.5 and dress_score > skirt_score:
                attr[10] = 1
                for i in (7, 8, 9, 11, 12):
                    attr[i] = 0
                for i in (0, 7, 8, 9, 10, 11, 12):
                    changed[i] = 1

        if not changed[12] and text_num == 2 and text_id == 1:
            # rompers / bodysuit / jumpsuit exclude the other garments
            _, score = _best_match(user_embeddings, rompers_types)
            if score > 0.6:
                attr[12] = 1
                for i in (7, 8, 9, 10, 11):
                    attr[i] = 0
                    changed[i] = 1
                changed[12] = 1

        if not changed[7] and text_num > 2 and text_id == 1:
            # upper garment type
            _, score = _best_match(user_embeddings, upper_types)
            if score > 0.6:
                attr[7] = 1
                changed[7] = 1

        if not changed[8] and text_id == 3:
            # Outer garment present.  Its opening (attr[6]) is decided by the
            # outer_shape_text check above; the argmax here indexes
            # outer_types, so it must not be used to index outer_shape_text
            # (that raised IndexError for 'blazer', 'outerwear', ...).
            _, score = _best_match(user_embeddings, outer_types)
            if score > 0.7:
                attr[8] = 1
                changed[8] = 1

        if not changed[9] and text_num > 2 and text_id == 2:
            # skirt: must also look more like 'skirt' than 'dress'
            _, skirt_score = _best_match(user_embeddings, skirt_types)
            _, dress_score = _best_match(user_embeddings, dress_types)
            if skirt_score > 0.7 and skirt_score > dress_score:
                attr[9] = 1
                attr[10] = 0
                changed[9] = 1
                changed[10] = 1

        if not changed[11] and text_num > 2 and text_id == 2:
            # pants exclude skirt, dress and rompers
            _, score = _best_match(user_embeddings, pant_types)
            if score > 0.6:
                attr[11] = 1
                for i in (9, 10, 12):
                    attr[i] = 0
                    changed[i] = 1
                changed[11] = 1

    return human_attr + attr
# Keyword groups checked in order; the first group with a keyword contained
# in the lower-cased segment decides that segment's texture code.
_TEXTURE_KEYWORDS = (
    (('pure color', 'solid color'), 4),
    (('spline', 'stripe'), 3),
    (('plaid', 'lattice'), 5),
    (('floral',), 1),
    (('denim',), 0),
)
# Code used when no keyword matches, and to pad the result to three entries.
_UNKNOWN_TEXTURE = 17


def _texture_code(segment):
    """Return the texture code for one description segment."""
    lowered = segment.lower()
    for keywords, code in _TEXTURE_KEYWORDS:
        if any(keyword in lowered for keyword in keywords):
            return code
    return _UNKNOWN_TEXTURE


def generate_texture_attributes(user_text):
    """Map a comma-separated texture description to texture codes.

    Each comma-separated segment is matched case-insensitively by keyword
    (see ``_TEXTURE_KEYWORDS``); unmatched segments get ``17``.  The result
    always has at least three entries: a lone segment's code is repeated in
    the second slot, and the third slot is padded with ``17`` when fewer
    than three segments are given.  Segments beyond the third are kept.

    Args:
        user_text: the user's texture request.

    Returns:
        list[int]: one code per segment, padded as described above.
    """
    attr = [_texture_code(segment) for segment in user_text.split(',')]
    if len(attr) == 1:
        # A single segment's code is duplicated into the second slot.
        attr.append(attr[0])
    if len(attr) == 2:
        attr.append(_UNKNOWN_TEXTURE)
    return attr
if __name__ == "__main__":
    # Interactive loop: enter a comma-separated description, or \q to quit.
    while True:
        try:
            user_request = input('Enter your request: ')
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D, Ctrl-C, or an exhausted pipe) ends the
            # session instead of crashing with a traceback.
            print()
            break
        if user_request.strip() == '\\q':
            break
        attr = generate_shape_attributes(user_request)
        print(attr)
        for attr_name, attr_value in zip(attr_names_list, attr):
            print(attr_name, attr_value)