Spaces:
Runtime error
Runtime error
mod health; | |
/// Text Generation Inference Webserver | |
mod infer; | |
mod queue; | |
pub mod server; | |
mod validation; | |
use infer::Infer; | |
use queue::{Entry, Queue}; | |
use serde::{Deserialize, Serialize}; | |
use utoipa::ToSchema; | |
use validation::Validation; | |
/// Metadata describing a model hosted on the Hub.
///
/// NOTE(review): presumably deserialized from the Hub API response — confirm
/// against the caller that populates it.
#[derive(Clone, Debug)]
pub struct HubModelInfo {
    /// Identifier of the model repository.
    pub model_id: String,
    /// Revision hash of the model, if known.
    pub sha: Option<String>,
    /// Pipeline tag attached to the model, if any.
    pub pipeline_tag: Option<String>,
}
/// Snapshot of model metadata, router parameters and router build information.
///
/// Fields are grouped into three sections, marked by the plain `//` comments
/// below. They are deliberately not `///` doc comments, because a doc comment
/// would attach only to the single field that follows it.
#[derive(Clone, Debug)]
pub struct Info {
    // Model info
    pub model_id: String,
    pub model_sha: Option<String>,
    pub model_dtype: String,
    pub model_device_type: String,
    pub model_pipeline_tag: Option<String>,
    // Router parameters
    pub max_concurrent_requests: usize,
    pub max_best_of: usize,
    pub max_stop_sequences: usize,
    pub max_input_length: usize,
    pub max_total_tokens: usize,
    pub waiting_served_ratio: f32,
    pub max_batch_total_tokens: u32,
    pub max_waiting_tokens: usize,
    pub validation_workers: usize,
    // Router info
    pub version: &'static str,
    pub sha: Option<&'static str>,
    pub docker_label: Option<&'static str>,
}
/// Sampling and decoding parameters attached to a generation request.
///
/// Baseline values come from [`default_parameters`], which the `Default`
/// impl below delegates to.
#[derive(Clone, Debug)]
pub(crate) struct GenerateParameters {
    pub best_of: Option<usize>,
    pub temperature: Option<f32>,
    pub repetition_penalty: Option<f32>,
    pub top_k: Option<i32>,
    pub top_p: Option<f32>,
    pub typical_p: Option<f32>,
    pub do_sample: bool,
    /// Token budget for generation; the baseline is [`default_max_new_tokens`].
    pub max_new_tokens: u32,
    pub return_full_text: Option<bool>,
    pub stop: Vec<String>,
    pub truncate: Option<usize>,
    pub watermark: bool,
    pub details: bool,
    pub seed: Option<u64>,
}

/// Baseline value for [`GenerateParameters::max_new_tokens`].
fn default_max_new_tokens() -> u32 {
    20
}

/// Returns the baseline parameter set.
///
/// Every optional setting is `None`, every flag is `false` and `stop` is
/// empty. `max_new_tokens` comes from [`default_max_new_tokens`].
fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        best_of: None,
        temperature: None,
        repetition_penalty: None,
        top_k: None,
        top_p: None,
        typical_p: None,
        do_sample: false,
        max_new_tokens: default_max_new_tokens(),
        return_full_text: None,
        stop: Vec::new(),
        truncate: None,
        watermark: false,
        details: false,
        seed: None,
    }
}

impl Default for GenerateParameters {
    /// Same value as [`default_parameters`].
    fn default() -> Self {
        default_parameters()
    }
}

/// A generation request: the input text plus its parameters.
#[derive(Clone, Debug)]
pub(crate) struct GenerateRequest {
    pub inputs: String,
    pub parameters: GenerateParameters,
}

/// A generation request shape that additionally carries a `stream` flag.
#[derive(Clone, Debug)]
pub(crate) struct CompatGenerateRequest {
    pub inputs: String,
    pub parameters: GenerateParameters,
    pub stream: bool,
}

