How to Perform Inference with GPT-2 Mini Model in Flutter Using ONNX Runtime

#101
by junssashu - opened

I am currently working on a Flutter project where I need to implement a chatbot using the GPT-2 Mini model with ONNX Runtime. I have successfully loaded the GPT-2 Mini model into my Flutter app, but I am facing challenges with performing inference to obtain logical outputs.

Here is the corrected text with the ">" quote at the beginning of each line for Markdown syntax:

import 'dart:io';
import 'dart:math';
import 'dart:typed_data';
import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';
import 'package:path_provider/path_provider.dart';
import 'dart:convert';
import 'package:flutter/services.dart' show rootBundle;

class Phi3ChatModel {

  late OrtSession _session;
  late OrtSessionOptions _sessionOptions;
  late Map<String, dynamic> _tokenizerConfig;
  late Map<String, dynamic> _generationConfig;
  late Map<String, dynamic> _modelConfig;
  late Map<String, dynamic> _tokenizer;
  late Map<String, dynamic> _vocab;
  late Map<String, dynamic> _specialTokensMap;
  late int _vocabSize;
  late String _bosToken;
  late String _eosToken;
  late String _unkToken;
  late int _maxLength;

  Phi3ChatModel() : _sessionOptions = OrtSessionOptions() {
    OrtEnv.instance.init();
  }

  Future<void> initModel() async {
    try {
      print("Initializing GPT-2 model...");
      _sessionOptions = OrtSessionOptions()
        ..setInterOpNumThreads(1)
        ..setIntraOpNumThreads(1)
        ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);

      final appDocDir = await getApplicationDocumentsDirectory();

      // Load ONNX model
      final modelFile = File('${appDocDir.path}/decoder_model.onnx');
      if (!await modelFile.exists()) {
        final data = await rootBundle.load('assets/models/gpt2/decoder_model.onnx');
        await modelFile.writeAsBytes(data.buffer.asUint8List());
      }
      _session = OrtSession.fromFile(modelFile, _sessionOptions);

      // Load configuration files
      _tokenizerConfig = await _loadJsonFile('assets/models/gpt2/tokenizer_config.json');
      _generationConfig = await _loadJsonFile('assets/models/gpt2/generation_config.json');
      _modelConfig = await _loadJsonFile('assets/models/gpt2/config.json');
      _tokenizer = await _loadJsonFile('assets/models/gpt2/tokenizer.json');
      _specialTokensMap = await _loadJsonFile('assets/models/gpt2/special_tokens_map.json');
      _vocab = await _loadJsonFile('assets/models/gpt2/vocab.json');

      // Initialize model parameters
      _vocabSize = _modelConfig['vocab_size'];
      _bosToken = _specialTokensMap['bos_token'];
      _eosToken = _specialTokensMap['eos_token'];
      _unkToken = _specialTokensMap['unk_token'];
      _maxLength = _modelConfig['n_positions'];

      print("GPT-2 model initialized successfully.");
    } catch (e) {
      print("Error initializing GPT-2 model: $e");
      rethrow;
    }
  }

  Future<Map<String, dynamic>> _loadJsonFile(String path) async {
    try {
      print("Loading JSON file $path...");
      String jsonString = await rootBundle.loadString(path);
      print("JSON file $path loaded successfully.");
      return json.decode(jsonString);
    } catch (e) {
      print("Error loading JSON file $path: $e");
      rethrow;
    }
  }

  List<int> encode(String text) {
    print("Encoding text: \"$text\"");
    List<int> tokens = [];
    for (String word in text.split(' ')) {
      if (_tokenizer['model']['vocab'].containsKey(word)) {
        tokens.add(_tokenizer['model']['vocab'][word]);
      } else {
        // TODO: handle unknown words
      }
    }
    print("Encoded text: $tokens");
    return tokens;
  }

  String decode(List<int> tokens) {
    print("Decoding tokens: $tokens");
    String text = '';
    for (int token in tokens) {
      if (_tokenizer['model']['vocab'].containsValue(token)) {
        String word = _tokenizer['model']['vocab'].keys.firstWhere((key) => _tokenizer['model']['vocab'][key] == token);
        text += '$word ';
      } else {
        // TODO: handle unknown tokens
      }
    }
    print("Decoded text: \"$text\"");
    return text.trim();
  }

  void processOutputTokens(Object outputTokensObject, List<int> generatedTokens) {
    if (outputTokensObject is List) {
      if (outputTokensObject[0] is List) {
        for (var item in outputTokensObject) {
          processOutputTokens(item, generatedTokens);
        }
      }
      if (outputTokensObject[0] is double) {
        var probs = _softmax(outputTokensObject as List<double>);
        var sample = _sampleFromProbs(probs);
        generatedTokens.add(sample);
      }
    }
  }

  Future<String> generateText(String prompt, {int maxNewTokens = 50}) async {
    try {
      print("Generating text with prompt: \"$prompt\"...");

      // Encode the prompt
      List<int> inputIds = encode(prompt);
      inputIds.insert(0, _modelConfig['bos_token_id']); // Beginning of sequence token
      inputIds.add(_modelConfig['eos_token_id']);

      // Initialize an empty list to store the generated tokens
      List<int> generatedTokens = [];

      var inputTensor = OrtValueTensor.createTensorWithDataList(
        inputIds,
        [1, inputIds.length],
      );

      // Create attention mask (all ones since there is no padding)
      var attentionMask = OrtValueTensor.createTensorWithDataList(
        Int64List.fromList(List.filled(inputIds.length, 1)),
        [1, inputIds.length],
      );

      var ortInput = {
        'input_ids': inputTensor,
        'attention_mask': attentionMask
      };

      final outputs = _session.run(
        OrtRunOptions(),
        ortInput
      );

      var out1 = outputs[0]?.value as List;

      processOutputTokens(out1, generatedTokens);

      inputTensor.release();
      attentionMask.release();

      // Decode the generated tokens
      String result = decode(generatedTokens);
      print("Generated text: \"$result\"");
      return result;
    } catch (e) {
      print("Error generating text: $e");
      return "Error: Unable to generate text.";
    }
  }


