TOPSInfosol committed
Commit ebc8ff7 · verified · 1 Parent(s): a3c472a

Create clip_base.py

Files changed (1)
  1. clip_base.py +145 -0
clip_base.py ADDED

import os
import zipfile

import pandas as pd
import tensorflow as tf
from PIL import Image
from transformers import CLIPProcessor, TFCLIPModel


class TfliteConverter:

    def __init__(self, model, image_data, title_list, unique_title_list):
        self.model = model
        self.image_data = image_data
        self.title_list = title_list
        self.unique_title_list = unique_title_list

    def build_model(self):
        concrete_func = self.serving_fn.get_concrete_function()

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])

        # CLIP uses ops outside the TFLite builtin set, so allow fallback to
        # TensorFlow (Flex) ops during conversion.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS
        ]

        tflite_model = converter.convert()

        model_path = os.path.join(os.getcwd(), 'clip_tflite_model.tflite')
        json_path = os.path.join(os.getcwd(), 'categories.json')
        zip_path = os.path.join(os.getcwd(), 'model_package.zip')

        with open(model_path, 'wb') as f:
            f.write(tflite_model)

        # Persist the category index alongside the model so output scores can
        # be mapped back to category names at inference time.
        categories_df = pd.DataFrame({
            'id': range(len(self.unique_title_list)),
            'title': self.unique_title_list
        })
        categories_df.to_json(json_path, orient='records', indent=2)

        # Create ZIP file containing both files
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            zipf.write(model_path, os.path.basename(model_path))
            zipf.write(json_path, os.path.basename(json_path))

        # # Clean up temporary files
        # os.remove(model_path)
        # os.remove(json_path)

        return zip_path

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 224, 224, 3], dtype=tf.float32, name='input_images')
    ])
    def serving_fn(self, input_images):
        # CLIP expects channels-first pixel values: NHWC -> NCHW
        x = tf.transpose(input_images, (0, 3, 1, 2))

        # TFCLIPModel is a TensorFlow model; no torch.no_grad() context is
        # needed (or usable) inside a tf.function.
        img_embeddings = self.model.get_image_features(pixel_values=x)

        # Keep only reference images whose embeddings were computed, so the
        # category names stay aligned with the stacked embedding matrix.
        valid_items = [item for item in self.image_data if item.get('embeddings') is not None]
        labels_embeddings = tf.constant([item['embeddings'] for item in valid_items],
                                        dtype=tf.float32)
        names = [item['image_category'] for item in valid_items]

        # Cosine similarity between the input image and every reference image
        similarities = tf.reduce_sum(tf.multiply(img_embeddings, labels_embeddings), axis=-1)
        norm_img = tf.norm(img_embeddings, axis=-1)
        norm_labels = tf.norm(labels_embeddings, axis=-1)
        cosine_similarity = similarities / (norm_img * norm_labels)

        # Keep the best score per category. tf.maximum is used because the
        # scores are symbolic tensors inside a tf.function, where a Python
        # `>` comparison cannot be evaluated.
        name_to_score = {}
        for i, name in enumerate(names):
            score = cosine_similarity[i]
            if name in name_to_score:
                name_to_score[name] = tf.maximum(name_to_score[name], score)
            else:
                name_to_score[name] = score

        result = [name_to_score.get(name, tf.constant(0.0)) for name in self.unique_title_list]
        return tf.stack(result, axis=0)


class OpenAiClipModel:

    def __init__(self, payload):
        self.payload = payload
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.df = pd.DataFrame(
            [(key, value) for key, values in self.payload.items() for value in values],
            columns=['image_category', 'image_path']
        )
        self.labels_embeddings = []
        self.image_data = []
        self.title_list = self.df["image_category"].tolist()
        self.unique_title_list = list(set(self.title_list))

    def generate_text_embedding(self, text):
        # TFCLIPModel is a TensorFlow model, so request TF tensors rather
        # than PyTorch ("pt") tensors from the processor.
        inputs = self.processor(text=[text], return_tensors="tf", padding=True)
        text_embedding = self.model.get_text_features(**inputs)
        return text_embedding

    def generate_image_embeddings(self):
        for index, row in self.df.iterrows():
            image_info = row.to_dict()
            image_path = image_info.get("image_path")

            if os.path.exists(image_path):
                try:
                    # Convert to RGB so grayscale/RGBA files don't break the processor
                    image = Image.open(image_path).convert("RGB")
                    inputs = self.processor(images=image, return_tensors="tf")
                    outputs = self.model.get_image_features(**inputs)
                    image_embedding = outputs.numpy().flatten()
                    image_info['embeddings'] = image_embedding.tolist()
                    self.labels_embeddings.append(outputs)
                except Exception:
                    image_info['embeddings'] = None
            else:
                image_info['embeddings'] = None

            self.image_data.append(image_info)

    def build_model(self):
        self.generate_image_embeddings()
        self.title_list = self.df["image_category"].tolist()

        tflite_client = TfliteConverter(
            model=self.model,
            image_data=self.image_data,
            title_list=self.title_list,
            unique_title_list=self.unique_title_list
        )
        model_file = tflite_client.build_model()
        return model_file
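
For reference, a minimal usage sketch of the new file. The payload keys and image paths below are hypothetical examples; build_model() writes model_package.zip to the current working directory.

    # Minimal usage sketch -- payload keys and paths are hypothetical examples
    payload = {
        "cats": ["images/cat_1.jpg", "images/cat_2.jpg"],
        "dogs": ["images/dog_1.jpg"],
    }

    client = OpenAiClipModel(payload)
    zip_path = client.build_model()
    print(zip_path)  # .../model_package.zip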
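
Because the converter enables SELECT_TF_OPS, the exported model needs a TFLite runtime with Flex-op support (the tf.lite.Interpreter bundled with the full tensorflow pip package qualifies; the slim tflite_runtime wheel does not). A sketch of running the model unpacked from model_package.zip, assuming the input has already been preprocessed to the same [1, 224, 224, 3] float32 layout the CLIPProcessor produces:

    import numpy as np
    import tensorflow as tf

    # Load the model extracted from model_package.zip
    interpreter = tf.lite.Interpreter(model_path="clip_tflite_model.tflite")
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Placeholder input; a real image must be resized and normalized the way
    # CLIPProcessor does before being fed to the model
    dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)
    interpreter.set_tensor(input_details[0]["index"], dummy)
    interpreter.invoke()

    # One similarity score per category; order matches the 'id' field in categories.json
    scores = interpreter.get_tensor(output_details[0]["index"])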