import streamlit as st
import torch
from PIL import Image
import tempfile
import os
import time
import json
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


class Qwen2Wrapper:
    """Thin Byaldi-style wrapper around Qwen2-VL: index images, then query them."""

    def __init__(self, model_name="Qwen/Qwen2-VL-7B-Instruct", device="cpu"):
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float32, device_map=device
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.device = device
        self.index = {}

    def index_image(self, image_path, index_name, overwrite=True):
        # Register an image under a named index; text is extracted lazily on first search.
        if index_name in self.index and not overwrite:
            raise ValueError(f"Index {index_name} already exists. Use overwrite=True to replace.")
        self.index[index_name] = {"image": Image.open(image_path), "extracted_text": None}

    def search(self, query, index_name, k=1):
        # Run OCR on the indexed image (once) and return the text in a Byaldi-like result format.
        if index_name not in self.index:
            raise ValueError(f"Index {index_name} does not exist.")
        image = self.index[index_name]["image"]
        if self.index[index_name]["extracted_text"] is None:
            self.index[index_name]["extracted_text"] = self._extract_text(image)
        return [{"metadata": {"ocr_text": self.index[index_name]["extracted_text"]}}]

    def _extract_text(self, image):
        conversation_history = []

        def prompt_text(query):
            # Build a single-turn chat containing the image and the OCR instruction.
            conversation_history.append({
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": query},
                ],
            })
            messages = conversation_history[:]
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, _ = process_vision_info(messages)
            inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt")
            inputs = inputs.to(self.device)
            generated_ids = self.model.generate(**inputs, max_new_tokens=128)
            # Strip the prompt tokens so only the newly generated tokens are decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            # Record the model's reply as an assistant turn, in the same list-of-parts format as the user turn.
            conversation_history.append({
                "role": "assistant",
                "content": [{"type": "text", "text": output_text[0]}],
            })
            return output_text[0]

        return prompt_text("give me just the text extracted")


# Function to load the Qwen2 model (cached so it is loaded only once per session)
@st.cache_resource
def load_qwen2_model():
    return Qwen2Wrapper(device="cpu")


# Streamlit Interface
st.title("OCR with Qwen2 (Byaldi-style implementation)")
st.write("Upload an image for OCR processing (supports Hindi and English text)")

# Image uploader
image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image:
    img = Image.open(image)
    st.image(img, caption="Uploaded Image", use_column_width=True)

    # OCR Extraction
    st.write("Extracting text from image using Qwen2...")
    qwen2_model = load_qwen2_model()

    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        # Convert to RGB so PNG uploads with an alpha channel can be saved as JPEG.
        img.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Create a unique index name
    unique_index_name = f"temp_index_{int(time.time())}"

    # Index the image
    qwen2_model.index_image(temp_file_path, unique_index_name)

    # Perform search (which triggers text extraction)
    ocr_results = qwen2_model.search("Extract all text from the image", unique_index_name)
    extracted_text = ocr_results[0]["metadata"]["ocr_text"]

    # Remove the temporary file
    os.unlink(temp_file_path)

    # Display results
    st.subheader("Qwen2 OCR Result:")
    st.text(extracted_text)
    st.json(json.dumps({"extracted_text": extracted_text}, ensure_ascii=False, indent=2))

    # Keyword search over the extracted text
    st.subheader("Search in Extracted Text")
    keywords = st.text_input("Enter keywords to search (separate multiple keywords with commas)")

    if keywords:
        search_keywords = [k.strip() for k in keywords.split(",")]

        def search_text(text, keywords):
            # Return every word that contains any keyword (case-insensitive substring match).
            words = text.split()
            return [
                word for word in words
                if any(keyword.lower() in word.lower() for keyword in keywords)
            ]

        search_results = search_text(extracted_text, search_keywords)
        st.write("Qwen2 Search Results:")
        st.write(", ".join(search_results) if search_results else "No matches found.")
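
# Usage sketch (not from the original source): assuming this script is saved as
# "qwen2_ocr_app.py" -- the filename is an assumption -- it can be launched with:
#
#   pip install streamlit torch transformers accelerate qwen-vl-utils pillow
#   streamlit run qwen2_ocr_app.py
#
# `accelerate` is listed because transformers requires it when `device_map` is
# passed to `from_pretrained`, as is done in Qwen2Wrapper.__init__ above.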