|
from transformers import PreTrainedModel |
|
from .configuration_test import TestConfig |
|
import torch.nn as nn |
|
from transformers import AutoModelForMaskedLM, AutoConfig |
|
import librosa |
|
from huggingface_hub import hf_hub_download |
|
import os |
|
|
|
|
|
class TestModel(PreTrainedModel):
    """Toy custom model combining a linear head with an ALBERT masked-LM backbone.

    The model's Hub repository is also expected to ship an audio asset
    (``output1.wav``); :meth:`get_audio_duration` downloads it and reports
    its length in seconds.
    """

    config_class = TestConfig

    def __init__(self, config: TestConfig):
        super().__init__(config)
        # Linear projection head — the only module `forward` routes through.
        self.input_dim = config.input_dim
        self.model1 = nn.Linear(config.input_dim, config.output_dim)
        # Secondary masked-LM backbone built from a fresh (randomly
        # initialized) ALBERT configuration — no pretrained weights loaded.
        backbone_config = AutoConfig.from_pretrained("albert/albert-base-v2")
        self.model2 = AutoModelForMaskedLM.from_config(backbone_config)
        # Repo id remembered so helper methods can fetch files from the Hub.
        self.path = config.name_or_path

    def get_audio_duration(self):
        """Download ``output1.wav`` from the model repo and return its duration (seconds)."""
        destination = os.path.dirname(os.path.abspath(__file__))
        audio_path = hf_hub_download(
            repo_id=self.path,
            filename="output1.wav",
            repo_type="model",
            local_dir=destination,
        )
        samples, sample_rate = librosa.load(audio_path)
        return librosa.get_duration(y=samples, sr=sample_rate)

    def forward(self, tensor):
        """Apply the linear head to ``tensor`` and return the projection."""
        return self.model1(tensor)
|
|