import os
import zipfile

import pandas as pd
import tensorflow as tf
from PIL import Image
from transformers import CLIPProcessor, TFCLIPModel
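
# Pipeline: embed a labelled gallery of images with CLIP, wrap category
# scoring as a tf.function, convert it to TFLite, and zip the model
# together with its category list for distribution.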


class TfliteConverter:

    def __init__(self, model, image_data, title_list, unique_title_list):
        self.model = model
        self.image_data = image_data
        self.title_list = title_list
        self.unique_title_list = unique_title_list

    def build_model(self):
        # Trace the scoring function into a concrete function for conversion.
        concrete_func = self.serving_fn.get_concrete_function()

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])

        # CLIP uses ops with no TFLite builtin, so allow fallback to
        # full TensorFlow ops.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]

        tflite_model = converter.convert()

        model_path = os.path.join(os.getcwd(), 'clip_tflite_model.tflite')
        json_path = os.path.join(os.getcwd(), 'categories.json')
        zip_path = os.path.join(os.getcwd(), 'model_package.zip')

        with open(model_path, 'wb') as f:
            f.write(tflite_model)

        # Persist the category list in the same order the model emits scores.
        categories_df = pd.DataFrame({
            'id': range(len(self.unique_title_list)),
            'title': self.unique_title_list,
        })
        categories_df.to_json(json_path, orient='records', indent=2)

        # Bundle the model and its category metadata into one archive.
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            zipf.write(model_path, os.path.basename(model_path))
            zipf.write(json_path, os.path.basename(json_path))

        return zip_path

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 224, 224, 3], dtype=tf.float32, name='input_images')
    ])
    def serving_fn(self, input_images):
        # TFCLIPModel expects channels-first pixel values, so convert NHWC -> NCHW.
        x = tf.transpose(input_images, (0, 3, 1, 2))
        img_embeddings = self.model.get_image_features(pixel_values=x)

        # Keep only gallery entries that produced an embedding, so the
        # category names stay aligned with the stacked embedding rows.
        valid_items = [item for item in self.image_data
                       if item.get('embeddings') is not None]
        labels_embeddings = tf.constant(
            [item['embeddings'] for item in valid_items], dtype=tf.float32)
        names = [item['image_category'] for item in valid_items]

        # Cosine similarity between the query image and every gallery image.
        similarities = tf.reduce_sum(tf.multiply(img_embeddings, labels_embeddings), axis=-1)
        norm_img = tf.norm(img_embeddings, axis=-1)
        norm_labels = tf.norm(labels_embeddings, axis=-1)
        cosine_similarity = similarities / (norm_img * norm_labels)

        # For each category keep the best score across its gallery images.
        # The dict lookup runs at trace time; tf.maximum keeps the
        # comparison inside the graph.
        name_to_score = {}
        for i, name in enumerate(names):
            score = cosine_similarity[i]
            if name in name_to_score:
                name_to_score[name] = tf.maximum(name_to_score[name], score)
            else:
                name_to_score[name] = score

        # Emit scores in the fixed category order; categories with no usable
        # gallery image score 0.
        result = [name_to_score.get(name, tf.constant(0.0))
                  for name in self.unique_title_list]
        return tf.convert_to_tensor(result, dtype=tf.float32)
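

# A minimal sketch of consuming the exported package on the client side
# (the file name matches what build_model writes; the random input is an
# assumption standing in for a preprocessed 224x224 RGB image):
#
#   interpreter = tf.lite.Interpreter(model_path="clip_tflite_model.tflite")
#   interpreter.allocate_tensors()
#   input_details = interpreter.get_input_details()
#   output_details = interpreter.get_output_details()
#   image = tf.random.uniform([1, 224, 224, 3], dtype=tf.float32).numpy()
#   interpreter.set_tensor(input_details[0]['index'], image)
#   interpreter.invoke()
#   scores = interpreter.get_tensor(output_details[0]['index'])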


class OpenAiClipModel:

    def __init__(self, payload):
        # payload maps a category name to a list of image paths for that category.
        self.payload = payload
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.df = pd.DataFrame(
            [(key, value) for key, values in self.payload.items() for value in values],
            columns=['image_category', 'image_path'],
        )
        self.labels_embeddings = []
        self.image_data = []
        self.title_list = self.df["image_category"].tolist()
        self.unique_title_list = list(set(self.title_list))

    def generate_text_embedding(self, text):
        # return_tensors must be "tf": the underlying model is TFCLIPModel.
        inputs = self.processor(text=[text], return_tensors="tf", padding=True)
        text_embedding = self.model.get_text_features(**inputs)
        return text_embedding

    def generate_image_embeddings(self):
        for index, row in self.df.iterrows():
            image_info = row.to_dict()
            image_path = image_info.get("image_path")

            if os.path.exists(image_path):
                try:
                    image = Image.open(image_path)
                    inputs = self.processor(images=image, return_tensors="tf")
                    outputs = self.model.get_image_features(**inputs)
                    image_embedding = outputs.numpy().flatten()
                    image_info['embeddings'] = image_embedding.tolist()
                    self.labels_embeddings.append(outputs)
                except Exception:
                    # Unreadable images keep their record but carry no
                    # embedding; serving_fn filters them out.
                    image_info['embeddings'] = None
            else:
                image_info['embeddings'] = None

            self.image_data.append(image_info)

    def build_model(self):
        self.generate_image_embeddings()
        self.title_list = self.df["image_category"].tolist()

        tflite_client = TfliteConverter(
            model=self.model,
            image_data=self.image_data,
            title_list=self.title_list,
            unique_title_list=self.unique_title_list,
        )
        model_file = tflite_client.build_model()
        return model_file
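

# A minimal usage sketch. The payload keys are category names and the values
# are lists of local image paths; the paths below are hypothetical.
if __name__ == "__main__":
    payload = {
        "shoes": ["images/shoe_1.jpg", "images/shoe_2.jpg"],
        "hats": ["images/hat_1.jpg"],
    }
    client = OpenAiClipModel(payload)
    package_path = client.build_model()
    print(f"Model package written to {package_path}")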