# Spaces: Running on T4
import re | |
import torch | |
import torchaudio | |
from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor | |
from tortoise.utils.audio import load_audio | |
def max_alignment(s1: str, s2: str, skip_character: str = '~', record=None) -> str:
    """
    Align s1 to s2 as best it can and return the aligned form of s1.

    The characters of s1 are matched, in order, against the characters of s2. Every character of s1 that ends up
    without a partner in s2 is replaced by skip_character, so (for a single-character skip_character) the result has
    the same length as s1. The alignment keeps as many characters of s1 as possible; when dropping the current
    character of s1 and skipping the current character of s2 would keep equally many characters, the s1 character is
    dropped.

    The table is filled iteratively rather than by recursion, so the length of the inputs is not limited by the
    interpreter's recursion limit.

    :param s1: The string to align. Must not contain skip_character.
    :param s2: The string s1 is aligned against.
    :param skip_character: Placeholder emitted for every unmatched character of s1. Expected to be a single character.
    :param record: Ignored. Retained in the signature for backward compatibility.
    :return: s1 with every unmatched character replaced by skip_character.
    :raises AssertionError: If skip_character occurs in s1.
    """
    if skip_character in s1:
        # Raised explicitly (rather than with `assert`) so the check survives `python -O`.
        raise AssertionError(f"Found the skip character {skip_character} in the provided string, {s1}")
    if s1 == s2:
        return s1  # Fast path: nothing to align, no table needed.

    n, m = len(s1), len(s2)

    # kept[i][j] is the number of s1 characters that survive when aligning s1[i:] against s2[j:].
    # The extra last row/column are the "one string exhausted" base cases and stay at 0.
    kept = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        row, below = kept[i], kept[i + 1]
        ch = s1[i]
        for j in range(m - 1, -1, -1):
            if ch == s2[j]:
                # Equal leading characters are always paired up.
                row[j] = below[j + 1] + 1
            else:
                skip_s2, drop_s1 = row[j + 1], below[j]
                row[j] = skip_s2 if skip_s2 > drop_s1 else drop_s1

    # Walk the table from the top-left corner, emitting one output character per character of s1.
    pieces = []
    i = j = 0
    while i < n and j < m:
        if s1[i] == s2[j]:
            pieces.append(s1[i])
            i += 1
            j += 1
        elif kept[i][j + 1] > kept[i + 1][j]:
            # Skipping a character of s2 is strictly better; nothing is emitted for s1 yet.
            j += 1
        else:
            # Tie or worse: give up on this character of s1.
            pieces.append(skip_character)
            i += 1
    # s2 ran out: whatever is left of s1 has nothing to match against.
    pieces.append(skip_character * (n - i))
    return ''.join(pieces)