// Helper function to calculate softmax
  List<double> _softmax(List<double> logits) {
    print("Calculating softmax for logits: $logits");
    double maxLogit = logits.reduce((a, b) => a > b ? a : b);
    List<double> expLogits = logits.map((logit) => exp(logit - maxLogit)).toList();
    double sum = expLogits.reduce((a, b) => a + b);
    List<double> result = expLogits.map((expLogit) => expLogit / sum).toList();
    print("Softmax result: $result");
    return result;
  }

// Helper function to sample from probabilities
  int _sampleFromProbs(List<double> probs) {
    print("Sampling from probabilities: $probs");
    Random rand = Random();
    double cumulativeProb = 0.0;
    for (int i = 0; i < probs.length; i++) {
      cumulativeProb += probs[i];
      if (rand.nextDouble() < cumulativeProb) {
        int result = i;
        print("Sampled token: $result");
        return result;
      }
    }
    int result = probs.length - 1; // Fallback to last token
    print("Sampled token: $result");
    return result;
  }

  Future<String> predict(String inputData) async {
    print("----------------------------- predicting start -");
    final inputTensor = OrtValueTensor.createTensorWithDataList([inputData], [1]);
    print("----------------------------- predicting input tensor created -");
    final inputs = {'input': inputTensor};
    final outputs = await _session.runAsync(OrtRunOptions(), inputs);
    print("----------------------------- predicting infered -");
    inputTensor.release();
    print("----------------------------- predicting memory released  -");

    final response = outputs?[0]?.value as List<String>;
    outputs?.forEach((element) => element?.release());

    print("----------------------------- predicting done -");
    return response.first;
  }

  void release() {
    print("Releasing resources...");
    _sessionOptions.release();
    _session.release();
    OrtEnv.instance.release();
    print("Resources released.");
  }
}
import 'package:flutter/material.dart';
import 'package:onnxruntime_example/features/phi3_chat_model.dart';

