|
import json |
|
from functools import lru_cache |
|
import youtube_transcript_api2 |
|
import json |
|
import re |
|
import requests |
|
from transformers import ( |
|
AutoModelForSequenceClassification, |
|
AutoTokenizer, |
|
TextClassificationPipeline, |
|
) |
|
from typing import Any, Dict, List |
|
import os |
|
import numpy as np |
|
|
|
# Category names indexed by the classifier's label (see
# SponsorBlockClassificationPipeline.postprocess). Index 0 maps to None —
# presumably "no sponsor-style segment"; TODO confirm against the model labels.
CATEGORIES = [
    None,
    'SPONSOR',
    'SELFPROMO',
    'INTERACTION',
]

# Occurrences of PROFANITY_RAW in caption text are rewritten to
# PROFANITY_CONVERTED by parse_transcript_json.
PROFANITY_RAW = '[ __ ]'
PROFANITY_CONVERTED = '*****'

# Decimal places that transcript start/end times (in seconds) are rounded to.
NUM_DECIMALS = 3

# Language codes handed to the transcript lookup in get_words; listed in
# preference order (regional English variants first, plain 'en' last).
LANGUAGE_PREFERENCE_LIST = [
    'en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ',
    'en-ZA', 'en-IE', 'en-IN', 'en-JM', 'en-BZ',
    'en-TT', 'en-PH', 'en-ZW',
    'en',
]
|
|
|
|
|
# Zero-width space / non-joiner / joiner and BOM characters stripped from
# caption text (one translate pass instead of four chained replaces).
_INVISIBLE_CHARS = str.maketrans('', '', '\u200b\u200c\u200d\ufeff')

# Split a caption segment after sentence/clause punctuation. The hyphen must
# stay LAST inside the class: written as ``[.!?,-;]`` the ``,-;`` part is a
# character range (',' .. ';') that also matches '.', '/', ':' and every
# digit, so chunks were split after numbers ("in 2019 we" -> "in 2019", "we").
_CHUNK_SPLIT_RE = re.compile(r'(?<=[.!?,;-])\s+')

# A chunk with no punctuation split and at least this many words is split into
# its individual words instead.
_MIN_WORDS_FOR_WHITESPACE_SPLIT = 8

# Cap (seconds) on the duration assigned to a word-level unit.
_MAX_WORD_DURATION = 1.5


def _clean_segment_text(raw_text):
    """Collapse whitespace, drop invisible characters and mask profanity."""
    text = ' '.join(raw_text.split())
    text = text.translate(_INVISIBLE_CHARS).strip()
    return text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)


def _parse_events(events):
    """Flatten caption events into ``{'text', 'start', 'end'}`` segments.

    Each event's duration is clamped so it never runs past the next event's
    ``tStartMs`` and is never negative. It is then spread over the event's
    segments in proportion to their character counts. Times are in seconds,
    rounded to ``NUM_DECIMALS``.
    """
    parsed = []
    for event_index, event in enumerate(events):
        raw_segments = event.get('segs')
        if not raw_segments:
            continue

        start_ms = event['tStartMs']
        segments = []
        total_characters = 0
        for seg in raw_segments:
            text = _clean_segment_text(seg['utf8'])
            if not text:
                continue
            offset_ms = seg.get('tOffsetMs', 0)
            segments.append({
                'text': text,
                'start': round((start_ms + offset_ms) / 1000, NUM_DECIMALS),
            })
            total_characters += len(text)

        if not segments:
            continue

        if event_index < len(events) - 1:
            next_start_ms = events[event_index + 1]['tStartMs']
            duration_ms = min(event.get('dDurationMs', float('inf')),
                              next_start_ms - start_ms)
        else:
            duration_ms = event.get('dDurationMs', 0)
        duration_ms = max(duration_ms, 0)

        # total_characters > 0: every kept segment has non-empty text.
        seconds_per_character = (duration_ms / total_characters) / 1000

        characters_so_far = 0
        for seg_index, seg in enumerate(segments):
            characters_so_far += len(seg['text'])
            # NOTE(review): this adds the event's *cumulative* character count
            # to this segment's own start (not the event start), so ends are
            # over-estimated. Only the clamp to the next segment's start below
            # bounds them. Confirm whether start_ms / 1000 + ... was intended.
            seg_end = seg['start'] + characters_so_far * seconds_per_character
            if seg_index < len(segments) - 1:
                seg_end = min(seg_end, segments[seg_index + 1]['start'])
            seg['end'] = round(seg_end, NUM_DECIMALS)
            parsed.append(seg)
    return parsed


