Summary
Whisper is a powerful speech recognition platform developed by OpenAI. This model has been specially optimized for converting sign language input features into German text.
Applications
The model is based on 'primeline/whisper-large-v3-german' and is used (in combination with Google MediaPipe) to translate a video of German sign language into text. The model decodes a sequence of input features, where each input feature represents keypoints extracted from a video (hands, upper body and face), into text.
We keep the decoder frozen, while training the encoder.
Evaluations - Word error rate
TBD
Training data
TBD
Training process
import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    WhisperForConditionalGeneration,
)
# Probe for CUDA once; it decides both the target device and the precision.
has_cuda = torch.cuda.is_available()
device = "cuda:0" if has_cuda else "cpu"
torch_dtype = torch.float16 if has_cuda else torch.float32

# Convolutional front-end applied to the keypoint features.
# When changing conv_preprocessing_layers make sure their final output
# has the shape b x 1280 x seq.
conv_layers = [
    {
        "in_channels": 128,
        "out_channels": 1280,
        "kernel_size": 3,
        "stride": 1,
        "padding": 1,
        "activation": "gelu",
        "bias": True,
    },
    {
        "in_channels": 1280,
        "out_channels": 1280,
        "kernel_size": 3,
        "stride": 1,
        "padding": 1,
        "activation": "gelu",
        "bias": True,
    },
]

# See custom config in model.py for configuration options.
# The config is loaded first (via AutoConfig) so that the overrides below
# are in place before the model itself is instantiated.
config = AutoConfig.from_pretrained(
    "mrprimenotes/sign-whisper-german",
    trust_remote_code=True,
    use_first_embeddings=True,
    # Further options, left at their defaults here:
    # embedding_stride=2,
    # conv_dropout=0.1,
    skip_connections=True,
    conv_preprocessing_layers=conv_layers,
)
# Tokenizer belonging to the checkpoint; it is later passed to the Trainer
# and used to decode generated ids.
tokenizer = AutoTokenizer.from_pretrained("mrprimenotes/sign-whisper-german")

# Instantiate the model with the customised config prepared above.
# Placement is done explicitly with `.to(device)`. `device_map='auto'` was
# dropped: it lets Accelerate dispatch the weights itself, and a model
# dispatched that way should not be moved again afterwards.
model = AutoModel.from_pretrained(
    pretrained_model_name_or_path="mrprimenotes/sign-whisper-german",
    config=config,
    use_safetensors=True,
    trust_remote_code=True,
    # The overridden config may not match every checkpoint tensor shape.
    ignore_mismatched_sizes=True,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
).to(device)
# You can see raw model outputs as follows:
# output = model(input_features, labels=labels)
# e.g.
# output.loss
# output.shape --> b x sq
# Placeholders: substitute your own dataset implementation here.
train_dataset = YourSignDataset(...)
val_dataset = YourSignDataset(...)

# Freeze the decoder for our purpose (only the encoder is trained).
model.freeze_decoder()

# Define training arguments
training_args = TrainingArguments(
    hub_model_id="mrprimenotes/sign-whisper-german_trained",
    push_to_hub=True,
    num_train_epochs=2,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=386,
    learning_rate=2e-5,  # the trailing comma was missing here (SyntaxError)
    warmup_steps=200,
    weight_decay=0.01,
    # Logging settings
    logging_steps=500,
    logging_strategy="steps",
    # Evaluation
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower eval_loss is better
    evaluation_strategy="steps",
    eval_steps=1000,
    # Saving (save_steps is kept a multiple of eval_steps, as
    # load_best_model_at_end requires)
    save_strategy="steps",
    save_steps=2000,
    save_total_limit=4,
    # NOTE(review): per the TrainingArguments documentation this field is not
    # consumed by Trainer itself; to actually resume, pass
    # `resume_from_checkpoint=` to `trainer.train()` - confirm intent.
    resume_from_checkpoint=True,
    load_best_model_at_end=True,
    # NOTE(review): on CUDA the model is loaded with float16 weights; fp16
    # mixed precision normally expects float32 weights - confirm that this
    # combination trains without a gradient-unscaling error.
    fp16=torch.cuda.is_available(),
)

# Initialize trainer with tokenizer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()
Use model for inference (with generate)
from transformers import TextStreamer
# Optional: only needed when the generated text should be streamed.
streamer = TextStreamer(tokenizer, skip_special_tokens=False)

# input preprocessing / feature extraction (TBD)
# input_features = ...

# Options for generation, collected in one place.
generation_options = {
    "max_new_tokens": 128,
    "return_timestamps": False,  # timestamps are not supported
    "streamer": streamer,  # only needed for streaming
}

# Generate
generated_ids = model.generate(input_features, **generation_options)

# Turn the generated ids back into text, keeping special tokens.
tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
- Downloads last month
- 260
Model tree for mrprimenotes/sign-whisper-german
Base model
primeline/whisper-large-v3-german