Spaces:
Runtime error
Runtime error
Commit
·
ce348f5
1
Parent(s):
f01f4ac
Update stitched_model.py
Browse files- stitched_model.py +2 -2
stitched_model.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
-
from transformers import
|
4 |
|
5 |
class CombinedModel(nn.Module):
|
6 |
def __init__(self, stt_model_name, nmt_model_name,device = "cuda"):
|
7 |
super(CombinedModel, self).__init__()
|
8 |
|
9 |
-
self.stt_processor =
|
10 |
self.stt_model = Wav2Vec2ForCTC.from_pretrained(stt_model_name)
|
11 |
self.nmt_tokenizer = AutoTokenizer.from_pretrained(nmt_model_name)
|
12 |
self.nmt_model = AutoModelForSeq2SeqLM.from_pretrained(nmt_model_name)
|
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
+
from transformers import AutoProcessor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForSeq2SeqLM
|
4 |
|
5 |
class CombinedModel(nn.Module):
|
6 |
def __init__(self, stt_model_name, nmt_model_name,device = "cuda"):
|
7 |
super(CombinedModel, self).__init__()
|
8 |
|
9 |
+
self.stt_processor = AutoProcessor.from_pretrained(stt_model_name)
|
10 |
self.stt_model = Wav2Vec2ForCTC.from_pretrained(stt_model_name)
|
11 |
self.nmt_tokenizer = AutoTokenizer.from_pretrained(nmt_model_name)
|
12 |
self.nmt_model = AutoModelForSeq2SeqLM.from_pretrained(nmt_model_name)
|