Spaces:
Running
Running
def get_model(model_name, adapted_component, adaptor_class, num_steers, rank,
              epsilon, init_var, low_resource_mode):
    """Instantiate the model wrapper selected by ``model_name``.

    The wrapper class is chosen by matching ``model_name`` against a fixed
    list of prefixes, checked in order. Each wrapper module is imported
    only inside the branch that needs it, so selecting one model family
    does not import the others.

    Args:
        model_name: Model identifier; its prefix selects the wrapper class
            and it is also passed to that class's constructor.
        adapted_component: Forwarded to the ``Switching_*`` wrappers only.
        adaptor_class: Forwarded to the ``Switching_*`` wrappers only.
        num_steers: Forwarded to the ``Switching_*`` wrappers only.
        rank: Forwarded to the ``Switching_*`` wrappers and to
            ``LORA_GPTNeoModel``.
        epsilon: Forwarded to the ``Switching_*`` wrappers and to
            ``LORA_GPTNeoModel``.
        init_var: Forwarded to the ``Switching_*`` wrappers only.
        low_resource_mode: Forwarded to the ``Switching_*`` wrappers only.

    Returns:
        A ``(model, tokenizer)`` tuple, where ``tokenizer`` is the
        ``tokenizer`` attribute of the constructed model.

    Raises:
        NotImplementedError: If ``model_name`` matches none of the
            supported prefixes.
    """
    # Every Switching_* wrapper is constructed with the same positional
    # arguments, so build the argument tuple once.
    switching_args = (
        model_name, adapted_component, adaptor_class, num_steers, rank,
        epsilon, init_var, low_resource_mode,
    )

    # NOTE(review): "EleutherAI/gpt-neox-*" names also match the
    # "EleutherAI/gpt-neo" prefix below and are therefore routed to the
    # GPT-Neo wrapper -- confirm that this is intended.
    if model_name.startswith(("EleutherAI/gpt-neo", "gpt2")):
        from lm_steer.models.model_gpt_neo import Switching_GPTNeoModel
        model = Switching_GPTNeoModel(*switching_args)
    elif model_name.startswith("lora-gpt2"):
        from lm_steer.models.model_lora_gpt_neo import LORA_GPTNeoModel
        model = LORA_GPTNeoModel(model_name, rank, epsilon)
    elif model_name.startswith("embedding_tuning"):
        from lm_steer.models.model_embedding_tuning_gpt_neo import \
            EmbeddingTuning_GPTNeoModel
        model = EmbeddingTuning_GPTNeoModel(model_name)
    elif model_name.startswith("prefix-gpt2"):
        from lm_steer.models.model_prefix_gpt_neo import PREFIX_GPTNeoModel
        model = PREFIX_GPTNeoModel(model_name)
    elif model_name.startswith("EleutherAI/pythia"):
        from lm_steer.models.model_gpt_neox import Switching_GPTNeoXModel
        model = Switching_GPTNeoXModel(*switching_args)
    elif model_name.startswith("EleutherAI/gpt-j"):
        from lm_steer.models.model_gpt_j import Switching_GPTJModel
        model = Switching_GPTJModel(*switching_args)
    elif model_name.startswith("microsoft/DialoGPT"):
        from lm_steer.models.model_dialogpt import Switching_DialoGPTModel
        model = Switching_DialoGPTModel(*switching_args)
    else:
        # Name the rejected identifier so the failure is diagnosable.
        raise NotImplementedError(
            f"Unsupported model name: {model_name!r}")

    return model, model.tokenizer