Spaces:
Runtime error
Runtime error
File size: 1,547 Bytes
69f3483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from typing import Literal, Optional
import fire
from packaging.version import Version
from ..pip_utils import is_installed, run_pip, version
import platform
def get_cuda_version_from_torch() -> Optional[str]:
    """Return the CUDA major version the installed torch was built against.

    Returns:
        The major component of ``torch.version.cuda`` as a string (for
        example ``"11"`` or ``"12"``), or ``None`` when torch cannot be
        imported or the torch build reports no CUDA version. Callers are
        responsible for checking the value is one they support.
    """
    try:
        import torch
    except ImportError:
        return None
    # CPU-only torch builds report torch.version.cuda as None; calling
    # .split() on it would raise AttributeError (and this function is used
    # as a default argument, so that would break module import).
    cuda = getattr(torch.version, "cuda", None)
    if not cuda:
        return None
    return cuda.split(".")[0]
def install(cu: Optional[Literal["11", "12"]] = get_cuda_version_from_torch()):
    """Install the pinned TensorRT toolchain for CUDA major version *cu*.

    Uses ``run_pip`` to install cuDNN 8.9.4.25 and TensorRT
    9.0.1.post11.dev4 (from the NVIDIA index) when TensorRT is missing,
    plus polygraphy 0.47.1 and onnx-graphsurgeon 0.3.26 (from the NGC
    index) when they are missing. An installed TensorRT older than 9.0.0
    is uninstalled first so the pinned 9.x build replaces it.

    Args:
        cu: CUDA major version, ``"11"`` or ``"12"``. When called through
            ``fire``, ``--cu 12`` arrives as the int ``12`` (and ``--cu 12.1``
            as a float), so the value is normalized to its major-version
            string before validation. The default is evaluated once, when
            this module is imported, by probing torch.

    Returns:
        None. If the CUDA version is missing or unsupported, a message is
        printed and nothing is installed.
    """
    if cu is None:
        print("Could not detect CUDA version. Please specify manually.")
        return
    # fire turns "12" into 12 and "12.1" into 12.1; reduce to the major part
    # as a string so it can be compared with the supported set.
    cu = str(cu).strip().split(".")[0]
    if cu not in ("11", "12"):
        print(f"Unsupported CUDA version {cu!r}. Supported versions: 11, 12.")
        return

    print("Installing TensorRT requirements...")

    # Pre-9.0.0 TensorRT is removed so the pinned 9.x build below replaces it.
    if is_installed("tensorrt") and version("tensorrt") < Version("9.0.0"):
        run_pip("uninstall -y tensorrt")

    if not is_installed("tensorrt"):
        cudnn_name = f"nvidia-cudnn-cu{cu}==8.9.4.25"
        run_pip(f"install {cudnn_name} --no-cache-dir")
        run_pip(
            "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir"
        )

    if not is_installed("polygraphy"):
        run_pip(
            "install polygraphy==0.47.1 --extra-index-url https://pypi.ngc.nvidia.com"
        )

    if not is_installed("onnx_graphsurgeon"):
        run_pip(
            "install onnx-graphsurgeon==0.3.26 --extra-index-url https://pypi.ngc.nvidia.com"
        )

    # NOTE(review): an optional pywin32 install on Windows was previously
    # sketched here but left disabled; restore it if Windows needs it.
if __name__ == "__main__":
    # Command-line entry point: fire exposes `install` so flags such as
    # `--cu` map onto its keyword arguments.
    fire.Fire(component=install)
|