Spaces:
Runtime error
Runtime error
import os | |
import sys | |
import typer | |
from pathlib import Path | |
from loguru import logger | |
from typing import Optional | |
# Typer application object; the ``__main__`` guard at the bottom of this
# module calls it to parse the command line and dispatch.
app: typer.Typer = typer.Typer()
# Register as a Typer command: ``app()`` needs at least one registered
# command to dispatch to, otherwise it fails at startup.
@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    """Configure logging (and optionally tracing), then run ``server.serve``.

    Args:
        model_id: Model identifier forwarded to ``server.serve``.
        revision: Optional model revision forwarded to ``server.serve``.
        sharded: When True, the environment variables RANK, WORLD_SIZE,
            MASTER_ADDR and MASTER_PORT must all be set.
        quantize: Flag forwarded to ``server.serve``.
        uds_path: Path forwarded to ``server.serve`` (presumably a Unix
            domain socket path -- confirm against ``server.serve``).
        logger_level: Minimum level for the stdout log sink.
        json_output: When True, log records are serialized as JSON.
        otlp_endpoint: When set, tracing is configured via ``setup_tracing``
            with this endpoint before the server starts.

    Raises:
        AssertionError: If ``sharded`` is True and one of the required
            environment variables is unset (checked in the order above).
    """
    if sharded:
        # Explicit raise instead of ``assert`` so the check still runs under
        # ``python -O``; exception type, order and messages are unchanged.
        for var in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
            if os.getenv(var) is None:
                raise AssertionError(f"{var} must be set when sharded is True")

    # Replace loguru's default handler with a stdout sink that only accepts
    # records from the ``text_generation_server`` package. diagnose=False
    # turns off loguru's variable-value annotation in tracebacks.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        # NOTE(review): os.getenv returns a str when RANK is set but the
        # fallback is the int 0 -- confirm setup_tracing accepts both.
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    server.serve(model_id, revision, sharded, quantize, uds_path)
def _pt_to_safetensors_path(pt_path: Path) -> Path:
    """Return the ``.safetensors`` sibling path for a PyTorch weight file.

    A literal leading ``pytorch_`` is dropped from the stem, e.g.
    ``pytorch_model-00001-of-00002.bin`` -> ``model-00001-of-00002.safetensors``.
    """
    # Remove the exact prefix only. ``stem.lstrip("pytorch_")`` strips any
    # leading run of the characters {p, y, t, o, r, c, h, _}, which mangles
    # names such as "pytorch_optimizer" -> "imizer".
    prefix = "pytorch_"
    stem = pt_path.stem
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return pt_path.parent / f"{stem}.safetensors"


# Register as a Typer command: ``app()`` needs at least one registered
# command to dispatch to, otherwise it fails at startup.
@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    logger_level: str = "INFO",
    json_output: bool = False,
):
    """Make the weight files for ``model_id`` available in the local cache.

    If ``utils.weight_files`` already finds local files with ``extension``,
    nothing is downloaded. Otherwise the matching hub files are downloaded.
    If that raises ``utils.EntryNotFoundError`` and ``extension`` is
    ``".safetensors"``, the ``.bin`` weights are downloaded instead and
    converted to ``.safetensors`` files next to them.

    Args:
        model_id: Model identifier passed to the ``utils`` helpers.
        revision: Optional model revision passed to the ``utils`` helpers.
        extension: Weight file extension to look for and download.
        logger_level: Minimum level for the stdout log sink.
        json_output: When True, log records are serialized as JSON.

    Raises:
        utils.EntryNotFoundError: Re-raised when no hub files match and
            ``extension`` is not ``".safetensors"`` (presumably also raised
            by the ``.bin`` fallback lookup when no PyTorch weights exist).
    """
    # Replace loguru's default handler with a stdout sink that only accepts
    # records from the ``text_generation_server`` package.
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
    except utils.LocalEntryNotFoundError:
        # Local files not found: fall through to downloading them.
        pass
    else:
        logger.info(
            "Files are already present in the local cache. " "Skipping download."
        )
        return

    # Download weights directly
    try:
        filenames = utils.weight_hub_files(model_id, revision, extension)
        utils.download_weights(filenames, model_id, revision)
    except utils.EntryNotFoundError:
        if extension != ".safetensors":
            raise

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights instead."
        )

        # Try to see if there are pytorch weights
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
        local_st_files = [_pt_to_safetensors_path(p) for p in local_pt_files]

        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files)
if __name__ == "__main__":
    # Executed as a script: hand the command line over to the Typer app.
    app()