# speaches/tests/model_manager_test.py
# Last commit: ba81a8e by Fedir Zadniprovskyi — "rename to `speaches`"
import asyncio
import anyio
import pytest
from speaches.config import Config, WhisperConfig
from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
MODEL = DEFAULT_WHISPER_MODEL  # model id used throughout these tests; a short alias of DEFAULT_WHISPER_MODEL for readability
@pytest.mark.asyncio
async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
    """A model loaded through ``POST /api/ps/{model}`` is evicted once its TTL elapses.

    The server runs with a 5 second TTL. No model is loaded before the explicit
    load request. Exactly one is loaded right after it, and none remain one
    second past the TTL.
    """
    ttl_seconds = 5
    cfg = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl_seconds), enable_ui=False)
    async with aclient_factory(cfg) as client:

        async def loaded_model_count() -> int:
            # Number of entries reported by the "list loaded models" endpoint.
            listing = await client.get("/api/ps")
            return len(listing.json()["models"])

        assert await loaded_model_count() == 0
        await client.post(f"/api/ps/{MODEL}")
        assert await loaded_model_count() == 1
        # One extra second of slack past the TTL so the unload has happened.
        await asyncio.sleep(ttl_seconds + 1)
        assert await loaded_model_count() == 0
@pytest.mark.asyncio
async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
    """Using a model for a transcription restarts its TTL countdown.

    Timeline with ttl = 5s:
      * load the model, wait ttl - 2 seconds -> still loaded;
      * transcribe (restarts the countdown), wait ttl - 2 seconds -> still loaded;
      * wait 3 more seconds (ttl + 1 since the last use) -> unloaded;
      * transcribe again -> the request succeeds and the model is loaded again.
    """
    ttl = 5
    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
    async with aclient_factory(config) as aclient:
        async with await anyio.open_file("audio.wav", "rb") as f:
            data = await f.read()

        async def loaded_model_count() -> int:
            # Number of entries reported by the "list loaded models" endpoint.
            return len((await aclient.get("/api/ps")).json()["models"])

        async def transcribe() -> None:
            # Check the status: a failed request would not exercise the model,
            # and the TTL assertions below would then test the wrong thing.
            res = await aclient.post(
                "/v1/audio/transcriptions",
                files={"file": ("audio.wav", data, "audio/wav")},
                data={"model": MODEL},
            )
            assert res.status_code == 200, res.text

        await aclient.post(f"/api/ps/{MODEL}")
        assert await loaded_model_count() == 1

        await asyncio.sleep(ttl - 2)  # less than the ttl: the model should not be unloaded
        assert await loaded_model_count() == 1

        await transcribe()  # using the model should restart its ttl countdown
        assert await loaded_model_count() == 1

        await asyncio.sleep(ttl - 2)  # less than the ttl since the last use: still loaded
        assert await loaded_model_count() == 1

        await asyncio.sleep(3)  # (ttl - 2) + 3 = ttl + 1 seconds since the last use: unloaded
        assert await loaded_model_count() == 0

        # The model must be loadable again after it was unloaded.
        await transcribe()
        assert await loaded_model_count() == 1
@pytest.mark.asyncio
async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
    """A model serving an in-flight request cannot be unloaded (``DELETE`` returns 409).

    The in-flight transcription must still succeed. With ``ttl=0`` no model
    should remain loaded once that request has finished.
    """
    ttl = 0
    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
    async with aclient_factory(config) as aclient:
        async with await anyio.open_file("audio.wav", "rb") as f:
            data = await f.read()
        task = asyncio.create_task(
            aclient.post(
                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL}
            )
        )
        # NOTE(review): timing-based. This assumes the server has started handling
        # the transcription within 0.1s, and so may be flaky on slow machines.
        await asyncio.sleep(0.1)
        res = await aclient.delete(f"/api/ps/{MODEL}")
        assert res.status_code == 409
        # The rejected unload must not have disrupted the in-flight request.
        transcription = await task
        assert transcription.status_code == 200, transcription.text
        res = (await aclient.get("/api/ps")).json()
        assert len(res["models"]) == 0
@pytest.mark.asyncio
async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
    """Loading an already-loaded model is rejected.

    The first ``POST /api/ps/{model}`` must return 201 and the second 409.
    Exactly one model is listed afterwards.
    """
    # ttl=-1: presumably disables automatic unloading so the model stays
    # resident for the whole test (TODO confirm against WhisperConfig).
    config = Config(whisper=WhisperConfig(model=MODEL, ttl=-1), enable_ui=False)
    load_url = f"/api/ps/{MODEL}"
    async with aclient_factory(config) as aclient:
        first = await aclient.post(load_url)
        assert first.status_code == 201
        second = await aclient.post(load_url)
        assert second.status_code == 409
        models = (await aclient.get("/api/ps")).json()["models"]
        assert len(models) == 1
@pytest.mark.asyncio
async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
    """With ``ttl=0``, no model remains loaded once a successful transcription finishes."""
    ttl = 0
    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
    async with aclient_factory(config) as aclient:
        async with await anyio.open_file("audio.wav", "rb") as f:
            data = await f.read()
        res = await aclient.post(
            "/v1/audio/transcriptions",
            files={"file": ("audio.wav", data, "audio/wav")},
            # Use MODEL, as the other tests do, instead of a hard-coded model id
            # that could drift from DEFAULT_WHISPER_MODEL.
            data={"model": MODEL},
        )
        # Without this check, a failed request (model never loaded) would make the
        # "0 models" assertion below pass trivially.
        assert res.status_code == 200, res.text
        res = (await aclient.get("/api/ps")).json()
        assert len(res["models"]) == 0