# clip_base.py
# Source: TOPSInfosol — "Create clip_base.py" (commit ebc8ff7, verified)
import os
import torch
import requests
import zipfile
from PIL import Image
from io import BytesIO
import tensorflow as tf
import pandas as pd
from transformers import CLIPProcessor, TFCLIPModel
class TfliteConverter:
    """Packages a TF CLIP model plus precomputed image embeddings into a
    TFLite model, zipped together with a categories.json index."""

    def __init__(self, model, image_data, title_list, unique_title_list):
        # model: a TFCLIPModel exposing get_image_features()
        # image_data: list of dicts with 'image_category' and 'embeddings'
        #   ('embeddings' is a flat list of floats, or None when extraction failed)
        # title_list: category name per image_data entry (may repeat)
        # unique_title_list: deduplicated category names (defines output order)
        self.model = model
        self.image_data = image_data
        self.title_list = title_list
        self.unique_title_list = unique_title_list

    def build_model(self):
        """Trace serving_fn, convert it to TFLite, write the model and the
        category JSON to the CWD, and return the path of a ZIP holding both.

        Returns:
            str: absolute path of model_package.zip.
        """
        concrete_func = self.serving_fn.get_concrete_function()
        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        # SELECT_TF_OPS keeps the conversion working for ops with no
        # TFLite builtin equivalent in the traced graph.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS
        ]
        tflite_model = converter.convert()
        model_path = os.path.join(os.getcwd(), 'clip_tflite_model.tflite')
        json_path = os.path.join(os.getcwd(), 'categories.json')
        zip_path = os.path.join(os.getcwd(), 'model_package.zip')
        with open(model_path, 'wb') as f:
            f.write(tflite_model)
        categories_df = pd.DataFrame({
            'id': range(len(self.unique_title_list)),
            'title': self.unique_title_list
        })
        categories_df.to_json(json_path, orient='records', indent=2)
        # Create ZIP file containing both artifacts
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            zipf.write(model_path, os.path.basename(model_path))
            zipf.write(json_path, os.path.basename(json_path))
        # # Clean up temporary files
        # os.remove(model_path)
        # os.remove(json_path)
        return zip_path

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 224, 224, 3], dtype=tf.float32, name='input_images')
    ])
    def serving_fn(self, input_images):
        """Score one NHWC float image against every stored embedding and
        return, per unique category, the best cosine similarity (float32)."""
        # CLIP expects NCHW pixel values.
        x = tf.transpose(input_images, (0, 3, 1, 2))
        # BUG FIX: dropped the torch.no_grad() wrapper — self.model is a
        # TensorFlow model, so the torch context did nothing and misleadingly
        # implied a PyTorch code path.
        img_embeddings = self.model.get_image_features(pixel_values=x)
        # BUG FIX: skip entries whose embedding extraction failed (None);
        # previously tf.stack crashed on any such entry during tracing.
        valid = [(item['image_category'], item['embeddings'])
                 for item in self.image_data
                 if item.get('embeddings') is not None]
        names = [name for name, _ in valid]
        labels_embeddings = tf.stack([emb for _, emb in valid], axis=0)
        similarities = tf.reduce_sum(tf.multiply(img_embeddings, labels_embeddings), axis=-1)
        norm_img = tf.norm(img_embeddings, axis=-1)
        norm_labels = tf.norm(labels_embeddings, axis=-1)
        cosine_similarity = similarities / (norm_img * norm_labels)
        # Keep the maximum similarity per category.
        # BUG FIX: the original evaluated `score > name_to_score[name]` in a
        # Python `if`, which raises on symbolic tensors during tf.function
        # tracing whenever a category has more than one image; tf.maximum
        # keeps the comparison in-graph.
        name_to_score = {}
        for i, name in enumerate(names):
            score = cosine_similarity[i]
            if name in name_to_score:
                name_to_score[name] = tf.maximum(name_to_score[name], score)
            else:
                name_to_score[name] = score
        # Categories with no valid embedding score 0.0 (float, to match dtype).
        result = [name_to_score.get(name, 0.0) for name in self.unique_title_list]
        return tf.convert_to_tensor(result, dtype=tf.float32)
class OpenAiClipModel:
    """Builds a packaged TFLite CLIP classifier from a payload mapping
    category name -> list of example image paths."""

    def __init__(self, payload):
        # payload: dict {image_category: [image_path, ...]}
        self.payload = payload
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        # Flatten the payload into one row per (category, image_path).
        self.df = pd.DataFrame(
            [(key, value) for key, values in self.payload.items() for value in values],
            columns=['image_category', 'image_path']
        )
        self.labels_embeddings = []
        self.image_data = []
        self.title_list = self.df["image_category"].tolist()
        self.unique_title_list = list(set(self.title_list))

    def generate_text_embedding(self, text):
        """Return the CLIP text embedding for a single string.

        BUG FIX: the original tokenized with return_tensors="pt" and wrapped
        the call in torch.no_grad(), but self.model is a *TensorFlow* model
        (TFCLIPModel) — feeding it PyTorch tensors fails. Tokenize as TF.
        """
        inputs = self.processor(text=[text], return_tensors="tf", padding=True)
        return self.model.get_text_features(**inputs)

    def generate_image_embeddings(self):
        """Compute an image embedding for every row of self.df, appending
        each row's dict (plus an 'embeddings' key) to self.image_data.

        Best-effort: rows whose file is missing or unreadable get
        embeddings=None instead of aborting the whole batch.
        """
        for _, row in self.df.iterrows():
            image_info = row.to_dict()
            image_path = image_info.get("image_path")
            if os.path.exists(image_path):
                try:
                    # Normalize to RGB so grayscale/RGBA files don't break
                    # the processor's channel expectations.
                    image = Image.open(image_path).convert("RGB")
                    inputs = self.processor(images=image, return_tensors="tf")
                    outputs = self.model.get_image_features(**inputs)
                    image_info['embeddings'] = outputs.numpy().flatten().tolist()
                    self.labels_embeddings.append(outputs)
                except Exception:
                    # Deliberate best-effort: one bad image must not stop the run.
                    image_info['embeddings'] = None
            else:
                image_info['embeddings'] = None
            self.image_data.append(image_info)

    def build_model(self):
        """Embed all payload images, convert to a packaged TFLite model, and
        return the path to the resulting ZIP archive."""
        self.generate_image_embeddings()
        self.title_list = self.df["image_category"].tolist()
        tflite_client = TfliteConverter(
            model=self.model,
            image_data=self.image_data,
            title_list=self.title_list,
            unique_title_list=self.unique_title_list
        )
        return tflite_client.build_model()