Spaces:
Runtime error
Runtime error
//! Multi-shard client.
use crate::Result; | |
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo}; | |
use futures::future::join_all; | |
use tonic::transport::Uri; | |
use tracing::instrument; | |
/// Text Generation Inference gRPC multi client | |
pub struct ShardedClient { | |
clients: Vec<Client>, | |
} | |
impl ShardedClient { | |
fn new(clients: Vec<Client>) -> Self { | |
Self { clients } | |
} | |
/// Create a new ShardedClient from a master client. The master client will communicate with | |
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. | |
async fn from_master_client(mut master_client: Client) -> Result<Self> { | |
// Get all uris/unix sockets from the master client | |
let uris = master_client.service_discovery().await?; | |
let futures = uris.into_iter().map(Client::connect_uds); | |
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect(); | |
Ok(Self::new(clients?)) | |
} | |
/// Returns a client connected to the given uri | |
pub async fn connect(uri: Uri) -> Result<Self> { | |
let master_client = Client::connect(uri).await?; | |
Self::from_master_client(master_client).await | |
} | |
/// Returns a client connected to the given unix socket | |
pub async fn connect_uds(path: String) -> Result<Self> { | |
let master_client = Client::connect_uds(path).await?; | |
Self::from_master_client(master_client).await | |
} | |
/// Get the model info | |
pub async fn info(&mut self) -> Result<ShardInfo> { | |
let futures: Vec<_> = self | |
.clients | |
.iter_mut() | |
.map(|client| client.info()) | |
.collect(); | |
join_all(futures).await.pop().unwrap() | |
} | |
/// GRPC health check | |
pub async fn health(&mut self) -> Result<HealthResponse> { | |
let futures: Vec<_> = self | |
.clients | |
.iter_mut() | |
.map(|client| client.health()) | |
.collect(); | |
join_all(futures).await.pop().unwrap() | |
} | |
/// Clear the past generations cache | |
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> { | |
let futures: Vec<_> = self | |
.clients | |
.iter_mut() | |
.map(|client| client.clear_cache(batch_id)) | |
.collect(); | |
join_all(futures).await.into_iter().collect() | |
} | |
/// Filter a cached batch | |
pub async fn filter_batch( | |
&mut self, | |
batch_id: u64, | |
keep_requests: Vec<Request>, | |
) -> Result<Option<Batch>> { | |
let futures: Vec<_> = self | |
.clients | |
.iter_mut() | |
.map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone()))) | |
.collect(); | |
// all shards return the same message | |
join_all(futures).await.pop().unwrap() | |
} | |
/// Generate one token for each request in the given batch | |
/// | |
/// Returns Generation for each request in batch | |
/// and the next cached batch | |
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> { | |
let futures: Vec<_> = self | |
.clients | |
.iter_mut() | |
.map(|client| Box::pin(client.prefill(batch.clone()))) | |
.collect(); | |
// all shards return the same message | |
join_all(futures).await.pop().unwrap() | |
} | |
/// Generate one token for each request in the given cached batches | |
/// | |
/// Returns Generation for each request in batches | |
/// and the next cached batch | |
pub async fn decode( | |
&mut self, | |
batches: Vec<Batch>, | |
) -> Result<(Vec<Generation>, Option<Batch>)> { | |
let futures: Vec<_> = self | |
.clients | |
.iter_mut() | |
.map(|client| Box::pin(client.decode(batches.clone()))) | |
.collect(); | |
// all shards return the same message | |
join_all(futures).await.pop().unwrap() | |
} | |
} | |