import * as ort from 'onnxruntime-web';

/**
 * Stateful wrapper around an ONNX model that is fed three tensors
 * (`input`, `state`, `sr`) and returns two (`output`, `stateN`).
 *
 * The wrapper keeps the recurrent state and a short tail of the previous
 * audio chunk (the "context") between calls, so consecutive chunks of one
 * stream must be passed to {@link OnnxWrapper.call} in order.
 */
class OnnxWrapper {
  /** Per-batch-item width of each of the two state planes (state tensor is `[2, batch, 128]`). */
  private static readonly STATE_WIDTH = 128;

  /** Set asynchronously by `initSession`; stays `undefined` until `sessionReady` resolves. */
  private session: ort.InferenceSession | undefined;
  private _state: number[][] = [];
  private _context: number[] = [];
  private _last_sr = 0;
  private _last_batch_size = 0;
  private readonly sample_rates: readonly number[] = [8000, 16000];
  private sessionReady: Promise<void>;

  /**
   * Starts loading the model; loading finishes asynchronously (see {@link OnnxWrapper.ready}).
   *
   * @param path - Location of the ONNX model, forwarded to `InferenceSession.create`.
   * @param force_onnx_cpu - When `true` only the `wasm` execution provider is requested;
   *   otherwise `webgl` is tried first with `wasm` as fallback.
   */
  constructor(path: string, force_onnx_cpu: boolean = true) {
    console.log(`Initializing OnnxWrapper with path: ${path}`);
    this.sessionReady = this.initSession(path, force_onnx_cpu);
    this.resetStates();
  }

  /** Resolves once the inference session has been created; rejects if creation failed. */
  async ready(): Promise<void> {
    console.log('Waiting for OnnxWrapper session to be ready');
    await this.sessionReady;
    console.log('OnnxWrapper session is ready');
  }

  /** Creates the inference session and stores it on the instance. */
  private async initSession(path: string, force_onnx_cpu: boolean): Promise<void> {
    console.log(`Initializing ONNX session with force_onnx_cpu: ${force_onnx_cpu}`);
    const options: ort.InferenceSession.SessionOptions = {
      executionProviders: force_onnx_cpu ? ['wasm'] : ['webgl', 'wasm'],
      graphOptimizationLevel: 'all',
      executionMode: 'sequential',
      enableCpuMemArena: true,
      enableMemPattern: true,
      // NOTE(review): confirm these extra config keys are honoured by the web
      // backend; thread count may need to be set through `ort.env.wasm` instead.
      extra: {
        session: {
          intra_op_num_threads: 1,
          inter_op_num_threads: 1,
        },
      },
    };
    this.session = await ort.InferenceSession.create(path, options);
    console.log('ONNX session created successfully');
  }

  /** Counts array nesting depth by following the first element at each level. */
  private static rankOf(value: unknown): number {
    let rank = 0;
    let probe: unknown = value;
    while (Array.isArray(probe)) {
      rank += 1;
      probe = probe[0];
    }
    return rank;
  }

  /**
   * Normalises an audio chunk to `[batch, samples]` at a supported sample rate.
   *
   * A flat array is promoted to a batch of one. Rates that are a whole multiple
   * of 16000 are decimated to 16000 by keeping every `step`-th sample.
   *
   * @param x - Audio samples, either `[batch][samples]` or a flat sample array.
   * @param sr - Sample rate of `x`.
   * @returns The normalised chunk and the (possibly reduced) sample rate.
   * @throws Error if the input is nested deeper than two levels, the rate is
   *   unsupported, or the chunk is too short for its rate.
   */
  private _validate_input(x: number[][], sr: number): [number[][], number] {
    let chunk: number[][] = Array.isArray(x[0]) ? x : [x as unknown as number[]];
    let rate = sr;

    // The limit is on nesting depth, not on batch size: any number of rows is fine.
    const rank = OnnxWrapper.rankOf(chunk);
    if (rank > 2) {
      throw new Error(`Too many dimensions for input audio chunk ${rank}`);
    }

    if (rate !== 16000 && rate % 16000 === 0) {
      const step = Math.floor(rate / 16000);
      chunk = chunk.map((row) => row.filter((_, i) => i % step === 0));
      rate = 16000;
    }

    if (!this.sample_rates.includes(rate)) {
      throw new Error(`Supported sampling rates: ${this.sample_rates} (or multiply of 16000)`);
    }

    // An empty first row gives Infinity here and is rejected as well.
    if (rate / chunk[0].length > 31.25) {
      throw new Error("Input audio chunk is too short");
    }

    return [chunk, rate];
  }

  /**
   * Clears the recurrent state, the carried-over context and the record of the
   * previous call's sample rate and batch size.
   *
   * @param batch_size - Number of rows the zeroed state is sized for.
   */
  resetStates(batch_size: number = 1): void {
    console.log(`Resetting states with batch_size: ${batch_size}`);
    this._state = Array.from({ length: 2 }, () =>
      new Array<number>(batch_size * OnnxWrapper.STATE_WIDTH).fill(0),
    );
    this._context = [];
    this._last_sr = 0;
    this._last_batch_size = 0;
  }

