Video_Summary_Beta / model_inference.py
Amith Adiraju
Changed onnx to checkpoint based inference, as onnx file is corrupted.
ef39fae
raw
history blame
3.48 kB
import streamlit as st
import torch
from utils import (
prompt_audio_summarization,
timer,
cosine_sim,
SiameseNetwork
)
from transformers import BartForConditionalGeneration, BartTokenizer
import numpy as np
import whisper
from streamlit import session_state as sst
import math
@timer
def get_text_from_audio(audio_tensors) -> str:
    """Transcribe audio with the module-level Whisper model.

    The input is handed unchanged to ``audio_transcriber_model.transcribe``
    in a single call; no batching or parallelism is performed here.

    Args:
        audio_tensors: Audio in a form accepted by Whisper's ``transcribe``
            (a file path or a waveform array/tensor; presumably 16 kHz mono
            -- TODO confirm against the caller).

    Returns:
        The full transcription text, i.e. ``result["text"]``.
    """
    result = audio_transcriber_model.transcribe(audio_tensors)
    return result["text"]
@timer
def summarize_from_text(raw_transcription):
    """Summarize a transcript with the cached BART tokenizer/model pair.

    The summarization prompt is prepended to the transcript. The combined
    text is tokenized (truncated to 1024 tokens) and a beam-search summary is
    generated. If the decoded summary starts with the same 15 characters as
    the prompt, the model is treated as having echoed the prompt, and the
    first 50 characters are dropped.

    Args:
        raw_transcription: Transcript text to summarize.

    Returns:
        The decoded summary string.
    """
    tokenizer = text_summarizer[0]
    bart_model = text_summarizer[1]

    encoded = tokenizer(
        prompt_audio_summarization + raw_transcription,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    generation_kwargs = dict(
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    generated_ids = bart_model.generate(**encoded, **generation_kwargs)
    summary = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # NOTE(review): the cut is a fixed 50 characters, not
    # len(prompt_audio_summarization); confirm it matches the prompt in utils.
    echoed_prompt = summary[:15] == prompt_audio_summarization[:15]
    return summary[50:] if echoed_prompt else summary
@timer
def rate_video_frames(video_frames):
    """Report what share of video clips match the reference embedding.

    Frames are grouped into clips of 5 consecutive frames and embedded with
    the module-level ``video_rating_model``. Each clip embedding is compared
    against ``base_frame_emb`` through ``cosine_sim`` with a 0.95 threshold.
    The count returned by ``cosine_sim`` is converted to a percentage of clips
    and formatted into a human-readable sentence.

    Args:
        video_frames: Sequence of frames whose total size reshapes to
            (clips, 5, 224, 224, 3). Presumably each frame is 224x224x3;
            TODO confirm against the frame sampler. Trailing frames that do
            not fill a complete clip of 5 are ignored.

    Returns:
        A sentence reporting the percentage of clips on each side of the
        threshold.

    Raises:
        ValueError: If fewer than 5 frames are supplied, so no complete clip
            can be formed.
    """
    frames_per_clip = 5
    num_clips = len(video_frames) // frames_per_clip
    if num_clips == 0:
        # Without this guard the reshape fails and the percentage divides by 0.
        raise ValueError(
            f"rate_video_frames needs at least {frames_per_clip} frames, "
            f"got {len(video_frames)}."
        )

    frames = np.array(video_frames, dtype=np.float32)
    # Drop leftover frames so the reshape cannot fail on a partial clip.
    inp_frames = frames[: num_clips * frames_per_clip].reshape(
        num_clips, frames_per_clip, 224, 224, 3
    )  # e.g. (20, 5, 224, 224, 3) for 100 frames

    with torch.no_grad():
        # from_numpy shares the freshly built buffer instead of copying it again.
        video_frame_emb = video_rating_model(torch.from_numpy(inp_frames))
        _, count_upg = cosine_sim(
            emb1=base_frame_emb,
            emb2=video_frame_emb,
            threshold=0.95,
        )

    perc_of_upg = count_upg / num_clips
    perc_of_upg = math.floor(perc_of_upg * 100)
    non_upg_perc = 100 - perc_of_upg
    response_string = f"{perc_of_upg} % of important moments from this video contain Under or PG content, rest of {non_upg_perc} % moments contain atleast PG-13, R or even NC-17 content."
    return response_string
@st.cache_resource
def load_models():
    """Load every model and artifact needed for inference.

    ``st.cache_resource`` caches the returned tuple, so Streamlit reruns reuse
    the same objects instead of reloading them.

    Returns:
        A 4-tuple ``(transcriber, (tokenizer, model), siamese_model,
        base_frame_emb)``:
          * transcriber: Whisper "base" speech-to-text model.
          * (tokenizer, model): ``facebook/bart-large-cnn`` tokenizer and
            conditional-generation model, used for summarization.
          * siamese_model: ``SiameseNetwork`` restored from the local
            checkpoint and switched to eval mode.
          * base_frame_emb: float32 tensor built from the ``'arr'`` entry of
            ``base_frame_medoid.npz``.
    """
    transcriber = whisper.load_model("base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

    # np.load on an .npz returns an NpzFile that keeps the file handle open.
    # The context manager closes it; torch.tensor copies the data first.
    with np.load('base_frame_medoid.npz') as medoid_archive:
        base_frame_emb = torch.tensor(medoid_archive['arr'], dtype=torch.float32)

    siamese_model = SiameseNetwork()
    # The model is constructed on CPU and load_state_dict copies into its own
    # parameters. Mapping to CPU therefore always works, and it avoids a
    # failure on CPU-only hosts if the checkpoint was saved from a GPU.
    checkpoint = torch.load(
        './checkpoint__fl_batch_480_epoch_0.pt',
        map_location="cpu",
        weights_only=True,
    )
    siamese_model.load_state_dict(checkpoint['model_state_dict'])
    siamese_model.eval()

    return transcriber, (tokenizer, model), siamese_model, base_frame_emb
# Load (or fetch from Streamlit's resource cache) all inference artifacts and
# expose them as the module-level globals that the inference functions read.
(
    audio_transcriber_model,
    text_summarizer,
    video_rating_model,
    base_frame_emb,
) = load_models()