class ONNXChatScreen extends StatefulWidget {
 const ONNXChatScreen({super.key});

 

@override
	
 _ONNXChatScreenState createState() => _ONNXChatScreenState();
}

class _ONNXChatScreenState extends State<ONNXChatScreen> {
 final TextEditingController _controller = TextEditingController();
 late Phi3ChatModel _chatbotModel;
 List<String> _messages = [];
 bool _isModelInitialized = false;

 

@override
	
 void initState() {
   super.initState();
   _initializeModel();
 }

 Future<void> _initializeModel() async {
   _chatbotModel = Phi3ChatModel();
   try {
     await _chatbotModel.initModel();
     setState(() {
       _isModelInitialized = true;
     });
   } catch (e) {
     print('Erreur lors de l\'initialisation du modèle : $e');
   }
 }

 

@override
	
 void dispose() {
   _chatbotModel.release();
   super.dispose();
 }

 void _sendMessage() async {
   final message = _controller.text;
   if (message.isEmpty) return;

   try {
     final response = await _chatbotModel.generateText(message);
     setState(() {
       _messages.add('You: $message');
       _messages.add('Bot: $response');
       _controller.clear();
     });
   } catch (e) {
     print('Erreur lors de l\'envoi du message : $e');
   }
 }

 

@override
	
 Widget build(BuildContext context) {
   return Scaffold(
     appBar: AppBar(title: const Text('ONNX Chatbot')),
     body: !_isModelInitialized
         ? const Center(child: CircularProgressIndicator())
         : Column(
       children: <Widget>[
         Expanded(
           child: ListView.builder(
             itemCount: _messages.length,
             itemBuilder: (context, index) => ListTile(
               title: Text(_messages[index]),
             ),
           ),
         ),
         Padding(
           padding: const EdgeInsets.all(8.0),
           child: Row(
             children: <Widget>[
               Expanded(
                 child: TextField(controller: _controller),
               ),
               IconButton(
                 icon: const Icon(Icons.send),
                 onPressed: _sendMessage,
               ),
             ],
           ),
         ),
       ],
     ),
   );
 }
}

the output i get
screenshot model gpt2.png

I would suggest use GPT2-medium moel for your chatbot. Thanks.

I actually tried it but the model file was too big for the asset file to load so I went with the GPT2-mini version.
thanks but in the demos they show how it infers directly on the transformers so the problem is with my inference method and I'd like some help with that.

My bad, I'm not very expert.
I just test different SLMs capabilities like their reasoning and role modeling ability, in search of Ideal Language model to run locally as a Chatbot. I have tested many bots like "ChatGPT-2-medium" and "nisten/Biggie-SmoLlm-0.15B-Base" - with slight adjustments in temperature (0.7-0.9)
But check out my chat with ChatGPT-AI: https://chatgpt.com/share/480c1e78-9b94-4172-b63e-c0e5093aa9d4
which contains the solution, updated and upgraded code for your chatbot.
For more to know how I test SLMs for free or anything else contact me at Telegram (UserName: @Saitama_AU), Link: https://t.me/Saitama_AU .
Thanks...

Here's updated code, for upgraded code please see my ChatGPT Chat...

import 'dart:io';
import 'dart:math';
import 'dart:typed_data';
import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';
import 'package:path_provider/path_provider.dart';
import 'dart:convert';
import 'package:flutter/services.dart' show rootBundle;

