Spaces:
Build error
Build error
import os | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import requests | |
import streamlit as st | |
from PIL import Image | |
from utils import load_model | |
def split_image(im, num_rows=3, num_cols=3):
    """Split an image into a ``num_rows`` x ``num_cols`` grid of equal tiles.

    Args:
        im: A PIL image or array-like whose first two axes are (height, width).
        num_rows: Number of grid rows; must be a positive int.
        num_cols: Number of grid columns; must be a positive int.

    Returns:
        A list of ``num_rows * num_cols`` NumPy array tiles in row-major order
        (left-to-right within a row, rows top-to-bottom). Every tile spans
        ``height // num_rows`` rows and ``width // num_cols`` columns; any
        remainder pixels along the bottom/right edges are dropped.

    Raises:
        ValueError: If ``num_rows``/``num_cols`` is not positive, or the grid
            is finer than the image so tiles would be zero-sized.
    """
    # np.array (not asarray) copies, so the returned tiles never alias the
    # caller's buffer.
    im = np.array(im)
    if num_rows < 1 or num_cols < 1:
        raise ValueError(
            f"num_rows and num_cols must be positive, got {num_rows}x{num_cols}"
        )
    height, width = im.shape[0], im.shape[1]
    row_size = height // num_rows
    col_size = width // num_cols
    if row_size == 0 or col_size == 0:
        raise ValueError(
            f"a {num_rows}x{num_cols} grid is too fine for a "
            f"{height}x{width} image (tiles would be empty)"
        )
    return [
        im[r * row_size : (r + 1) * row_size, c * col_size : (c + 1) * col_size]
        for r in range(num_rows)
        for c in range(num_cols)
    ]
# Intro text rendered under the page title.
_DESCRIPTION = """
Given a piece of text, the CLIP model finds the part of an image that best explains the text.
To try it out, you can
1. Upload an image
2. Explain a part of the image in text
which will yield the most relevant image tile from a grid of the image. You can specify how
granular you want to be with your search by specifying the number of rows and columns that
make up the image grid.
---
"""


def _load_query_image(url, uploaded_file, timeout=10):
    """Open the query image as RGB, preferring an uploaded file over a URL.

    Args:
        url: Image URL typed by the user; only used when ``uploaded_file`` is None.
        uploaded_file: File-like object from ``st.file_uploader``, or None.
        timeout: Seconds to wait on the HTTP request before giving up.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.RequestException: If the download fails or returns an HTTP
            error status.
        OSError: If the data cannot be decoded as an image (PIL's
            ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    import io

    if uploaded_file is not None:
        image_data = uploaded_file
    else:
        # Use the decoded body rather than ``response.raw``: ``raw`` skips
        # Content-Encoding (e.g. gzip) decoding. The timeout keeps the page
        # from hanging forever on an unresponsive host.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image_data = io.BytesIO(response.content)
    # split_image turns the image into NumPy tiles, so grayscale (2-D) or
    # RGBA (4-channel) sources would yield tiles with a non-RGB channel layout.
    return Image.open(image_data).convert("RGB")


def app(model_name):
    """Render the patch-based relevance ranking page.

    The user supplies an image (URL or upload) and a text prompt. The image is
    cut into a user-chosen grid of tiles, each tile is scored against the
    prompt with the ``koclip/{model_name}`` model, and tiles are shown in
    descending order of score.

    Args:
        model_name: Model name under the ``koclip/`` namespace, passed to
            ``load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Patch-based Relevance Ranking")
    st.markdown(_DESCRIPTION)

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter a prompt to query the image.",
        value="이건 서울의 경복궁 사진이다.",
    )

    # ``st.beta_columns`` was superseded by ``st.columns`` and later removed;
    # prefer the current name and fall back for older Streamlit releases.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    col1, col2 = make_columns(2)
    with col1:
        num_rows = st.slider(
            "Number of rows", min_value=1, max_value=5, value=3, step=1
        )
    with col2:
        num_cols = st.slider(
            "Number of columns", min_value=1, max_value=5, value=3, step=1
        )

    if not st.button("질문 (Query)"):
        return
    if not (query1 or query2):
        st.error("Please upload an image or paste an image URL.")
        return

    st.markdown("""---""")
    with st.spinner("Computing..."):
        try:
            image = _load_query_image(query1, query2)
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load the image: {err}")
            return
        st.image(image)

        images = split_image(image, num_rows, num_cols)
        inputs = processor(
            text=captions, images=images, return_tensors="jax", padding=True
        )
        # Move axis 1 to the end (channels-first -> channels-last), assuming
        # the processor returns NCHW and the koclip model expects NHWC --
        # confirm against the model definition.
        inputs["pixel_values"] = jnp.transpose(
            inputs["pixel_values"], axes=[0, 2, 3, 1]
        )
        outputs = model(**inputs)

        # Rows of logits_per_image line up with ``images``; softmax over
        # axis 0 turns the first caption's scores into a distribution over
        # tiles.
        probs = np.asarray(jax.nn.softmax(outputs.logits_per_image, axis=0))[:, 0]
        # A stable sort on the negated scores keeps equally scored tiles in
        # grid order, matching sorted(..., reverse=True).
        for idx in np.argsort(-probs, kind="stable"):
            st.text(f"Score: {probs[idx]:.3f}")
            st.image(images[idx])