AmithAdiraju1694 committed on
Commit 533891e · verified · 1 Parent(s): 2eeab11

Upload 10 files

Files changed (9)
  1. __init__.py +1 -0
  2. base_frame_medoid.npz +3 -0
  3. model_inference.py +168 -0
  4. packages.txt +3 -0
  5. pages.py +107 -0
  6. preprocessing.py +109 -0
  7. requirements.txt +38 -0
  8. runtime.txt +1 -0
  9. utils.py +143 -0
__init__.py ADDED
@@ -0,0 +1 @@
+RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, 8, -1] because the unspecified dimension size -1 can be any value and is ambiguous
base_frame_medoid.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a14d1f517dcc8296c0f67402f44ffb071ed751749c27f4641f84e20f4e99ff1
+size 717
model_inference.py ADDED
@@ -0,0 +1,168 @@
+from transformers import pipeline
+import torch
+from PIL import Image
+
+import torch.nn as nn
+import torchvision.models as models
+import torch.nn.functional as F
+
+from utils import prompt_frame_summarization, assistant_role, prompt_audio_summarization
+import streamlit as st
+from utils import timer
+
+import numpy as np
+import whisper
+from utils import batch_generator, cosine_sim
+
+
+class SiameseNetwork(nn.Module):
+    def __init__(self, model_name="vit_b_16"):
+        super(SiameseNetwork, self).__init__()
+
+        self.encoder = models.vit_b_16(weights="IMAGENET1K_V1")  # Pretrained ViT
+        self.encoder.heads = nn.Identity()  # Remove classification head
+
+        self.fc = nn.Linear(768, 128)  # Reduce to 128-d embedding
+
+    def forward(self, video_frames1, video_frames2):
+        """
+        video_frames1: (B, num_frames, H, W, C)  # Batch of videos (num_frames frames each)
+        video_frames2: (B, num_frames, H, W, C)
+        """
+        B, num_frames, H, W, C = video_frames1.shape  # (Batch, num_frames, Height, Width, Channels)
+
+        # Flatten frames into the batch dimension for ViT: (B*num_frames, C, H, W)
+        video_frames1 = video_frames1.permute(0, 1, 4, 2, 3).reshape(B * num_frames, C, H, W)
+        video_frames2 = video_frames2.permute(0, 1, 4, 2, 3).reshape(B * num_frames, C, H, W)
+
+        # Extract frame-level embeddings
+        emb1 = self.encoder(video_frames1)  # (B*num_frames, 768)
+        emb2 = self.encoder(video_frames2)
+
+        # Reshape back to (B, num_frames, 768) and average over frames
+        # TODO: Change this to use an LSTM instead of averaging
+        emb1 = emb1.reshape(B, num_frames, -1).mean(dim=1)  # (B, 768)
+        emb2 = emb2.reshape(B, num_frames, -1).mean(dim=1)
+
+        # Pass through fully connected layer
+        emb1 = self.fc(emb1)  # (B, 128)
+        emb2 = self.fc(emb2)
+
+        return emb1, emb2
+
+    def inference(self, video_frames):
+        """
+        video_frames: (B, num_frames, H, W, C)
+        """
+        B, num_frames, H, W, C = video_frames.shape
+
+        video_frames = video_frames.permute(0, 1, 4, 2, 3).reshape(B * num_frames, C, H, W)
+        emb = self.encoder(video_frames)
+        emb = emb.reshape(B, num_frames, -1).mean(dim=1)
+        emb = self.fc(emb)
+
+        return emb
+
+
+@timer
+def summarize_from_audio(audio_tensor):
+
+    # Transcribe the in-memory audio
+    result = audio_transcriber_model.transcribe(audio_tensor)
+    all_transcription_segments = result["text"]
+
+    summary = text_summarizer(prompt_audio_summarization + all_transcription_segments,
+                              max_length=108,
+                              min_length=36, do_sample=False)[0]['summary_text']
+
+    return summary
+
+
+def get_important_frames_ML(frame):
+    """
+    Placeholder: classifies frames using a second ML model.
+    """
+    # Implement the model's logic here
+    # ...
+    return None
+
+
+def Vit_Summarize_Video(video_frames):
+    """
+    Placeholder: summarizes video frames into a text sentence.
+    The processor, model and tokenizer below still need to be loaded.
+    """
+
+    processor = None
+    messages = None
+    model = None
+    tokenizer = None
+
+    if video_frames is None or len(video_frames) == 0:
+        return "Error: No video frames available."
+
+    # Ensure frames are properly formatted
+    video_frames = [Image.fromarray(frame.astype("uint8")) for frame in video_frames]
+
+    # Ensure correct format for processor
+    inputs = processor(messages, images=None, videos=[video_frames])
+
+    inputs.update({
+        "tokenizer": tokenizer,
+        "max_new_tokens": 54,
+        "decode_text": True,
+    })
+
+    summary_text = model.generate(**inputs)
+
+    return summary_text
+
+
+@timer
+def rate_video_frames(video_frames):
+    """
+    Rates video frames as family-friendly or not, based on their similarity
+    to a reference (base) frame embedding.
+    """
+
+    tensor = torch.tensor(
+        np.array(video_frames),
+        dtype=torch.float32
+    ).reshape(len(video_frames) // 5, 5, 224, 224, 3)  # e.g. (20, 5, 224, 224, 3)
+    video_frame_emb = video_rating_model.inference(tensor)  # e.g. (20, 128)
+
+    overall_sim, count_upg = cosine_sim(emb1=base_frame_emb,
+                                        emb2=video_frame_emb,
+                                        threshold=0.4
+                                        )
+
+    if count_upg / (len(video_frames) // 5) > 0.5:
+        return f"Out of {len(video_frames)} important moments of this video, {count_upg*5} moments contain content rated PG or below. Hence this video is suitable for kids & family."
+    else:
+        return f"Out of {len(video_frames)} important moments of this video, {(len(video_frames)//5 - count_upg)*5} moments contain at least PG-13 content. Hence parental guidance is strongly suggested for this video."
+
+
+@st.cache_resource
+def load_models():
+
+    transcriber = whisper.load_model("base")
+    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+
+    base_frame_emb = torch.tensor(
+        np.load('base_frame_medoid.npz')['arr'],
+        dtype=torch.float32
+    )
+
+    video_rating_model = SiameseNetwork()
+    # video_rating_model.load_state_dict(
+    #     torch.load('/Users/amithadiraju/Desktop/Video_Summary_App/video_contrastive-siamese_v3.pt',
+    #                weights_only=True
+    #                )
+    #     )
+    video_rating_model.eval()
+
+    return (
+        transcriber, summarizer, video_rating_model, base_frame_emb
+    )
+
+
+audio_transcriber_model, text_summarizer, video_rating_model, base_frame_emb = load_models()
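
For context, a minimal shape check for SiameseNetwork.inference (not part of the commit): it expects channels-last clips of shape (B, num_frames, H, W, C), which is exactly how rate_video_frames regroups its flat frame list into clips of 5. Random pixels stand in for real frames here, and note that importing model_inference also runs load_models() at module level, so Whisper and BART are fetched on first import.

import torch
from model_inference import SiameseNetwork  # importing this module also triggers load_models()

model = SiameseNetwork()
model.eval()

with torch.no_grad():
    # 4 clips of 5 frames each, 224x224 RGB, channels-last as inference() expects
    clips = torch.rand(4, 5, 224, 224, 3)
    emb = model.inference(clips)

print(emb.shape)  # torch.Size([4, 128]) -- one 128-d embedding per clip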
packages.txt ADDED
@@ -0,0 +1,3 @@
+ffmpeg
+libsndfile1
+pkgconfig
pages.py ADDED
@@ -0,0 +1,107 @@
+import streamlit as st
+from streamlit import session_state as sst
+import time
+
+import pandas as pd
+from utils import navigate_to
+
+from model_inference import rate_video_frames, summarize_from_audio
+from utils import read_important_frames, extract_audio
+import numpy as np
+
+
+# Define size limits (adjust based on your system)
+SMALL_VIDEO_LIMIT_MB = 35  # Files ≤ 35MB are small
+LARGE_VIDEO_LIMIT_MB = 50  # Max file upload allowed
+
+# Convert MB to bytes
+SMALL_VIDEO_LIMIT_BYTES = SMALL_VIDEO_LIMIT_MB * 1024 * 1024
+LARGE_VIDEO_LIMIT_BYTES = LARGE_VIDEO_LIMIT_MB * 1024 * 1024
+
+
+async def landing_page():
+
+    uploaded_file = st.file_uploader("Upload a video",
+                                     type=["mp4", "avi", "mov"])
+
+    if uploaded_file is not None:
+        file_size = uploaded_file.size  # File size in bytes
+
+        # Restrict max file upload size
+        if file_size > LARGE_VIDEO_LIMIT_BYTES:
+            st.error(f"File is too large! Max allowed size is {LARGE_VIDEO_LIMIT_MB}MB. Please upload a smaller version of it.")
+
+        else:
+            # Bytes object that can be decoded into audio or video
+            video_bytes = uploaded_file.read()
+
+            with st.spinner("Getting the most important moments from your video."):
+                important_frames = read_important_frames(video_bytes, 100)
+                st.success("Got important moments.")
+
+            print(f"Time taken to extract {len(important_frames)} important frames: {read_important_frames.total_time}")
+
+            with st.spinner("Getting the audio transcript from your video for the summary."):
+                audio_transcript_bytes = extract_audio(video_bytes)
+                st.success("Got audio transcript.")
+
+            print(f"Time taken to extract audio data: {extract_audio.total_time}")
+
+            # Add important frames to session state and redirect to the model inference page
+            sst["important_frames"] = important_frames
+
+            # Add audio transcript to session state
+            sst["audio_transcript"] = audio_transcript_bytes
+
+            st.button("Summarize & Analyze Video",
+                      on_click=navigate_to,
+                      args=("model_inference_page",)
+                      )
+
+
+async def model_inference_page():
+
+    df = pd.DataFrame(columns=['Video_Text_Summary', 'Video_Rating_Scale'])
+    sl_df = st.table(df)
+
+    # Fallback messages in case session state is missing the audio or the frames
+    video_summary_text = "Sorry, we couldn't find any audio data from your video, hence couldn't generate any summary."
+    video_rating_scale = "Sorry, we couldn't find any images from your video, hence couldn't generate any rating."
+
+    # Check that audio is present and non-empty
+    if "audio_transcript" in sst:
+
+        video_summary_text = summarize_from_audio(sst["audio_transcript"])
+
+        if len(video_summary_text) == 0:
+            video_summary_text = "Sorry, we couldn't find any audio data from your video, hence couldn't generate any summary."
+
+        print("Time taken to generate text summary from audio in seconds: ", summarize_from_audio.total_time)
+
+    # Check that frames are present and non-empty
+    if "important_frames" in sst:
+
+        important_frames = sst["important_frames"]
+        with st.spinner("Rating the content of your video."):
+            video_rating_scale = rate_video_frames(important_frames)
+
+        if len(video_rating_scale) == 0:
+            video_rating_scale = "Sorry, we couldn't find any images from your video, hence couldn't generate any rating."
+
+        print("Time taken to generate video rating in seconds: ", rate_video_frames.total_time)
+
+    sl_df.add_rows(
+        [(video_summary_text, video_rating_scale)]
+    )
+
+    st.button("Go Home",
+              on_click=navigate_to,
+              args=("landing_page",)
+              )
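
The app's entry script is not part of this diff, so the following is only a hypothetical sketch of how the async pages and utils.navigate_to (which stores the target page name in session state) might be wired together; the actual entry point may differ.

# app.py (hypothetical, not in this commit)
import asyncio
from streamlit import session_state as sst

from pages import landing_page, model_inference_page

PAGES = {
    "landing_page": landing_page,
    "model_inference_page": model_inference_page,
}

if "page" not in sst:
    sst["page"] = "landing_page"  # default page; navigate_to() overwrites this on button clicks

# Each Streamlit rerun renders whichever page is currently stored in session state.
asyncio.run(PAGES[sst["page"]]())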
preprocessing.py ADDED
@@ -0,0 +1,109 @@
+import numpy as np
+
+def pad_to_center(img, img_height, img_width, req_height, req_width):
+    """
+    Pads the original image up to the required height and width, keeping the
+    original content centered.
+
+    Parameters:
+        img -> 3D numpy array of shape Height x Width x Channels
+        img_height -> height of the current image
+        img_width -> width of the current image
+        req_height -> height the image should be padded to.
+        req_width -> width the image should be padded to.
+    """
+
+    # How many rows and columns need to be added to bring the current image up to
+    # the required size (clamped at 0 if a dimension is already large enough)
+    rem_height = max(req_height - img_height, 0)
+    rem_width = max(req_width - img_width, 0)
+
+    # Split the remaining height evenly between the top and bottom
+    pad_top = rem_height // 2
+    pad_bottom = rem_height - pad_top
+
+    # Split the remaining width evenly between the left and right
+    pad_left = rem_width // 2
+    pad_right = rem_width - pad_left
+
+    # Don't pad along the channels dimension; pad height and width to the required size.
+    # The tuples specify how many values to add on each of the 4 sides of the image.
+    return np.pad(
+        img,
+        (
+            (pad_top, pad_bottom),
+            (pad_left, pad_right),
+            (0, 0)
+        ),
+        mode='reflect'
+    )
+
+
+def crop_to_center(img, img_height, img_width, req_height, req_width, req_channel=3):
+    """
+    Reduces the original image size to the required height and width,
+    trimming only the edges and keeping the middle part of the image.
+
+    Parameters:
+        img -> 3D numpy array of shape Height x Width x Channels
+    """
+
+    # Difference in height and width of the image, divided into equal halves
+    toph = (img_height - req_height) // 2
+    leftw = (img_width - req_width) // 2
+
+    # Bottom boundary of the centered crop
+    bothen = toph + req_height
+
+    # Right boundary of the centered crop
+    rightwen = leftw + req_width
+
+    cropped_image = img[toph:bothen, leftw:rightwen, :]
+
+    assert cropped_image.shape == (req_height, req_width, req_channel)
+    return cropped_image
+
+
+def preprocess_images(img, req_height, req_width):
+    """
+    Pads or crops the input image array to the specified height and width,
+    centered around the middle.
+
+    Args:
+        img (np.ndarray): The image to resize, represented as a NumPy array
+            (height, width, channels).
+        req_height (int): The desired height of the output image.
+        req_width (int): The desired width of the output image.
+
+    Returns:
+        np.ndarray: The center-padded or center-cropped image.
+    """
+
+    image_shape_tuple = img.shape
+    assert len(image_shape_tuple) == 3, f"Please pass a 3D image with height, width and channels; you passed: {image_shape_tuple}"
+
+    # Assuming the layout is H, W, C
+    img_height, img_width, img_channel = image_shape_tuple
+
+    # If the original height is less than req_height or the original width is less than
+    # req_width, pad the image up to the required dimensions and return it
+    if img_height < req_height or img_width < req_width:
+        return pad_to_center(img, img_height, img_width, req_height, req_width)
+
+    # If the image size already matches the required size, there is nothing to crop
+    elif img_height == req_height and img_width == req_width:
+        return img
+
+    # The image is larger than the required height and width, so crop it around the center
+    else:
+        return crop_to_center(img,
+                              img_height,
+                              img_width,
+                              req_height,
+                              req_width,
+                              img_channel
+                              )
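
An illustrative check (not part of the commit): preprocess_images reflect-pads inputs smaller than the target and center-crops larger ones, always returning the requested spatial size, which is how the pipeline normalizes frames to 224x224.

import numpy as np
from preprocessing import preprocess_images

small = np.random.rand(200, 160, 3)   # smaller than 224x224 in both dimensions -> reflect-padded
large = np.random.rand(480, 640, 3)   # larger than 224x224 in both dimensions -> center-cropped

print(preprocess_images(small, 224, 224).shape)  # (224, 224, 3)
print(preprocess_images(large, 224, 224).shape)  # (224, 224, 3)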
requirements.txt ADDED
@@ -0,0 +1,38 @@
+aiohappyeyeballs==2.4.0
+aiohttp==3.10.5
+aiosignal==1.3.1
+async-timeout==4.0.3
+av==14.1.0
+ctranslate2==4.5.0
+ffmpeg-python==0.2.0
+httpcore==1.0.7
+huggingface-hub==0.24.6
+Jinja2==3.1.5
+networkx==3.2.1
+nltk==3.9.1
+num2words==0.5.14
+numba==0.60.0
+numpy==1.26.3
+openai-whisper==20240930
+opencv-python==4.11.0.86
+peft==0.12.0
+pillow==10.4.0
+protobuf==5.29.3
+pydantic==1.10.21
+PyYAML==6.0.2
+safetensors==0.4.5
+scipy==1.13.1
+sentencepiece==0.2.0
+smmap==5.0.2
+sniffio==1.3.1
+soundfile==0.13.1
+sseclient-py==1.8.0
+streamlit==1.42.0
+tiktoken==0.8.0
+tokenizers==0.21.0
+torch==2.6.0
+torchaudio==2.6.0
+torchvision==0.21.0
+transformers==4.48.3
+typeguard==4.4.1
+typing_extensions==4.12.2
runtime.txt ADDED
@@ -0,0 +1 @@
+3.9.6
utils.py ADDED
@@ -0,0 +1,143 @@
+from streamlit import session_state as sst
+import time
+import torch.nn.functional as F
+
+import cv2
+import av
+import heapq
+
+import numpy as np
+from preprocessing import preprocess_images
+
+import io
+from io import BytesIO
+import torch
+import soundfile as sf
+import subprocess
+from typing import List
+
+
+prompt_frame_summarization = "These are important frames of a video file. Please generate a summary such that the end user gets the gist of what the video is about."
+prompt_audio_summarization = "This is a video transcript, tell me what this is about: "
+assistant_role = "You are an agent who summarizes videos from important frames; use domain-specific language to generate the summary: sports, cartoon, education, finance, etc."
+
+def timer(func):
+    """Decorator that prints each call's duration and accumulates it in `wrapper.total_time`."""
+    def wrapper(*args, **kwargs):
+        start = time.time()
+        result = func(*args, **kwargs)
+        duration = time.time() - start
+        wrapper.total_time += duration
+        print(f"Execution time of {func.__name__}: {duration}")
+        return result
+
+    wrapper.total_time = 0
+    return wrapper
+
+def navigate_to(page: str) -> None:
+    """
+    Sets the current page in Streamlit's session state; a helper for
+    simulating navigation in Streamlit.
+
+    Parameters:
+        page: str, required.
+
+    Returns:
+        None
+    """
+
+    sst["page"] = page
+
+@timer
+def read_important_frames(video_bytes, top_k_frames) -> List:
+
+    # Read the uploaded video in memory
+    video_io = io.BytesIO(video_bytes)
+
+    # Open the uploaded video frames
+    container = av.open(video_io, format='mp4')
+
+    prev_frame = None; important_frames = []
+
+    # For each frame, compute a movement score and push it onto a heap that keeps the top_k movement frames
+    for frameId, frame in enumerate(container.decode(video=0)):  # Decode all frames
+
+        img = frame.to_ndarray(format="bgr24")  # Convert frame to NumPy array (BGR format)
+        assert len(img.shape) == 3, f"Expected a 3D frame; instead it is: {img.shape}"
+
+        if prev_frame is not None:
+
+            # Compute frame difference in gray scale for efficiency
+            diff = cv2.absdiff(prev_frame, img)
+            gray_diff = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
+
+            movement_score = np.sum(gray_diff)  # Sum of pixel differences
+            processed_frame = preprocess_images(frame.to_ndarray(format="rgb24"),
+                                                224,
+                                                224
+                                                )
+
+            # Keep only the top_k_frames frames with the highest movement scores (min-heap)
+            if len(important_frames) < top_k_frames:
+                heapq.heappush(important_frames,
+                               (movement_score, frameId, processed_frame)
+                               )
+            else:
+                heapq.heappushpop(important_frames,
+                                  (movement_score, frameId, processed_frame)
+                                  )
+
+        prev_frame = img  # Update previous frame
+
+    # Sort the retained top_k frames back into chronological order of appearance
+    important_frames = [item[2] for item in sorted(important_frames, key=lambda x: x[1])]
+    return important_frames
+
+@timer
+def extract_audio(video_bytes):
+    """Extracts raw audio from a video file given as bytes, without writing temp files."""
+
+    # Run FFmpeg to extract raw WAV audio without writing a file
+    process = subprocess.run(
+        ["ffmpeg", "-i", "pipe:0", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le", "-f", "wav", "pipe:1"],
+        input=video_bytes,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.DEVNULL
+    )
+
+    # Convert FFmpeg output to a BytesIO stream
+    audio_stream = BytesIO(process.stdout)
+
+    # Read the audio stream into a NumPy array
+    audio_array, sample_rate = sf.read(audio_stream, dtype="float32")
+
+    # Convert to a PyTorch tensor (Whisper accepts a torch.Tensor)
+    audio_tensor = torch.tensor(audio_array)
+
+    return audio_tensor
+
+def batch_generator(array_list, batch_size=5):
+    """
+    Generator that yields batches of `batch_size` NumPy arrays stacked along the first
+    dimension. A trailing partial batch is dropped.
+
+    Parameters:
+        array_list (list of np.ndarray): List of NumPy arrays of shape (H, W, C).
+        batch_size (int): Number of arrays per batch (default is 5).
+
+    Yields:
+        np.ndarray: A batch of shape (batch_size, H, W, C).
+    """
+    for i in range(0, len(array_list), batch_size):
+        batch = array_list[i:i + batch_size]
+        if len(batch) == batch_size:
+            yield np.stack(batch, axis=0)
+
+@timer
+def cosine_sim(emb1, emb2, threshold=0.5):
+    # Cosine similarity between the base embedding and each clip embedding,
+    # plus the count of clips whose similarity exceeds the threshold.
+    cosine_sim = F.cosine_similarity(emb1, emb2)
+    counts = torch.count_nonzero(cosine_sim > threshold).numpy()
+    return (cosine_sim.mean(), counts)
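
A short illustrative usage of these helpers (not part of the commit): the @timer decorator accumulates total_time across calls, batch_generator drops a trailing partial batch, and cosine_sim counts embeddings whose similarity to the base embedding exceeds the threshold. The toy embeddings below are stand-ins for real model outputs.

import numpy as np
import torch
from utils import batch_generator, cosine_sim

# 12 frames with batch_size=5 -> two full batches, the leftover 2 frames are dropped
frames = [np.zeros((224, 224, 3)) for _ in range(12)]
batches = list(batch_generator(frames, batch_size=5))
print(len(batches), batches[0].shape)  # 2 (5, 224, 224, 3)

# One base embedding broadcast against 5 clip embeddings: 3 similar, 2 dissimilar
base = torch.ones(1, 128)
clips = torch.cat([torch.ones(3, 128), -torch.ones(2, 128)])
mean_sim, count_above = cosine_sim(base, clips, threshold=0.4)
print(mean_sim.item(), count_above)  # ~0.2 3
print(cosine_sim.total_time)         # accumulated runtime, courtesy of @timer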