impl From<CompatGenerateRequest> for GenerateRequest {
    /// Keeps `inputs` and `parameters` and discards the `stream` flag.
    fn from(req: CompatGenerateRequest) -> Self {
        Self {
            inputs: req.inputs,
            parameters: req.parameters,
        }
    }
}
/// A token from the prefill stage: its id, text and log-probability.
#[derive(Clone, Debug)]
pub struct PrefillToken {
    id: u32,
    text: String,
    logprob: f32,
}

/// A generated token: its id, text, log-probability and whether it is a
/// special token.
#[derive(Clone, Debug)]
pub struct Token {
    id: u32,
    text: String,
    logprob: f32,
    special: bool,
}

/// Reason a sequence finished generating.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum FinishReason {
    Length,
    EndOfSequenceToken,
    StopSequence,
}

/// One candidate sequence, reported alongside the returned generation.
#[derive(Clone, Debug)]
pub(crate) struct BestOfSequence {
    pub generated_text: String,
    pub finish_reason: FinishReason,
    pub generated_tokens: u32,
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
}

/// Per-generation details: finish reason, token counts, seed and token lists.
///
/// NOTE(review): `best_of_sequences` presumably holds the other candidates when
/// `best_of` > 1 — confirm against the producer.
#[derive(Clone, Debug)]
pub(crate) struct Details {
    pub finish_reason: FinishReason,
    pub generated_tokens: u32,
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
    pub best_of_sequences: Option<Vec<BestOfSequence>>,
}

/// Non-streaming generation result: the text plus optional [`Details`].
#[derive(Clone, Debug)]
pub(crate) struct GenerateResponse {
    pub generated_text: String,
    pub details: Option<Details>,
}

/// Reduced details carried by a streamed message.
#[derive(Clone, Debug)]
pub(crate) struct StreamDetails {
    pub finish_reason: FinishReason,
    pub generated_tokens: u32,
    pub seed: Option<u64>,
}

/// A single streamed message: one token, plus optional final text and details.
#[derive(Clone, Debug)]
pub(crate) struct StreamResponse {
    pub token: Token,
    pub generated_text: Option<String>,
    pub details: Option<StreamDetails>,
}

/// Error payload: a message and an error category string.
#[derive(Clone, Debug)]
pub(crate) struct ErrorResponse {
    pub error: String,
    pub error_type: String,
}
/// Test-only helpers shared by the crate's unit tests.
#[cfg(test)]
mod tests {
    use tokenizers::Tokenizer;

    /// Local cache path for the tokenizer used by tests.
    const TOKENIZER_PATH: &str = "tokenizer.json";
    /// Download location used when the tokenizer is not cached yet.
    const TOKENIZER_URL: &str = "https://huggingface.co/gpt2/raw/main/tokenizer.json";

    /// Returns the GPT-2 tokenizer, downloading it to `tokenizer.json` in the
    /// current working directory on first use.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the request fails;
    /// - the server answers with an error status;
    /// - the file cannot be written or moved into place;
    /// - the cached file cannot be parsed as a tokenizer.
    pub(crate) async fn get_tokenizer() -> Tokenizer {
        if !std::path::Path::new(TOKENIZER_PATH).exists() {
            let content = reqwest::get(TOKENIZER_URL)
                .await
                .expect("failed to request tokenizer.json")
                // Without this check an HTTP error page would be cached as
                // tokenizer.json and break every later test run.
                .error_for_status()
                .expect("tokenizer.json download returned an error status")
                .bytes()
                .await
                .expect("failed to read tokenizer.json response body");
            // The test harness runs tests in parallel. Writing to a per-thread
            // temp file and renaming it into place means no caller can observe a
            // partially written tokenizer.json.
            let tmp_path = format!(
                "{}.{}.{:?}.tmp",
                TOKENIZER_PATH,
                std::process::id(),
                std::thread::current().id()
            );
            std::fs::write(&tmp_path, &content).expect("failed to write tokenizer.json");
            std::fs::rename(&tmp_path, TOKENIZER_PATH)
                .expect("failed to move tokenizer.json into place");
        }
        Tokenizer::from_file(TOKENIZER_PATH).expect("failed to load tokenizer.json")
    }
}