Update modeling_emotion_classifier.py
Browse files
modeling_emotion_classifier.py
CHANGED
@@ -1,8 +1,20 @@
|
|
1 |
-
|
2 |
-
import torch.nn as nn
|
3 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from .configuration_emotion_classifier import EmotionClassifierConfig
|
5 |
|
|
|
|
|
6 |
class EmotionClassifierHuBERT(PreTrainedModel):
|
7 |
config_class = EmotionClassifierConfig
|
8 |
|
@@ -47,7 +59,7 @@ class EmotionClassifierHuBERT(PreTrainedModel):
|
|
47 |
mirror = kwargs.pop("mirror", None)
|
48 |
|
49 |
# Load config if we don't provide a configuration
|
50 |
-
if not isinstance(config,
|
51 |
config_path = config if config is not None else pretrained_model_name_or_path
|
52 |
config, model_kwargs = cls.config_class.from_pretrained(
|
53 |
config_path,
|
@@ -156,4 +168,4 @@ class EmotionClassifierHuBERT(PreTrainedModel):
|
|
156 |
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
|
157 |
return model, loading_info
|
158 |
|
159 |
-
return model
|
|
|
1 |
+
import os
|
|
|
2 |
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from transformers import PreTrainedModel, HubertConfig, HubertModel
|
5 |
+
from transformers.file_utils import (
|
6 |
+
WEIGHTS_NAME,
|
7 |
+
TF2_WEIGHTS_NAME,
|
8 |
+
TF_WEIGHTS_NAME,
|
9 |
+
cached_path,
|
10 |
+
hf_bucket_url,
|
11 |
+
is_remote_url,
|
12 |
+
)
|
13 |
+
from transformers.utils import logging
|
14 |
from .configuration_emotion_classifier import EmotionClassifierConfig
|
15 |
|
16 |
+
logger = logging.get_logger(__name__)
|
17 |
+
|
18 |
class EmotionClassifierHuBERT(PreTrainedModel):
|
19 |
config_class = EmotionClassifierConfig
|
20 |
|
|
|
59 |
mirror = kwargs.pop("mirror", None)
|
60 |
|
61 |
# Load config if we don't provide a configuration
|
62 |
+
if not isinstance(config, EmotionClassifierConfig):
|
63 |
config_path = config if config is not None else pretrained_model_name_or_path
|
64 |
config, model_kwargs = cls.config_class.from_pretrained(
|
65 |
config_path,
|
|
|
168 |
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
|
169 |
return model, loading_info
|
170 |
|
171 |
+
return model
|