"""Minimal Hugging Face ``PreTrainedModel`` that also exports itself to ONNX.

The script saves the model with ``save_pretrained`` (which writes the usual
transformers files plus an ONNX graph), reloads it with ``from_pretrained``,
runs it in PyTorch, then runs the exported graph with ONNX Runtime and checks
that both outputs agree.
"""
import os
from typing import Union

import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

# File name used for the ONNX graph when the config does not specify one.
DEFAULT_ONNX_FILENAME = "model.onnx"


class ONNXBaseConfig(PretrainedConfig):
    """Configuration for :class:`ONNXBaseModel`.

    Args:
        model_path: File name of the exported ONNX graph, joined onto the
            directory passed to ``ONNXBaseModel.save_pretrained``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "onnx-base"

    def __init__(self, model_path=None, **kwargs):
        self.model_path = model_path
        super().__init__(**kwargs)


class ONNXBaseModel(PreTrainedModel):
    """Toy model whose forward pass returns a zero tensor shaped like its input."""

    config_class = ONNXBaseConfig

    def __init__(self, config):
        super().__init__(config)
        # Empty (0-element) parameter; presumably present so the model has a
        # registered parameter for transformers' save/load path — confirm.
        self.dummy_param = nn.Parameter(torch.zeros(0))

    def forward(self, inputs):
        """Return zeros with the same shape and dtype as ``inputs``."""
        return torch.zeros_like(inputs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Save with transformers, then export an ONNX graph alongside.

        Args:
            save_directory: Target directory (``str`` or path-like).
            **kwargs: Forwarded to ``PreTrainedModel.save_pretrained``. When
                ``is_main_process`` is passed as ``False``, the ONNX export is
                skipped so only one process writes the file.
        """
        super().save_pretrained(save_directory=save_directory, **kwargs)
        if not kwargs.get("is_main_process", True):
            return

        # os.path.join works for both str and os.PathLike directories.
        onnx_filename = self.config.model_path or DEFAULT_ONNX_FILENAME
        onnx_file_path = os.path.join(save_directory, onnx_filename)

        # Example input used only to trace the graph; the batch axis is
        # declared dynamic below so other batch sizes are accepted.
        dummy_input = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
        torch.onnx.export(
            self,
            dummy_input,
            onnx_file_path,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )


model_directory = "./new_model"
config = ONNXBaseConfig(model_path=DEFAULT_ONNX_FILENAME)

# Initialize and save the model (also writes config.json and the ONNX graph).
model = ONNXBaseModel(config)
model.save_pretrained(save_directory=model_directory)

# from_pretrained is a classmethod; call it on the class, not an instance.
model = ONNXBaseModel.from_pretrained(model_directory)

# Test the PyTorch model.
dummy_input = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
output_tensor = model(dummy_input)
print(output_tensor)

# Test the exported ONNX model.
onnx_file_path = os.path.join(model_directory, model.config.model_path)
onnx.checker.check_model(onnx.load(onnx_file_path))

# ONNX Runtime >= 1.9 requires explicit providers on multi-EP builds.
ort_session = ort.InferenceSession(onnx_file_path, providers=["CPUExecutionProvider"])
outputs = ort_session.run(None, {"input": dummy_input.numpy()})
print("Model output:", outputs)

# The ONNX graph must reproduce the PyTorch result.
torch.testing.assert_close(torch.from_numpy(outputs[0]), output_tensor.detach())