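"""Visualized-BGE multimodal retrieval demo.

A Gradio app that indexes uploaded PDFs (rendered page by page via pdf2image,
which requires poppler) and images with the Visualized-BGE model, then
retrieves the top-k matches for a text or image query by cosine similarity.

Assumes the Visualized-BGE weight file (e.g. Visualized_base_en_v1.5.pth from
the FlagEmbedding project) has been downloaded next to this script.
"""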
import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from FlagEmbedding.visual.modeling import Visualized_BGE
from PIL import Image
from pdf2image import convert_from_path

# Initialize the Visualized-BGE model
def load_bge_model(model_name: str, model_weight_path: str):
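    """Load Visualized-BGE from a base BGE model name and a local weight file."""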
    model = Visualized_BGE(model_name_bge=model_name, model_weight=model_weight_path)
    model.eval()
    return model

# Load the BGE model (ensure you have downloaded the weights and provide the correct path)
model_name = "BAAI/bge-base-en-v1.5"  # or "BAAI/bge-m3" for multilingual
model_weight_path = "./Visualized_base_en_v1.5.pth"
model = load_bge_model(model_name, model_weight_path)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)

# Function to encode images
def encode_image(image_input):
    """
    Encodes an image for retrieval.

    Args:
        image_input: Can be a file path (str), a NumPy array, or a PIL Image.

    Returns:
        torch.Tensor: The image embedding.
    """
    delete_temp_file = False  
    if isinstance(image_input, str):
        image_path = image_input
    else:
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            if isinstance(image_input, np.ndarray):
                image = Image.fromarray(image_input)
            elif isinstance(image_input, Image.Image):
                image = image_input
            else:
                raise ValueError("Unsupported image input type for image encoding.")

            image.save(tmp_file, format="PNG")  # write through the open handle (also works on Windows)
            image_path = tmp_file.name
            delete_temp_file = True  # Mark that we need to delete this temp file

    try:
        with torch.no_grad():
            embed = model.encode(image=image_path)
        embed = embed.squeeze(0)
    finally:
        if delete_temp_file:
            # Remove the temporary file
            os.remove(image_path)

    return embed.cpu()


# Function to encode text
def encode_text(text):
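    """Encode a text query and return a [D] embedding tensor on CPU."""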
    with torch.no_grad():
        embed = model.encode(text=text)  # Assuming encode returns [1, D]
    embed = embed.squeeze(0)  # Remove the batch dimension if present
    return embed.cpu()

# Function to index uploaded files (PDFs or images)
def index_files(files, embeddings_state, metadata_state):
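    """Embed each uploaded file (PDF pages are rendered to images first) and
    return a status message plus the stacked embeddings and their metadata,
    which Gradio keeps in session state."""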
    print("Indexing files...")
    embeddings = []
    metadata = []

    for file in files:
        if file.name.lower().endswith('.pdf'):
            images = convert_from_path(file.name, thread_count=4)
            for idx, img in enumerate(images):
                img_path = f"{file.name}_page_{idx}.png"
                img.save(img_path)
                embed = encode_image(img_path)
                print(f"Embedding shape after encoding image: {embed.shape}")  # Should be [768]
                embeddings.append(embed)
                metadata.append({"type": "image", "path": img_path, "info": f"Page {idx}"})
        elif file.name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
            img_path = file.name
            embed = encode_image(img_path)
            print(f"Embedding shape after encoding image: {embed.shape}")  # Should be [768]
            embeddings.append(embed)
            metadata.append({"type": "image", "path": img_path, "info": "Uploaded Image"})
        else:
            raise gr.Error("Unsupported file type. Please upload PDFs or image files.")

    if not embeddings:
        return "No supported files were indexed.", embeddings_state, metadata_state

    embeddings = torch.stack(embeddings).to(device)  # Shape [N, D] (768 for the base model)
    print(f"Stacked embeddings shape: {embeddings.shape}")
    return f"Indexed {len(embeddings)} items.", embeddings, metadata

def search(query_text, query_image, k, embeddings_state, metadata_state):
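    """Embed the query (text or image), score it against the index with cosine
    similarity, and return the top-k items for the results gallery."""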
    embeddings = embeddings_state
    metadata = metadata_state

    if embeddings is None or embeddings.size(0) == 0:
        raise gr.Error("No embeddings indexed. Please upload and index files first.")

    if query_text and query_image is not None:
        gr.Warning("Provide either a text query or an image query, not both. Using the text query.")
        query_emb = encode_text(query_text)  # [D]
    elif query_text:
        query_emb = encode_text(query_text)  # [D]
        print("Encoded text query.")
    elif query_image is not None:
        query_emb = encode_image(query_image)  # [D]
        print("Encoded image query.")
    else:
        raise gr.Error("Please provide a text query or an image query.")

    # Ensure query_emb has shape [1, D]
    if query_emb.dim() == 1:
        query_emb = query_emb.unsqueeze(0)  # [1, D]

    # Normalize embeddings for cosine similarity
    query_emb = F.normalize(query_emb.to(device), p=2, dim=1)  # [1, D]
    indexed_emb = F.normalize(embeddings.to(device), p=2, dim=1)  # [N, D]

    print(f"Query embedding shape: {query_emb.shape}")  # Should be [1, 768]
    print(f"Indexed embeddings shape: {indexed_emb.shape}")  # Should be [N, 768]

    # Compute cosine similarities
    similarities = torch.matmul(query_emb, indexed_emb.T).squeeze(0)  # [N]
    print(f"Similarities shape: {similarities.shape}")

    # Get top-k results (clamp k so it never exceeds the number of indexed items)
    k = min(int(k), similarities.size(0))
    topk = torch.topk(similarities, k)
    topk_indices = topk.indices.cpu().numpy()
    topk_scores = topk.values.cpu().numpy()

    print(f"Top-{k} indices: {topk_indices}")
    print(f"Top-{k} scores: {topk_scores}")

    results = []
    for idx, score in zip(topk_indices, topk_scores):
        item = metadata[idx]
        if item["type"] == "image":
            # Load image from path
            img = Image.open(item["path"]).convert("RGB")
            results.append((img, f"Score: {score:.4f} | {item['info']}"))
        else:
            # Only image items are indexed above; kept defensive for future text entries
            results.append((item.get("data", ""), f"Score: {score:.4f} | {item['info']}"))

    return results

# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Visualized-BGE: Multimodal Retrieval Demo πŸŽ‰")
    gr.Markdown("""
    Upload PDF or image files to index them. Then search with a text query or an image query to retrieve the most relevant items.

    **Note:** Ensure that you have indexed the files before performing a search.
    """)

    # Initialize state variables
    embeddings_state = gr.State(None)
    metadata_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload and Index Files")
            file_input = gr.File(file_types=[".pdf", ".png", ".jpg", ".jpeg", ".bmp", ".gif"], file_count="multiple", label="Upload Files")
            index_button = gr.Button("πŸ”„ Index Files")
            index_status = gr.Textbox("No files indexed yet.", label="Indexing Status")

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Perform Search")
            with gr.Row():
                query_text = gr.Textbox(placeholder="Enter your text query here...", label="Text Query")
                query_image = gr.Image(label="Image Query (Optional)")
            k = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5)
            search_button = gr.Button("πŸ” Search")
            output_gallery = gr.Gallery(label="Retrieved Results", show_label=True, columns=2)

    # Define button actions
    index_button.click(
        index_files,
        inputs=[file_input, embeddings_state, metadata_state],
        outputs=[index_status, embeddings_state, metadata_state]
    )
    search_button.click(
        search,
        inputs=[query_text, query_image, k, embeddings_state, metadata_state],
        outputs=output_gallery
    )

    gr.Markdown("""
    ---
    ## About
    This demo uses the **Visualized-BGE** model for multimodal retrieval. Upload your documents or images, index them, and search with text or image queries.

    **References:**
    - [Visualized-BGE Paper](https://arxiv.org/abs/2406.04292)
    - [FlagEmbedding GitHub](https://github.com/FlagOpen/FlagEmbedding)
    """)

if __name__ == "__main__":
    demo.launch(debug=True, share=True)