Spaces:
Runtime error
Runtime error
//! Single-shard client for the text generation gRPC service.
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; | |
use crate::pb::generate::v1::*; | |
use crate::Result; | |
use grpc_metadata::InjectTelemetryContext; | |
use tonic::transport::{Channel, Uri}; | |
use tracing::instrument; | |
/// Text Generation Inference gRPC client | |
pub struct Client { | |
stub: TextGenerationServiceClient<Channel>, | |
} | |
impl Client { | |
/// Returns a client connected to the given url | |
pub async fn connect(uri: Uri) -> Result<Self> { | |
let channel = Channel::builder(uri).connect().await?; | |
Ok(Self { | |
stub: TextGenerationServiceClient::new(channel), | |
}) | |
} | |
/// Returns a client connected to the given unix socket | |
pub async fn connect_uds(path: String) -> Result<Self> { | |
let channel = Channel::from_shared("http://[::]:50051".to_string()) | |
.unwrap() | |
.connect_with_connector(tower::service_fn(move |_: Uri| { | |
tokio::net::UnixStream::connect(path.clone()) | |
})) | |
.await?; | |
Ok(Self { | |
stub: TextGenerationServiceClient::new(channel), | |
}) | |
} | |
/// Returns a list of uris or unix sockets of all shards | |
pub async fn service_discovery(&mut self) -> Result<Vec<String>> { | |
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); | |
let response = self.stub.service_discovery(request).await?; | |
let urls = response | |
.into_inner() | |
.urls | |
.into_iter() | |
// Remove unix socket prefix | |
.map(|url| match url.strip_prefix("unix://") { | |
None => url, | |
Some(stripped_url) => stripped_url.to_string(), | |
}) | |
.collect(); | |
Ok(urls) | |
} | |
/// Get model info | |
pub async fn info(&mut self) -> Result<InfoResponse> { | |
let request = tonic::Request::new(InfoRequest {}).inject_context(); | |
let response = self.stub.info(request).await?.into_inner(); | |
Ok(response) | |
} | |
/// Get model health | |
pub async fn health(&mut self) -> Result<HealthResponse> { | |
let request = tonic::Request::new(HealthRequest {}).inject_context(); | |
let response = self.stub.health(request).await?.into_inner(); | |
Ok(response) | |
} | |
/// Clear the past generations cache | |
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> { | |
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); | |
self.stub.clear_cache(request).await?; | |
Ok(()) | |
} | |
/// Filter a cached batch | |
pub async fn filter_batch( | |
&mut self, | |
batch_id: u64, | |
keep_requests: Vec<Request>, | |
) -> Result<Option<Batch>> { | |
let request = tonic::Request::new(FilterBatchRequest { | |
batch_id, | |
keep_requests, | |
}) | |
.inject_context(); | |
let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); | |
Ok(filtered_batch.batch) | |
} | |
/// Generate one token for each request in the given batch | |
/// | |
/// Returns Generation for each request in batch | |
/// and the next cached batch | |
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> { | |
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); | |
let response = self.stub.prefill(request).await?.into_inner(); | |
Ok((response.generations, response.batch)) | |
} | |
/// Generate one token for each request in the given cached batches | |
/// | |
/// Returns Generation for each request in batches | |
/// and the next cached batch | |
pub async fn decode( | |
&mut self, | |
batches: Vec<Batch>, | |
) -> Result<(Vec<Generation>, Option<Batch>)> { | |
let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); | |
let response = self.stub.decode(request).await?.into_inner(); | |
Ok((response.generations, response.batch)) | |
} | |
} | |