// Source: digitalhuman / Silero.js — uploaded by atlury ("Upload 6 files", commit 16bb6c5, verified), 6.59 kB
import * as ort from 'onnxruntime-web';
/**
 * Stateful wrapper around an ONNX model that is fed three tensors (`input`,
 * `state`, `sr`) and read back through two (`output`, `stateN`).
 *
 * NOTE(review): the file name suggests this is the Silero voice-activity
 * model; the code itself only relies on the tensor names and shapes above.
 *
 * The wrapper keeps, between calls:
 *  - the recurrent state returned by the previous inference, and
 *  - the trailing samples of the previous chunk ("context"), which are
 *    prepended to the next chunk before inference.
 * Both are reset whenever the sample rate or the batch size changes.
 */
class OnnxWrapper {
  /** Length of each per-item state vector; the state tensor is [2, batch, STATE_SIZE]. */
  private static readonly STATE_SIZE = 128;

  // Assigned asynchronously by initSession(); await ready() before using it.
  private session!: ort.InferenceSession;
  // Two flat rows of batch * STATE_SIZE values each.
  private _state: number[][] = [];
  // Flat buffer holding the last `context size` samples of every batch row.
  private _context: number[] = [];
  // Sample rate / batch size of the previous successful call (0 = none yet).
  private _last_sr = 0;
  private _last_batch_size = 0;
  private readonly sample_rates: number[] = [8000, 16000];
  private readonly sessionReady: Promise<void>;

  /**
   * Starts creating the inference session in the background.
   *
   * @param path Location of the ONNX model, passed to `InferenceSession.create`.
   * @param force_onnx_cpu When true only the `wasm` execution provider is
   *   requested; otherwise `webgl` is tried first with `wasm` as fallback.
   */
  constructor(path: string, force_onnx_cpu: boolean = true) {
    console.log(`Initializing OnnxWrapper with path: ${path}`);
    this.sessionReady = this.initSession(path, force_onnx_cpu);
    this.resetStates();
  }

  /** Resolves once the session has been created; rejects if creation failed. */
  async ready(): Promise<void> {
    console.log('Waiting for OnnxWrapper session to be ready');
    await this.sessionReady;
    console.log('OnnxWrapper session is ready');
  }

  /** Creates the inference session with single-threaded, sequential settings. */
  private async initSession(path: string, force_onnx_cpu: boolean): Promise<void> {
    console.log(`Initializing ONNX session with force_onnx_cpu: ${force_onnx_cpu}`);
    const options: ort.InferenceSession.SessionOptions = {
      executionProviders: force_onnx_cpu ? ['wasm'] : ['webgl', 'wasm'],
      graphOptimizationLevel: 'all',
      executionMode: 'sequential',
      enableCpuMemArena: true,
      enableMemPattern: true,
      extra: {
        session: {
          intra_op_num_threads: 1,
          inter_op_num_threads: 1,
        }
      }
    };
    this.session = await ort.InferenceSession.create(path, options);
    console.log('ONNX session created successfully');
  }

  /** Counts how many array levels wrap the first element (`[[1]]` -> 2, `[]` -> 1). */
  private static nestingDepth(value: unknown): number {
    let depth = 0;
    let cursor: unknown = value;
    while (Array.isArray(cursor)) {
      depth += 1;
      cursor = cursor[0];
    }
    return depth;
  }

  /** Samples the model consumes per inference at the given (validated) rate. */
  private static windowSize(sr: number): number {
    return sr === 16000 ? 512 : 256;
  }

  /** Samples of the previous chunk that are prepended to the next one. */
  private static contextSize(sr: number): number {
    return sr === 16000 ? 64 : 32;
  }

  /**
   * Normalises an audio chunk to `[batch][samples]` at a supported rate.
   *
   * - A 1-D array is wrapped into a batch of one.
   * - Rates that are an integer multiple of 16000 are decimated to 16000 by
   *   keeping every `sr / 16000`-th sample (no low-pass filtering is applied).
   *
   * @returns The normalised batch and the effective sample rate.
   * @throws Error if the input is nested deeper than two levels, the rate is
   *   unsupported, or the first row is shorter than `sr / 31.25` samples
   *   (i.e. less than 32 ms of audio).
   */
  private _validate_input(x: number[] | number[][], sr: number): [number[][], number] {
    // Check dimensionality, not batch size: a batch may have any number of rows.
    const depth = OnnxWrapper.nestingDepth(x);
    if (depth > 2) {
      throw new Error(`Too many dimensions for input audio chunk ${depth}`);
    }
    let batch: number[][] = depth === 2 ? (x as number[][]) : [x as number[]];
    let rate = sr;
    if (rate !== 16000 && (rate % 16000 === 0)) {
      const step = Math.floor(rate / 16000);
      batch = batch.map(row => row.filter((_, i) => i % step === 0));
      rate = 16000;
    }
    if (!this.sample_rates.includes(rate)) {
      throw new Error(`Supported sampling rates: ${this.sample_rates} (or multiply of 16000)`);
    }
    if (rate / batch[0].length > 31.25) {
      throw new Error("Input audio chunk is too short");
    }
    return [batch, rate];
  }

  /**
   * Clears the recurrent state, the carried-over context and the
   * "previous call" bookkeeping.
   *
   * @param batch_size Number of batch rows the zeroed state is sized for.
   */
  resetStates(batch_size: number = 1): void {
    console.log(`Resetting states with batch_size: ${batch_size}`);
    const rowLength = batch_size * OnnxWrapper.STATE_SIZE;
    this._state = Array.from({ length: 2 }, () => new Array<number>(rowLength).fill(0));
    this._context = [];
    this._last_sr = 0;
    this._last_batch_size = 0;
  }