class Phi3ChatModel {

late OrtSession _session;
late OrtSessionOptions _sessionOptions;
late Map<String, dynamic> _tokenizerConfig;
late Map<String, dynamic> _generationConfig;
late Map<String, dynamic> _modelConfig;
late Map<String, dynamic> _tokenizer;
late Map<String, dynamic> _vocab;
late Map<String, dynamic> _specialTokensMap;
late int _vocabSize;
late String _bosToken;
late String _eosToken;
late String _unkToken;
late int _maxLength;

Phi3ChatModel() : _sessionOptions = OrtSessionOptions() {
OrtEnv.instance.init();
}

Future initModel() async {
try {
print("Initializing GPT-2 model...");
_sessionOptions = OrtSessionOptions()
..setInterOpNumThreads(1)
..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);

  final appDocDir = await getApplicationDocumentsDirectory();

  // Load ONNX model
  final modelFile = File('${appDocDir.path}/decoder_model.onnx');
  if (!await modelFile.exists()) {
    final data = await rootBundle.load('assets/models/gpt2/decoder_model.onnx');
    await modelFile.writeAsBytes(data.buffer.asUint8List());
  }
  _session = OrtSession.fromFile(modelFile, _sessionOptions);

  // Load configuration files
  _tokenizerConfig = await _loadJsonFile('assets/models/gpt2/tokenizer_config.json');
  _generationConfig = await _loadJsonFile('assets/models/gpt2/generation_config.json');
  _modelConfig = await _loadJsonFile('assets/models/gpt2/config.json');
  _tokenizer = await _loadJsonFile('assets/models/gpt2/tokenizer.json');
  _specialTokensMap = await _loadJsonFile('assets/models/gpt2/special_tokens_map.json');
  _vocab = await _loadJsonFile('assets/models/gpt2/vocab.json');

  // Initialize model parameters
  _vocabSize = _modelConfig['vocab_size'];
  _bosToken = _specialTokensMap['bos_token'];
  _eosToken = _specialTokensMap['eos_token'];
  _unkToken = _specialTokensMap['unk_token'];
  _maxLength = _modelConfig['n_positions'];

  print("GPT-2 model initialized successfully.");
} catch (e) {
  print("Error initializing GPT-2 model: $e");
  rethrow;
}

}

Future<Map<String, dynamic>> _loadJsonFile(String path) async {
try {
print("Loading JSON file $path...");
String jsonString = await rootBundle.loadString(path);
print("JSON file $path loaded successfully.");
return json.decode(jsonString);
} catch (e) {
print("Error loading JSON file $path: $e");
rethrow;
}
}

List encode(String text) {
print("Encoding text: "$text"");
List tokens = [];
List words = _tokenizer['model']['pre_tokenizer']['splitter']
.split(RegExp(text));
for (String word in words) {
if (_tokenizer['model']['vocab'].containsKey(word)) {
tokens.add(_tokenizer['model']['vocab'][word]);
} else {
tokens.add(_tokenizer['model']['vocab'][_unkToken]); // Handle unknown words
}
}
print("Encoded text: $tokens");
return tokens;
}

String decode(List tokens) {
print("Decoding tokens: $tokens");
String text = '';
for (int token in tokens) {
if (_vocab.containsValue(token)) {
String word = _vocab.keys.firstWhere((key) => _vocab[key] == token);
text += '$word ';
} else {
text += _unkToken + ' '; // Handle unknown tokens
}
}
print("Decoded text: "$text"");
return text.trim();
}

void processOutputTokens(Object outputTokensObject, List generatedTokens) {
if (outputTokensObject is List) {
if (outputTokensObject[0] is List) {
for (var item in outputTokensObject) {
processOutputTokens(item, generatedTokens);
}
} else if (outputTokensObject[0] is double) {
var probs = _softmax(outputTokensObject.cast());
var sample = _sampleFromProbs(probs);
generatedTokens.add(sample);
}
}
}

