change_medoid_inference #3
by AmithAdiraju1694 - opened

Files changed:
- base_frame_medoid.npz +1 -1
- video_rating_siamesev2.onnx → checkpoint__fl_batch_480_epoch_0.pt +2 -2
- model_inference.py +32 -25
- pages.py +3 -3
- utils.py +31 -0
base_frame_medoid.npz
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9dad2de4ef28891c9cf509177d3ff24beb0eea068fe3d8159b4b1050d4f55139
 size 772
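Note: load_models in this PR reads this file via np.load('base_frame_medoid.npz')['arr']. A minimal sanity check after pulling the LFS object, as a sketch; the 128-d shape is an assumption inferred from SiameseNetwork's output dimension below, not something the diff states:

# Hedged sanity check of the updated medoid file; the key 'arr' matches load_models.
import numpy as np

arr = np.load('base_frame_medoid.npz')['arr']
print(arr.shape, arr.dtype)  # assumption: a 128-d float32 embedding, per SiameseNetwork.fc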
video_rating_siamesev2.onnx → checkpoint__fl_batch_480_epoch_0.pt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:15a5a088b3fc06010de0b00ccbf00f0be0058c96d667c417c9a3188fcfb6e6dc
+size 382240253
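Note: the ONNX model is replaced by a raw PyTorch checkpoint. A quick way to inspect it after `git lfs pull`, sketched under the assumption the LFS object has been fetched; the 'model_state_dict' key and weights_only=True usage are confirmed by load_models in this PR:

# Inspect the new checkpoint the same way load_models consumes it.
import torch

ckpt = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
print(list(ckpt.keys()))  # expect 'model_state_dict' among the keys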
model_inference.py
CHANGED
@@ -3,20 +3,21 @@ import torch
 from utils import (
     prompt_audio_summarization,
     timer,
-    cosine_sim
+    cosine_sim,
+    SiameseNetwork
 )
 from transformers import BartForConditionalGeneration, BartTokenizer
 import numpy as np
 import whisper
 from streamlit import session_state as sst
-import 
+import math
 
 
 @timer
 def get_text_from_audio(audio_tensors) -> str:
     """Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
     # Transcribe the in-memory audio
-    audio_tensors = audio_tensors
+    audio_tensors = audio_tensors
     result = audio_transcriber_model.transcribe(audio_tensors
     )
     all_transcription_segments = result["text"]
@@ -28,17 +29,23 @@ def summarize_from_text(raw_transcription):
     inputs = text_summarizer[0](prompt_audio_summarization + raw_transcription,
                                 return_tensors="pt",
                                 max_length=1024,
-                                truncation=True)
-
-
+                                truncation=True)
+
     summary_ids = text_summarizer[1].generate(**inputs,
                                               max_length=150,
                                               min_length=30,
                                               length_penalty=2.0,
-                                              num_beams=4
+                                              num_beams=4,
+                                              early_stopping = True
     )
 
-
+    prediction_string = text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)
+
+
+    if prompt_audio_summarization[:15] == prediction_string[:15]:
+        prediction_string = prediction_string[50: ]
+
+    return prediction_string
 
 @timer
 def rate_video_frames(video_frames):
@@ -47,39 +54,39 @@ def rate_video_frames(video_frames):
     """
 
     inp_frames = np.array(video_frames, dtype = np.float32).reshape(len(video_frames)//5, 5, 224,224,3)# 20,5,224,224,3
-
-
-    video_frame_emb = video_rating_model.run(['emb'], inputs_dict)[0]
+    with torch.no_grad():
+        video_frame_emb = video_rating_model(torch.tensor(inp_frames) )
 
     overall_sim, count_upg = cosine_sim(emb1 = base_frame_emb,
-                                        emb2 = 
-                                        threshold=0.
+                                        emb2 = video_frame_emb,
+                                        threshold=0.95
     )
 
     perc_of_upg = count_upg / (len(video_frames)//5)
 
-
-
-
-
+    perc_of_upg = math.floor(perc_of_upg*100)
+    non_upg_perc = math.ceil(100 - perc_of_upg)
+
+    response_string = f"{perc_of_upg} % of important moments from this video contain Under or PG content, rest of {non_upg_perc} % moments contain atleast PG-13, R or even NC-17 content."
+    return response_string
 
 @st.cache_resource
 def load_models():
-
-    transcriber = whisper.load_model("base", device = sst['device'])
+    transcriber = whisper.load_model("base")
 
-    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
 
     base_frame_emb = torch.tensor(
         np.load('base_frame_medoid.npz')['arr'],
-        dtype = torch.float32
-        device = sst['device']
+        dtype = torch.float32
     )
 
-    session = 
-
-
+    session = SiameseNetwork()
+    checkpoint = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
+    session.load_state_dict(checkpoint['model_state_dict'])
+    _ = session.eval()
+
 
     return (
         transcriber, (tokenizer, model), session, base_frame_emb
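Note: rate_video_frames now runs the PyTorch SiameseNetwork directly under torch.no_grad() instead of an ONNX Runtime session.run call. A hypothetical smoke test, assuming load_models has already populated the module-level video_rating_model and base_frame_emb, and that the frame count is a multiple of 5 (the reshape to (N//5, 5, 224, 224, 3) requires it):

# Hypothetical smoke test of the new inference path: 20 dummy 224x224 RGB frames.
import numpy as np

dummy_frames = [np.zeros((224, 224, 3), dtype=np.float32) for _ in range(20)]
print(rate_video_frames(dummy_frames))  # e.g. "NN % of important moments ..."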
pages.py
CHANGED
@@ -76,7 +76,7 @@ async def model_inference_page():
     try:
         video_rating_scale = rate_video_frames(important_frames)
     except Exception as e:
-        video_rating_scale = 
+        video_rating_scale = "Sorry, we couldn't generate rating of your video because of this error "
 
     st.toast("Done")
     st.header("Movie Scale Rating of Your Video: ", divider = True)
@@ -91,7 +91,7 @@ async def model_inference_page():
     try:
         video_summary_text = get_text_from_audio(sst["audio_transcript"])
     except Exception as e:
-        video_summary_text = 
+        video_summary_text = "Sorry, we couldn't extract text from audio of this file because of this error"
     st.toast("Done")
 
     if video_summary_text[:5] != "Sorry":
@@ -99,7 +99,7 @@ async def model_inference_page():
     try:
         video_summary_text = summarize_from_text(video_summary_text)
     except Exception as e:
-        video_summary_text = 
+        video_summary_text = "Sorry, we couldn't summarize text from audio of this file"
     st.toast("Done")
 
     st.header("Audio Transcript summary of your video: ", divider = True)
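Note: the guard video_summary_text[:5] != "Sorry" above only skips the summarization step because every new fallback string starts with "Sorry"; a one-line illustration of the invariant:

# The prefix check in pages.py depends on all fallback strings sharing this prefix.
assert "Sorry, we couldn't extract text from audio of this file because of this error"[:5] == "Sorry"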
utils.py
CHANGED
@@ -12,6 +12,8 @@ import time
 
 from io import BytesIO
 import torch
+import torchvision.models as models
+import torch.nn as nn
 import soundfile as sf
 import subprocess
 from typing import List
@@ -19,6 +21,34 @@ from typing import List
 prompt_audio_summarization = "This is a video transcript, tell me what is this about: "
 
 
+class SiameseNetwork(nn.Module):
+    def __init__(self, model_name="vit_b_16"):
+        super(SiameseNetwork, self).__init__()
+
+        self.encoder = models.vit_b_16(weights="IMAGENET1K_V1") # Pretrained ViT
+        self.encoder.heads = nn.Identity() # Remove classification head
+
+        self.fc = nn.Linear(768, 128) # Reduce to 128-d embedding
+
+    def forward(self, frames):
+
+        B,num_frames,H,W,C = frames.shape # (Batch,num_frames, H, W, C)
+
+        # Flatten frames into batch dimension for ViT
+        frames = frames.permute(0,1,4,2,3).reshape(B * num_frames, C,H,W)
+
+        # Extract frame-level embeddings
+        emb = self.encoder(frames)
+
+        # Reshape back to (B, T, 768) and average over T
+        #TODO: Change this to use LSTM instead of averaging
+        emb = emb.reshape(B, num_frames, -1).mean(dim=1) # (B, 768)
+
+        # Pass through fully connected layer
+        emb = self.fc(emb)
+
+        return emb
+
 def timer(func):
     def wrapper(*args, **kwargs):
         start = time.time()
@@ -138,3 +168,4 @@ def cosine_sim(emb1, emb2, threshold = 0.5):
     cosine_sim = F.cosine_similarity(emb1, emb2)
     counts = torch.count_nonzero(cosine_sim > threshold).numpy()
     return (cosine_sim.mean(), counts)
+
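Note: a minimal shape check for SiameseNetwork as defined above, assuming the channels-last (B, num_frames, H, W, C) input layout that rate_video_frames produces:

# Sketch: verify the embedding contract (B, num_frames, 224, 224, 3) -> (B, 128).
import torch

net = SiameseNetwork()
net.eval()
with torch.no_grad():
    clips = torch.zeros(2, 5, 224, 224, 3)  # 2 clips of 5 RGB frames each
    emb = net(clips)
print(emb.shape)  # torch.Size([2, 128])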