Amrrs ryparmar commited on
Commit
2296a5c
β€’
0 Parent(s):

Duplicate from ryparmar/fashion-aggregator

Browse files

Co-authored-by: Martin Rypar <[email protected]>

Files changed (4) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +216 -0
  4. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Fashion Aggregator
3
+ emoji: πŸ‘•
4
+ colorFrom: purple
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.9
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: ryparmar/fashion-aggregator
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Provide a text query describing what you are looking for and get back out images with links!"""
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import wandb
6
+ import gradio as gr
7
+
8
+ import zipfile
9
+ import pickle
10
+ from pathlib import Path
11
+ from typing import List, Any, Dict
12
+ from PIL import Image
13
+ from pathlib import Path
14
+
15
+ from transformers import AutoTokenizer
16
+ from sentence_transformers import SentenceTransformer, util
17
+ from multilingual_clip import pt_multilingual_clip
18
+ import torch
19
+
20
+ from pathlib import Path
21
+ from typing import Callable, Dict, List, Tuple
22
+ from PIL.Image import Image
23
+
24
+ print(__file__)
25
+
26
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # do not use GPU
27
+
28
+ logging.basicConfig(level=logging.INFO)
29
+ DEFAULT_APPLICATION_NAME = "fashion-aggregator"
30
+
31
+ APP_DIR = Path(__file__).resolve().parent # what is the directory for this application?
32
+ FAVICON = APP_DIR / "t-shirt_1f455.png" # path to a small image for display in browser tab and social media
33
+ README = APP_DIR / "README.md" # path to an app readme file in HTML/markdown
34
+
35
+ DEFAULT_PORT = 11700
36
+
37
+ EMBEDDINGS_DIR = "artifacts/img-embeddings"
38
+ EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "embeddings.pkl")
39
+ RAW_PHOTOS_DIR = "artifacts/raw-photos"
40
+
41
+ # Download image embeddings and raw photos
42
+ wandb.login(key="4b5a23a662b20fdd61f2aeb5032cf56fdce278a4") # os.getenv('wandb')
43
+ api = wandb.Api()
44
+ artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
45
+ artifact_embeddings.download(EMBEDDINGS_DIR)
46
+ artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
47
+ artifact_raw_photos.download("artifacts")
48
+
49
+ with zipfile.ZipFile("artifacts/unimoda.zip", 'r') as zip_ref:
50
+ zip_ref.extractall(RAW_PHOTOS_DIR)
51
+
52
+
53
+ class TextEncoder:
54
+ """Encodes the given text"""
55
+
56
+ def __init__(self, model_path="M-CLIP/XLM-Roberta-Large-Vit-B-32"):
57
+ self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
58
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
59
+
60
+ @torch.no_grad()
61
+ def encode(self, query: str) -> torch.Tensor:
62
+ """Predict/infer text embedding for a given query."""
63
+ query_emb = self.model.forward([query], self.tokenizer)
64
+ return query_emb
65
+
66
+
67
+ class ImageEnoder:
68
+ """Encodes the given image"""
69
+
70
+ def __init__(self, model_path="clip-ViT-B-32"):
71
+ self.model = SentenceTransformer(model_path)
72
+
73
+ @torch.no_grad()
74
+ def encode(self, image: Image) -> torch.Tensor:
75
+ """Predict/infer text embedding for a given query."""
76
+ image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
77
+ return image_emb
78
+
79
+
80
+ class Retriever:
81
+ """Retrieves relevant images for a given text embedding."""
82
+
83
+ def __init__(self, image_embeddings_path=None):
84
+ self.text_encoder = TextEncoder()
85
+ self.image_encoder = ImageEnoder()
86
+
87
+ with open(image_embeddings_path, "rb") as file:
88
+ self.image_names, self.image_embeddings = pickle.load(file)
89
+ self.image_names = [
90
+ img_name.replace("fashion-aggregator/fashion_aggregator/data/photos/", "")
91
+ for img_name in self.image_names
92
+ ]
93
+ print("Images:", len(self.image_names))
94
+
95
+ @torch.no_grad()
96
+ def predict(self, text_query: str, k: int = 10) -> List[Any]:
97
+ """Return top-k relevant items for a given embedding"""
98
+ query_emb = self.text_encoder.encode(text_query)
99
+ relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
100
+ return relevant_images
101
+
102
+ @torch.no_grad()
103
+ def search_images(self, text_query: str, k: int = 6) -> Dict[str, List[Any]]:
104
+ """Return top-k relevant images for a given embedding"""
105
+ images = self.predict(text_query, k)
106
+ paths_and_scores = {"path": [], "score": []}
107
+ for img in images:
108
+ paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
109
+ paths_and_scores["score"].append(img["score"])
110
+ return paths_and_scores
111
+
112
+
113
+ def main(args):
114
+ predictor = PredictorBackend(url=args.model_url)
115
+ frontend = make_frontend(predictor.run, flagging=args.flagging, gantry=args.gantry, app_name=args.application)
116
+ frontend.launch(
117
+ # server_name="0.0.0.0", # make server accessible, binding all interfaces # noqa: S104
118
+ # server_port=args.port, # set a port to bind to, failing if unavailable
119
+ # share=False, # should we create a (temporary) public link on https://gradio.app?
120
+ # favicon_path=FAVICON, # what icon should we display in the address bar?
121
+ )
122
+
123
+
124
+ def make_frontend(
125
+ fn: Callable[[Image], str], flagging: bool = False, gantry: bool = False, app_name: str = "fashion-aggregator"
126
+ ):
127
+ """Creates a gradio.Interface frontend for text to image search function."""
128
+
129
+ allow_flagging = "never"
130
+
131
+ # build a basic browser interface to a Python function
132
+ frontend = gr.Interface(
133
+ fn=fn, # which Python function are we interacting with?
134
+ outputs=gr.Gallery(label="Relevant Items"),
135
+ # what input widgets does it need? we configure an image widget
136
+ inputs=gr.components.Textbox(label="Item Description"),
137
+ title="πŸ“ Text2Image πŸ‘•", # what should we display at the top of the page?
138
+ thumbnail=FAVICON, # what should we display when the link is shared, e.g. on social media?
139
+ description=__doc__, # what should we display just above the interface?
140
+ cache_examples=False, # should we cache those inputs for faster inference? slows down start
141
+ allow_flagging=allow_flagging, # should we show users the option to "flag" outputs?
142
+ flagging_options=["incorrect", "offensive", "other"], # what options do users have for feedback?
143
+ )
144
+ return frontend
145
+
146
+
147
+ class PredictorBackend:
148
+ """Interface to a backend that serves predictions.
149
+
150
+ To communicate with a backend accessible via a URL, provide the url kwarg.
151
+
152
+ Otherwise, runs a predictor locally.
153
+ """
154
+
155
+ def __init__(self, url=None):
156
+ if url is not None:
157
+ self.url = url
158
+ self._predict = self._predict_from_endpoint
159
+ else:
160
+ model = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
161
+ self._predict = model.predict
162
+ self._search_images = model.search_images
163
+
164
+ def run(self, text: str):
165
+ pred, metrics = self._predict_with_metrics(text)
166
+ self._log_inference(pred, metrics)
167
+ return pred
168
+
169
+ def _predict_with_metrics(self, text: str) -> Tuple[List[str], Dict[str, float]]:
170
+ paths_and_scores = self._search_images(text)
171
+ metrics = {"mean_score": sum(paths_and_scores["score"]) / len(paths_and_scores["score"])}
172
+ return paths_and_scores["path"], metrics
173
+
174
+ def _log_inference(self, pred, metrics):
175
+ for key, value in metrics.items():
176
+ logging.info(f"METRIC {key} {value}")
177
+ logging.info(f"PRED >begin\n{pred}\nPRED >end")
178
+
179
+
180
+ def _make_parser():
181
+ parser = argparse.ArgumentParser(description=__doc__)
182
+ parser.add_argument(
183
+ "--model_url",
184
+ default=None,
185
+ type=str,
186
+ help="Identifies a URL to which to send image data. Data is base64-encoded, converted to a utf-8 string, and then set via a POST request as JSON with the key 'image'. Default is None, which instead sends the data to a model running locally.",
187
+ )
188
+ parser.add_argument(
189
+ "--port",
190
+ default=DEFAULT_PORT,
191
+ type=int,
192
+ help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.",
193
+ )
194
+ parser.add_argument(
195
+ "--flagging",
196
+ action="store_true",
197
+ help="Pass this flag to allow users to 'flag' model behavior and provide feedback.",
198
+ )
199
+ parser.add_argument(
200
+ "--gantry",
201
+ action="store_true",
202
+ help="Pass --flagging and this flag to log user feedback to Gantry. Requires GANTRY_API_KEY to be defined as an environment variable.",
203
+ )
204
+ parser.add_argument(
205
+ "--application",
206
+ default=DEFAULT_APPLICATION_NAME,
207
+ type=str,
208
+ help=f"Name of the Gantry application to which feedback should be logged, if --gantry and --flagging are passed. Default is {DEFAULT_APPLICATION_NAME}.",
209
+ )
210
+ return parser
211
+
212
+
213
+ if __name__ == "__main__":
214
+ parser = _make_parser()
215
+ args = parser.parse_args()
216
+ main(args)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ sentence-transformers==2.2.2
2
+ clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
3
+ multilingual-clip==1.0.10
4
+ wandb