Future generateText(String prompt, {int maxNewTokens = 50}) async {
try {
print("Generating text with prompt: "$prompt"...");

  // Encode the prompt
  List<int> inputIds = encode(prompt);
  inputIds.insert(0, _modelConfig['bos_token_id']); // Beginning of sequence token
  inputIds.add(_modelConfig['eos_token_id']);

  // Initialize an empty list to store the generated tokens
  List<int> generatedTokens = [];

  var inputTensor = OrtValueTensor.createTensorWithDataList(
    inputIds,
    [1, inputIds.length],
  );

  // Create attention mask (all ones since there is no padding)
  var attentionMask = OrtValueTensor.createTensorWithDataList(
    Int64List.fromList(List.filled(inputIds.length, 1)),
    [1, inputIds.length],
  );

  var ortInput = {
    'input_ids': inputTensor,
    'attention_mask': attentionMask
  };

  final outputs = _session.run(
    OrtRunOptions(),
    ortInput
  );

  var out1 = outputs[0]?.value as List;

  processOutputTokens(out1, generatedTokens);

  inputTensor.release();
  attentionMask.release();

  // Decode the generated tokens
  String result = decode(generatedTokens);
  print("Generated text: \"$result\"");
  return result;
} catch (e) {
  print("Error generating text: $e");
  return "Error: Unable to generate text.";
}

}

// Helper function to calculate softmax
List _softmax(List logits) {
print("Calculating softmax for logits: $logits");
double maxLogit = logits.reduce((a, b) => a > b ? a : b);
List expLogits = logits.map((logit) => exp(logit - maxLogit)).toList();
double sum = expLogits.reduce((a, b) => a + b);
List result = expLogits.map((expLogit) => expLogit / sum).toList();
print("Softmax result: $result");
return result;
}

// Helper function to sample from probabilities
int _sampleFromProbs(List probs) {
print("Sampling from probabilities: $probs");
Random rand = Random();
double cumulativeProb = 0.0;
for (int i = 0; i < probs.length; i++) {
cumulativeProb += probs[i];
if (rand.nextDouble() < cumulativeProb) {
int result = i;
print("Sampled token: $result");
return result;
}
}
int result = probs.length - 1; // Fallback to last token
print("Sampled token: $result");
return result;
}

void release() {
print("Releasing resources...");
_sessionOptions.release();
_session.release();
OrtEnv.instance.release();
print("Resources released.");
}
}

import 'package:flutter/material.dart';

class ONNXChatScreen extends StatefulWidget {
const ONNXChatScreen({super.key});

@override
_ONNXChatScreenState createState() => _ONNXChatScreenState();
}

class _ONNXChatScreenState extends State {
final TextEditingController _controller = TextEditingController();
late Phi3ChatModel _chatbotModel;
List _messages = [];
bool _isModelInitialized = false;

@override
void initState() {
super.initState();
_initializeModel();
}

Future _initializeModel() async {
_chatbotModel = Phi3ChatModel();
try {
await _chatbotModel.initModel();
setState(() {
_isModelInitialized = true;
});
} catch (e) {
print('Error during model initialization: $e');
}
}

@override
void dispose() {
_chatbotModel.release();
super.dispose();
}

void _sendMessage() async {
final message = _controller.text;
if (message.isEmpty) return;

try {
final response = await _chatbotModel.generateText(message);
setState(() {
_messages.add('You: $message');
_messages.add('Bot: $response');
_controller.clear();
});
} catch (e) {
print('Error sending message: $e');
}
}

@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(title: const Text('ONNX Chatbot')),
body: !_isModelInitialized
? const Center(child: CircularProgressIndicator())
: Column(
children: [
Expanded(
child: ListView.builder(
itemCount: _messages.length,
itemBuilder: (context, index) => ListTile(
title: Text(_messages[index]),
),
),
),
Padding(
padding: const EdgeInsets.all(8.0),
child: Row(
children: [
Expanded(
child: TextField(controller: _controller),
),
IconButton(
icon: const Icon(Icons.send),
onPressed: _sendMessage,
),
],
),
),
],
),
);
}
}

Sign up or log in to comment