from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel
from transformers.pipelines import PIPELINE_REGISTRY
from huggingface_hub import hf_hub_download
import onnxruntime as ort
import torch
import os

# File name used for the ONNX graph when the config does not specify one.
_DEFAULT_ONNX_FILENAME = 'model.onnx'

# Keyword arguments of ``from_pretrained`` that are also forwarded to
# ``hf_hub_download`` so every file comes from the same revision/cache.
_HUB_DOWNLOAD_KWARGS = ('revision', 'cache_dir', 'token', 'force_download', 'local_files_only')


# 1. register AutoConfig
class ONNXBaseConfig(PretrainedConfig):
    """Configuration for a model whose forward pass is an onnxruntime session.

    Args:
        model_path: Path of the ONNX file relative to the model directory.
            ``None`` means loaders fall back to ``'model.onnx'``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = 'onnx-base'

    def __init__(self, model_path=None, **kwargs):
        # Declaring the attribute guarantees ``config.model_path`` exists even
        # when the loaded config.json does not contain it.
        self.model_path = model_path
        super().__init__(**kwargs)


AutoConfig.register('onnx-base', ONNXBaseConfig)


# 2. register AutoModel
class ONNXBaseModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that runs an ONNX graph with onnxruntime.

    NOTE(review): ``save_pretrained`` exports ``self`` with ``torch.onnx.export``,
    which traces ``forward``. The ``forward`` defined here calls onnxruntime and
    is presumably not traceable, so export likely only works for subclasses that
    override ``forward`` with torch ops — confirm intended usage.
    """

    config_class = ONNXBaseConfig

    def __init__(self, config, base_path=None):
        """Build the model and open the ONNX inference session when possible.

        Args:
            config: Model configuration. ``config.model_path`` names the ONNX
                file relative to ``base_path`` (``'model.onnx'`` when unset).
            base_path: Directory containing the ONNX file. When it is falsy or
                the file does not exist, no session is created and ``forward``
                raises ``RuntimeError``.
        """
        super().__init__(config)
        # Always define the attribute so a missing session is reported clearly
        # by ``forward`` instead of surfacing as an AttributeError.
        self.session = None
        if base_path:
            model_file = getattr(config, 'model_path', None) or _DEFAULT_ONNX_FILENAME
            model_path = os.path.join(base_path, model_file)
            if os.path.exists(model_path):
                # ORT >= 1.9 requires an explicit provider list on builds that
                # ship more than one execution provider (e.g. onnxruntime-gpu).
                self.session = ort.InferenceSession(
                    model_path, providers=ort.get_available_providers()
                )

    def forward(self, input=None, **kwargs):
        """Run the ONNX graph on ``input``.

        Args:
            input: Value fed to the graph input named ``'input'``. A
                ``torch.Tensor`` is converted to a NumPy array first, because
                onnxruntime consumes NumPy arrays.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The list returned by ``InferenceSession.run`` (all graph outputs).

        Raises:
            RuntimeError: If no inference session was created.
        """
        if self.session is None:
            raise RuntimeError(
                'ONNX inference session is not initialized; load the model with '
                'from_pretrained() or pass a base_path containing the ONNX file.'
            )
        if isinstance(input, torch.Tensor):
            input = input.detach().cpu().numpy()
        return self.session.run(None, {'input': input})

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save config/weights via ``PreTrainedModel`` and export an ONNX graph.

        The graph is written to ``save_directory/<config.model_path>``
        (``'model.onnx'`` when unset). ``config.model_path`` is recorded in the
        saved config so that ``from_pretrained`` can locate the file again. The
        export traces ``self`` on a 2x2 float32 dummy input, with a dynamic
        batch axis (axis 0) on the ``'input'`` and ``'output'`` tensors.

        Args:
            save_directory: Target directory.
            **kwargs: Forwarded to ``PreTrainedModel.save_pretrained``.
        """
        model_file = getattr(self.config, 'model_path', None) or _DEFAULT_ONNX_FILENAME
        # Record the file name before the base class serializes the config.
        self.config.model_path = model_file
        super().save_pretrained(save_directory=save_directory, **kwargs)
        onnx_file_path = os.path.join(save_directory, model_file)
        # model_path may contain sub-directories below save_directory.
        os.makedirs(os.path.dirname(onnx_file_path), exist_ok=True)
        dummy_input = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
        torch.onnx.export(
            self,
            dummy_input,
            onnx_file_path,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the config and ONNX file from a local directory or the Hub.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            *model_args: Accepted for API compatibility; unused.
            **kwargs: Forwarded to ``AutoConfig.from_pretrained``. The Hub options
                ``revision``, ``cache_dir``, ``token``, ``force_download`` and
                ``local_files_only`` are also forwarded to ``hf_hub_download``.

        Returns:
            An instance of ``cls`` built with the loaded config. Its session is
            opened from the ONNX file when that file exists.
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        if getattr(config, 'model_path', None) is None:
            config.model_path = _DEFAULT_ONNX_FILENAME
        if os.path.isdir(pretrained_model_name_or_path):
            base_path = pretrained_model_name_or_path
        else:
            hub_kwargs = {key: kwargs[key] for key in _HUB_DOWNLOAD_KWARGS if key in kwargs}
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename='config.json', **hub_kwargs
            )
            # The ONNX file is downloaded into the same snapshot directory as
            # config.json, so that directory serves as the base path.
            base_path = os.path.dirname(config_path)
            hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config.model_path, **hub_kwargs
            )
        return cls(config, base_path=base_path)

    @property
    def device(self):
        """Return ``cuda`` when CUDA is available, otherwise ``cpu``.

        NOTE(review): this is decided independently of where the onnxruntime
        session actually executes.
        """
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


AutoModel.register(ONNXBaseConfig, ONNXBaseModel)
# 3. register Pipeline
from transformers.pipelines import Pipeline


class ONNXBasePipeline(Pipeline):
    """Pipeline that feeds raw inputs to ``ONNXBaseModel`` and returns its outputs as-is."""

    def __init__(self, model, **kwargs):
        """Create the pipeline.

        Args:
            model: The model to run (expected to be an ``ONNXBaseModel``).
            **kwargs: Forwarded to ``Pipeline.__init__``. The optional
                ``device`` entry is also kept as ``self.device_id``.
        """
        # ``device`` may be absent from kwargs (e.g. when the caller does not
        # set one); indexing it directly raised KeyError.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess/forward/postprocess parameter dicts.

        This pipeline accepts no extra call-time parameters.
        """
        return {}, {}, {}

    def preprocess(self, input):
        """Wrap the raw input under the ``'input'`` key expected by the model's forward."""
        return {'input': input}

    def _forward(self, model_input):
        """Call the model with the preprocessed inputs, with autograd disabled."""
        with torch.no_grad():
            outputs = self.model(**model_input)
        return outputs

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs


PIPELINE_REGISTRY.register_pipeline(
    task='onnx-base',
    pipeline_class=ONNXBasePipeline,
    pt_model=ONNXBaseModel,
)