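"""Gradio demo for TrCLIP, a Turkish CLIP model.

Runs image retrieval and text retrieval with user-supplied images/texts
or with pre-computed TrCaption embeddings pulled from the Hub.
"""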
import datetime
import os
import pickle
import random

import gradio as gr
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from trclip.trclip import Trclip
from trclip.visualizer import image_retrieval_visualize, text_retrieval_visualize

print(f'gr version : {gr.__version__}')
|
|
# Download the fine-tuned TrCLIP checkpoint and the pre-computed
# TrCaption embeddings from the Hugging Face Hub on first run.
model_name = 'trclip-vitl14-e10'
if not os.path.exists(model_name):
    os.system(f'git clone https://huggingface.co/yusufani/{model_name} --progress')

if not os.path.exists('TrCaption-trclip-vitl14-e10'):
    os.system('git clone https://huggingface.co/datasets/yusufani/TrCaption-trclip-vitl14-e10/ --progress')
    os.chdir('TrCaption-trclip-vitl14-e10')
    os.system('git lfs install')
    os.system('git lfs fetch')
    os.system('git lfs pull')
    os.chdir('..')
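# TrCaption embeddings are stored as pickle shards of up to 100k rows each
# (image_em_0.pkl, image_em_100000.pkl, ...), roughly 20 GB in total.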
|
|
def load_image_embeddings(load_batch=True):
    """Load the pre-computed TrCaption image embeddings.

    With load_batch=True, return a generator that yields one shard at a
    time; otherwise return a single concatenated tensor.
    """
    path = os.path.join('TrCaption-trclip-vitl14-e10', 'image_embeddings')
    bs = 100_000

    def shards():
        for i in tqdm(range(0, 3_100_000, bs), desc='Loading TrCaption image embeddings'):
            with open(os.path.join(path, f'image_em_{i}.pkl'), 'rb') as f:
                yield pickle.load(f)

    if load_batch:
        return shards()
    return torch.cat(list(shards()), dim=0)
|
|
def load_text_embeddings(load_batch=True):
    """Load the pre-computed TrCaption text embeddings.

    With load_batch=True, return a generator that yields one shard at a
    time; otherwise return a single concatenated tensor.
    """
    path = os.path.join('TrCaption-trclip-vitl14-e10', 'text_embeddings')
    bs = 100_000

    def shards():
        for i in tqdm(range(0, 3_600_000, bs), desc='Loading TrCaption text embeddings'):
            with open(os.path.join(path, f'text_em_{i}.pkl'), 'rb') as f:
                yield pickle.load(f)

    if load_batch:
        return shards()
    return torch.cat(list(shards()), dim=0)
|
|
def load_metadata():
    """Return the TrCaption captions and image URLs from metadata.pkl."""
    path = os.path.join('TrCaption-trclip-vitl14-e10', 'metadata.pkl')
    with open(path, 'rb') as f:
        metadata = pickle.load(f)
    trcap_texts = metadata['texts']
    trcap_urls = metadata['image_urls']
    return trcap_texts, trcap_urls
|
|
def load_specific_tensor(index, kind, bs=100_000):
    """Load a single embedding by its global index.

    kind is 'image' or 'text'; each shard file holds bs embeddings.
    """
    part = index // bs
    idx = index % bs
    with open(os.path.join('TrCaption-trclip-vitl14-e10', f'{kind}_embeddings', f'{kind}_em_{part * bs}.pkl'), 'rb') as f:
        embeddings = pickle.load(f)
    return embeddings[idx]
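# e.g. load_specific_tensor(123_456, 'text') opens text_em_100000.pkl
# and returns its row 23_456.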
|
|
trcap_texts, trcap_urls = load_metadata()

print('INFO : Model loading')
model_path = os.path.join(model_name, 'pytorch_model.bin')
trclip = Trclip(model_path, clip_model='ViT-L/14', device='cpu')
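# Retrieval flows: run_im ranks images for each query text ('per_text');
# run_text ranks texts for each query image ('per_image'). With TrCaption
# embeddings enabled, similarity scores are computed shard by shard,
# concatenated, then softmaxed at the end.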
|
|
def run_im(im1, use_trcap_images, text1, use_trcap_texts):
    """Image retrieval: rank candidate images against each query text."""
    print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} INFO : Image retrieval starting')
    f_texts_embeddings = None
    ims = None
    if use_trcap_images:
        print('INFO : TrCaption images used')
        im_paths = trcap_urls
    else:
        print('INFO : Own images used')
        im_paths = [i.name for i in im1]
        ims = [Image.open(i) for i in im_paths]

    if use_trcap_texts:
        print('INFO : TrCaption texts used')
        # Pick 2 random captions and load their pre-computed embeddings.
        random_indexes = random.sample(range(len(trcap_texts)), 2)
        f_texts_embeddings = torch.stack([load_specific_tensor(i, 'text') for i in random_indexes])
        texts = [trcap_texts[i] for i in random_indexes]
    else:
        print('INFO : Own texts used')
        texts = [i.strip() for i in text1.split('\n')[:2] if i.strip() != '']

    if use_trcap_images:
        per_mode_probs = []
        f_texts_embeddings = f_texts_embeddings if use_trcap_texts else trclip.get_text_features(texts)
        for f_image_embeddings in tqdm(load_image_embeddings(load_batch=True), desc='Running image retrieval'):
            batch_probs = trclip.get_results(
                text_features=f_texts_embeddings, image_features=f_image_embeddings, mode='per_text', return_probs=True)
            per_mode_probs.append(batch_probs)
        per_mode_probs = torch.cat(per_mode_probs, dim=1)
        per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
        per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]
    else:
        per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, text_features=f_texts_embeddings, mode='per_text')

    print(f'per_mode_indices = {per_mode_indices}\nper_mode_probs = {per_mode_probs}')
    print(f'im_paths = {im_paths}')
    return image_retrieval_visualize(per_mode_indices, per_mode_probs, texts, im_paths,
                                     n_figure_in_column=2, n_images_in_figure=4, n_figure_in_row=1,
                                     save_fig=False, show=False, break_on_index=-1)
|
|
def run_text(im1, use_trcap_images, text1, use_trcap_texts):
    """Text retrieval: rank candidate texts against each query image."""
    print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} INFO : Text retrieval starting')
    f_image_embeddings = None
    ims = None
    if use_trcap_images:
        print('INFO : TrCaption images used')
        # Pick 2 random images and load their pre-computed embeddings.
        random_indexes = random.sample(range(len(trcap_urls)), 2)
        f_image_embeddings = torch.stack([load_specific_tensor(i, 'image') for i in random_indexes])
        print(f'f_image_embeddings = {f_image_embeddings}')
        im_paths = [trcap_urls[i] for i in random_indexes]
        print(f'im_paths = {im_paths}')
    else:
        print('INFO : Own images used')
        im_paths = [i.name for i in im1[:2]]
        ims = [Image.open(i) for i in im_paths]

    if use_trcap_texts:
        texts = trcap_texts
    else:
        texts = [i.strip() for i in text1.split('\n')[:2] if i.strip() != '']

    if use_trcap_texts:
        f_image_embeddings = f_image_embeddings if use_trcap_images else trclip.get_image_features(ims)
        per_mode_probs = []
        for f_texts_embeddings in tqdm(load_text_embeddings(load_batch=True), desc='Running text retrieval'):
            batch_probs = trclip.get_results(
                text_features=f_texts_embeddings, image_features=f_image_embeddings, mode='per_image', return_probs=True)
            per_mode_probs.append(batch_probs)
        per_mode_probs = torch.cat(per_mode_probs, dim=1)
        per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
        per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]
    else:
        per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, image_features=f_image_embeddings, mode='per_image')

    print(per_mode_indices)
    print(per_mode_probs)
    return text_retrieval_visualize(per_mode_indices, per_mode_probs, im_paths, texts,
                                    n_figure_in_column=4,
                                    n_texts_in_figure=min(4, len(texts)),
                                    n_figure_in_row=2,
                                    save_fig=False, show=False, break_on_index=-1)
|
|
def change_textbox(choice):
    # Currently unused helper: would toggle the image input's visibility.
    if choice == "Use Own Images":
        return gr.Image.update(visible=True)
    return gr.Image.update(visible=False)
|
|
with gr.Blocks() as demo: |
|
gr.HTML(""" |
|
<div style="text-align: center; max-width: 650px; margin: 0 auto;"> |
|
<div |
|
style=" |
|
display: inline-flex; |
|
align-items: center; |
|
gap: 0.8rem; |
|
font-size: 1.75rem; |
|
" |
|
> |
|
<svg |
|
width="0.65em" |
|
height="0.65em" |
|
viewBox="0 0 115 115" |
|
fill="none" |
|
xmlns="http://www.w3.org/2000/svg" |
|
> |
|
<rect width="23" height="23" fill="white"></rect> |
|
<rect y="69" width="23" height="23" fill="white"></rect> |
|
<rect x="23" width="23" height="23" fill="#AEAEAE"></rect> |
|
<rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect> |
|
<rect x="46" width="23" height="23" fill="white"></rect> |
|
<rect x="46" y="69" width="23" height="23" fill="white"></rect> |
|
<rect x="69" width="23" height="23" fill="black"></rect> |
|
<rect x="69" y="69" width="23" height="23" fill="black"></rect> |
|
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect> |
|
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect> |
|
<rect x="115" y="46" width="23" height="23" fill="white"></rect> |
|
<rect x="115" y="115" width="23" height="23" fill="white"></rect> |
|
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect> |
|
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect> |
|
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect> |
|
<rect x="92" y="69" width="23" height="23" fill="white"></rect> |
|
<rect x="69" y="46" width="23" height="23" fill="white"></rect> |
|
<rect x="69" y="115" width="23" height="23" fill="white"></rect> |
|
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect> |
|
<rect x="46" y="46" width="23" height="23" fill="black"></rect> |
|
<rect x="46" y="115" width="23" height="23" fill="black"></rect> |
|
<rect x="46" y="69" width="23" height="23" fill="black"></rect> |
|
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect> |
|
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect> |
|
<rect x="23" y="69" width="23" height="23" fill="black"></rect> |
|
</svg> |
|
          <h1 style="font-weight: 1500; margin-bottom: 7px;">
            TrCLIP Demo
            <a
              href="https://github.com/yusufani/TrCLIP"
              style="text-decoration: underline;"
              target="_blank"
            >GitHub</a>
          </h1>
|
</div> |
|
        <p style="margin-bottom: 10px; font-size: 94%">
          TrCLIP is a Turkish port of the original CLIP. In this Space you can try your own images and/or texts.
          <br>You can also use the pre-computed TrCaption embeddings.
          <br>Number of texts = 3,533,312
          <br>Number of images = 3,070,976
          <br>
          Some images are no longer available on the internet, because I downloaded them and computed the TrCaption embeddings a long time ago. Don't be surprised if you run into "Image not found" :D
        </p>
        <div style="text-align: center; font-size: 100%">
          <p><strong><span style="background-color: #000000; color: #ffffff;"><a style="background-color: #000000; color: #ffffff;" href="https://github.com/yusufani/TrCLIP">GitHub Repository</a></span> --- <span style="background-color: #000000; color: #ffffff;">Paper (not available yet)</span></strong></p>
        </div>
|
|
|
</div> |
|
      <div style="text-align: center; margin: 0 auto;">
        <p style="margin-bottom: 10px; font-size: 75%"><em>Hugging Face Space containers have 16 GB of RAM, while the TrCaption embeddings total 20 GB. I did a lot of file reading and writing to make this Space workable. That's why <span style="background-color: #ff6600; color: #ffffff;"><strong>it runs much slower when you use the TrCaption embeddings</strong></span>.</em></p>
      </div>
|
""") |
|
    with gr.Tabs():
        with gr.TabItem("Upload Images"):
            im_input = gr.components.File(label="Image input", optional=True, file_count='multiple')
            is_trcap_ims = gr.Checkbox(label="Use TrCaption images [Note: 2 random samples are selected in text retrieval mode]", default=True)

    with gr.Tabs():
        with gr.TabItem("Input texts (separated by new lines; max 2 used for image retrieval)"):
            text_input = gr.components.Textbox(label="Text input", optional=True, placeholder="kedi\nköpek\nGemi\nKahvesini içmekte olan bir adam\nKahvesini içmekte olan bir kadın\nAraba")
            is_trcap_texts = gr.Checkbox(label="Use TrCaption captions [Note: 2 random samples are selected in image retrieval mode]", default=True)
|
|
    im_ret_but = gr.Button("Image Retrieval")
    text_ret_but = gr.Button("Text Retrieval")

    im_out = gr.components.Image()

    im_ret_but.click(run_im, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
    text_ret_but.click(run_text, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
|
demo.launch() |