  /**
   * Runs one inference step on a single chunk per batch row.
   *
   * @param x Audio chunk(s): `[samples]` or `[batch][samples]`. After
   *   validation each row must hold exactly 512 samples (16 kHz) or 256 (8 kHz).
   * @param sr Sample rate of `x`: 8000, 16000, or an integer multiple of 16000.
   * @returns The model output, one row per batch item.
   * @throws Error on invalid input (see `_validate_input`) or wrong chunk length.
   */
  async call(x: number[] | number[][], sr: number): Promise<number[][]> {
    // Validate before touching x[0] so malformed input yields a descriptive error.
    const [batch, rate] = this._validate_input(x, sr);
    console.log(`Calling model with input shape: [${batch.length}, ${batch[0].length}], sample rate: ${rate}`);
    await this.ready();

    const num_samples = OnnxWrapper.windowSize(rate);
    if (batch[0].length !== num_samples) {
      throw new Error(`Provided number of samples is ${batch[0].length} (Supported values: 256 for 8000 sample rate, 512 for 16000)`);
    }
    const batch_size = batch.length;
    const context_size = OnnxWrapper.contextSize(rate);

    // Start from a clean state on first use or when the stream parameters change.
    const firstCall = !this._last_batch_size;
    const rateChanged = this._last_sr !== 0 && this._last_sr !== rate;
    const batchChanged = this._last_batch_size !== 0 && this._last_batch_size !== batch_size;
    if (firstCall || rateChanged || batchChanged) {
      this.resetStates(batch_size);
    }
    if (this._context.length === 0) {
      this._context = new Array<number>(batch_size * context_size).fill(0);
    }

    // Prepend each row's carried-over context to its new samples.
    const framed = batch.map((row, i) => [
      ...this._context.slice(i * context_size, (i + 1) * context_size),
      ...row,
    ]);

    const feeds: Record<string, ort.Tensor> = {
      input: new ort.Tensor('float32', framed.flat(), [batch_size, framed[0].length]),
      state: new ort.Tensor('float32', this._state.flat(), [2, batch_size, OnnxWrapper.STATE_SIZE]),
      sr: new ort.Tensor('int64', [rate], []),
    };
    const results = await this.session.run(feeds);
    const outputData = results.output.data as Float32Array;
    const stateData = results.stateN.data as Float32Array;

    // Split the flat returned state back into its two rows.
    const stateRow = batch_size * OnnxWrapper.STATE_SIZE;
    this._state = Array.from({ length: 2 }, (_, i) =>
      Array.from(stateData.slice(i * stateRow, (i + 1) * stateRow))
    );

    const outputShape = results.output.dims as number[];
    const out = Array.from({ length: outputShape[0] }, (_, i) =>
      Array.from(outputData.slice(i * outputShape[1], (i + 1) * outputShape[1]))
    );

    // Remember the tail of this chunk as context for the next call.
    this._context = framed.map(row => row.slice(-context_size)).flat();
    this._last_sr = rate;
    this._last_batch_size = batch_size;
    console.log(`Model call completed, output shape: [${out.length}, ${out[0].length}]`);
    return out;
  }

  /**
   * Runs the model over a whole recording, chunk by chunk.
   *
   * State is reset first; the audio is zero-padded at the end to a whole
   * number of chunks. Padding and chunking use the length of the first row,
   * so all rows are assumed to have equal length.
   *
   * @param x Audio: `[samples]` or `[batch][samples]`.
   * @param sr Sample rate of `x`: 8000, 16000, or an integer multiple of 16000.
   * @returns Per batch row, the concatenated outputs of every chunk.
   * @throws Error on invalid input (see `_validate_input`).
   */
  async audio_forward(x: number[] | number[][], sr: number): Promise<number[][]> {
    const [validated, rate] = this._validate_input(x, sr);
    console.log(`Running audio_forward with input shape: [${validated.length}, ${validated[0].length}], sample rate: ${rate}`);
    this.resetStates();

    const num_samples = OnnxWrapper.windowSize(rate);
    const remainder = validated[0].length % num_samples;
    const padded = remainder === 0
      ? validated
      : validated.map(row => [...row, ...new Array<number>(num_samples - remainder).fill(0)]);

    // Append each chunk's output in place instead of re-copying the whole result.
    const merged: number[][] = padded.map(() => []);
    for (let start = 0; start < padded[0].length; start += num_samples) {
      const wavs_batch = padded.map(row => row.slice(start, start + num_samples));
      const out_chunk = await this.call(wavs_batch, rate);
      out_chunk.forEach((values, i) => {
        merged[i].push(...values);
      });
    }
    console.log(`audio_forward completed, output shape: [${merged.length}, ${merged[0].length}]`);
    return merged;
  }

  /**
   * Releases the inference session.
   *
   * Waits for session creation to settle first, so calling this right after
   * construction does not dereference a missing session. Failures (creation
   * or release) are logged rather than thrown because the method is
   * fire-and-forget.
   */
  close(): void {
    console.log('Closing OnnxWrapper session');
    void this.sessionReady
      .then(() => this.session.release())
      .catch((err: unknown) => {
        console.error('Failed to release OnnxWrapper session', err);
      });
  }
}
export default OnnxWrapper;