File size: 887 Bytes
e4f3acf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#!/usr/bin/env python
# coding=utf-8
"""Automatically get correct model type.
"""

from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.models.text_regression_model import TextRegressionModel
from lmflow.models.hf_encoder_decoder_model import HFEncoderDecoderModel

class AutoModel:

    @classmethod
    def get_model(self, model_args, *args, **kwargs):
        arch_type = model_args.arch_type
        if arch_type == "decoder_only":
            return HFDecoderModel(model_args, *args, **kwargs)
        elif arch_type == "text_regression":
            return TextRegressionModel(model_args, *args, **kwargs)
        elif arch_type == "encoder_decoder":
            return HFEncoderDecoderModel(model_args, *args, **kwargs)
        else:
            raise NotImplementedError(
                f"model architecture type \"{arch_type}\" is not supported"
            )