|
import streamlit as st |
|
import torch |
|
from utils import ( |
|
prompt_audio_summarization, |
|
timer, |
|
cosine_sim, |
|
SiameseNetwork |
|
) |
|
from transformers import BartForConditionalGeneration, BartTokenizer |
|
import numpy as np |
|
import whisper |
|
from streamlit import session_state as sst |
|
import math |
|
|
|
|
|
@timer
def get_text_from_audio(audio_tensors) -> str:
    """Transcribe one audio input with the cached Whisper model.

    Args:
        audio_tensors: Passed unchanged as the single ``audio`` argument of
            Whisper's ``transcribe`` (per its signature a file path, NumPy
            array or torch tensor). Despite the plural name, no batching or
            parallelism happens here: one ``transcribe`` call is made.

    Returns:
        The full transcription text, i.e. ``result["text"]`` from Whisper.
    """
    result = audio_transcriber_model.transcribe(audio_tensors)
    return result["text"]
|
|
|
def _strip_prompt_echo(prediction: str, prompt: str, min_match: int = 15) -> str:
    """Remove a copy of *prompt* that the summarizer echoed at the start of *prediction*.

    Counts how many leading characters the two strings share. If that shared
    prefix is long enough to be a genuine echo (``min_match`` characters, or
    the whole prompt when the prompt is shorter than that), exactly that many
    characters are dropped, plus any whitespace that follows. This replaces
    the old approach of always cutting a fixed number of characters, which
    ignored both the prompt's length and how much of it was actually echoed.
    """
    if not prompt:
        return prediction

    shared = 0
    for pred_char, prompt_char in zip(prediction, prompt):
        if pred_char != prompt_char:
            break
        shared += 1

    if shared < min(min_match, len(prompt)):
        # Too short a match to be an echo; leave the summary untouched.
        return prediction
    return prediction[shared:].lstrip()


@timer
def summarize_from_text(raw_transcription):
    """Summarize a transcription with the cached BART model.

    The instruction prompt ``prompt_audio_summarization`` is prepended to the
    transcription. The combined text is tokenized and truncated to 1024
    tokens, then summarized with beam search (4 beams, 30 to 150 tokens,
    ``length_penalty=2.0``).

    Args:
        raw_transcription: Transcript text to summarize.

    Returns:
        The decoded summary, with any leading echo of the prompt removed.
    """
    tokenizer, model = text_summarizer  # built as (tokenizer, model) in load_models

    inputs = tokenizer(
        prompt_audio_summarization + raw_transcription,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )

    summary_ids = model.generate(
        **inputs,
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )

    prediction_string = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # The model may repeat the instruction prompt before the actual summary.
    return _strip_prompt_echo(prediction_string, prompt_audio_summarization)
|
|
|
@timer
def rate_video_frames(video_frames):
    """Estimate the share of a video's key moments that are U/PG content.

    Frames are grouped into clips of 5 consecutive frames, each of shape
    (224, 224, 3). The clips are embedded with ``video_rating_model`` and
    compared against the reference embedding ``base_frame_emb`` by
    ``cosine_sim`` with a 0.95 threshold. The count returned by ``cosine_sim``
    is treated as the number of U/PG clips (presumably the clips at or above
    the threshold; confirm against ``utils.cosine_sim``).

    Args:
        video_frames: Sliceable sequence of frames, each convertible to a
            float32 array of shape (224, 224, 3). Trailing frames that do not
            fill a complete clip of 5 are ignored.

    Returns:
        A sentence reporting the U/PG percentage and the remaining percentage.

    Raises:
        ValueError: If fewer than 5 frames are given (no complete clip).
    """
    frames_per_clip = 5
    num_clips = len(video_frames) // frames_per_clip
    if num_clips == 0:
        # Previously this crashed in reshape (1-4 frames) or with
        # ZeroDivisionError (0 frames); fail with a clear message instead.
        raise ValueError(
            f"rate_video_frames needs at least {frames_per_clip} frames, "
            f"got {len(video_frames)}."
        )

    # Drop the incomplete trailing clip so the reshape cannot fail.
    usable_frames = video_frames[: num_clips * frames_per_clip]
    inp_frames = np.array(usable_frames, dtype=np.float32).reshape(
        num_clips, frames_per_clip, 224, 224, 3
    )

    with torch.no_grad():
        # np.array already made a fresh copy; from_numpy avoids a second one.
        video_frame_emb = video_rating_model(torch.from_numpy(inp_frames))

    _overall_sim, count_upg = cosine_sim(
        emb1=base_frame_emb,
        emb2=video_frame_emb,
        threshold=0.95,
    )

    perc_of_upg = math.floor(count_upg / num_clips * 100)
    non_upg_perc = 100 - perc_of_upg  # perc_of_upg is already an int

    response_string = f"{perc_of_upg} % of important moments from this video contain Under or PG content, rest of {non_upg_perc} % moments contain atleast PG-13, R or even NC-17 content."
    return response_string
|
|
|
@st.cache_resource
def load_models():
    """Load every model and reference tensor the app needs.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of reloading weights from disk.

    Returns:
        tuple: ``(transcriber, (tokenizer, summarizer), rating_model,
        base_frame_emb)`` where

        * ``transcriber`` is the Whisper "base" model;
        * ``tokenizer`` / ``summarizer`` are the ``facebook/bart-large-cnn``
          tokenizer and conditional-generation model;
        * ``rating_model`` is a ``SiameseNetwork`` restored from the local
          checkpoint and switched to eval mode;
        * ``base_frame_emb`` is a float32 tensor built from the ``arr`` entry
          of ``base_frame_medoid.npz``.

    Note:
        Both local data files are opened by relative path, so they are
        resolved against the process's current working directory.
    """
    transcriber = whisper.load_model("base")

    summarizer = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

    # np.load on an .npz returns an NpzFile that holds the archive open;
    # the context manager closes it. torch.tensor copies the data first.
    with np.load("base_frame_medoid.npz") as medoid_archive:
        base_frame_emb = torch.tensor(medoid_archive["arr"], dtype=torch.float32)

    rating_model = SiameseNetwork()
    # The model is fed CPU tensors elsewhere in this module, so map tensors to
    # CPU so a checkpoint saved on a GPU still loads on a CPU-only host.
    checkpoint = torch.load(
        "./checkpoint__fl_batch_480_epoch_0.pt",
        map_location="cpu",
        weights_only=True,
    )
    rating_model.load_state_dict(checkpoint["model_state_dict"])
    rating_model.eval()

    return transcriber, (tokenizer, summarizer), rating_model, base_frame_emb
|
|
|
# Load the models once at import time and expose them as module-level globals
# that get_text_from_audio, summarize_from_text and rate_video_frames read.
(
    audio_transcriber_model,
    text_summarizer,
    video_rating_model,
    base_frame_emb,
) = load_models()