# Spaces:
# Runtime error
# Runtime error
import pytest | |
from text_generation import Client, AsyncClient | |
from text_generation.errors import NotFoundError, ValidationError | |
from text_generation.types import FinishReason, PrefillToken, Token | |
def test_generate(flan_t5_xxl_url, hf_headers):
    """Generate a single new token with the blocking client and check its details.

    Expects an empty generated text, a length-limited finish, no seed, exactly
    one ``<pad>`` prefill token and exactly one generated token with a fixed
    log-probability.
    """
    sync_client = Client(flan_t5_xxl_url, hf_headers)
    result = sync_client.generate("test", max_new_tokens=1)
    details = result.details

    assert result.generated_text == ""
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None

    assert len(details.prefill) == 1
    expected_prefill = PrefillToken(id=0, text="<pad>", logprob=None)
    assert details.prefill[0] == expected_prefill

    assert len(details.tokens) == 1
    expected_token = Token(id=3, text="", logprob=-1.984375, special=False)
    assert details.tokens[0] == expected_token
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
    """Sample with ``best_of=2`` and check that seeds and alternatives are reported."""
    sync_client = Client(flan_t5_xxl_url, hf_headers)
    result = sync_client.generate(
        "test", max_new_tokens=1, best_of=2, do_sample=True
    )
    details = result.details

    assert details.seed is not None

    alternatives = details.best_of_sequences
    assert alternatives is not None
    # Only one entry is expected for best_of=2; presumably the selected
    # sequence itself is not repeated in this list.
    assert len(alternatives) == 1
    assert alternatives[0].seed is not None
def test_generate_not_found(fake_url, hf_headers):
    """A blocking generate call against the fake URL raises ``NotFoundError``."""
    unreachable = Client(fake_url, hf_headers)
    # Functional form of pytest.raises: calls unreachable.generate("test")
    # and fails unless NotFoundError is raised.
    pytest.raises(NotFoundError, unreachable.generate, "test")
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
    """Asking for ``max_new_tokens=10_000`` is rejected with ``ValidationError``."""
    sync_client = Client(flan_t5_xxl_url, hf_headers)
    # Functional form of pytest.raises: the keyword argument is forwarded to
    # sync_client.generate.
    pytest.raises(
        ValidationError, sync_client.generate, "test", max_new_tokens=10_000
    )
def test_generate_stream(flan_t5_xxl_url, hf_headers):
    """Stream a single new token and check the one chunk that is produced."""
    sync_client = Client(flan_t5_xxl_url, hf_headers)
    chunks = list(sync_client.generate_stream("test", max_new_tokens=1))
    assert len(chunks) == 1

    final = chunks[0]
    assert final.generated_text == ""

    details = final.details
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None
def test_generate_stream_not_found(fake_url, hf_headers):
    """Draining a stream from the fake URL raises ``NotFoundError``."""
    unreachable = Client(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        # Drain the stream inside the context so that any error raised
        # lazily during iteration is caught here.
        for _ in unreachable.generate_stream("test"):
            pass
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
    """Streaming with ``max_new_tokens=10_000`` raises ``ValidationError``."""
    sync_client = Client(flan_t5_xxl_url, hf_headers)
    with pytest.raises(ValidationError):
        # Drain the stream inside the context so that any error raised
        # lazily during iteration is caught here.
        for _ in sync_client.generate_stream("test", max_new_tokens=10_000):
            pass
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
    """Generate a single new token with the async client and check its details.

    Expects an empty generated text, a length-limited finish, no seed, exactly
    one ``<pad>`` prefill token and exactly one generated token with a fixed
    log-probability.

    The explicit ``asyncio`` marker lets pytest-asyncio run this coroutine even
    in its default strict mode. Without it, pytest skips async tests with a
    warning instead of executing them.
    """
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    response = await client.generate("test", max_new_tokens=1)
    details = response.details

    assert response.generated_text == ""
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    assert details.seed is None
    assert len(details.prefill) == 1
    assert details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
    assert len(details.tokens) == 1
    assert details.tokens[0] == Token(
        id=3, text="", logprob=-1.984375, special=False
    )
@pytest.mark.asyncio
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
    """Sample with ``best_of=2`` via the async client; seeds and alternatives are set.

    The explicit ``asyncio`` marker lets pytest-asyncio run this coroutine even
    in its default strict mode. Without it, pytest skips the test.
    """
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    response = await client.generate(
        "test", max_new_tokens=1, best_of=2, do_sample=True
    )
    details = response.details

    assert details.seed is not None
    assert details.best_of_sequences is not None
    # Only one entry is expected for best_of=2; presumably the selected
    # sequence itself is not repeated in this list.
    assert len(details.best_of_sequences) == 1
    assert details.best_of_sequences[0].seed is not None
@pytest.mark.asyncio
async def test_generate_async_not_found(fake_url, hf_headers):
    """An async generate call against the fake URL raises ``NotFoundError``.

    The explicit ``asyncio`` marker lets pytest-asyncio run this coroutine even
    in its default strict mode. Without it, pytest skips the test.
    """
    client = AsyncClient(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        await client.generate("test")
@pytest.mark.asyncio
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
    """Async generate with ``max_new_tokens=10_000`` raises ``ValidationError``.

    The explicit ``asyncio`` marker lets pytest-asyncio run this coroutine even
    in its default strict mode. Without it, pytest skips the test.
    """
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    with pytest.raises(ValidationError):
        await client.generate("test", max_new_tokens=10_000)
@pytest.mark.asyncio
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
    """Stream a single new token via the async client and check the one chunk.

    The explicit ``asyncio`` marker lets pytest-asyncio run this coroutine even
    in its default strict mode. Without it, pytest skips the test.
    """
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    responses = [
        response async for response in client.generate_stream("test", max_new_tokens=1)
    ]
    assert len(responses) == 1

    response = responses[0]
    assert response.generated_text == ""
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None
@pytest.mark.asyncio
async def test_generate_stream_async_not_found(fake_url, hf_headers):
    """Draining an async stream from the fake URL raises ``NotFoundError``.

    The explicit ``asyncio`` marker lets pytest-asyncio run this coroutine even
    in its default strict mode. Without it, pytest skips the test.
    """
    client = AsyncClient(fake_url, hf_headers)
    with pytest.raises(NotFoundError):
        # Drain the stream inside the context so that any error raised
        # lazily during iteration is caught here.
        async for _ in client.generate_stream("test"):
            pass
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
    """Async streaming with ``max_new_tokens=10_000`` raises ``ValidationError``.

    The explicit ``asyncio`` marker lets pytest-asyncio run this coroutine even
    in its default strict mode. Without it, pytest skips the test.
    """
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    with pytest.raises(ValidationError):
        # Drain the stream inside the context so that any error raised
        # lazily during iteration is caught here.
        async for _ in client.generate_stream("test", max_new_tokens=10_000):
            pass