class Wav2VecAlignment:
    """
    Uses wav2vec2 to perform audio<->text alignment.
    """

    def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
        """
        Load the pretrained CTC model, feature extractor and tokenizer.

        The model is placed on the CPU here; align() moves it to `device` only for the duration of a forward pass.

        :param device: Device the model and audio are moved to while aligning. The default is evaluated once, when the
                       class is defined: 'mps' if torch reports MPS as available, otherwise 'cuda'.
        """
        self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
        self.device = device

    def align(self, audio, expected_text, audio_sample_rate=24000):
        """
        Estimate where each character of expected_text starts in audio.

        The audio is resampled to 16kHz, normalised and run through the CTC model. expected_text (lower-cased) is
        aligned against the model's decoded prediction with max_alignment(); characters the model did not produce are
        marked and later given interpolated positions between their located neighbours.

        :param audio: Audio tensor whose last dimension is the sample axis. Presumably shaped (1, samples), since only
                      the first entry of the model output is used - TODO confirm against callers.
        :param expected_text: The text spoken in the audio.
        :param audio_sample_rate: Sample rate of `audio`.
        :return: A list of offsets, one per character of expected_text, expressed in samples of the original
                 (pre-resampling) audio. The first offset is always 0. If the text encodes to a single token, [0].
        :raises AssertionError: If the alignment could not be completed. The resampled audio and expected_text are
                                first saved to 'alignment_debug.pth' in the current working directory.
        """
        orig_len = audio.shape[-1]

        with torch.no_grad():
            self.model = self.model.to(self.device)
            try:
                audio = audio.to(self.device)
                audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
                clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
                logits = self.model(clip_norm).logits
            finally:
                # Move the model back to the CPU even when inference fails.
                self.model = self.model.cpu()

        logits = logits[0]
        # Most likely token for every output frame; computed once and reused for decoding and for the scan below.
        frame_tokens = logits.argmax(-1).tolist()
        pred_string = self.tokenizer.decode(frame_tokens)

        fixed_expectation = max_alignment(expected_text.lower(), pred_string)
        # Number of original audio samples represented by one model output frame.
        w2v_compression = orig_len // logits.shape[0]
        expected_tokens = self.tokenizer.encode(fixed_expectation)
        expected_chars = list(fixed_expectation)
        if len(expected_tokens) == 1:
            return [0]  # The alignment is simple; there is only one token.
        expected_tokens.pop(0)  # The first token is a given.
        expected_chars.pop(0)

        alignments = [0]

        def pop_till_you_win():
            # Consume expected tokens up to and including the next one the model actually predicted. Every skipped
            # ('~') character gets a -1 placeholder in `alignments`. Returns None once the tokens run out.
            while expected_tokens:
                token = expected_tokens.pop(0)
                if expected_chars.pop(0) != '~':
                    return token
                alignments.append(-1)
            return None

        next_expected_token = pop_till_you_win()
        for frame_index, top in enumerate(frame_tokens):
            if next_expected_token == top:
                alignments.append(frame_index * w2v_compression)
                if not expected_tokens:
                    break
                next_expected_token = pop_till_you_win()

        pop_till_you_win()  # Flush any trailing skipped characters into placeholders.
        if expected_tokens or len(alignments) != len(expected_text):
            torch.save([audio, expected_text], 'alignment_debug.pth')
            # Raised explicitly (rather than `assert False`) so the failure is still reported under `python -O`.
            raise AssertionError(
                "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to "
                "your current working directory. Please report this along with the file so it can get fixed."
            )

        # Now fix up alignments. Anything with -1 should be interpolated.
        alignments.append(orig_len)  # Sentinel: guarantees every run of -1 is followed by a known position.
        for i in range(len(alignments)):
            if alignments[i] != -1:
                continue
            # Earlier iterations fill whole runs, so `i` is always the first placeholder of a run here and
            # alignments[i - 1] is a known position (index 0 is always 0).
            next_found_token = next(j for j in range(i + 1, len(alignments)) if alignments[j] != -1)
            anchor = alignments[i - 1]
            gap = alignments[next_found_token] - anchor
            span = next_found_token - i + 1
            for j in range(i, next_found_token):
                alignments[j] = (j - i + 1) * gap // span + anchor

        return alignments[:-1]

    def redact(self, audio, expected_text, audio_sample_rate=24000):
        """
        Cut the audio belonging to [bracketed] parts of expected_text out of audio.

        The brackets are stripped, the full remaining text (including the parts that were inside brackets) is aligned
        against the audio with align(), and only the audio for the text outside the brackets is concatenated and
        returned.

        :param audio: Audio tensor, indexed as audio[:, start:stop] - i.e. two-dimensional with samples last.
        :param expected_text: The spoken text, with the portions to remove wrapped in "[" and "]".
        :param audio_sample_rate: Sample rate of `audio`, forwarded to align().
        :return: `audio` unchanged if expected_text contains no "[", otherwise the kept pieces joined along the last
                 dimension.
        :raises AssertionError: If a "[" is not followed by a "]" before the next "[" (or the end of the text).
        """
        if '[' not in expected_text:
            return audio

        head, *bracketed = expected_text.split('[')
        fully_split = [head]
        for spl in bracketed:
            if ']' not in spl:
                # Raised explicitly (rather than with `assert`) so the check survives `python -O`.
                raise AssertionError('Every "[" character must be paired with a "]" with no nesting.')
            fully_split.extend(spl.split(']'))

        # At this point, fully_split is a list of strings, with every other string being something that should be
        # redacted (odd indices are redacted, even indices are kept).
        non_redacted_intervals = []
        last_point = 0
        for i, piece in enumerate(fully_split):
            if i % 2 == 0 and piece != "":  # Empty kept pieces are skipped; they would index out of range.
                end_interval = max(0, last_point + len(piece) - 1)
                non_redacted_intervals.append((last_point, end_interval))
            last_point += len(piece)

        bare_text = ''.join(fully_split)
        alignments = self.align(audio, bare_text, audio_sample_rate)

        # NOTE(review): `stop` is the index of the interval's last character, so each kept slice ends where that
        # character begins rather than where it ends - confirm this is intended.
        output_audio = [audio[:, alignments[start]:alignments[stop]] for start, stop in non_redacted_intervals]
        return torch.cat(output_audio, dim=-1)