|
|
|
from langfuse import Langfuse |
|
from langfuse.decorators import observe, langfuse_context |
|
|
|
from config.config import settings |
|
import os |
|
|
|
|
|
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-9f2c32d2-266f-421d-9b87-51377f0a268c" |
|
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-229e10c5-6210-4a4b-a432-0f17bc66e56c" |
|
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space" |
|
|
|
try:
    langfuse = Langfuse()
except Exception as e:
    # Fall back to offline mode if the Langfuse host cannot be reached.
    langfuse = None
    print(f"Langfuse offline: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from llama_cpp import Llama |
|
from typing import Optional, Dict, Any
|
import logging |
|
from functools import lru_cache |
|
from config.config import GenerationConfig, ModelConfig |
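
# --- Model management --------------------------------------------------------
# ModelManager loads Hugging Face transformers models and quantized GGUF models
# (via llama-cpp-python), keeps their tokenizers, and frees GPU memory on unload.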
|
|
|
|
|
class ModelManager: |
|
def __init__(self, device: Optional[str] = None): |
|
self.logger = logging.getLogger(__name__) |
|
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") |
|
self.models: Dict[str, Any] = {} |
|
self.tokenizers: Dict[str, Any] = {} |
|
|
|
@observe() |
|
def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None: |
|
"""Load a model with specified configuration.""" |
|
try: |
|
|
|
if model_type == "llama": |
|
self.tokenizers[model_id] = AutoTokenizer.from_pretrained( |
|
model_path, |
|
padding_side='left', |
|
trust_remote_code=True, |
|
**config.tokenizer_kwargs |
|
) |
|
if self.tokenizers[model_id].pad_token is None: |
|
self.tokenizers[model_id].pad_token = self.tokenizers[model_id].eos_token |
|
|
|
self.models[model_id] = AutoModelForCausalLM.from_pretrained( |
|
model_path, |
|
device_map="auto", |
|
trust_remote_code=True, |
|
**config.model_kwargs |
|
) |
|
            elif model_type == "gguf":
                # GGUF models are loaded through llama-cpp-python.
                self.models[model_id] = self._load_quantized_model(
                    model_path,
                    **config.quantization_kwargs
                )
            else:
                raise ValueError(f"Unsupported model type: {model_type}")
|
except Exception as e: |
|
self.logger.error(f"Failed to load model {model_id}: {str(e)}") |
|
raise |
|
|
|
@observe() |
|
def unload_model(self, model_id: str) -> None: |
|
"""Unload a model and free resources.""" |
|
if model_id in self.models: |
|
del self.models[model_id] |
|
if model_id in self.tokenizers: |
|
del self.tokenizers[model_id] |
|
torch.cuda.empty_cache() |
|
|
|
def _load_quantized_model(self, model_path: str, **kwargs) -> Llama: |
|
"""Load a quantized GGUF model.""" |
|
try: |
|
n_gpu_layers = -1 if torch.cuda.is_available() else 0 |
|
model = Llama( |
|
model_path=model_path, |
|
n_ctx=kwargs.get('n_ctx', 2048), |
|
n_batch=kwargs.get('n_batch', 512), |
|
n_gpu_layers=kwargs.get('n_gpu_layers', n_gpu_layers), |
|
verbose=kwargs.get('verbose', False) |
|
) |
|
return model |
|
except Exception as e: |
|
self.logger.error(f"Failed to load GGUF model: {str(e)}") |
|
raise |
|
|
|
|
|
|
|
from collections import OrderedDict
from typing import Tuple, Any
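
# --- Response caching --------------------------------------------------------
# ResponseCache keeps an in-memory LRU cache of (response, score) pairs keyed by
# the prompt and a hash of the generation config.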
|
|
|
|
|
class ResponseCache:
    """Simple LRU cache for generated responses, keyed by prompt and generation config."""

    def __init__(self, cache_size: int = 1000):
        self.cache_size = cache_size
        self._cache: "OrderedDict[Tuple[str, str], Tuple[str, float]]" = OrderedDict()

    @staticmethod
    def _make_key(prompt: str, config: GenerationConfig) -> Tuple[str, str]:
        return prompt, str(hash(str(config.__dict__)))

    def cache_response(self, prompt: str, config: GenerationConfig, response: str, score: float) -> None:
        key = self._make_key(prompt, config)
        self._cache[key] = (response, score)
        self._cache.move_to_end(key)
        if len(self._cache) > self.cache_size:
            # Evict the least recently used entry.
            self._cache.popitem(last=False)

    def get_response(self, prompt: str, config: GenerationConfig) -> Optional[Tuple[str, float]]:
        key = self._make_key(prompt, config)
        if key in self._cache:
            self._cache.move_to_end(key)
            return self._cache[key]
        return None
|
|
|
|
|
|
|
from typing import List, Dict, Any
|
import asyncio |
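
# --- Request batching --------------------------------------------------------
# BatchProcessor collects incoming requests and flushes them either when the
# batch is full or after max_wait_time seconds.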
|
|
|
|
|
class BatchProcessor: |
|
def __init__(self, max_batch_size: int = 32, max_wait_time: float = 0.1): |
|
self.max_batch_size = max_batch_size |
|
self.max_wait_time = max_wait_time |
|
self.pending_requests: List[Dict] = [] |
|
self.lock = asyncio.Lock() |
|
|
|
    async def add_request(self, request: Dict) -> Any:
        async with self.lock:
            self.pending_requests.append(request)
            if len(self.pending_requests) >= self.max_batch_size:
                return await self._process_batch()
        # Release the lock while waiting so other requests can join the batch.
        await asyncio.sleep(self.max_wait_time)
        async with self.lock:
            if self.pending_requests:
                return await self._process_batch()
|
|
|
async def _process_batch(self) -> List[Any]: |
|
batch = self.pending_requests[:self.max_batch_size] |
|
self.pending_requests = self.pending_requests[self.max_batch_size:] |
|
|
|
return batch |
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple |
|
from dataclasses import dataclass |
|
from logging import getLogger |
|
|
|
|
|
from config.config import GenerationConfig, ModelConfig |
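
# --- Generator base class ----------------------------------------------------
# BaseGenerator wires together the ModelManager, ResponseCache, BatchProcessor
# and HealthCheck, and defines the interface that concrete generators implement.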
|
|
|
class BaseGenerator(ABC): |
|
"""Base class for all generator implementations.""" |
|
|
|
def __init__( |
|
self, |
|
model_name: str, |
|
device: Optional[str] = None, |
|
default_generation_config: Optional[GenerationConfig] = None, |
|
model_config: Optional[ModelConfig] = None, |
|
cache_size: int = 1000, |
|
max_batch_size: int = 32 |
|
): |
|
self.logger = getLogger(__name__) |
|
self.model_manager = ModelManager(device) |
|
self.cache = ResponseCache(cache_size) |
|
self.batch_processor = BatchProcessor(max_batch_size) |
|
self.health_check = HealthCheck() |
|
|
|
|
|
self.default_config = default_generation_config or GenerationConfig() |
|
self.model_config = model_config or ModelConfig() |
|
|
|
@abstractmethod |
|
async def generate_stream( |
|
self, |
|
prompt: str, |
|
config: Optional[GenerationConfig] = None |
|
) -> AsyncGenerator[str, None]: |
|
pass |
|
|
|
@abstractmethod |
|
def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]: |
|
pass |
|
|
|
@abstractmethod |
|
def generate( |
|
self, |
|
prompt: str, |
|
model_kwargs: Dict[str, Any], |
|
strategy: str = "default", |
|
**kwargs |
|
) -> str: |
|
pass |
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
from typing import List, Tuple |
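
# --- Generation strategies ---------------------------------------------------
# Each strategy implements a different test-time decoding scheme: plain
# generation, majority voting, best-of-N with PRM scoring, beam search, and a
# PRM-scored breadth/depth tree search (DVT). Strategies receive the generator
# so they can reuse its model, tokenizer, device and PRM model.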
|
|
|
# NOTE: @observe() is applied to methods rather than classes in this section;
# decorating a class replaces it with a wrapped function and breaks inheritance.
class GenerationStrategy(ABC):
|
"""Base class for generation strategies.""" |
|
|
|
@abstractmethod |
|
def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str: |
|
pass |
|
|
|
|
|
class DefaultStrategy(GenerationStrategy): |
|
|
|
def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str: |
|
input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device) |
|
output = generator.model.generate(input_ids, **model_kwargs) |
|
return generator.tokenizer.decode(output[0], skip_special_tokens=True) |
|
|
|
class MajorityVotingStrategy(GenerationStrategy):
    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
|
outputs = [] |
|
for _ in range(num_samples): |
|
input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device) |
|
output = generator.model.generate(input_ids, **model_kwargs) |
|
outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True)) |
|
return max(set(outputs), key=outputs.count) |
|
|
|
class BestOfN(GenerationStrategy):
    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        scored_outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # NOTE: scoring assumes generator.prm_model is a Hugging Face-style reward model that returns logits.
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            scored_outputs.append((response, score))
        return max(scored_outputs, key=lambda x: x[1])[0]
|
|
|
class BeamSearch(GenerationStrategy):
    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        # Return the highest-scoring beam (the first returned sequence).
        return generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
class DVT(GenerationStrategy):
    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], depth: int = 3, breadth: int = 2, **kwargs) -> str:
        results = []
        # Sample `breadth` initial candidates and score them with the PRM.
        for _ in range(breadth):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            results.append((response, score))

        # Expand the best candidates for `depth - 1` additional rounds.
        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _ in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
                results.append((extended_response, score))
        return max(results, key=lambda x: x[1])[0]
|
|
|
class COT(GenerationStrategy):
    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # Chain-of-thought prompting is not implemented yet.
|
|
|
|
|
return "Not implemented yet" |
|
|
|
class ReAct(GenerationStrategy):
    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # ReAct-style reasoning-and-acting generation is not implemented yet.
|
|
|
return "Not implemented yet" |
|
|
|
|
|
|
|
from typing import Protocol, List, Tuple |
|
from transformers import AutoTokenizer |
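
# --- Prompt templates --------------------------------------------------------
# PromptTemplate defines the formatting protocol; LlamaPromptTemplate emits the
# raw Llama 3 chat format, while TransformersPromptTemplate delegates to the
# tokenizer's built-in chat template.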
|
|
|
class PromptTemplate(Protocol):
|
"""Protocol for prompt templates.""" |
|
def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str: |
|
pass |
|
|
|
class LlamaPromptTemplate:
    @observe()
    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
|
system_message = f"Please assist based on the following context: {context}" |
|
prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>" |
|
|
|
for user_msg, assistant_msg in chat_history[-max_history_turns:]: |
|
prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>" |
|
prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>" |
|
|
|
prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>" |
|
prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n" |
|
return prompt |
|
|
|
class TransformersPromptTemplate:
    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @observe()
    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
|
messages = [ |
|
{ |
|
"role": "system", |
|
"content": f"Please assist based on the following context: {context}", |
|
} |
|
] |
|
|
|
for user_msg, assistant_msg in chat_history: |
|
messages.extend([ |
|
{"role": "user", "content": user_msg}, |
|
{"role": "assistant", "content": assistant_msg} |
|
]) |
|
|
|
messages.append({"role": "user", "content": user_input}) |
|
|
|
tokenized_chat = self.tokenizer.apply_chat_template( |
|
messages, |
|
tokenize=False, |
|
add_generation_prompt=True |
|
) |
|
return tokenized_chat |
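
# Example (illustrative): formatting a single-turn prompt with the Llama template.
#
#   template = LlamaPromptTemplate()
#   prompt = template.format(
#       context="You are a helpful assistant",
#       user_input="What is the capital of France?",
#       chat_history=[],
#   )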
|
|
|
|
|
import psutil |
|
from dataclasses import dataclass |
|
from typing import Dict, Any |
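
# --- Health monitoring -------------------------------------------------------
# HealthCheck reports GPU memory, CPU and RAM usage so the service can expose a
# lightweight health status.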
|
|
|
@dataclass |
|
class HealthStatus: |
|
status: str |
|
gpu_memory: Dict[str, float] |
|
cpu_usage: float |
|
ram_usage: float |
|
model_status: Dict[str, str] |
|
|
|
class HealthCheck: |
|
@staticmethod |
|
def check_gpu_memory() -> Dict[str, float]: |
|
if torch.cuda.is_available(): |
|
return { |
|
f"gpu_{i}": torch.cuda.memory_allocated(i) / 1024**3 |
|
for i in range(torch.cuda.device_count()) |
|
} |
|
return {} |
|
|
|
@staticmethod |
|
def check_system_resources() -> HealthStatus: |
|
return HealthStatus( |
|
status="healthy", |
|
gpu_memory=HealthCheck.check_gpu_memory(), |
|
cpu_usage=psutil.cpu_percent(), |
|
ram_usage=psutil.virtual_memory().percent, |
|
|
|
model_status={} |
|
) |
|
|
|
|
|
|
|
from config.config import GenerationConfig, ModelConfig |
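
# --- LLaMA generator ---------------------------------------------------------
# LlamaGenerator loads the main LLaMA model plus a GGUF process reward model
# (PRM), builds prompts, and dispatches generation to the configured strategy.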
|
|
|
class LlamaGenerator(BaseGenerator):
|
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        super().__init__(
            llama_model_name,
            device,
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size
        )

        # Load the main LLaMA model (transformers) and the PRM scoring model (GGUF).
        self.model_manager.load_model(
            "llama",
            llama_model_name,
            "llama",
            self.model_config
        )
        self.model_manager.load_model(
            "prm",
            prm_model_path,
            "gguf",
            self.model_config
        )

        # Expose the loaded artifacts under the names the generation strategies expect.
        self.model = self.model_manager.models["llama"]
        self.tokenizer = self.model_manager.tokenizers["llama"]
        self.prm_model = self.model_manager.models["prm"]
        self.device = self.model_manager.device

        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()
|
|
|
def _init_strategies(self): |
|
self.strategies = { |
|
"default": DefaultStrategy(), |
|
"majority_voting": MajorityVotingStrategy(), |
|
"best_of_n": BestOfN(), |
|
"beam_search": BeamSearch(), |
|
"dvts": DVT(), |
|
} |
|
|
|
def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]: |
|
"""Get generation kwargs based on config.""" |
|
return { |
|
key: getattr(config, key) |
|
for key in [ |
|
"max_new_tokens", |
|
"temperature", |
|
"top_p", |
|
"top_k", |
|
"repetition_penalty", |
|
"length_penalty", |
|
"do_sample" |
|
] |
|
if hasattr(config, key) |
|
} |
|
|
|
    @observe()
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        # Token-by-token streaming is not implemented yet.
        raise NotImplementedError("Streaming generation is not implemented yet")
        yield  # unreachable; makes this an async generator
|
|
|
@observe() |
|
def generate( |
|
self, |
|
prompt: str, |
|
model_kwargs: Dict[str, Any], |
|
strategy: str = "default", |
|
**kwargs |
|
) -> str: |
|
""" |
|
Generate text based on a given strategy. |
|
|
|
Args: |
|
prompt (str): Input prompt for text generation. |
|
model_kwargs (Dict[str, Any]): Additional arguments for model generation. |
|
strategy (str): The generation strategy to use (default: "default"). |
|
**kwargs: Additional arguments passed to the strategy. |
|
|
|
Returns: |
|
str: Generated text response. |
|
|
|
Raises: |
|
ValueError: If the specified strategy is not available. |
|
""" |
|
|
|
if strategy not in self.strategies: |
|
raise ValueError(f"Unknown strategy: {strategy}. Available strategies are: {list(self.strategies.keys())}") |
|
|
|
|
|
kwargs.pop("generator", None) |
|
|
|
|
|
return self.strategies[strategy].generate( |
|
generator=self, |
|
prompt=prompt, |
|
model_kwargs=model_kwargs, |
|
**kwargs |
|
) |
|
|
|
@observe() |
|
def generate_with_context( |
|
self, |
|
context: str, |
|
user_input: str, |
|
chat_history: List[Tuple[str, str]], |
|
model_kwargs: Dict[str, Any], |
|
max_history_turns: int = 3, |
|
strategy: str = "default", |
|
num_samples: int = 5, |
|
depth: int = 3, |
|
breadth: int = 2, |
|
|
|
) -> str: |
|
"""Generate a response using context and chat history. |
|
|
|
Args: |
|
context (str): Context for the conversation |
|
user_input (str): Current user input |
|
chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs |
|
model_kwargs (dict): Additional arguments for model.generate() |
|
max_history_turns (int): Maximum number of history turns to include |
|
strategy (str): Generation strategy |
|
num_samples (int): Number of samples for applicable strategies |
|
depth (int): Depth for DVTS strategy |
|
breadth (int): Breadth for DVTS strategy |
|
|
|
Returns: |
|
str: Generated response |
|
""" |
|
prompt = self.prompt_builder.format( |
|
context, |
|
user_input, |
|
chat_history, |
|
max_history_turns |
|
) |
|
return self.generate( |
|
|
prompt=prompt, |
|
model_kwargs=model_kwargs, |
|
strategy=strategy, |
|
num_samples=num_samples, |
|
depth=depth, |
|
breadth=breadth |
|
) |
|
|
|
|
|
|
|
def check_health(self) -> HealthStatus: |
|
"""Check the health status of the generator.""" |
|
return self.health_check.check_system_resources() |
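
    def cleanup(self) -> None:
        """Release loaded models; called from the FastAPI lifespan shutdown hook."""
        self.model_manager.unload_model("llama")
        self.model_manager.unload_model("prm")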
|
|
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, Depends |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import StreamingResponse |
|
from pydantic import BaseModel, Field, ConfigDict |
|
from typing import List, Optional, Dict, Any, AsyncGenerator |
|
import asyncio |
|
import uuid |
|
from datetime import datetime |
|
import json |
|
from huggingface_hub import hf_hub_download |
|
from contextlib import asynccontextmanager |
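
# --- FastAPI service ---------------------------------------------------------
# Pydantic request/response models, a lifespan hook that downloads the PRM GGUF
# model and builds the LlamaGenerator, and REST + WebSocket endpoints.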
|
|
|
|
|
|
|
class ChatMessage(BaseModel): |
|
"""A single message in the chat history.""" |
|
role: str = Field( |
|
..., |
|
description="Role of the message sender", |
|
examples=["user", "assistant"] |
|
) |
|
content: str = Field(..., description="Content of the message") |
|
|
|
model_config = ConfigDict( |
|
json_schema_extra={ |
|
"example": { |
|
"role": "user", |
|
"content": "What is the capital of France?" |
|
} |
|
} |
|
) |
|
|
|
class GenerationConfig(BaseModel): |
|
"""Configuration for text generation.""" |
|
temperature: float = Field( |
|
0.7, |
|
ge=0.0, |
|
le=2.0, |
|
description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random, lower values (e.g., 0.2) make it more focused and deterministic." |
|
) |
|
max_new_tokens: int = Field( |
|
100, |
|
ge=1, |
|
le=2048, |
|
description="Maximum number of tokens to generate" |
|
) |
|
top_p: float = Field( |
|
0.9, |
|
ge=0.0, |
|
le=1.0, |
|
description="Nucleus sampling parameter. Only tokens with cumulative probability < top_p are considered." |
|
) |
|
top_k: int = Field( |
|
50, |
|
ge=0, |
|
description="Only consider the top k tokens for text generation" |
|
) |
|
strategy: str = Field( |
|
"default", |
|
description="Generation strategy to use", |
|
examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"] |
|
) |
|
num_samples: int = Field( |
|
5, |
|
ge=1, |
|
le=10, |
|
description="Number of samples to generate (used in majority_voting and best_of_n strategies)" |
|
) |
|
|
|
class GenerationRequest(BaseModel): |
|
"""Request model for text generation.""" |
|
context: Optional[str] = Field( |
|
None, |
|
description="Additional context to guide the generation", |
|
examples=["You are a helpful assistant skilled in Python programming"] |
|
) |
|
messages: List[ChatMessage] = Field( |
|
..., |
|
description="Chat history including the current message", |
|
        min_length=1
|
) |
|
config: Optional[GenerationConfig] = Field( |
|
None, |
|
description="Generation configuration parameters" |
|
) |
|
stream: bool = Field( |
|
False, |
|
description="Whether to stream the response token by token" |
|
) |
|
|
|
model_config = ConfigDict( |
|
json_schema_extra={ |
|
"example": { |
|
"context": "You are a helpful assistant", |
|
"messages": [ |
|
{"role": "user", "content": "What is the capital of France?"} |
|
], |
|
"config": { |
|
"temperature": 0.7, |
|
"max_new_tokens": 100 |
|
}, |
|
"stream": False |
|
} |
|
} |
|
) |
|
|
|
class GenerationResponse(BaseModel): |
|
"""Response model for text generation.""" |
|
id: str = Field(..., description="Unique generation ID") |
|
content: str = Field(..., description="Generated text content") |
|
created_at: datetime = Field( |
|
default_factory=datetime.now, |
|
description="Timestamp of generation" |
|
) |
|
|
|
|
|
|
|
async def get_prm_model_path(): |
|
"""Download and cache the PRM model.""" |
|
return await asyncio.to_thread( |
|
hf_hub_download, |
|
repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF", |
|
filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf" |
|
) |
|
|
|
|
|
generator = None |
|
|
|
@asynccontextmanager |
|
async def lifespan(app: FastAPI): |
|
"""Lifecycle management for the FastAPI application.""" |
|
|
|
global generator |
|
try: |
|
prm_model_path = await get_prm_model_path() |
|
generator = LlamaGenerator( |
|
llama_model_name="meta-llama/Llama-3.2-1B-Instruct", |
|
prm_model_path=prm_model_path, |
|
default_generation_config=GenerationConfig( |
|
max_new_tokens=100, |
|
temperature=0.7 |
|
) |
|
) |
|
yield |
|
finally: |
|
|
|
if generator: |
|
await asyncio.to_thread(generator.cleanup) |
|
|
|
|
|
app = FastAPI( |
|
title="Inference Deluxe Service", |
|
description=""" |
|
A service for generating text using LLaMA models with various generation strategies. |
|
|
|
Generation Strategies: |
|
- default: Standard autoregressive generation |
|
- majority_voting: Generates multiple responses and selects the most common one |
|
- best_of_n: Generates multiple responses and selects the best based on a scoring metric |
|
- beam_search: Uses beam search for more coherent text generation |
|
    - dvts: Tree search over candidate continuations guided by a process reward model
|
""", |
|
version="1.0.0", |
|
lifespan=lifespan |
|
) |
|
|
|
|
|
app.add_middleware( |
|
CORSMiddleware, |
|
allow_origins=["*"], |
|
allow_credentials=True, |
|
allow_methods=["*"], |
|
allow_headers=["*"], |
|
) |
|
|
|
async def get_generator(): |
|
"""Dependency to get the generator instance.""" |
|
if not generator: |
|
raise HTTPException( |
|
status_code=503, |
|
detail="Generator not initialized" |
|
) |
|
return generator |
|
|
|
@app.post( |
|
"/generate", |
|
response_model=GenerationResponse, |
|
tags=["generation"], |
|
summary="Generate text response", |
|
response_description="Generated text with unique identifier" |
|
) |
|
async def generate( |
|
request: GenerationRequest, |
|
generator: Any = Depends(get_generator) |
|
): |
|
""" |
|
Generate a text response based on the provided context and chat history. |
|
""" |
|
try: |
|
        # Pair prior messages into (user, assistant) turns; the last message is the current user input.
        history = request.messages[:-1]
        chat_history = [
            (history[i].content, history[i + 1].content)
            for i in range(0, len(history) - 1, 2)
            if history[i].role == "user" and history[i + 1].role == "assistant"
        ]
        user_input = request.messages[-1].content
|
|
|
|
|
config = request.config or GenerationConfig() |
|
model_kwargs = { |
|
"temperature": config.temperature if hasattr(config, "temperature") else 0.7, |
|
"max_new_tokens": config.max_new_tokens if hasattr(config, "max_new_tokens") else 100, |
|
|
|
} |
|
|
|
|
|
response = await asyncio.to_thread( |
|
generator.generate_with_context, |
|
context=request.context or "", |
|
user_input=user_input, |
|
chat_history=chat_history, |
|
model_kwargs=model_kwargs, |
|
max_history_turns=config.max_history_turns if hasattr(config, "max_history_turns") else 3, |
|
strategy=config.strategy if hasattr(config, "strategy") else "default", |
|
num_samples=config.num_samples if hasattr(config, "num_samples") else 5, |
|
depth=config.depth if hasattr(config, "depth") else 3, |
|
breadth=config.breadth if hasattr(config, "breadth") else 2, |
|
) |
|
|
|
return GenerationResponse( |
|
id=str(uuid.uuid4()), |
|
content=response |
|
) |
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=str(e)) |
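
# Example request (illustrative):
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "What is the capital of France?"}],
#          "config": {"temperature": 0.7, "max_new_tokens": 100}}'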
|
|
|
@app.websocket("/generate/stream") |
|
async def generate_stream( |
|
websocket: WebSocket, |
|
generator: Any = Depends(get_generator) |
|
): |
|
""" |
|
Stream generated text tokens over a WebSocket connection. |
|
|
|
The stream sends JSON messages with the following format: |
|
- During generation: {"token": "generated_token", "finished": false} |
|
- End of generation: {"token": "", "finished": true} |
|
- Error: {"error": "error_message"} |
|
""" |
|
await websocket.accept() |
|
|
|
try: |
|
while True: |
|
request_data = await websocket.receive_text() |
|
            request = GenerationRequest.model_validate_json(request_data)
|
|
|
            # Pair prior messages into (user, assistant) turns; the last message is the current user input.
            history = request.messages[:-1]
            chat_history = [
                (history[i].content, history[i + 1].content)
                for i in range(0, len(history) - 1, 2)
                if history[i].role == "user" and history[i + 1].role == "assistant"
            ]
            user_input = request.messages[-1].content
|
|
|
config = request.config or GenerationConfig() |
|
|
|
async for token in generator.generate_stream( |
|
prompt=generator.prompt_builder.format( |
|
context=request.context or "", |
|
user_input=user_input, |
|
chat_history=chat_history |
|
), |
|
config=config |
|
): |
|
await websocket.send_text(json.dumps({ |
|
"token": token, |
|
"finished": False |
|
})) |
|
|
|
await websocket.send_text(json.dumps({ |
|
"token": "", |
|
"finished": True |
|
})) |
|
|
|
except Exception as e: |
|
await websocket.send_text(json.dumps({ |
|
"error": str(e) |
|
})) |
|
finally: |
|
await websocket.close() |
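
# Example streaming client (illustrative, assumes the `websockets` package):
#
#   import asyncio, json, websockets
#
#   async def stream():
#       async with websockets.connect("ws://localhost:8000/generate/stream") as ws:
#           await ws.send(json.dumps({
#               "messages": [{"role": "user", "content": "Tell me a joke"}],
#               "stream": True
#           }))
#           while True:
#               msg = json.loads(await ws.recv())
#               if msg.get("finished") or "error" in msg:
#                   break
#               print(msg["token"], end="", flush=True)
#
#   asyncio.run(stream())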
|
|
|
if __name__ == "__main__": |
|
import uvicorn |
|
uvicorn.run(app, host="0.0.0.0", port=8000) |
|
|
|
|