allandclive commited on
Commit
ce348f5
·
1 Parent(s): f01f4ac

Update stitched_model.py

Browse files
Files changed (1) hide show
  1. stitched_model.py +2 -2
stitched_model.py CHANGED
@@ -1,12 +1,12 @@
1
  import torch
2
  from torch import nn
3
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
  class CombinedModel(nn.Module):
6
  def __init__(self, stt_model_name, nmt_model_name,device = "cuda"):
7
  super(CombinedModel, self).__init__()
8
 
9
- self.stt_processor = Wav2Vec2Processor.from_pretrained(stt_model_name)
10
  self.stt_model = Wav2Vec2ForCTC.from_pretrained(stt_model_name)
11
  self.nmt_tokenizer = AutoTokenizer.from_pretrained(nmt_model_name)
12
  self.nmt_model = AutoModelForSeq2SeqLM.from_pretrained(nmt_model_name)
 
1
  import torch
2
  from torch import nn
3
+ from transformers import AutoProcessor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
  class CombinedModel(nn.Module):
6
  def __init__(self, stt_model_name, nmt_model_name,device = "cuda"):
7
  super(CombinedModel, self).__init__()
8
 
9
+ self.stt_processor = AutoProcessor.from_pretrained(stt_model_name)
10
  self.stt_model = Wav2Vec2ForCTC.from_pretrained(stt_model_name)
11
  self.nmt_tokenizer = AutoTokenizer.from_pretrained(nmt_model_name)
12
  self.nmt_model = AutoModelForSeq2SeqLM.from_pretrained(nmt_model_name)