change_medoid_inference

#3
base_frame_medoid.npz CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7ccfba511e72cddd47db59a02147bfb03a8216b6bdfa4129d7c98b604cd048f7
+oid sha256:9dad2de4ef28891c9cf509177d3ff24beb0eea068fe3d8159b4b1050d4f55139
 size 772
video_rating_siamesev2.onnx → checkpoint__fl_batch_480_epoch_0.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3f073fb71728915d53cde75d316c3196c07bda3fe79af5acab6596bc397146b6
-size 344064697
+oid sha256:15a5a088b3fc06010de0b00ccbf00f0be0058c96d667c417c9a3188fcfb6e6dc
+size 382240253
model_inference.py CHANGED
@@ -3,20 +3,21 @@ import torch
 from utils import (
     prompt_audio_summarization,
     timer,
-    cosine_sim
+    cosine_sim,
+    SiameseNetwork
 )
 from transformers import BartForConditionalGeneration, BartTokenizer
 import numpy as np
 import whisper
 from streamlit import session_state as sst
-import onnxruntime
+import math


 @timer
 def get_text_from_audio(audio_tensors) -> str:
     """Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
     # Transcribe the in-memory audio
-    audio_tensors = audio_tensors.to(sst['device'])
+    audio_tensors = audio_tensors
     result = audio_transcriber_model.transcribe(audio_tensors
     )
     all_transcription_segments = result["text"]
@@ -28,17 +29,23 @@ def summarize_from_text(raw_transcription):
     inputs = text_summarizer[0](prompt_audio_summarization + raw_transcription,
                                 return_tensors="pt",
                                 max_length=1024,
-                                truncation=True)\
-                                .to(sst['device'])
-
+                                truncation=True)
+
     summary_ids = text_summarizer[1].generate(**inputs,
                                               max_length=150,
                                               min_length=30,
                                               length_penalty=2.0,
-                                              num_beams=4
+                                              num_beams=4,
+                                              early_stopping = True
                                               )

-    return text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)
+    prediction_string = text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)
+
+
+    if prompt_audio_summarization[:15] == prediction_string[:15]:
+        prediction_string = prediction_string[50: ]
+
+    return prediction_string

 @timer
 def rate_video_frames(video_frames):
@@ -47,39 +54,39 @@ def rate_video_frames(video_frames):
     """

     inp_frames = np.array(video_frames, dtype = np.float32).reshape(len(video_frames)//5, 5, 224,224,3)# 20,5,224,224,3
-    inputs_dict = {"frames": inp_frames}
-
-    video_frame_emb = video_rating_model.run(['emb'], inputs_dict)[0]
+    with torch.no_grad():
+        video_frame_emb = video_rating_model(torch.tensor(inp_frames) )

     overall_sim, count_upg = cosine_sim(emb1 = base_frame_emb,
-                                        emb2 = torch.tensor(video_frame_emb),
-                                        threshold=0.4
+                                        emb2 = video_frame_emb,
+                                        threshold=0.95
                                         )

     perc_of_upg = count_upg / (len(video_frames)//5)

-    if perc_of_upg > 0.4:
-        return f"Out of {len(video_frames)} important moments of this video, {count_upg*5} moments contain under or at least PG content. Hence this video is suitable for kids & family."
-    else:
-        return f"Out of {len(video_frames)} important moments of this video, {(len(video_frames)//5 - count_upg)*5} moments contain at least PG-13 content.Hence parental guidance is strongly suggested for this video."
+    perc_of_upg = math.floor(perc_of_upg*100)
+    non_upg_perc = math.ceil(100 - perc_of_upg)
+
+    response_string = f"{perc_of_upg} % of important moments from this video contain Under or PG content, rest of {non_upg_perc} % moments contain atleast PG-13, R or even NC-17 content."
+    return response_string

 @st.cache_resource
 def load_models():
-    sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
-    transcriber = whisper.load_model("base", device = sst['device'])
+    transcriber = whisper.load_model("base")

-    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(sst['device'])
+    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

     base_frame_emb = torch.tensor(
         np.load('base_frame_medoid.npz')['arr'],
-        dtype = torch.float32,
-        device = sst['device']
+        dtype = torch.float32
     )

-    session = onnxruntime.InferenceSession("video_rating_siamesev2.onnx",
-                                           providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
-                                           )
+    session = SiameseNetwork()
+    checkpoint = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
+    session.load_state_dict(checkpoint['model_state_dict'])
+    _ = session.eval()
+

     return (
         transcriber, (tokenizer, model), session, base_frame_emb
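
Note on the new rating math in rate_video_frames: the similarity threshold moves from 0.4 to 0.95 and the response now reports percentages instead of frame counts. A minimal sketch with hypothetical numbers (100 important frames, 13 of the 20 five-frame chunks matching the base medoid embedding; the counts are illustrative, not taken from this PR):

import math

chunks = 100 // 5                                    # 20 chunks of 5 frames each
count_upg = 13                                       # hypothetical: chunks with cosine similarity > 0.95
perc_of_upg = math.floor(count_upg / chunks * 100)   # 65
non_upg_perc = math.ceil(100 - perc_of_upg)          # 35
# -> "65 % of important moments ... rest of 35 % moments contain atleast PG-13, R or even NC-17 content."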
pages.py CHANGED
@@ -76,7 +76,7 @@ async def model_inference_page():
     try:
         video_rating_scale = rate_video_frames(important_frames)
     except Exception as e:
-        video_rating_scale = f"Sorry, we couldn't generate rating of your video because of this error: {e} "
+        video_rating_scale = "Sorry, we couldn't generate rating of your video because of this error "

     st.toast("Done")
     st.header("Movie Scale Rating of Your Video: ", divider = True)
@@ -91,7 +91,7 @@ async def model_inference_page():
     try:
         video_summary_text = get_text_from_audio(sst["audio_transcript"])
     except Exception as e:
-        video_summary_text = f"Sorry, we couldn't extract text from audio of this file because of this error: {e} "
+        video_summary_text = "Sorry, we couldn't extract text from audio of this file because of this error"
     st.toast("Done")

     if video_summary_text[:5] != "Sorry":
@@ -99,7 +99,7 @@ async def model_inference_page():
     try:
         video_summary_text = summarize_from_text(video_summary_text)
     except Exception as e:
-        video_summary_text = f"Sorry, we couldn't summarize text from audio of this file because of this error: {e} "
+        video_summary_text = "Sorry, we couldn't summarize text from audio of this file"
     st.toast("Done")

     st.header("Audio Transcript summary of your video: ", divider = True)
utils.py CHANGED
@@ -12,6 +12,8 @@ import time

 from io import BytesIO
 import torch
+import torchvision.models as models
+import torch.nn as nn
 import soundfile as sf
 import subprocess
 from typing import List
@@ -19,6 +21,34 @@ from typing import List
 prompt_audio_summarization = "This is a video transcript, tell me what is this about: "


+class SiameseNetwork(nn.Module):
+    def __init__(self, model_name="vit_b_16"):
+        super(SiameseNetwork, self).__init__()
+
+        self.encoder = models.vit_b_16(weights="IMAGENET1K_V1") # Pretrained ViT
+        self.encoder.heads = nn.Identity() # Remove classification head
+
+        self.fc = nn.Linear(768, 128) # Reduce to 128-d embedding
+
+    def forward(self, frames):
+
+        B,num_frames,H,W,C = frames.shape # (Batch,num_frames, H, W, C)
+
+        # Flatten frames into batch dimension for ViT
+        frames = frames.permute(0,1,4,2,3).reshape(B * num_frames, C,H,W)
+
+        # Extract frame-level embeddings
+        emb = self.encoder(frames)
+
+        # Reshape back to (B, T, 768) and average over T
+        #TODO: Change this to use LSTM instead of averaging
+        emb = emb.reshape(B, num_frames, -1).mean(dim=1) # (B, 768)
+
+        # Pass through fully connected layer
+        emb = self.fc(emb)
+
+        return emb
+
 def timer(func):
     def wrapper(*args, **kwargs):
         start = time.time()
@@ -138,3 +168,4 @@ def cosine_sim(emb1, emb2, threshold = 0.5):
     cosine_sim = F.cosine_similarity(emb1, emb2)
     counts = torch.count_nonzero(cosine_sim > threshold).numpy()
     return (cosine_sim.mean(), counts)
+
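
Note on how the pieces fit together: the ONNX session is replaced by the SiameseNetwork above plus the renamed PyTorch checkpoint, and model_inference.py now calls the model directly under torch.no_grad(). A minimal end-to-end sketch on dummy frames (shapes follow the comments in this PR; the base embedding shape is whatever base_frame_medoid.npz stores):

import numpy as np
import torch
from utils import SiameseNetwork, cosine_sim

model = SiameseNetwork()
checkpoint = torch.load('./checkpoint__fl_batch_480_epoch_0.pt', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

frames = np.zeros((20, 5, 224, 224, 3), dtype=np.float32)   # 20 chunks of 5 RGB frames, 224x224
with torch.no_grad():
    emb = model(torch.tensor(frames))                        # (20, 128) chunk embeddings

base_frame_emb = torch.tensor(np.load('base_frame_medoid.npz')['arr'], dtype=torch.float32)
mean_sim, count_upg = cosine_sim(emb1=base_frame_emb, emb2=emb, threshold=0.95)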