import math

import numpy as np
import streamlit as st
import torch
import whisper
from streamlit import session_state as sst
from transformers import BartForConditionalGeneration, BartTokenizer

from utils import (
    prompt_audio_summarization,
    timer,
    cosine_sim,
    SiameseNetwork,
)


@timer
def get_text_from_audio(audio_tensors) -> str:
    """Transcribe in-memory audio with Whisper and return the full transcript."""
    result = audio_transcriber_model.transcribe(audio_tensors)
    return result["text"]


@timer
def summarize_from_text(raw_transcription):
    """Summarize a transcript with BART, prepending the summarization prompt."""
    tokenizer, model = text_summarizer
    inputs = tokenizer(
        prompt_audio_summarization + raw_transcription,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    summary_ids = model.generate(
        **inputs,
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    prediction_string = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # If the model echoed the prompt back, strip the leading prompt text.
    if prompt_audio_summarization[:15] == prediction_string[:15]:
        prediction_string = prediction_string[50:]
    return prediction_string


@timer
def rate_video_frames(video_frames):
    """Rate video frames by comparing their embeddings against a reference frame embedding."""
    # Group frames into clips of 5: (num_clips, 5, 224, 224, 3), e.g. (20, 5, 224, 224, 3).
    inp_frames = np.array(video_frames, dtype=np.float32).reshape(
        len(video_frames) // 5, 5, 224, 224, 3
    )
    with torch.no_grad():
        video_frame_emb = video_rating_model(torch.tensor(inp_frames))

    overall_sim, count_upg = cosine_sim(
        emb1=base_frame_emb,
        emb2=video_frame_emb,
        threshold=0.95,
    )

    perc_of_upg = math.floor(100 * count_upg / (len(video_frames) // 5))
    non_upg_perc = math.ceil(100 - perc_of_upg)

    return (
        f"{perc_of_upg}% of important moments from this video contain U/PG (PG or milder) content; "
        f"the remaining {non_upg_perc}% of moments contain at least PG-13, R, or even NC-17 content."
    )


@st.cache_resource
def load_models():
    """Load and cache all models once per Streamlit session."""
    transcriber = whisper.load_model("base")

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

    # Pre-computed embedding of the reference (medoid) frame used for rating.
    base_frame_emb = torch.tensor(
        np.load("base_frame_medoid.npz")["arr"],
        dtype=torch.float32,
    )

    siamese_model = SiameseNetwork()
    checkpoint = torch.load("./checkpoint__fl_batch_480_epoch_0.pt", weights_only=True)
    siamese_model.load_state_dict(checkpoint["model_state_dict"])
    _ = siamese_model.eval()

    return (
        transcriber,
        (tokenizer, model),
        siamese_model,
        base_frame_emb,
    )


audio_transcriber_model, text_summarizer, video_rating_model, base_frame_emb = load_models()
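
# --- Illustrative usage sketch (assumption, not part of the original app logic) ---
# A minimal example of how these helpers could be wired into a Streamlit page.
# The uploader widget and the temporary file path below are hypothetical;
# Whisper's transcribe() accepts a file path or a 16 kHz mono float32 waveform,
# so whisper.load_audio() is used here to produce the expected input.
#
# uploaded = st.file_uploader("Upload an audio file", type=["wav", "mp3", "m4a"])
# if uploaded is not None:
#     with open("tmp_upload.wav", "wb") as f:
#         f.write(uploaded.getbuffer())
#     audio = whisper.load_audio("tmp_upload.wav")
#     transcript = get_text_from_audio(audio)
#     st.write(summarize_from_text(transcript))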