def _split_segment(text, granularity):
    """Split one segment's text into units; return ``(units, word_level)``.

    ``word_level`` is True when each unit's duration should be capped at
    ``_MAX_WORD_DURATION``.
    """
    if granularity == 'word':
        return text.split(), True

    chunks = _CHUNK_SPLIT_RE.split(text)
    if len(chunks) == 1:
        words = text.split()
        if len(words) >= _MIN_WORDS_FOR_WHITESPACE_SPLIT:
            # Too long to keep as one chunk: fall back to whitespace words.
            return words, False
        # Short unsplittable chunk: keep it whole but cap it like a word.
        return chunks, True
    return chunks, False


def parse_transcript_json(json_data, granularity):
    """Convert a YouTube ``json3`` caption payload into timed text units.

    Args:
        json_data: Decoded payload. Must have ``wireMagic == 'pb3'``. Its
            ``events`` list holds events with ``tStartMs``, optional
            ``dDurationMs`` and optional ``segs``. Each seg has ``utf8`` text
            and an optional ``tOffsetMs``.
        granularity: ``'word'`` for one entry per whitespace-separated word,
            or ``'chunk'`` for punctuation-delimited chunks.

    Returns:
        List of ``{'text', 'start', 'end'}`` dicts. Times are in seconds,
        rounded to ``NUM_DECIMALS``. A segment's duration is divided among its
        units by character count.

    Raises:
        ValueError: unsupported ``wireMagic`` or unknown ``granularity``.
    """
    # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
    if json_data.get('wireMagic') != 'pb3':
        raise ValueError('Unsupported transcript format (expected wireMagic "pb3")')
    if granularity not in ('word', 'chunk'):
        raise ValueError('Unknown granularity')

    segments = _parse_events(json_data.get('events', []))

    result = []
    for index, segment in enumerate(segments):
        units, word_level = _split_segment(segment['text'], granularity)

        segment_start = segment['start']
        segment_end = segment['end']
        if index < len(segments) - 1:
            segment_end = min(segment_end, segments[index + 1]['start'])
        segment_duration = segment_end - segment_start

        total_characters = sum(map(len, units))
        characters_so_far = 0
        current_offset = 0
        for unit in units:
            characters_so_far += len(unit)
            next_offset = (characters_so_far / total_characters) * segment_duration
            unit_start = round(segment_start + current_offset, NUM_DECIMALS)
            unit_end = round(segment_start + next_offset, NUM_DECIMALS)
            if word_level:
                unit_end = min(unit_end, unit_start + _MAX_WORD_DURATION)
            result.append({'text': unit, 'start': unit_start, 'end': unit_end})
            current_offset = next_offset
    return result
|
|
|
|
|
def list_transcripts(video_id):
    """Return the transcript listing for ``video_id`` from youtube_transcript_api2.

    A listing response that cannot be decoded as JSON yields ``None`` rather
    than an exception; every other error propagates to the caller.
    """
    try:
        listing = youtube_transcript_api2.YouTubeTranscriptApi.list_transcripts(video_id)
    except json.decoder.JSONDecodeError:
        listing = None
    return listing
|
|
|
|
|
# Bracketed non-speech caption annotations that get_words drops when
# filter_words_to_remove is True. Commas are required: without them Python
# concatenates adjacent literals into one string ('[Music][Applause][Laughter]'),
# which never matches a single transcript word.
WORDS_TO_REMOVE = [
    '[Music]',
    '[Applause]',
    '[Laughter]',
]
|
|
|
|
|
# How many times one transcript download is attempted when it fails with a
# network error or an undecodable response. The previous implementation
# retried via unbounded recursion and ended in RecursionError.
_MAX_FETCH_ATTEMPTS = 5


def _fetch_transcript_json(video_id, transcript_type):
    """Download and decode one transcript track of ``video_id``.

    ``transcript_type == 'manual'`` selects a manually created track; any
    other value selects an auto-generated one. Returns the decoded ``json3``
    payload, or ``None`` when no matching transcript is available.

    ``TooManyRequests`` / ``YouTubeRequestFailed`` propagate immediately.
    ``requests`` errors and JSON decode errors are retried up to
    ``_MAX_FETCH_ATTEMPTS`` times, after which the last one is re-raised.
    """
    for attempt in range(1, _MAX_FETCH_ATTEMPTS + 1):
        try:
            transcript_list = list_transcripts(video_id)
            if transcript_list is None:
                return None
            if transcript_type == 'manual':
                ts = transcript_list.find_manually_created_transcript(
                    LANGUAGE_PREFERENCE_LIST)
            else:
                ts = transcript_list.find_generated_transcript(
                    LANGUAGE_PREFERENCE_LIST)
            raw_transcript = ts._http_client.get(
                f'{ts._url}&fmt=json3').content
            return json.loads(raw_transcript) if raw_transcript else None
        except (youtube_transcript_api2.TooManyRequests,
                youtube_transcript_api2.YouTubeRequestFailed):
            # Checked first: these must reach the caller, not be retried or
            # swallowed by the CouldNotRetrieveTranscript handler.
            raise
        except (requests.exceptions.RequestException,
                json.decoder.JSONDecodeError):
            if attempt == _MAX_FETCH_ATTEMPTS:
                raise
        except youtube_transcript_api2.CouldNotRetrieveTranscript:
            # No transcript of the requested kind exists.
            return None
    return None  # not reached: the final failed attempt re-raises


