jaketae commited on
Commit
8ff0261
โ€ข
1 Parent(s): cb6e7b6

feature: add image2text feature

Browse files
Files changed (1) hide show
  1. image2text.py +17 -12
image2text.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import numpy as np
 
3
  import jax.numpy as jnp
4
  from PIL import Image
5
 
@@ -24,17 +25,21 @@ def app(model_name):
24
  st.error("Please upload an image query.")
25
  else:
26
  image = Image.open(query)
27
- pixel_values = processor(
28
- text=[""], images=image, return_tensors="jax", padding=True
29
- ).pixel_values
30
- pixel_values = jnp.transpose(pixel_values, axes=[0, 2, 3, 1])
31
- vec = np.asarray(model.get_image_features(pixel_values))
32
- # ids, dists = index.knnQuery(vec, k=10)
33
- # result_files = map(lambda id: files[id], ids)
34
- # result_imgs, result_captions = [], []
35
- # for file, dist in zip(result_files, dists):
36
- # result_imgs.append(plt.imread(os.path.join(images_directory, file)))
37
- # result_captions.append("{:s} (์œ ์‚ฌ๋„: {:.3f})".format(file, 1.0 - dist))
38
-
 
 
 
 
39
 
40
 
 
1
  import streamlit as st
2
  import numpy as np
3
+ import jax
4
  import jax.numpy as jnp
5
  from PIL import Image
6
 
 
25
  st.error("Please upload an image query.")
26
  else:
27
  image = Image.open(query)
28
+ st.image(image)
29
+ # pixel_values = processor(
30
+ # text=[""], images=image, return_tensors="jax", padding=True
31
+ # ).pixel_values
32
+ # pixel_values = jnp.transpose(pixel_values, axes=[0, 2, 3, 1])
33
+ # vec = np.asarray(model.get_image_features(pixel_values))
34
+ captions = captions.split(",")
35
+ inputs = processor(text=captions, images=image, return_tensors="jax", padding=True)
36
+ inputs["pixel_values"] = jnp.transpose(
37
+ inputs["pixel_values"], axes=[0, 2, 3, 1]
38
+ )
39
+ outputs = model(**inputs)
40
+ probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
41
+
42
+ for idx, prob in sorted(enumerate(*probs), key=lambda x: x[1], reverse=True):
43
+ st.text(f"Score: `{prob}`, {captions[idx]}")
44
 
45