Added tests for app module
- pytest.ini +5 -0
- tests/test_api.py +62 -0
- tests/test_engine.py +66 -0
- tests/test_scorer.py +85 -0
- tests/test_vectorizer.py +73 -0
pytest.ini
ADDED
@@ -0,0 +1,5 @@
+[pytest]
+pythonpath = .
+markers =
+    unit: Marks unit tests
+    integration: Marks integration tests
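With the unit and integration markers registered, the suite can be filtered by kind. As a quick sketch (not part of the commit), the same selection can also be driven programmatically through pytest.main, which mirrors the CLI:

import pytest

# Equivalent to `pytest -m unit` and `pytest -m integration` on the command line
pytest.main(["-m", "unit"])
pytest.main(["-m", "integration"])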
tests/test_api.py
ADDED
@@ -0,0 +1,62 @@
+import pytest
+from fastapi.testclient import TestClient
+from app.api import app
+
+
+# Mock the most_similar function for testing
+def mock_most_similar(query, n):
+    if query == "error":
+        raise Exception("Test exception")
+    return [[0.9, "Result 1"], [0.8, "Result 2"]][:n]
+
+
+# Initialize the FastAPI test client
+client = TestClient(app)
+
+
+@pytest.fixture(autouse=True)
+def setup_mock_search_engine():
+    # Replace the app's search engine with a mock class whose most_similar
+    # attribute is the plain function above (accessed on the class itself,
+    # the function is not bound, so the (query, n) signature works as-is)
+    app.state.search_engine = type(
+        "MockSearchEngine",
+        (object,),
+        {"most_similar": mock_most_similar}
+    )
+
+
+@pytest.mark.unit
+def test_search_valid_input():
+    response = client.get("/search", params={"query": "test", "n": 2})
+    assert response.status_code == 200
+    assert response.json() == {
+        "query": "test",
+        "results": [[0.9, "Result 1"], [0.8, "Result 2"]]
+    }
+
+
+@pytest.mark.unit
+def test_search_empty_query():
+    response = client.get("/search", params={"query": "", "n": 2})
+    assert response.status_code == 400
+    assert response.json()["detail"] == "Query cannot be empty."
+
+
+@pytest.mark.unit
+def test_search_invalid_n():
+    response = client.get("/search", params={"query": "test", "n": 0})
+    assert response.status_code == 422
+
+
+@pytest.mark.unit
+def test_search_engine_error():
+    response = client.get("/search", params={"query": "error", "n": 2})
+    assert response.status_code == 500
+    assert "An unexpected error occurred: Test exception" in response.json()["detail"]
+
+
+@pytest.mark.unit
+def test_search_no_engine():
+    app.state.search_engine = None  # Simulate uninitialized search engine
+    response = client.get("/search", params={"query": "test", "n": 2})
+    assert response.status_code == 500
+    assert response.json()["detail"] == "Search engine not initialized."
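Note on the autouse fixture above: it installs the mock engine but has no teardown, so test_search_no_engine leaves app.state.search_engine set to None until the fixture re-runs for the next test. A yield-based variant (a sketch only, reusing app and mock_most_similar from this module) would restore the previous engine explicitly:

import pytest
from app.api import app

@pytest.fixture(autouse=True)
def setup_mock_search_engine():
    # Remember whatever engine was installed before this test
    original = getattr(app.state, "search_engine", None)
    app.state.search_engine = type(
        "MockSearchEngine", (object,), {"most_similar": mock_most_similar}
    )
    yield
    # Restore the original engine even if the test replaced or cleared it
    app.state.search_engine = original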
tests/test_engine.py
ADDED
@@ -0,0 +1,66 @@
+import pytest
+import numpy as np
+from unittest.mock import MagicMock
+from app.engine import PromptSearchEngine
+
+
+@pytest.fixture
+def mock_prompts():
+    return ["prompt 1", "prompt 2", "prompt 3"]
+
+
+@pytest.fixture
+def mock_model():
+    model = MagicMock()
+    model.encode = MagicMock(return_value=np.array([
+        [0.1, 0.2, 0.3],
+        [0.4, 0.5, 0.6],
+        [0.7, 0.8, 0.9]
+    ]))
+    return model
+
+
+@pytest.mark.unit
+def test_engine_initialization(mock_prompts, mock_model, monkeypatch):
+    # SentenceTransformer is mocked to return the mock model
+    monkeypatch.setattr("app.engine.SentenceTransformer", MagicMock(return_value=mock_model))
+    engine = PromptSearchEngine(mock_prompts)
+    # Verify that the engine initializes correctly with the mock prompts and vectors
+    assert engine.prompts == mock_prompts
+    assert engine.corpus_vectors.shape == (3, 3)
+    assert np.array_equal(
+        engine.corpus_vectors,
+        np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
+    )
+
+
+@pytest.mark.unit
+def test_most_similar_valid_query(mock_prompts, mock_model, monkeypatch):
+    monkeypatch.setattr("app.engine.SentenceTransformer", MagicMock(return_value=mock_model))
+    engine = PromptSearchEngine(mock_prompts)
+    # Mock the vectorizer's transform method to return a single query vector
+    engine.vectorizer.transform = MagicMock(return_value=np.array([[0.1, 0.2, 0.3]]))
+    results = engine.most_similar("test query", n=2)
+    assert len(results) == 2
+    assert all(isinstance(score, float) and isinstance(prompt, str) for score, prompt in results)
+
+
+@pytest.mark.unit
+def test_most_similar_exceeding_n(mock_prompts, mock_model, monkeypatch):
+    monkeypatch.setattr("app.engine.SentenceTransformer", MagicMock(return_value=mock_model))
+    engine = PromptSearchEngine(mock_prompts)
+    engine.vectorizer.transform = MagicMock(return_value=np.array([[0.1, 0.2, 0.3]]))
+    # Call most_similar with n greater than the number of prompts
+    results = engine.most_similar("test query", n=10)
+    assert len(results) == len(mock_prompts)  # Should return at most the number of prompts
+    assert all(isinstance(score, float) and isinstance(prompt, str) for score, prompt in results)
+
+
+@pytest.mark.integration
+def test_most_similar_integration(mock_prompts):
+    engine = PromptSearchEngine(mock_prompts)
+    results = engine.most_similar("prompt 1", n=2)
+    # Verify that the results include the expected number of matches and correct types
+    assert len(results) == 2
+    assert all(isinstance(score, float) and isinstance(prompt, str) for score, prompt in results)
+    assert results[0][1] == "prompt 1"
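For orientation, these tests pin down the engine's contract: the constructor encodes the corpus into corpus_vectors, and most_similar returns (score, prompt) pairs ranked by cosine similarity, capped at the corpus size. A minimal sketch consistent with that contract, assuming the all-MiniLM-L6-v2 sentence-transformers model (the real app.engine.PromptSearchEngine evidently also routes queries through a vectorizer.transform wrapper, which this sketch folds into the model call):

import numpy as np
from sentence_transformers import SentenceTransformer

class PromptSearchEngineSketch:
    def __init__(self, prompts):
        self.prompts = list(prompts)
        self.model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model name
        self.corpus_vectors = np.asarray(self.model.encode(self.prompts))

    def most_similar(self, query, n=5):
        query_vector = np.asarray(self.model.encode([query])).ravel()
        # Cosine similarity of the query against every corpus vector
        sims = (self.corpus_vectors @ query_vector) / (
            np.linalg.norm(self.corpus_vectors, axis=1) * np.linalg.norm(query_vector)
        )
        # Highest-scoring prompts first, at most len(prompts) results
        top = np.argsort(sims)[::-1][: min(n, len(self.prompts))]
        return [(float(sims[i]), self.prompts[i]) for i in top]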
tests/test_scorer.py
ADDED
@@ -0,0 +1,85 @@
+import pytest
+import numpy as np
+from app.scorer import cosine_similarity, DimensionalityMismatchError, ZeroVectorError, EmptyInputError
+
+
+@pytest.fixture
+def valid_input():
+    query_vector = np.array([[1, 0]])
+    corpus_vectors = np.array([[1, 0], [0, 1], [1, 1]])
+    return query_vector, corpus_vectors
+
+
+@pytest.fixture
+def zero_query_vector():
+    query_vector = np.array([[0, 0]])
+    corpus_vectors = np.array([[1, 0], [0, 1]])
+    return query_vector, corpus_vectors
+
+
+@pytest.fixture
+def corpus_with_zero_vector():
+    query_vector = np.array([[1, 1]])
+    corpus_vectors = np.array([[1, 0], [0, 1], [0, 0]])
+    return query_vector, corpus_vectors
+
+
+@pytest.fixture
+def dimensionality_mismatch():
+    query_vector = np.array([[1, 0]])
+    corpus_vectors = np.array([[1, 0, 0], [0, 1, 0]])
+    return query_vector, corpus_vectors
+
+
+@pytest.fixture
+def empty_input():
+    query_vector = np.array([[]])
+    corpus_vectors = np.array([[]])
+    return query_vector, corpus_vectors
+
+
+@pytest.mark.unit
+def test_cosine_similarity_valid_input(valid_input):
+    query_vector, corpus_vectors = valid_input
+    similarities = cosine_similarity(query_vector, corpus_vectors)
+    assert isinstance(similarities, np.ndarray)
+    assert similarities.shape == (3,)
+    assert similarities[0] == pytest.approx(1.0)  # Same direction
+    assert similarities[1] == pytest.approx(0.0)  # Orthogonal
+    assert similarities[2] == pytest.approx(1 / np.sqrt(2))  # Diagonal similarity
+
+
+@pytest.mark.unit
+def test_cosine_similarity_zero_query_vector(zero_query_vector):
+    query_vector, corpus_vectors = zero_query_vector
+    with pytest.raises(ZeroVectorError):
+        cosine_similarity(query_vector, corpus_vectors)
+
+
+@pytest.mark.unit
+def test_cosine_similarity_corpus_with_zero_vector(corpus_with_zero_vector):
+    query_vector, corpus_vectors = corpus_with_zero_vector
+    with pytest.raises(ZeroVectorError):
+        cosine_similarity(query_vector, corpus_vectors)
+
+
+@pytest.mark.unit
+def test_cosine_similarity_dimensionality_mismatch(dimensionality_mismatch):
+    query_vector, corpus_vectors = dimensionality_mismatch
+    with pytest.raises(DimensionalityMismatchError):
+        cosine_similarity(query_vector, corpus_vectors)
+
+
+@pytest.mark.unit
+def test_cosine_similarity_empty_inputs(empty_input):
+    query_vector, corpus_vectors = empty_input
+    with pytest.raises(EmptyInputError):
+        cosine_similarity(query_vector, corpus_vectors)
+
+
+@pytest.mark.integration
+def test_cosine_similarity_output_range(valid_input):
+    query_vector, corpus_vectors = valid_input
+    similarities = cosine_similarity(query_vector, corpus_vectors)
+    assert np.all(similarities >= -1)
+    assert np.all(similarities <= 1)
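Taken together, these tests fully specify the error contract of cosine_similarity: empty inputs, mismatched dimensionality, and zero-norm vectors each raise a dedicated exception, and valid input yields one similarity per corpus row. A minimal implementation consistent with them (a sketch; the ValueError base class is an assumption, and app.scorer remains the authoritative version):

import numpy as np

class DimensionalityMismatchError(ValueError):
    pass

class ZeroVectorError(ValueError):
    pass

class EmptyInputError(ValueError):
    pass

def cosine_similarity(query_vector, corpus_vectors):
    if query_vector.size == 0 or corpus_vectors.size == 0:
        raise EmptyInputError("Input vectors must be non-empty.")
    if query_vector.shape[1] != corpus_vectors.shape[1]:
        raise DimensionalityMismatchError("Query and corpus dimensions differ.")
    query_norm = np.linalg.norm(query_vector)
    corpus_norms = np.linalg.norm(corpus_vectors, axis=1)
    if query_norm == 0 or np.any(corpus_norms == 0):
        raise ZeroVectorError("Cosine similarity is undefined for zero vectors.")
    # Row-wise dot products, normalized into [-1, 1]
    return (corpus_vectors @ query_vector.ravel()) / (corpus_norms * query_norm)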
tests/test_vectorizer.py
ADDED
@@ -0,0 +1,73 @@
+import pytest
+import numpy as np
+from unittest.mock import MagicMock
+from app.engine import PromptSearchEngine
+
+
+@pytest.fixture
+def mock_prompts():
+    return ["prompt 1", "prompt 2", "prompt 3"]
+
+
+@pytest.fixture
+def mock_model():
+    embedding_dim = 384  # Embedding dimensionality of the SentenceTransformer model
+    model = MagicMock()
+    model.encode = MagicMock(return_value=np.random.rand(3, embedding_dim))
+    return model
+
+
+@pytest.mark.unit
+def test_engine_initialization(mock_prompts, mock_model):
+    # Mock the vectorizer to use the mock model
+    PromptSearchEngine.vectorizer = MagicMock()
+    PromptSearchEngine.vectorizer.transform = MagicMock(return_value=mock_model.encode(mock_prompts))
+    # Initialize the engine
+    engine = PromptSearchEngine(mock_prompts)
+    assert engine.prompts == mock_prompts
+    assert engine.corpus_vectors.shape == (3, 384)  # Correct dimensionality
+
+
+@pytest.mark.unit
+def test_most_similar_valid_query(mock_prompts, mock_model):
+    # Mock the vectorizer and its transform method
+    embedding_dim = 384
+    query_embedding = np.random.rand(1, embedding_dim)
+    PromptSearchEngine.vectorizer = MagicMock()
+    PromptSearchEngine.vectorizer.transform = MagicMock(return_value=query_embedding)
+    # Initialize the engine
+    engine = PromptSearchEngine(mock_prompts)
+    engine.vectorizer = MagicMock()
+    engine.vectorizer.transform = MagicMock(return_value=query_embedding)
+    results = engine.most_similar("test query", n=2)
+    assert len(results) == 2
+    assert all(isinstance(score, float) and isinstance(prompt, str) for score, prompt in results)
+
+
+@pytest.mark.unit
+def test_most_similar_empty_query(mock_prompts):
+    # Mock the vectorizer to raise a ValueError for empty input
+    engine = PromptSearchEngine(mock_prompts)
+    engine.vectorizer = MagicMock()
+    engine.vectorizer.transform = MagicMock(side_effect=ValueError("Invalid query"))
+    with pytest.raises(ValueError):
+        engine.most_similar("", n=2)
+
+
+@pytest.mark.unit
+def test_most_similar_exceeding_n(mock_prompts, mock_model):
+    # Initialize the engine
+    PromptSearchEngine.vectorizer = MagicMock()
+    engine = PromptSearchEngine(mock_prompts)
+    # Call most_similar with n greater than the number of prompts
+    results = engine.most_similar("test query", n=10)
+    assert len(results) == len(mock_prompts)  # Should return at most the number of prompts
+
+
+@pytest.mark.integration
+def test_most_similar_integration(mock_prompts):
+    engine = PromptSearchEngine(mock_prompts)
+    results = engine.most_similar("prompt 1", n=2)
+    assert len(results) == 2
+    assert all(isinstance(score, float) and isinstance(prompt, str) for score, prompt in results)
+    assert results[0][1] == "prompt 1"
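One caveat with the class-level patching used here: assigning PromptSearchEngine.vectorizer directly mutates the class for the rest of the session, so the mock can leak into later tests. A scoped alternative (a sketch, not part of the commit) uses pytest's monkeypatch fixture, which undoes the patch at teardown:

from unittest.mock import MagicMock
import pytest
from app.engine import PromptSearchEngine

@pytest.mark.unit
def test_engine_initialization_scoped(mock_prompts, monkeypatch):
    # raising=False permits setting the attribute even if the class does not
    # already define it; monkeypatch restores the prior state afterwards
    monkeypatch.setattr(PromptSearchEngine, "vectorizer", MagicMock(), raising=False)
    engine = PromptSearchEngine(mock_prompts)
    assert engine.prompts == mock_prompts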