@lru_cache(maxsize=16)
def get_words(video_id, transcript_type='auto', fallback='manual', filter_words_to_remove=True, granularity='word'):
    """Return the parsed transcript of ``video_id``, memoised per argument tuple.

    Args:
        video_id: Video id passed to the transcript API.
        transcript_type: ``'manual'`` for a manually created track, anything
            else (default ``'auto'``) for an auto-generated one.
        fallback: Track type tried when ``transcript_type`` yields nothing;
            ``None`` disables the fallback.
        filter_words_to_remove: Drop entries whose text is in
            ``WORDS_TO_REMOVE``.
        granularity: Forwarded to ``parse_transcript_json`` (``'word'`` or
            ``'chunk'``).

    Returns:
        A list of ``{'text', 'start', 'end'}`` dicts, or ``None`` when neither
        track is available. Because of ``lru_cache`` the same list object is
        returned to every caller with equal arguments, so do not mutate it.

    Raises:
        youtube_transcript_api2.TooManyRequests,
        youtube_transcript_api2.YouTubeRequestFailed: propagated unchanged.
        requests.exceptions.RequestException, json.JSONDecodeError: when every
            download attempt for a track fails.
    """
    raw_transcript_json = _fetch_transcript_json(video_id, transcript_type)
    if not raw_transcript_json and fallback is not None:
        raw_transcript_json = _fetch_transcript_json(video_id, fallback)
    if not raw_transcript_json:
        return None

    processed_transcript = parse_transcript_json(raw_transcript_json, granularity)
    if filter_words_to_remove:
        processed_transcript = [
            word for word in processed_transcript
            if word['text'] not in WORDS_TO_REMOVE
        ]
    return processed_transcript
|
|
|
|
|
def word_start(word):
    """Return the ``'start'`` value stored on a transcript word mapping."""
    start_time = word['start']
    return start_time
|
|
|
|
|
def word_end(word):
    """Return a word's ``'end'`` value, falling back to its ``'start'`` value.

    ``'start'`` is always looked up, so a mapping without it raises
    ``KeyError`` even when ``'end'`` is present.
    """
    fallback_time = word['start']
    return word.get('end', fallback_time)
|
|
|
|
|
def extract_segment(words, start, end, map_function=None):
    """Return the words whose midpoint time lies within ``[start, end]``.

    A word's midpoint is the mean of ``word_start`` and ``word_end``.
    ``words`` must be ordered by midpoint, because both bounds come from
    binary searches. When ``map_function`` is callable it is applied to each
    selected word; otherwise the words themselves are returned.
    """
    count = len(words)
    first = max(binary_search_below(words, 0, count, start), 0)
    last = binary_search_above(words, -1, count - 1, end)
    stop = min(last + 1, count)

    selected = range(first, stop)
    if callable(map_function):
        return [map_function(words[i]) for i in selected]
    return [words[i] for i in selected]
|
|
|
|
|
def avg(*items):
    """Return the arithmetic mean of the positional arguments.

    Raises ZeroDivisionError when called with no arguments.
    """
    total = sum(items)
    return total / len(items)
|
|
|
|
|
def binary_search_below(transcript, start_index, end_index, time):
    """Return the first index in ``[start_index, end_index)`` whose midpoint >= ``time``.

    The midpoint of an entry is the mean of ``word_start`` and ``word_end``.
    Entries must be ordered by midpoint. Returns ``end_index`` when no entry
    qualifies, including when ``start_index >= end_index``.
    """
    low, high = start_index, end_index
    while low < high:
        pivot = (low + high) // 2
        entry = transcript[pivot]
        if time <= avg(word_start(entry), word_end(entry)):
            high = pivot
        else:
            low = pivot + 1
    return high