  /**
   * Runs the model on one chunk and advances the internal state.
   *
   * State is reset automatically on the first call and whenever the sample
   * rate or batch size differs from the previous call.
   *
   * @param x - One chunk per batch row: 512 samples at 16000 Hz, 256 at 8000 Hz
   *   (counted after any decimation done by validation).
   * @param sr - Sample rate of `x`.
   * @returns The model's `output` tensor reshaped to rows using its first two dims.
   * @throws Error on invalid input, on a wrong chunk length, or if the model
   *   does not return `output` and `stateN`.
   */
  async call(x: number[][], sr: number): Promise<number[][]> {
    console.log(`Calling model with input shape: [${x.length}, ${x[0]?.length}], sample rate: ${sr}`);
    await this.ready();
    const session = this.session;
    if (session === undefined) {
      throw new Error('ONNX session is not initialized');
    }

    const [chunk, rate] = this._validate_input(x, sr);
    const num_samples = rate === 16000 ? 512 : 256;
    if (chunk[0].length !== num_samples) {
      throw new Error(`Provided number of samples is ${chunk[0].length} (Supported values: 256 for 8000 sample rate, 512 for 16000)`);
    }

    const batch_size = chunk.length;
    const context_size = rate === 16000 ? 64 : 32;

    // resetStates() zeroes both _last_* fields, so at most one of these can fire per call.
    const firstCall = !this._last_batch_size;
    const rateChanged = Boolean(this._last_sr) && this._last_sr !== rate;
    const batchChanged = Boolean(this._last_batch_size) && this._last_batch_size !== batch_size;
    if (firstCall || rateChanged || batchChanged) {
      this.resetStates(batch_size);
    }

    if (this._context.length === 0) {
      this._context = new Array<number>(batch_size * context_size).fill(0);
    }

    // Prefix every row with the tail of the previous chunk for the same row.
    const withContext = chunk.map((row, i) => [
      ...this._context.slice(i * context_size, (i + 1) * context_size),
      ...row,
    ]);

    if (rate !== 8000 && rate !== 16000) {
      throw new Error(`Unsupported sample rate: ${rate}. Supported rates are 8000 and 16000.`);
    }

    const inputTensor = new ort.Tensor('float32', withContext.flat(), [batch_size, withContext[0].length]);
    const stateTensor = new ort.Tensor('float32', this._state.flat(), [2, batch_size, OnnxWrapper.STATE_WIDTH]);
    // int64 tensors are backed by BigInt64Array; build it explicitly instead of
    // relying on the runtime to coerce a plain number.
    const srTensor = new ort.Tensor('int64', BigInt64Array.from([BigInt(rate)]), []);
    const feeds: Record<string, ort.Tensor> = { input: inputTensor, state: stateTensor, sr: srTensor };

    const results = await session.run(feeds);
    const output = results.output;
    const nextState = results.stateN;
    if (output === undefined || nextState === undefined) {
      throw new Error('Model did not return the expected "output" and "stateN" tensors');
    }
    // Both tensors are read as float32 data, as in the original implementation.
    const outputData = output.data as Float32Array;
    const stateData = nextState.data as Float32Array;

    const planeSize = batch_size * OnnxWrapper.STATE_WIDTH;
    this._state = Array.from({ length: 2 }, (_, i) =>
      Array.from(stateData.slice(i * planeSize, (i + 1) * planeSize)),
    );

    const [rows, cols] = output.dims;
    const out = Array.from({ length: rows }, (_, i) =>
      Array.from(outputData.slice(i * cols, (i + 1) * cols)),
    );

    this._context = withContext.flatMap((row) => row.slice(-context_size));
    this._last_sr = rate;
    this._last_batch_size = batch_size;
    console.log(`Model call completed, output shape: [${out.length}, ${out[0].length}]`);
    return out;
  }

  /**
   * Runs the model over a whole recording, chunk by chunk, from a fresh state.
   *
   * The input is zero-padded on the right to a whole number of chunks.
   *
   * @param x - Audio samples, `[batch][samples]`.
   * @param sr - Sample rate of `x`.
   * @returns Per batch row, the outputs of all chunks concatenated in order.
   * @throws Error under the same conditions as {@link OnnxWrapper.call}.
   */
  async audio_forward(x: number[][], sr: number): Promise<number[][]> {
    console.log(`Running audio_forward with input shape: [${x.length}, ${x[0]?.length}], sample rate: ${sr}`);
    const outs: number[][][] = [];
    const [validated, rate] = this._validate_input(x, sr);
    this.resetStates();

    const num_samples = rate === 16000 ? 512 : 256;
    const remainder = validated[0].length % num_samples;
    const padded =
      remainder === 0
        ? validated
        : validated.map((row) => row.concat(new Array<number>(num_samples - remainder).fill(0)));

    // Chunks must run strictly in sequence: each call consumes the state left by the previous one.
    for (let start = 0; start < padded[0].length; start += num_samples) {
      const wavs_batch = padded.map((row) => row.slice(start, start + num_samples));
      outs.push(await this.call(wavs_batch, rate));
    }

    console.log(`audio_forward completed, output shape: [${outs.length}, ${outs[0].length}]`);
    // Join the chunks per row in a single pass instead of re-copying the accumulator each step.
    return outs[0].map((_, rowIndex) => outs.flatMap((chunkOut) => chunkOut[rowIndex]));
  }

  /**
   * Releases the inference session.
   *
   * Release is deferred until session creation has settled, so calling this
   * while the model is still loading neither throws nor leaks the session.
   */
  close(): void {
    console.log('Closing OnnxWrapper session');
    void this.sessionReady
      .then(
        () => this.session?.release(),
        () => undefined, // creation failed, so there is nothing to release
      )
      .catch((err: unknown) => {
        console.error('Failed to release ONNX session', err);
      });
  }
}

export default OnnxWrapper;