Spaces:
Runtime error
Runtime error
/// Text Generation Inference benchmarking tool
///
/// Inspired by the great Oha app: <https://github.com/hatoo/oha>
/// and by the Rust TUI template: <https://github.com/orhun/rust-tui-template>
use clap::Parser; | |
use std::path::Path; | |
use text_generation_client::ShardedClient; | |
use tokenizers::{FromPretrainedParameters, Tokenizer}; | |
use tracing_subscriber::layer::SubscriberExt; | |
use tracing_subscriber::util::SubscriberInitExt; | |
use tracing_subscriber::EnvFilter; | |
/// App Configuration | |
struct Args { | |
tokenizer_name: String, | |
revision: String, | |
batch_size: Option<Vec<u32>>, | |
sequence_length: u32, | |
decode_length: u32, | |
runs: usize, | |
warmups: usize, | |
master_shard_uds_path: String, | |
} | |
fn main() -> Result<(), Box<dyn std::error::Error>> { | |
// Get args | |
let args = Args::parse(); | |
// Pattern match configuration | |
let Args { | |
tokenizer_name, | |
revision, | |
batch_size, | |
sequence_length, | |
decode_length, | |
runs, | |
warmups, | |
master_shard_uds_path, | |
} = args; | |
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]); | |
init_logging(); | |
// Tokenizer instance | |
// This will only be used to validate payloads | |
tracing::info!("Loading tokenizer"); | |
let local_path = Path::new(&tokenizer_name); | |
let tokenizer = | |
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists() | |
{ | |
// Load local tokenizer | |
tracing::info!("Found local tokenizer"); | |
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap() | |
} else { | |
tracing::info!("Downloading tokenizer"); | |
// Parse Huggingface hub token | |
let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); | |
// Download and instantiate tokenizer | |
// We need to download it outside of the Tokio runtime | |
let params = FromPretrainedParameters { | |
revision, | |
auth_token, | |
..Default::default() | |
}; | |
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap() | |
}; | |
tracing::info!("Tokenizer loaded"); | |
// Launch Tokio runtime | |
tokio::runtime::Builder::new_multi_thread() | |
.enable_all() | |
.build() | |
.unwrap() | |
.block_on(async { | |
// Instantiate sharded client from the master unix socket | |
tracing::info!("Connect to model server"); | |
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) | |
.await | |
.expect("Could not connect to server"); | |
// Clear the cache; useful if the webserver rebooted | |
sharded_client | |
.clear_cache(None) | |
.await | |
.expect("Unable to clear cache"); | |
tracing::info!("Connected"); | |
// Run app | |
text_generation_benchmark::run( | |
tokenizer_name, | |
tokenizer, | |
batch_size, | |
sequence_length, | |
decode_length, | |
runs, | |
warmups, | |
sharded_client, | |
) | |
.await | |
.unwrap(); | |
}); | |
Ok(()) | |
} | |
/// Init logging using LOG_LEVEL | |
fn init_logging() { | |
// STDOUT/STDERR layer | |
let fmt_layer = tracing_subscriber::fmt::layer() | |
.with_file(true) | |
.with_line_number(true); | |
// Filter events with LOG_LEVEL | |
let env_filter = | |
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); | |
tracing_subscriber::registry() | |
.with(env_filter) | |
.with(fmt_layer) | |
.init(); | |
} | |