---
license: apache-2.0
language:
- de
tags:
- sign-language
- whisper
- german
- safetensors
library_name: transformers
model-index:
- name: whisper-large-v3-turbo-german
  results:
  - task:
      type: automatic-speech-recognition
      name: Speech Recognition
    dataset:
      name: German ASR Data-Mix
      type: flozi00/asr-german-mixed
    metrics:
    - type: wer
      value: TBD
datasets:
- flozi00/asr-german-mixed
base_model:
- primeline/whisper-large-v3-german
---
## Summary

Whisper is a powerful speech recognition model developed by OpenAI. This model has been specially adapted to convert sign-language input features into German text.
## Applications

The model is based on primeline/whisper-large-v3-german and is used (in combination with Google MediaPipe) to translate a video of German sign language into text. It decodes a sequence of input features, where each input feature represents keypoints extracted from the video (hands, upper body, and face), into text.

The decoder is kept frozen while the encoder is trained.
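The exact layout of the feature channels is not documented here, so the following is only a minimal sketch of what the MediaPipe preprocessing could look like. `video_to_features` and the channel packing are hypothetical; the 128-channel count is taken from the `conv_preprocessing_layers` configuration shown below.

```python
import cv2
import mediapipe as mp
import torch

holistic = mp.solutions.holistic.Holistic(static_image_mode=False)

def video_to_features(path, n_channels=128):
    """Hypothetical sketch: flatten MediaPipe keypoints into a 1 x 128 x seq tensor."""
    frames = []
    cap = cv2.VideoCapture(path)
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        res = holistic.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        coords = []
        # Pose and hand landmarks; face landmarks could be packed in the same way.
        for lms in (res.pose_landmarks, res.left_hand_landmarks, res.right_hand_landmarks):
            if lms is not None:
                coords += [c for lm in lms.landmark for c in (lm.x, lm.y, lm.z)]
        vec = torch.zeros(n_channels)
        n = min(len(coords), n_channels)
        vec[:n] = torch.tensor(coords[:n])
        frames.append(vec)
    cap.release()
    return torch.stack(frames, dim=1).unsqueeze(0)  # shape: 1 x n_channels x seq
```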
## Evaluations - Word error rate

TBD
## Training data

TBD
## Training process

!!! Make sure to install Transformers 4.46.0 !!!

```python
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, Trainer, TrainingArguments
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# First load the config using AutoConfig
# See custom config in model.py for configuration options.
config = AutoConfig.from_pretrained(
"mrprimenotes/sign-whisper-german",
trust_remote_code=True,
use_first_embeddings=True,
#embedding_stride=2,
#conv_dropout=0.1,
skip_connections=True,
conv_preprocessing_layers=[
{ # When changing conv_preprocessing_layers make sure their final output has the shape b x 1280 x seq.
"in_channels": 128,
"out_channels": 1280,
"kernel_size": 3,
"stride": 1,
"padding": 1,
"activation": "gelu",
"bias": True
},
{
"in_channels": 1280,
"out_channels": 1280,
"kernel_size": 3,
"stride": 1,
"padding": 1,
"activation": "gelu",
"bias": True
}
]
)
tokenizer = AutoTokenizer.from_pretrained("mrprimenotes/sign-whisper-german")
model = AutoModel.from_pretrained(
    pretrained_model_name_or_path="mrprimenotes/sign-whisper-german",
    config=config,
    use_safetensors=True,
    trust_remote_code=True,
    ignore_mismatched_sizes=True,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    device_map="auto",  # handles device placement, so no extra .to(device) call is needed
)
# You can see raw model outputs as follows:
# output = model(input_features, labels=labels)
# e.g.
# output.loss
# output.logits.shape --> b x seq x vocab
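# A quick sanity check with random tensors (shapes are assumptions: 128 input
# channels from the config above, arbitrary sequence lengths):
# dummy_features = torch.randn(2, 128, 1500, dtype=torch_dtype, device=model.device)
# dummy_labels = torch.randint(0, len(tokenizer), (2, 16), device=model.device)
# print(model(dummy_features, labels=dummy_labels).loss)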
# Load your dataset (e.g. mrprimenotes/sign-whisper-german-example)
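# A minimal sketch of such a dataset (hypothetical; adapt it to your own data).
# Each item must provide "input_features" (keypoints, shape 128 x seq) and
# tokenized "labels" for the target German text.
class YourSignDataset(torch.utils.data.Dataset):
    def __init__(self, features, texts, tokenizer):
        self.features = features  # list of float tensors, each 128 x seq
        self.labels = [tokenizer(t).input_ids for t in texts]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return {"input_features": self.features[idx], "labels": self.labels[idx]}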
train_dataset = YourSignDataset(...)
val_dataset = YourSignDataset(...)
# Freeze the decoder for our purpose
model.freeze_decoder()
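# freeze_decoder comes with the custom model code (trust_remote_code).
# Conceptually it disables gradients for all decoder parameters, e.g.:
#   for p in model.model.decoder.parameters():
#       p.requires_grad = False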
# Define training arguments
training_args = TrainingArguments(
hub_model_id="mrprimenotes/sign-whisper-german_trained",
push_to_hub=True,
num_train_epochs=2,
per_device_train_batch_size=256,
per_device_eval_batch_size=386,
    learning_rate=2e-5,
warmup_steps=200,
weight_decay=0.01,
# Logging settings
logging_steps=500,
logging_strategy="steps",
# Evaluation
metric_for_best_model="eval_loss",
greater_is_better=False,
    eval_strategy="steps",
eval_steps=1000,
# Saving
save_strategy="steps",
save_steps=2000,
save_total_limit=4,
resume_from_checkpoint=True,
load_best_model_at_end=True,
fp16=torch.cuda.is_available(),
)
# Initialize trainer with tokenizer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer` argument in 4.46
)
# Train the model
trainer.train()
```
## Use model for inference (with generate)

!!! Make sure to install Transformers 4.46.0 !!!

```python
from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_special_tokens=False)  # only needed for streaming
# input preprocessing / feature extraction (TBD)
# input_features = ...
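# Until the feature extraction is published, a random tensor with the expected
# shape (b x 128 x seq, an assumption based on the config above) can serve as a
# smoke test:
# input_features = torch.randn(1, 128, 1500, dtype=torch_dtype, device=model.device)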
# Generate
generated_ids = model.generate(
input_features,
max_new_tokens=128,
    return_timestamps=False,  # timestamps are not supported
    streamer=streamer  # only needed for streaming
)
tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
```
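With `skip_special_tokens=False` the decoded strings still contain Whisper's special tokens; for plain text output you can drop them:

```python
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text[0])  # decoded German sentence
```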