/// Multi shard Client use crate::Result; use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo}; use futures::future::join_all; use tonic::transport::Uri; use tracing::instrument; #[derive(Debug, Clone)] /// Text Generation Inference gRPC multi client pub struct ShardedClient { clients: Vec, } impl ShardedClient { fn new(clients: Vec) -> Self { Self { clients } } /// Create a new ShardedClient from a master client. The master client will communicate with /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. async fn from_master_client(mut master_client: Client) -> Result { // Get all uris/unix sockets from the master client let uris = master_client.service_discovery().await?; let futures = uris.into_iter().map(Client::connect_uds); let clients: Result> = join_all(futures).await.into_iter().collect(); Ok(Self::new(clients?)) } /// Returns a client connected to the given uri pub async fn connect(uri: Uri) -> Result { let master_client = Client::connect(uri).await?; Self::from_master_client(master_client).await } /// Returns a client connected to the given unix socket pub async fn connect_uds(path: String) -> Result { let master_client = Client::connect_uds(path).await?; Self::from_master_client(master_client).await } /// Get the model info #[instrument(skip(self))] pub async fn info(&mut self) -> Result { let futures: Vec<_> = self .clients .iter_mut() .map(|client| client.info()) .collect(); join_all(futures).await.pop().unwrap() } /// GRPC health check #[instrument(skip(self))] pub async fn health(&mut self) -> Result { let futures: Vec<_> = self .clients .iter_mut() .map(|client| client.health()) .collect(); join_all(futures).await.pop().unwrap() } /// Clear the past generations cache #[instrument(skip(self))] pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { let futures: Vec<_> = self .clients .iter_mut() .map(|client| client.clear_cache(batch_id)) .collect(); join_all(futures).await.into_iter().collect() } /// Filter a cached batch #[instrument(skip(self))] pub async fn filter_batch( &mut self, batch_id: u64, keep_requests: Vec, ) -> Result> { let futures: Vec<_> = self .clients .iter_mut() .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone()))) .collect(); // all shards return the same message join_all(futures).await.pop().unwrap() } /// Generate one token for each request in the given batch /// /// Returns Generation for each request in batch /// and the next cached batch #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { let futures: Vec<_> = self .clients .iter_mut() .map(|client| Box::pin(client.prefill(batch.clone()))) .collect(); // all shards return the same message join_all(futures).await.pop().unwrap() } /// Generate one token for each request in the given cached batches /// /// Returns Generation for each request in batches /// and the next cached batch #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] pub async fn decode( &mut self, batches: Vec, ) -> Result<(Vec, Option)> { let futures: Vec<_> = self .clients .iter_mut() .map(|client| Box::pin(client.decode(batches.clone()))) .collect(); // all shards return the same message join_all(futures).await.pop().unwrap() } }