Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import asyncio | |
import anyio | |
import pytest | |
from speaches.config import Config, WhisperConfig | |
from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory | |
MODEL = DEFAULT_WHISPER_MODEL # just to make the test more readable | |
async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None: | |
ttl = 5 | |
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) | |
async with aclient_factory(config) as aclient: | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 0 | |
await aclient.post(f"/api/ps/{MODEL}") | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 1 | |
await asyncio.sleep(ttl + 1) # wait for the model to be unloaded | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 0 | |
async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None: | |
ttl = 5 | |
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) | |
async with aclient_factory(config) as aclient: | |
await aclient.post(f"/api/ps/{MODEL}") | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 1 | |
await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 1 | |
async with await anyio.open_file("audio.wav", "rb") as f: | |
data = await f.read() | |
res = ( | |
await aclient.post( | |
"/v1/audio/transcriptions", | |
files={"file": ("audio.wav", data, "audio/wav")}, | |
data={"model": MODEL}, | |
) | |
).json() | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 1 | |
await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 1 | |
await asyncio.sleep(3) # sleep for a bit more. The model should be unloaded | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 0 | |
# test the model can be used again after being unloaded | |
# this just ensures the model can be loaded again after being unloaded | |
res = ( | |
await aclient.post( | |
"/v1/audio/transcriptions", | |
files={"file": ("audio.wav", data, "audio/wav")}, | |
data={"model": MODEL}, | |
) | |
).json() | |
async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None: | |
ttl = 0 | |
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) | |
async with aclient_factory(config) as aclient: | |
async with await anyio.open_file("audio.wav", "rb") as f: | |
data = await f.read() | |
task = asyncio.create_task( | |
aclient.post( | |
"/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL} | |
) | |
) | |
await asyncio.sleep(0.1) # wait for the server to start processing the request | |
res = await aclient.delete(f"/api/ps/{MODEL}") | |
assert res.status_code == 409 | |
await task | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 0 | |
async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None: | |
ttl = -1 | |
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) | |
async with aclient_factory(config) as aclient: | |
res = await aclient.post(f"/api/ps/{MODEL}") | |
assert res.status_code == 201 | |
res = await aclient.post(f"/api/ps/{MODEL}") | |
assert res.status_code == 409 | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 1 | |
async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None: | |
ttl = 0 | |
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False) | |
async with aclient_factory(config) as aclient: | |
async with await anyio.open_file("audio.wav", "rb") as f: | |
data = await f.read() | |
res = await aclient.post( | |
"/v1/audio/transcriptions", | |
files={"file": ("audio.wav", data, "audio/wav")}, | |
data={"model": "Systran/faster-whisper-tiny.en"}, | |
) | |
res = (await aclient.get("/api/ps")).json() | |
assert len(res["models"]) == 0 | |