5roop commited on
Commit
e049326
1 Parent(s): c669a62

Update README.md

Browse files

Add versions, correct use example.

Files changed (1) hide show
  1. README.md +12 -18
README.md CHANGED
@@ -35,40 +35,34 @@ Nikola Ljubešić, Danijel Koržinek, Peter Rupnik, Ivo-Pavao Jazbec. ParlaSpeec
35
 
36
  ## Usage in `transformers`
37
 
38
- So far untested approach that worked before:
 
39
 
40
  ```python
41
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
42
  import soundfile as sf
43
  import torch
44
  import os
45
-
46
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
47
-
48
  # load model and tokenizer
49
- processor = Wav2Vec2Processor.from_pretrained(
50
  "5roop/wav2vec2-xls-r-parlaspeech-hr-lm")
51
  model = Wav2Vec2ForCTC.from_pretrained("5roop/wav2vec2-xls-r-parlaspeech-hr-lm")
52
-
53
-
54
  # download the example wav files:
55
- os.system("wget https://huggingface.co/5roop/wav2vec2-xls-r-parlaspeech-hr-lm/raw/main/00020570a.flac.wav")
56
-
57
  # read the wav file
58
  speech, sample_rate = sf.read("00020570a.flac.wav")
59
- input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.to(device)
 
 
 
 
60
 
61
  # remove the raw wav file
62
  os.system("rm 00020570a.flac.wav")
 
63
 
64
- # retrieve logits
65
- logits = model.to(device)(input_values).logits
66
-
67
- # take argmax and decode
68
- predicted_ids = torch.argmax(logits, dim=-1)
69
- transcription = processor.decode(predicted_ids[0]).lower()
70
-
71
- # transcription: 'veliki broj poslovnih subjekata posluje sa minusom velik dio'
72
  ```
73
 
74
 
 
35
 
36
  ## Usage in `transformers`
37
 
38
+ Tested with `transformers==4.18.0`, `torch==1.11.0`, and `SoundFile==0.10.3.post1`.
39
+
40
 
41
  ```python
42
+ from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
43
  import soundfile as sf
44
  import torch
45
  import os
 
46
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
47
  # load model and tokenizer
48
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained(
49
  "5roop/wav2vec2-xls-r-parlaspeech-hr-lm")
50
  model = Wav2Vec2ForCTC.from_pretrained("5roop/wav2vec2-xls-r-parlaspeech-hr-lm")
 
 
51
  # download the example wav files:
52
+ os.system("wget https://huggingface.co/classla/wav2vec2-large-slavic-parlaspeech-hr/raw/main/00020570a.flac.wav")
 
53
  # read the wav file
54
  speech, sample_rate = sf.read("00020570a.flac.wav")
55
+ input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.cuda()
56
+ inputs = processor(speech, sampling_rate=sample_rate, return_tensors="pt")
57
+ with torch.no_grad():
58
+ logits = model(**inputs).logits
59
+ transcription = processor.batch_decode(logits.numpy()).text[0]
60
 
61
  # remove the raw wav file
62
  os.system("rm 00020570a.flac.wav")
63
+ transcription
64
 
65
+ # transcription: 'velik broj poslovnih subjekata posluje sa minusom velik dio'
 
 
 
 
 
 
 
66
  ```
67
 
68