Spaces:
Running
Running
File size: 5,330 Bytes
8eae3a5 96e35ac 8eae3a5 96e35ac 8eae3a5 96e35ac 8eae3a5 b5ba0b7 8eae3a5 b5ba0b7 96e35ac b5ba0b7 8eae3a5 96e35ac 8eae3a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import hashlib
import os
import tempfile

import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# Caching the model loading
@st.cache_resource
def load_rag_model():
return RAGMultiModalModel.from_pretrained("vidore/colpali")
@st.cache_resource
def load_qwen_model():
    """Load Qwen2-VL-2B-Instruct in bfloat16 on `device`, in eval mode.

    Cached with st.cache_resource so the weights are loaded only once.
    """
    qwen = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    qwen = qwen.to(device)
    # eval() disables training-only behaviour such as dropout.
    return qwen.eval()
@st.cache_resource
def load_processor():
    """Return the (cached) AutoProcessor paired with Qwen2-VL-2B-Instruct."""
    qwen_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
    )
    return qwen_processor
# Obtain the retriever, the Qwen2-VL generator and its processor (loaded on the
# first run, then served from the Streamlit resource cache), in that order.
RAG, model, processor = load_rag_model(), load_qwen_model(), load_processor()
st.title("Multimodal RAG App")
# The slow-processing disclaimer is only true when no CUDA device was found.
if device.type == "cpu":
    st.warning("⚠️ Disclaimer: This app is currently running on CPU, which may result in slow processing times. For optimal performance, download and run the app locally on a machine with GPU support.")
# Link to the source of this app (host fixed: no trailing dot after ".co").
st.markdown("[📥 Download the app code](https://huggingface.co/spaces/clayton07/colpali-qwen2-ocr/blob/main/app.py)")
# Per-session state: whether this session has created its byaldi index yet,
# and a record of the images already added to it.
if 'index_created' not in st.session_state:
    st.session_state.index_created = False
if 'processed_images' not in st.session_state:
    st.session_state.processed_images = set()
# Ask where the image comes from. Afterwards `uploaded_file` is either the
# Streamlit file-uploader value (None until a file is chosen) or the path of
# the bundled example image.
image_source = st.radio(
    "Choose image source:",
    ("Upload an image", "Use example image"),
)
if image_source != "Upload an image":
    example_image_path = "hindi-qp.jpg"
    uploaded_file = example_image_path
else:
    uploaded_file = st.file_uploader(
        "Choose an image file",
        type=["png", "jpg", "jpeg"],
    )
def _index_image(image_path, image_key):
    """Add the image at *image_path* to this session's byaldi index, at most once.

    *image_key* is what gets recorded in ``st.session_state.processed_images``.
    For uploads it is a hash of the file's bytes, so a new upload is indexed
    even though every upload is written to a fresh temporary path.
    """
    if image_key in st.session_state.processed_images:
        return
    with st.spinner('Processing image...'):
        if not st.session_state.index_created:
            # First image of this session: create (overwrite) the index.
            RAG.index(
                input_path=image_path,
                index_name="temp_index",
                store_collection_with_index=False,
                overwrite=True,
            )
            st.session_state.index_created = True
            st.success('Index created successfully!')
        else:
            # Subsequent images are appended to the existing index.
            RAG.add_to_index(
                input_item=image_path,
                store_collection_with_index=False,
            )
            st.success('Image added to index successfully!')
    st.session_state.processed_images.add(image_key)


def _answer_query(image_path, query, max_new_tokens):
    """Ask Qwen2-VL *query* about the image at *image_path*; return the reply text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": query},
            ],
        }
    ]
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # The generated sequences start with the prompt tokens; keep only the new ones.
    trimmed_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]


if uploaded_file is not None:
    is_upload = image_source == "Upload an image"
    if is_upload:
        # Uploads arrive in memory, but byaldi and qwen_vl_utils need a file
        # path. A unique temp file per run avoids clobbering between concurrent
        # sessions. Keying on the content hash keeps different uploads distinct.
        # A single fixed path would make every later upload look already indexed.
        image_bytes = uploaded_file.getvalue()
        image_key = hashlib.sha256(image_bytes).hexdigest()
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(image_bytes)
        image_path = tmp.name
    else:
        # The example image already lives on disk; its path is a stable key.
        image_path = uploaded_file
        image_key = image_path

    try:
        _index_image(image_path, image_key)
        st.image(image_path, caption="Uploaded Image", use_column_width=True)

        # Text query input
        text_query = st.text_input("Enter your query about the image:")
        max_new_tokens = st.slider("Max new tokens for response", min_value=100, max_value=1000, value=100, step=10)
        if text_query:
            cpu_note = " This may take a while due to CPU processing." if device.type == "cpu" else ""
            with st.spinner(
                    f'Processing your query...{cpu_note} Generating up to {max_new_tokens} tokens.'):
                # Perform RAG search.
                # NOTE(review): `results` is not used below; the answer is
                # generated from the current image only, not the retrieved pages.
                results = RAG.search(text_query, k=2)
                answer = _answer_query(image_path, text_query, max_new_tokens)
            # Display results
            st.subheader("Results:")
            st.write(answer)
    finally:
        # Remove the per-run temp copy even if indexing or generation raised.
        if is_upload and os.path.exists(image_path):
            os.remove(image_path)
else:
    st.write("Please upload an image to get started.")