Amith Adiraju committed
Commit ef39fae · 1 Parent(s): d3acc32

Changed ONNX to checkpoint-based inference, as the ONNX file is corrupted.


Replaced the medoid array with the correct embeddings and increased the PG classification threshold to 0.95.
Modified the rating status message.
Fixed issues with audio text summarization.

Signed-off-by: Amith Adiraju <[email protected]>
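In short: the video-rating model is now rebuilt as a PyTorch module and restored from a training checkpoint instead of being served through onnxruntime. A condensed sketch of the swap, using the file and key names exactly as they appear in the diffs below:

import torch
from utils import SiameseNetwork

# Old path (removed): ONNX session over the corrupted export
# import onnxruntime
# session = onnxruntime.InferenceSession("video_rating_siamesev2.onnx",
#                                        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# emb = session.run(['emb'], {"frames": inp_frames})[0]

# New path: instantiate the module and load weights from the checkpoint
model = SiameseNetwork()
checkpoint = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode: dropout off, frozen norm statistics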

base_frame_medoid.npz CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7ccfba511e72cddd47db59a02147bfb03a8216b6bdfa4129d7c98b604cd048f7
+oid sha256:9dad2de4ef28891c9cf509177d3ff24beb0eea068fe3d8159b4b1050d4f55139
 size 772
video_rating_siamesev2.onnx → checkpoint__fl_batch_480_epoch_0.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3f073fb71728915d53cde75d316c3196c07bda3fe79af5acab6596bc397146b6
-size 344064697
+oid sha256:15a5a088b3fc06010de0b00ccbf00f0be0058c96d667c417c9a3188fcfb6e6dc
+size 382240253
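Both files above are Git LFS pointer stubs: `oid` is the SHA-256 of the real blob and `size` its byte length. Since a corrupted artifact is exactly what this commit works around, it is worth verifying a pulled file against its pointer; a minimal check (hypothetical helper, not part of the repo):

import hashlib

def lfs_oid(path: str, chunk_size: int = 1 << 20) -> str:
    """Streamed SHA-256 of a file; should match the 'oid' in the LFS pointer."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Expected value taken from the new pointer above
assert lfs_oid("checkpoint__fl_batch_480_epoch_0.pt") == \
    "15a5a088b3fc06010de0b00ccbf00f0be0058c96d667c417c9a3188fcfb6e6dc"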
model_inference.py CHANGED
@@ -3,20 +3,21 @@ import torch
 from utils import (
     prompt_audio_summarization,
     timer,
-    cosine_sim
+    cosine_sim,
+    SiameseNetwork
 )
 from transformers import BartForConditionalGeneration, BartTokenizer
 import numpy as np
 import whisper
 from streamlit import session_state as sst
-import onnxruntime
+import math
 
 
 @timer
 def get_text_from_audio(audio_tensors) -> str:
     """Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
     # Transcribe the in-memory audio
-    audio_tensors = audio_tensors.to(sst['device'])
+    audio_tensors = audio_tensors
     result = audio_transcriber_model.transcribe(audio_tensors
     )
     all_transcription_segments = result["text"]
@@ -28,17 +29,23 @@ def summarize_from_text(raw_transcription):
     inputs = text_summarizer[0](prompt_audio_summarization + raw_transcription,
                                 return_tensors="pt",
                                 max_length=1024,
-                                truncation=True)\
-                                .to(sst['device'])
-
+                                truncation=True)
+
     summary_ids = text_summarizer[1].generate(**inputs,
                                               max_length=150,
                                               min_length=30,
                                               length_penalty=2.0,
-                                              num_beams=4
+                                              num_beams=4,
+                                              early_stopping = True
                                               )
 
-    return text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)
+    prediction_string = text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)
+
+
+    if prompt_audio_summarization[:15] == prediction_string[:15]:
+        prediction_string = prediction_string[50:]
+
+    return prediction_string
 
 @timer
 def rate_video_frames(video_frames):
@@ -47,39 +54,39 @@ def rate_video_frames(video_frames):
     """
 
     inp_frames = np.array(video_frames, dtype = np.float32).reshape(len(video_frames)//5, 5, 224, 224, 3)  # e.g. (20, 5, 224, 224, 3)
-    inputs_dict = {"frames": inp_frames}
-
-    video_frame_emb = video_rating_model.run(['emb'], inputs_dict)[0]
+    with torch.no_grad():
+        video_frame_emb = video_rating_model(torch.tensor(inp_frames))
 
     overall_sim, count_upg = cosine_sim(emb1 = base_frame_emb,
-                                        emb2 = torch.tensor(video_frame_emb),
-                                        threshold=0.4
+                                        emb2 = video_frame_emb,
+                                        threshold=0.95
                                         )
 
     perc_of_upg = count_upg / (len(video_frames)//5)
 
-    if perc_of_upg > 0.4:
-        return f"Out of {len(video_frames)} important moments of this video, {count_upg*5} moments contain under or at least PG content. Hence this video is suitable for kids & family."
-    else:
-        return f"Out of {len(video_frames)} important moments of this video, {(len(video_frames)//5 - count_upg)*5} moments contain at least PG-13 content. Hence parental guidance is strongly suggested for this video."
+    perc_of_upg = math.floor(perc_of_upg*100)
+    non_upg_perc = math.ceil(100 - perc_of_upg)
+
+    response_string = f"{perc_of_upg} % of important moments from this video contain Under or PG content, rest of {non_upg_perc} % moments contain at least PG-13, R or even NC-17 content."
+    return response_string
 
 @st.cache_resource
 def load_models():
-    sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
-    transcriber = whisper.load_model("base", device = sst['device'])
+    transcriber = whisper.load_model("base")
 
-    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(sst['device'])
+    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
 
     base_frame_emb = torch.tensor(
        np.load('base_frame_medoid.npz')['arr'],
-       dtype = torch.float32,
-       device = sst['device']
+       dtype = torch.float32
    )
 
-    session = onnxruntime.InferenceSession("video_rating_siamesev2.onnx",
-                                           providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
-                                           )
+    session = SiameseNetwork()
+    checkpoint = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
+    session.load_state_dict(checkpoint['model_state_dict'])
+    _ = session.eval()
+
 
    return (
        transcriber, (tokenizer, model), session, base_frame_emb
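For reference, the shape flow through the new rating path, with dummy frames standing in for the sampled video (the 100-frame count is illustrative, and the fresh SiameseNetwork stands in for the checkpoint-loaded model):

import numpy as np
import torch
from utils import SiameseNetwork

video_rating_model = SiameseNetwork().eval()  # stand-in for the loaded checkpoint

video_frames = [np.zeros((224, 224, 3), dtype=np.float32)] * 100  # sampled frames
# Group into clips of 5: (num_clips, 5, H, W, C) -> here (20, 5, 224, 224, 3)
inp_frames = np.array(video_frames, dtype=np.float32).reshape(
    len(video_frames) // 5, 5, 224, 224, 3)

with torch.no_grad():
    video_frame_emb = video_rating_model(torch.tensor(inp_frames))  # (20, 128)

# cosine_sim then scores each clip embedding against the 128-d medoid embedding;
# a clip counts as Under/PG only when similarity exceeds the new 0.95 threshold.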
pages.py CHANGED
@@ -76,7 +76,7 @@ async def model_inference_page():
     try:
         video_rating_scale = rate_video_frames(important_frames)
     except Exception as e:
-        video_rating_scale = f"Sorry, we couldn't generate rating of your video because of this error: {e} "
+        video_rating_scale = "Sorry, we couldn't generate rating of your video because of this error "
 
     st.toast("Done")
     st.header("Movie Scale Rating of Your Video: ", divider = True)
@@ -91,7 +91,7 @@ async def model_inference_page():
     try:
         video_summary_text = get_text_from_audio(sst["audio_transcript"])
     except Exception as e:
-        video_summary_text = f"Sorry, we couldn't extract text from audio of this file because of this error: {e} "
+        video_summary_text = "Sorry, we couldn't extract text from audio of this file because of this error"
     st.toast("Done")
 
     if video_summary_text[:5] != "Sorry":
@@ -99,7 +99,7 @@ async def model_inference_page():
     try:
         video_summary_text = summarize_from_text(video_summary_text)
     except Exception as e:
-        video_summary_text = f"Sorry, we couldn't summarize text from audio of this file because of this error: {e} "
+        video_summary_text = "Sorry, we couldn't summarize text from audio of this file"
     st.toast("Done")
 
     st.header("Audio Transcript summary of your video: ", divider = True)
utils.py CHANGED
@@ -12,6 +12,8 @@ import time
 
 from io import BytesIO
 import torch
+import torchvision.models as models
+import torch.nn as nn
 import soundfile as sf
 import subprocess
 from typing import List
@@ -19,6 +21,34 @@ from typing import List
 prompt_audio_summarization = "This is a video transcript, tell me what is this about: "
 
 
+class SiameseNetwork(nn.Module):
+    def __init__(self, model_name="vit_b_16"):
+        super(SiameseNetwork, self).__init__()
+
+        self.encoder = models.vit_b_16(weights="IMAGENET1K_V1")  # Pretrained ViT
+        self.encoder.heads = nn.Identity()  # Remove classification head
+
+        self.fc = nn.Linear(768, 128)  # Reduce to 128-d embedding
+
+    def forward(self, frames):
+
+        B, num_frames, H, W, C = frames.shape  # (Batch, num_frames, H, W, C)
+
+        # Flatten frames into batch dimension for ViT
+        frames = frames.permute(0, 1, 4, 2, 3).reshape(B * num_frames, C, H, W)
+
+        # Extract frame-level embeddings
+        emb = self.encoder(frames)
+
+        # Reshape back to (B, T, 768) and average over T
+        # TODO: Change this to use LSTM instead of averaging
+        emb = emb.reshape(B, num_frames, -1).mean(dim=1)  # (B, 768)
+
+        # Pass through fully connected layer
+        emb = self.fc(emb)
+
+        return emb
+
 def timer(func):
     def wrapper(*args, **kwargs):
         start = time.time()
@@ -138,3 +168,4 @@ def cosine_sim(emb1, emb2, threshold = 0.5):
     cosine_sim = F.cosine_similarity(emb1, emb2)
     counts = torch.count_nonzero(cosine_sim > threshold).numpy()
     return (cosine_sim.mean(), counts)
+
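A quick smoke test of the new encoder, confirming the forward pass is per-frame ViT embeddings averaged over the clip and projected to 128-d (channels-last input, as used by rate_video_frames):

import torch
from utils import SiameseNetwork

net = SiameseNetwork().eval()
clips = torch.randn(1, 5, 224, 224, 3)  # (B, num_frames, H, W, C)

with torch.no_grad():
    emb = net(clips)                                       # (1, 128)
    per_frame = net.encoder(clips[0].permute(0, 3, 1, 2))  # (5, 768)
    manual = net.fc(per_frame.mean(dim=0, keepdim=True))   # (1, 128)

assert emb.shape == (1, 128)
assert torch.allclose(emb, manual, atol=1e-5)  # averaging over frames, then fc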