|
|
|
|
|
def binary_search_above(transcript, start_index, end_index, time):
    """Return the last index in ``(start_index, end_index]`` whose midpoint <= ``time``.

    The midpoint of an entry is the mean of ``word_start`` and ``word_end``.
    Entries must be ordered by midpoint. ``transcript[start_index]`` is never
    read, so ``start_index`` may be -1. It is returned when no entry
    qualifies; ``end_index`` is returned directly when
    ``start_index >= end_index``.
    """
    low, high = start_index, end_index
    while low < high:
        # Round the pivot up so ``low = pivot`` always makes progress.
        pivot = (low + high + 1) // 2
        entry = transcript[pivot]
        if time >= avg(word_start(entry), word_end(entry)):
            low = pivot
        else:
            high = pivot - 1
    return high
|
|
|
|
|
class PreTrainedPipeline:
    """Entry point wrapping a SponsorBlockClassificationPipeline.

    The ``__init__(path)`` / ``__call__(inputs)`` shape presumably follows a
    hosted-inference custom-pipeline convention — TODO confirm.
    """

    def __init__(self, path: str):
        """Load the sequence-classification model and tokenizer from ``<path>/model``."""
        model_dir = os.path.join(path, 'model')
        self.model2 = AutoModelForSequenceClassification.from_pretrained(model_dir)
        self.tokenizer2 = AutoTokenizer.from_pretrained(model_dir)
        self.pipeline2 = SponsorBlockClassificationPipeline(
            model=self.model2, tokenizer=self.tokenizer2)

    def __call__(self, inputs: str) -> List[Dict[str, Any]]:
        """Classify free text or time ranges of a video transcript.

        A space-free ``inputs`` containing at least two commas is read as
        ``'<video_id>,<start>,<end>[,<start>,<end>...]'``. It is expanded into
        one ``{'video_id', 'start', 'end'}`` dict per pair, with times parsed
        as floats. An odd number of times makes the reshape raise
        ``ValueError``. Any other input is passed to the pipeline unchanged.
        """
        compact_spec = ' ' not in inputs and inputs.count(',') >= 2
        if not compact_spec:
            return self.pipeline2(inputs)

        video_id, time_fields = inputs.split(',', 1)
        pairs = np.array(time_fields.split(',')).reshape(-1, 2)
        requests_ = [
            {'video_id': video_id, 'start': float(begin), 'end': float(finish)}
            for begin, finish in pairs
        ]
        return self.pipeline2(requests_)
|
|
|
|
|
|
|
class SponsorBlockClassificationPipeline(TextClassificationPipeline):
    """Text-classification pipeline that also accepts video time ranges.

    An input is either a plain string or a mapping with ``video_id``,
    ``start`` and ``end``. In the mapping case the matching transcript words
    are classified. Each output entry gains a ``'label_text'`` key looked up
    in ``CATEGORIES``.
    """

    def __init__(self, model, tokenizer):
        # return_all_scores=True asks the parent pipeline for a score per
        # label. Per transformers docs this flag is deprecated in newer
        # releases in favour of top_k=None; verify for the pinned version.
        super().__init__(model=model, tokenizer=tokenizer, return_all_scores=True)

    def preprocess(self, data, **tokenizer_kwargs):
        """Tokenise ``data``, first resolving a video time range to its text.

        A ``str`` is tokenised as-is. Otherwise the transcript of
        ``data['video_id']`` is fetched with ``get_words``. The words that
        ``extract_segment`` selects for ``[data['start'], data['end']]`` are
        joined with single spaces.

        NOTE(review): ``get_words`` can return ``None`` when no transcript
        exists, and ``extract_segment`` then fails on ``len(None)``. Confirm
        how callers expect that case to be handled.
        """
        if not isinstance(data, str):
            transcript = get_words(data['video_id'])
            window = extract_segment(transcript, data['start'], data['end'])
            data = ' '.join(word['text'] for word in window)

        return self.tokenizer(
            data, return_tensors=self.framework, **tokenizer_kwargs)

    def postprocess(self, model_outputs, function_to_apply=None, return_all_scores=False):
        """Run the parent post-processing, then attach ``'label_text'`` to each entry.

        NOTE(review): assumes the parent returns an iterable of dicts whose
        ``'label'`` is an integer index into CATEGORIES. That depends on the
        model config's id2label mapping and on the transformers version's
        postprocess signature (the third positional argument). Confirm.
        """
        results = super().postprocess(model_outputs, function_to_apply, return_all_scores)
        for entry in results:
            entry['label_text'] = CATEGORIES[entry['label']]
        return results
|
|