viktor-enzell commited on
Commit
2c48270
·
1 Parent(s): 0ffcd4e

Update model card example.

Browse files
Files changed (1) hide show
  1. README.md +28 -20
README.md CHANGED
@@ -20,30 +20,38 @@ Training of the acoustic model is the work of KBLab. See [VoxRex-C](https://hugg
20
  VoxRex-C is extended with a 4-gram language model estimated from a subset extracted from [The Swedish Culturomics Gigaword Corpus](https://spraakbanken.gu.se/resurser/gigaword) from Språkbanken. The subset contains 40M words from the social media genre between 2010 and 2015.
21
 
22
  ## How to use
23
- Audio should be downsampled to 16kHz.
24
 
25
  ```python
26
- import torch
27
- import torchaudio
28
  from datasets import load_dataset
29
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
30
- test_dataset = load_dataset("common_voice", "sv-SE", split="test[:2%]").
31
- processor = Wav2Vec2Processor.from_pretrained("KBLab/wav2vec2-large-voxrex-swedish")
32
- model = Wav2Vec2ForCTC.from_pretrained("KBLab/wav2vec2-large-voxrex-swedish")
33
- resampler = torchaudio.transforms.Resample(48_000, 16_000)
34
- # Preprocessing the datasets.
35
- # We need to read the aduio files as arrays
36
- def speech_file_to_array_fn(batch):
37
- speech_array, sampling_rate = torchaudio.load(batch["path"])
38
- batch["speech"] = resampler(speech_array).squeeze().numpy()
39
- return batch
40
- test_dataset = test_dataset.map(speech_file_to_array_fn)
41
- inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
 
 
 
 
 
 
 
 
 
 
42
  with torch.no_grad():
43
- logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
44
- predicted_ids = torch.argmax(logits, dim=-1)
45
- print("Prediction:", processor.batch_decode(predicted_ids))
46
- print("Reference:", test_dataset["sentence"][:2])
47
  ```
48
 
49
  ## Training procedure
 
20
  VoxRex-C is extended with a 4-gram language model estimated from a subset extracted from [The Swedish Culturomics Gigaword Corpus](https://spraakbanken.gu.se/resurser/gigaword) from Språkbanken. The subset contains 40M words from the social media genre between 2010 and 2015.
21
 
22
  ## How to use
23
+ Example of transcribing 1% of the Common Voice test split, using GPU if available. The model expects 16kHz audio.
24
 
25
  ```python
26
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
 
27
  from datasets import load_dataset
28
+ import torch
29
+ import torchaudio.functional as F
30
+
31
+ # Import model and processor
32
+ model_name = 'viktor-enzell/wav2vec2-large-voxrex-swedish-4gram'
33
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34
+ model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device);
35
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
36
+
37
+ # Import and process speech data
38
+ common_voice = load_dataset('common_voice', 'sv-SE', split='test[:1%]')
39
+
40
+ def speech_file_to_array(sample):
41
+ # Convert speech file to array and downsample to 16 kHz
42
+ sampling_rate = sample['audio']['sampling_rate']
43
+ sample['speech'] = F.resample(torch.tensor(sample['audio']['array']), sampling_rate, 16_000)
44
+ return sample
45
+
46
+ common_voice = common_voice.map(speech_file_to_array)
47
+
48
+ # Run inference
49
+ inputs = processor(common_voice['speech'], sampling_rate=16_000, return_tensors='pt', padding=True).to(device)
50
+
51
  with torch.no_grad():
52
+ logits = model(**inputs).logits
53
+
54
+ transcripts = processor.batch_decode(logits.cpu().numpy()).text
 
55
  ```
56
 
57
  ## Training procedure