Amith Adiraju committed
Commit · ef39fae
Parent(s): d3acc32

Changed ONNX to checkpoint-based inference, as the ONNX file is corrupted.
Replaced the medoid array with the right embeddings and increased the PG classification threshold to 0.95.
Modified the rating status message.
Fixed issues with audio text summarization.

Signed-off-by: Amith Adiraju <[email protected]>
- base_frame_medoid.npz +1 -1
- video_rating_siamesev2.onnx → checkpoint__fl_batch_480_epoch_0.pt +2 -2
- model_inference.py +32 -25
- pages.py +3 -3
- utils.py +31 -0
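Taken together, the diffs below amount to the following swap (a condensed sketch, not the verbatim commit; the onnxruntime setup is an assumption, since the removed ONNX lines are truncated in this view):

# Before (assumed from the truncated ONNX lines): an onnxruntime session
# produced the clip embeddings.
#   session = onnxruntime.InferenceSession('video_rating_siamesev2.onnx')
#   video_frame_emb = session.run(['emb'], inputs_dict)[0]

# After: the Siamese model is rebuilt in PyTorch, restored from a training
# checkpoint, and run in eval mode under no_grad.
import torch
from utils import SiameseNetwork

session = SiameseNetwork()
checkpoint = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
session.load_state_dict(checkpoint['model_state_dict'])
session.eval()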
base_frame_medoid.npz
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9dad2de4ef28891c9cf509177d3ff24beb0eea068fe3d8159b4b1050d4f55139
 size 772
video_rating_siamesev2.onnx → checkpoint__fl_batch_480_epoch_0.pt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:15a5a088b3fc06010de0b00ccbf00f0be0058c96d667c417c9a3188fcfb6e6dc
+size 382240253
model_inference.py
CHANGED
@@ -3,20 +3,21 @@ import torch
 from utils import (
     prompt_audio_summarization,
     timer,
-    cosine_sim
+    cosine_sim,
+    SiameseNetwork
 )
 from transformers import BartForConditionalGeneration, BartTokenizer
 import numpy as np
 import whisper
 from streamlit import session_state as sst
-import
+import math
 
 
 @timer
 def get_text_from_audio(audio_tensors) -> str:
     """Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
     # Transcribe the in-memory audio
-    audio_tensors = audio_tensors
+    audio_tensors = audio_tensors
     result = audio_transcriber_model.transcribe(audio_tensors
                                                 )
     all_transcription_segments = result["text"]
@@ -28,17 +29,23 @@ def summarize_from_text(raw_transcription):
     inputs = text_summarizer[0](prompt_audio_summarization + raw_transcription,
                                 return_tensors="pt",
                                 max_length=1024,
-                                truncation=True)
-
-
+                                truncation=True)
+
     summary_ids = text_summarizer[1].generate(**inputs,
                                               max_length=150,
                                               min_length=30,
                                               length_penalty=2.0,
-                                              num_beams=4
+                                              num_beams=4,
+                                              early_stopping = True
                                               )
 
-
+    prediction_string = text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)
+
+
+    if prompt_audio_summarization[:15] == prediction_string[:15]:
+        prediction_string = prediction_string[50: ]
+
+    return prediction_string
 
 @timer
 def rate_video_frames(video_frames):
@@ -47,39 +54,39 @@ def rate_video_frames(video_frames):
     """
 
     inp_frames = np.array(video_frames, dtype = np.float32).reshape(len(video_frames)//5, 5, 224,224,3)# 20,5,224,224,3
-
-
-    video_frame_emb = video_rating_model.run(['emb'], inputs_dict)[0]
+    with torch.no_grad():
+        video_frame_emb = video_rating_model(torch.tensor(inp_frames) )
 
     overall_sim, count_upg = cosine_sim(emb1 = base_frame_emb,
-                                        emb2 =
-                                        threshold=0.
+                                        emb2 = video_frame_emb,
+                                        threshold=0.95
                                         )
 
     perc_of_upg = count_upg / (len(video_frames)//5)
 
-
-
-
-
+    perc_of_upg = math.floor(perc_of_upg*100)
+    non_upg_perc = math.ceil(100 - perc_of_upg)
+
+    response_string = f"{perc_of_upg} % of important moments from this video contain Under or PG content, rest of {non_upg_perc} % moments contain atleast PG-13, R or even NC-17 content."
+    return response_string
 
 @st.cache_resource
 def load_models():
-
-    transcriber = whisper.load_model("base", device = sst['device'])
+    transcriber = whisper.load_model("base")
 
-    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
 
     base_frame_emb = torch.tensor(
         np.load('base_frame_medoid.npz')['arr'],
-        dtype = torch.float32,
-        device = sst['device']
+        dtype = torch.float32
     )
 
-    session =
-
-
+    session = SiameseNetwork()
+    checkpoint = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
+    session.load_state_dict(checkpoint['model_state_dict'])
+    _ = session.eval()
+
 
     return (
         transcriber, (tokenizer, model), session, base_frame_emb
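For reference, the rating arithmetic introduced in rate_video_frames works out as below; a minimal sketch with stand-in random embeddings, mirroring the cosine_sim helper from utils.py (the 0.95 threshold and the 5-frame clip grouping come from the diff):

import math
import torch
import torch.nn.functional as F

def cosine_sim(emb1, emb2, threshold=0.95):
    sims = F.cosine_similarity(emb1, emb2)          # one score per 5-frame clip
    counts = torch.count_nonzero(sims > threshold)  # clips matching the Under/PG medoid
    return sims.mean(), counts

base_frame_emb = torch.randn(20, 128)   # stand-in for base_frame_medoid.npz
video_frame_emb = torch.randn(20, 128)  # stand-in for SiameseNetwork output (20 clips)
_, count_upg = cosine_sim(base_frame_emb, video_frame_emb)

perc_of_upg = math.floor(int(count_upg) / 20 * 100)  # % of clips rated Under/PG
non_upg_perc = math.ceil(100 - perc_of_upg)          # remainder reported as PG-13/R/NC-17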
pages.py
CHANGED
@@ -76,7 +76,7 @@ async def model_inference_page():
     try:
         video_rating_scale = rate_video_frames(important_frames)
     except Exception as e:
-        video_rating_scale =
+        video_rating_scale = "Sorry, we couldn't generate rating of your video because of this error "
 
     st.toast("Done")
     st.header("Movie Scale Rating of Your Video: ", divider = True)
@@ -91,7 +91,7 @@ async def model_inference_page():
     try:
         video_summary_text = get_text_from_audio(sst["audio_transcript"])
     except Exception as e:
-        video_summary_text =
+        video_summary_text = "Sorry, we couldn't extract text from audio of this file because of this error"
     st.toast("Done")
 
     if video_summary_text[:5] != "Sorry":
@@ -99,7 +99,7 @@ async def model_inference_page():
     try:
         video_summary_text = summarize_from_text(video_summary_text)
     except Exception as e:
-        video_summary_text =
+        video_summary_text = "Sorry, we couldn't summarize text from audio of this file"
     st.toast("Done")
 
     st.header("Audio Transcript summary of your video: ", divider = True)
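Note the coupling these fallbacks rely on: the guard if video_summary_text[:5] != "Sorry" keys on the literal "Sorry" prefix, so every fallback string must begin with it. A minimal sketch of the sentinel pattern, using the same names as pages.py:

try:
    video_summary_text = get_text_from_audio(sst["audio_transcript"])
except Exception as e:
    # Fallback doubles as the user-facing message; the "Sorry" prefix is the sentinel.
    video_summary_text = "Sorry, we couldn't extract text from audio of this file because of this error"

if video_summary_text[:5] != "Sorry":   # only summarize a real transcript
    video_summary_text = summarize_from_text(video_summary_text)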
utils.py
CHANGED
@@ -12,6 +12,8 @@ import time
 
 from io import BytesIO
 import torch
+import torchvision.models as models
+import torch.nn as nn
 import soundfile as sf
 import subprocess
 from typing import List
@@ -19,6 +21,34 @@ from typing import List
 
 prompt_audio_summarization = "This is a video transcript, tell me what is this about: "
 
 
+class SiameseNetwork(nn.Module):
+    def __init__(self, model_name="vit_b_16"):
+        super(SiameseNetwork, self).__init__()
+
+        self.encoder = models.vit_b_16(weights="IMAGENET1K_V1")  # Pretrained ViT
+        self.encoder.heads = nn.Identity()  # Remove classification head
+
+        self.fc = nn.Linear(768, 128)  # Reduce to 128-d embedding
+
+    def forward(self, frames):
+
+        B, num_frames, H, W, C = frames.shape  # (Batch, num_frames, H, W, C)
+
+        # Flatten frames into batch dimension for ViT
+        frames = frames.permute(0, 1, 4, 2, 3).reshape(B * num_frames, C, H, W)
+
+        # Extract frame-level embeddings
+        emb = self.encoder(frames)
+
+        # Reshape back to (B, T, 768) and average over T
+        # TODO: Change this to use LSTM instead of averaging
+        emb = emb.reshape(B, num_frames, -1).mean(dim=1)  # (B, 768)
+
+        # Pass through fully connected layer
+        emb = self.fc(emb)
+
+        return emb
+
 def timer(func):
     def wrapper(*args, **kwargs):
         start = time.time()
@@ -138,3 +168,4 @@ def cosine_sim(emb1, emb2, threshold = 0.5):
     cosine_sim = F.cosine_similarity(emb1, emb2)
     counts = torch.count_nonzero(cosine_sim > threshold).numpy()
     return (cosine_sim.mean(), counts)
+
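A hypothetical smoke test tying the new class to the committed checkpoint (the file name and the 'model_state_dict' key are from the diffs above; the input layout matches rate_video_frames):

import torch
from utils import SiameseNetwork

model = SiameseNetwork()
ckpt = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

frames = torch.randn(2, 5, 224, 224, 3)  # two clips of five 224x224 RGB frames
with torch.no_grad():
    emb = model(frames)
print(emb.shape)  # torch.Size([2, 128]) -- one 128-d embedding per clip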