yainage90 committed · verified
Commit 7329ef7 · Parent(s): 011d19a

Update README.md

Files changed (1): README.md (+70, -3)

README.md CHANGED

@@ -4,6 +4,73 @@ tags:
  - pytorch_model_hub_mixin
  ---

- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- - Library: [More Information Needed]
- - Docs: [More Information Needed]

This is a fashion image feature extractor model.

# 1. Model Architecture

I used [microsoft/swin-base-patch4-window7-224](https://huggingface.co/microsoft/swin-base-patch4-window7-224) as the base image encoder and added a 128-dimensional fully connected layer on top to reduce the embedding size.
The training data consists of anchor (product area detected in a post image) and positive (product thumbnail) image pairs. Within each batch, every sample except an anchor's own positive is used as a negative, and the model is trained to minimize the distance between anchor-positive pairs while maximizing the distance between anchor-negative pairs. This is the contrastive learning setup also used to train OpenAI's CLIP model.
Initially, anchor-positive-negative triplets were explicitly constructed in a 1:1:1 ratio and trained with a triplet loss, but in-batch negative sampling with a contrastive loss performed much better because it lets the model learn from far more negative samples.
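
For reference, the in-batch objective described above can be written as a symmetric InfoNCE-style loss over the batch similarity matrix. This is a minimal sketch assuming L2-normalized embeddings and a fixed temperature; the function name and temperature value are illustrative, not taken from this repository's training code.

```python
import torch
import torch.nn.functional as F

def in_batch_contrastive_loss(anchor_emb: torch.Tensor, positive_emb: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """InfoNCE-style loss: each anchor's own positive is the target class,
    and every other positive in the batch serves as a negative."""
    # Both inputs: (batch_size, 128), assumed L2-normalized, so the matrix
    # product below is a matrix of cosine similarities.
    logits = anchor_emb @ positive_emb.T / temperature          # (B, B)
    targets = torch.arange(anchor_emb.size(0), device=anchor_emb.device)
    # Symmetric cross-entropy over rows (anchor -> positive) and
    # columns (positive -> anchor), as in CLIP-style training.
    loss_a = F.cross_entropy(logits, targets)
    loss_p = F.cross_entropy(logits.T, targets)
    return (loss_a + loss_p) / 2
```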

<img src="image_encoder.png" width="500" alt="image_encoder">

<img src="contrastive_learning.png" width="500" alt="contrastive_learning">

# 2. Training dataset

User posting images from onthelook and kream were crawled and preprocessed. First, raw image-product thumbnail combinations were collected from posts. Then, object detection was run on the posting images, and category classification was run on the product thumbnails so that images of the same category could be paired together; a separately trained category classifier was used for the thumbnails. Finally, about 290,000 anchor-positive image pairs were created across 6 categories: tops, bottoms, outer, shoes, bags, and hats.

You can find the object-detection model here -> [https://huggingface.co/yainage90/fashion-object-detection](https://huggingface.co/yainage90/fashion-object-detection)

You can find more details of the model in this GitHub repo -> [fashion-visual-search](https://github.com/yainage90/fashion-visual-search)
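
To make the pairing step concrete, here is a rough sketch of how same-category anchor-positive pairs could be assembled. `detect_fashion_items` and `classify_category` are hypothetical placeholders standing in for the object-detection and category-classification models mentioned above; the actual preprocessing code lives in the linked repository.

```python
from dataclasses import dataclass
from PIL import Image

@dataclass
class TrainingPair:
    anchor: Image.Image    # product area cropped from a post image
    positive: Image.Image  # matching product thumbnail

def build_pairs(post_image, thumbnails, detect_fashion_items, classify_category):
    """Pair detected product crops with thumbnails of the same category.

    `detect_fashion_items(image) -> list[(crop, category)]` and
    `classify_category(image) -> category` are hypothetical stand-ins for the
    detection and classification models referenced above.
    """
    pairs = []
    crops = detect_fashion_items(post_image)                    # [(PIL crop, category), ...]
    labeled_thumbs = [(t, classify_category(t)) for t in thumbnails]
    for crop, crop_category in crops:
        for thumb, thumb_category in labeled_thumbs:
            if crop_category == thumb_category:                 # same-category anchor-positive pair
                pairs.append(TrainingPair(anchor=crop, positive=thumb))
    return pairs
```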

The encoder itself can be loaded and used as follows:

```python
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import v2
from transformers import AutoImageProcessor, SwinModel, SwinConfig
from huggingface_hub import PyTorchModelHubMixin

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the checkpoint's Swin config and image preprocessing settings.
ckpt = "yainage90/fashion-image-feature-extractor"
encoder_config = SwinConfig.from_pretrained(ckpt)
encoder_image_processor = AutoImageProcessor.from_pretrained(ckpt)

class ImageEncoder(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.swin = SwinModel(config=encoder_config)
        # 128-dimensional projection head on top of the Swin pooled output.
        self.embedding_layer = nn.Linear(encoder_config.hidden_size, 128)

    def forward(self, image_tensor):
        features = self.swin(image_tensor).pooler_output
        embeddings = self.embedding_layer(features)
        # L2-normalize so cosine similarity reduces to a dot product.
        embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings

encoder = ImageEncoder.from_pretrained(ckpt).to(device)
encoder.eval()

transform = v2.Compose([
    v2.Resize((encoder_config.image_size, encoder_config.image_size)),
    v2.ToTensor(),
    v2.Normalize(mean=encoder_image_processor.image_mean, std=encoder_image_processor.image_std),
])

# Preprocess a single image and extract its 128-d embedding.
image = Image.open('<path/to/image>').convert('RGB')
image = transform(image)
with torch.no_grad():
    embedding = encoder(image.unsqueeze(0).to(device)).cpu().numpy()
```
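
Because the embeddings are L2-normalized, cosine similarity is just a dot product, which is how you would rank catalog images against a query crop for visual search. The snippet below is an illustrative sketch with random placeholder arrays standing in for real encoder outputs; for a large catalog you would typically use an ANN index (e.g. Faiss) instead of brute force.

```python
import numpy as np

# Placeholders: `query_embedding` is the (1, 128) output from the snippet above,
# `catalog_embeddings` is an (N, 128) array of embeddings for catalog images.
query_embedding = np.random.randn(1, 128).astype(np.float32)
catalog_embeddings = np.random.randn(1000, 128).astype(np.float32)

# Normalize the random placeholders; real encoder outputs are already unit-norm.
query_embedding /= np.linalg.norm(query_embedding, axis=1, keepdims=True)
catalog_embeddings /= np.linalg.norm(catalog_embeddings, axis=1, keepdims=True)

# Cosine similarity == dot product for unit-norm vectors.
scores = catalog_embeddings @ query_embedding.T          # (N, 1)
top_k = np.argsort(-scores[:, 0])[:5]                    # 5 most similar catalog images
print(top_k, scores[top_k, 0])
```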

<img src="detection_image1.png" width="500" alt="detection_image1">
<img src="result_image1.png" width="700" alt="result_image1">
<img src="detection_image2.png" width="500" alt="detection_image2">
<img src="result_image2.png" width="700" alt="result_image2">