Joshua Lochner commited on
Commit
8e622c8
·
1 Parent(s): aacb405

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +282 -0
pipeline.py CHANGED
@@ -9,6 +9,280 @@ import os
9
  import numpy as np
10
  from PIL import Image
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  class PreTrainedPipeline():
@@ -18,6 +292,13 @@ class PreTrainedPipeline():
18
 
19
  def __call__(self, inputs: "Image.Image")-> List[Dict[str, Any]]:
20
 
 
 
 
 
 
 
 
21
  # convert img to numpy array, resize and normalize to make the prediction
22
  img = np.array(inputs)
23
 
@@ -74,5 +355,6 @@ class PreTrainedPipeline():
74
  "label": f"LABEL_{cls}",
75
  "mask": mask_codes[f"mask_{cls}"],
76
  "score": 1.0,
 
77
  })
78
  return labels
 
9
  import numpy as np
10
  from PIL import Image
11
 
12
+ import youtube_transcript_api2
13
+ import json
14
+ import re
15
+ import requests
16
+ from transformers import (
17
+ AutoModelForSequenceClassification,
18
+ AutoTokenizer,
19
+ TextClassificationPipeline,
20
+ )
21
+ from typing import Any, Dict, List
22
+
23
+ CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
24
+
25
+ PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
26
+ PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
27
+
28
+ NUM_DECIMALS = 3
29
+
30
+ # https://www.fincher.org/Utilities/CountryLanguageList.shtml
31
+ # https://lingohub.com/developers/supported-locales/language-designators-with-regions
32
+ LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
33
+ 'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
34
+ 'en']
35
+
36
+
37
+ def parse_transcript_json(json_data, granularity):
38
+ assert json_data['wireMagic'] == 'pb3'
39
+
40
+ assert granularity in ('word', 'chunk')
41
+
42
+ # TODO remove bracketed words?
43
+ # (kiss smacks)
44
+ # (upbeat music)
45
+ # [text goes here]
46
+
47
+ # Some manual transcripts aren't that well formatted... but do have punctuation
48
+ # https://www.youtube.com/watch?v=LR9FtWVjk2c
49
+
50
+ parsed_transcript = []
51
+
52
+ events = json_data['events']
53
+
54
+ for event_index, event in enumerate(events):
55
+ segments = event.get('segs')
56
+ if not segments:
57
+ continue
58
+
59
+ # This value is known (when phrase appears on screen)
60
+ start_ms = event['tStartMs']
61
+ total_characters = 0
62
+
63
+ new_segments = []
64
+ for seg in segments:
65
+ # Replace \n, \t, etc. with space
66
+ text = ' '.join(seg['utf8'].split())
67
+
68
+ # Remove zero-width spaces and strip trailing and leading whitespace
69
+ text = text.replace('\u200b', '').replace('\u200c', '').replace(
70
+ '\u200d', '').replace('\ufeff', '').strip()
71
+
72
+ # Alternatively,
73
+ # text = text.encode('ascii', 'ignore').decode()
74
+
75
+ # Needed for auto-generated transcripts
76
+ text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)
77
+
78
+ if not text:
79
+ continue
80
+
81
+ offset_ms = seg.get('tOffsetMs', 0)
82
+
83
+ new_segments.append({
84
+ 'text': text,
85
+ 'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
86
+ })
87
+
88
+ total_characters += len(text)
89
+
90
+ if not new_segments:
91
+ continue
92
+
93
+ if event_index < len(events) - 1:
94
+ next_start_ms = events[event_index + 1]['tStartMs']
95
+ total_event_duration_ms = min(
96
+ event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
97
+ else:
98
+ total_event_duration_ms = event.get('dDurationMs', 0)
99
+
100
+ # Ensure duration is non-negative
101
+ total_event_duration_ms = max(total_event_duration_ms, 0)
102
+
103
+ avg_seconds_per_character = (
104
+ total_event_duration_ms/total_characters)/1000
105
+
106
+ num_char_count = 0
107
+ for seg_index, seg in enumerate(new_segments):
108
+ num_char_count += len(seg['text'])
109
+
110
+ # Estimate segment end
111
+ seg_end = seg['start'] + \
112
+ (num_char_count * avg_seconds_per_character)
113
+
114
+ if seg_index < len(new_segments) - 1:
115
+ # Do not allow longer than next
116
+ seg_end = min(seg_end, new_segments[seg_index+1]['start'])
117
+
118
+ seg['end'] = round(seg_end, NUM_DECIMALS)
119
+ parsed_transcript.append(seg)
120
+
121
+ final_parsed_transcript = []
122
+ for i in range(len(parsed_transcript)):
123
+
124
+ word_level = granularity == 'word'
125
+ if word_level:
126
+ split_text = parsed_transcript[i]['text'].split()
127
+ elif granularity == 'chunk':
128
+ # Split on space after punctuation
129
+ split_text = re.split(
130
+ r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
131
+ if len(split_text) == 1:
132
+ split_on_whitespace = parsed_transcript[i]['text'].split()
133
+
134
+ if len(split_on_whitespace) >= 8: # Too many words
135
+ # Rather split on whitespace instead of punctuation
136
+ split_text = split_on_whitespace
137
+ else:
138
+ word_level = True
139
+ else:
140
+ raise ValueError('Unknown granularity')
141
+
142
+ segment_end = parsed_transcript[i]['end']
143
+ if i < len(parsed_transcript) - 1:
144
+ segment_end = min(segment_end, parsed_transcript[i+1]['start'])
145
+
146
+ segment_duration = segment_end - parsed_transcript[i]['start']
147
+
148
+ num_chars_in_text = sum(map(len, split_text))
149
+
150
+ num_char_count = 0
151
+ current_offset = 0
152
+ for s in split_text:
153
+ num_char_count += len(s)
154
+
155
+ next_offset = (num_char_count/num_chars_in_text) * segment_duration
156
+
157
+ word_start = round(
158
+ parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
159
+ word_end = round(
160
+ parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
161
+
162
+ # Make the reasonable assumption that min wps is 1.5
163
+ final_parsed_transcript.append({
164
+ 'text': s,
165
+ 'start': word_start,
166
+ 'end': min(word_end, word_start + 1.5) if word_level else word_end
167
+ })
168
+ current_offset = next_offset
169
+
170
+ return final_parsed_transcript
171
+
172
+
173
+ def list_transcripts(video_id):
174
+ try:
175
+ return youtube_transcript_api2.YouTubeTranscriptApi.list_transcripts(video_id)
176
+ except json.decoder.JSONDecodeError:
177
+ return None
178
+
179
+
180
+ WORDS_TO_REMOVE = [
181
+ '[Music]'
182
+ '[Applause]'
183
+ '[Laughter]'
184
+ ]
185
+
186
+
187
+ def get_words(video_id, transcript_type='auto', fallback='manual', filter_words_to_remove=True, granularity='word'):
188
+ """Get parsed video transcript with caching system
189
+ returns None if not processed yet and process is False
190
+ """
191
+
192
+ raw_transcript_json = None
193
+ try:
194
+ transcript_list = list_transcripts(video_id)
195
+
196
+ if transcript_list is not None:
197
+ if transcript_type == 'manual':
198
+ ts = transcript_list.find_manually_created_transcript(
199
+ LANGUAGE_PREFERENCE_LIST)
200
+ else:
201
+ ts = transcript_list.find_generated_transcript(
202
+ LANGUAGE_PREFERENCE_LIST)
203
+ raw_transcript = ts._http_client.get(
204
+ f'{ts._url}&fmt=json3').content
205
+ if raw_transcript:
206
+ raw_transcript_json = json.loads(raw_transcript)
207
+
208
+ except (youtube_transcript_api2.TooManyRequests, youtube_transcript_api2.YouTubeRequestFailed):
209
+ raise # Cannot recover from these errors and do not mark as empty transcript
210
+
211
+ except requests.exceptions.RequestException: # Can recover
212
+ return get_words(video_id, transcript_type, fallback, granularity)
213
+
214
+ except youtube_transcript_api2.CouldNotRetrieveTranscript: # Retrying won't solve
215
+ pass # Mark as empty transcript
216
+
217
+ except json.decoder.JSONDecodeError:
218
+ return get_words(video_id, transcript_type, fallback, granularity)
219
+
220
+ if not raw_transcript_json and fallback is not None:
221
+ return get_words(video_id, transcript_type=fallback, fallback=None, granularity=granularity)
222
+
223
+ if raw_transcript_json:
224
+ processed_transcript = parse_transcript_json(
225
+ raw_transcript_json, granularity)
226
+ if filter_words_to_remove:
227
+ processed_transcript = list(
228
+ filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
229
+ else:
230
+ processed_transcript = raw_transcript_json # Either None or []
231
+
232
+ return processed_transcript
233
+
234
+
235
+ def word_start(word):
236
+ return word['start']
237
+
238
+
239
+ def word_end(word):
240
+ return word.get('end', word['start'])
241
+
242
+
243
+ def extract_segment(words, start, end, map_function=None):
244
+ """Extracts all words with time in [start, end]"""
245
+
246
+ a = max(binary_search_below(words, 0, len(words), start), 0)
247
+ b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
248
+
249
+ to_transform = map_function is not None and callable(map_function)
250
+
251
+ return [
252
+ map_function(words[i]) if to_transform else words[i] for i in range(a, b)
253
+ ]
254
+
255
+
256
+ def avg(*items):
257
+ return sum(items)/len(items)
258
+
259
+
260
+ def binary_search_below(transcript, start_index, end_index, time):
261
+ if start_index >= end_index:
262
+ return end_index
263
+
264
+ middle_index = (start_index + end_index) // 2
265
+ middle = transcript[middle_index]
266
+ middle_time = avg(word_start(middle), word_end(middle))
267
+
268
+ if time <= middle_time:
269
+ return binary_search_below(transcript, start_index, middle_index, time)
270
+ else:
271
+ return binary_search_below(transcript, middle_index + 1, end_index, time)
272
+
273
+
274
+ def binary_search_above(transcript, start_index, end_index, time):
275
+ if start_index >= end_index:
276
+ return end_index
277
+
278
+ middle_index = (start_index + end_index + 1) // 2
279
+ middle = transcript[middle_index]
280
+ middle_time = avg(word_start(middle), word_end(middle))
281
+
282
+ if time >= middle_time:
283
+ return binary_search_above(transcript, middle_index, end_index, time)
284
+ else:
285
+ return binary_search_above(transcript, start_index, middle_index - 1, time)
286
 
287
 
288
  class PreTrainedPipeline():
 
292
 
293
  def __call__(self, inputs: "Image.Image")-> List[Dict[str, Any]]:
294
 
295
+ # TEMP testing
296
+ # data = [{"video_id": "pqh4LfPeCYs", "start": 835.933, "end": 927.581, "category": "sponsor"}]
297
+ words = get_words("pqh4LfPeCYs")
298
+ segment = extract_segment(words, 835.933, 927.581)
299
+ # END TEMP
300
+
301
+
302
  # convert img to numpy array, resize and normalize to make the prediction
303
  img = np.array(inputs)
304
 
 
355
  "label": f"LABEL_{cls}",
356
  "mask": mask_codes[f"mask_{cls}"],
357
  "score": 1.0,
358
+ "words": segment
359
  })
360
  return labels