# promptsearchengine/tests/test_vectorizer.py
# Author: Jokica17 — commit 45b4689, "Added tests for app module" (2.74 kB)
import pytest
import numpy as np
from unittest.mock import MagicMock
from app.engine import PromptSearchEngine
@pytest.fixture
def mock_prompts():
    """Provide a small, fixed corpus of three prompts for the engine tests."""
    return [f"prompt {index}" for index in range(1, 4)]
@pytest.fixture
def mock_model(mock_prompts):
    """Provide a stand-in embedding model whose ``encode`` returns a fixed matrix.

    ``encode`` returns one row per prompt in ``mock_prompts`` and 384 columns,
    the embedding dimensionality these tests expect from the sentence-embedding
    model. Values come from a seeded generator so every run sees the same
    embeddings.
    """
    embedding_dim = 384
    rng = np.random.default_rng(seed=0)
    # Tie the row count to the prompt fixture instead of hard-coding 3.
    embeddings = rng.random((len(mock_prompts), embedding_dim))
    model = MagicMock()
    model.encode = MagicMock(return_value=embeddings)
    return model
@pytest.mark.unit
def test_engine_initialization(mock_prompts, mock_model, monkeypatch):
    """The engine keeps the prompts it was given and stores one vector per prompt.

    The vectorizer fake is installed on the class through ``monkeypatch`` so it
    is removed again at teardown instead of leaking into later tests.
    """
    fake_vectorizer = MagicMock()
    fake_vectorizer.transform = MagicMock(return_value=mock_model.encode(mock_prompts))
    # raising=False: NOTE(review) it is not visible here whether PromptSearchEngine
    # defines a class-level ``vectorizer``; monkeypatch restores or deletes it
    # either way on teardown.
    monkeypatch.setattr(PromptSearchEngine, "vectorizer", fake_vectorizer, raising=False)

    engine = PromptSearchEngine(mock_prompts)

    assert engine.prompts == mock_prompts
    assert engine.corpus_vectors.shape == (len(mock_prompts), 384)
@pytest.mark.unit
def test_most_similar_valid_query(mock_prompts, mock_model, monkeypatch):
    """A valid query returns ``n`` (score, prompt) pairs drawn from the corpus.

    During construction the class-level fake embeds the corpus (one row per
    prompt). Afterwards the instance gets its own fake that embeds the query as
    a single row.
    """
    embedding_dim = 384
    query_embedding = np.random.default_rng(seed=1).random((1, embedding_dim))

    corpus_vectorizer = MagicMock()
    corpus_vectorizer.transform = MagicMock(return_value=mock_model.encode(mock_prompts))
    # Patched via monkeypatch so the fake does not outlive this test.
    monkeypatch.setattr(PromptSearchEngine, "vectorizer", corpus_vectorizer, raising=False)

    engine = PromptSearchEngine(mock_prompts)
    # most_similar presumably reads self.vectorizer; give the instance a
    # query-only fake.
    engine.vectorizer = MagicMock()
    engine.vectorizer.transform = MagicMock(return_value=query_embedding)

    results = engine.most_similar("test query", n=2)

    assert len(results) == 2
    assert all(isinstance(score, float) and isinstance(prompt, str) for score, prompt in results)
    assert all(prompt in mock_prompts for _, prompt in results)
@pytest.mark.unit
def test_most_similar_empty_query(mock_prompts, mock_model, monkeypatch):
    """A ValueError raised by the vectorizer for an empty query reaches the caller.

    Construction uses a patched, corpus-only fake so this test does not depend
    on whatever vectorizer state an earlier test left on the class.
    """
    corpus_vectorizer = MagicMock()
    corpus_vectorizer.transform = MagicMock(return_value=mock_model.encode(mock_prompts))
    monkeypatch.setattr(PromptSearchEngine, "vectorizer", corpus_vectorizer, raising=False)
    engine = PromptSearchEngine(mock_prompts)

    # Simulate a vectorizer that rejects empty input.
    engine.vectorizer = MagicMock()
    engine.vectorizer.transform = MagicMock(side_effect=ValueError("Invalid query"))

    with pytest.raises(ValueError):
        engine.most_similar("", n=2)
@pytest.mark.unit
def test_most_similar_exceeding_n(mock_prompts, mock_model, monkeypatch):
    """Asking for more results than there are prompts caps the count at the corpus size."""
    corpus_vectorizer = MagicMock()
    corpus_vectorizer.transform = MagicMock(return_value=mock_model.encode(mock_prompts))
    monkeypatch.setattr(PromptSearchEngine, "vectorizer", corpus_vectorizer, raising=False)
    engine = PromptSearchEngine(mock_prompts)

    # Give the query a real array. A bare MagicMock would hand the engine a mock
    # object instead of an embedding, so nothing meaningful would be tested.
    engine.vectorizer = MagicMock()
    engine.vectorizer.transform = MagicMock(
        return_value=np.random.default_rng(seed=1).random((1, 384))
    )

    results = engine.most_similar("test query", n=10)

    assert len(results) == len(mock_prompts)
@pytest.mark.integration
def test_most_similar_integration(mock_prompts):
    """Querying with a corpus prompt's exact text ranks that prompt first.

    No mocks are installed here, so whatever vectorizer PromptSearchEngine uses
    by default is exercised.
    NOTE(review): this relies on no class-level ``vectorizer`` fake being left
    on PromptSearchEngine by another test.
    """
    engine = PromptSearchEngine(mock_prompts)
    top_two = engine.most_similar("prompt 1", n=2)

    assert len(top_two) == 2
    for score, prompt in top_two:
        assert isinstance(score, float) and isinstance(prompt, str)
    best_prompt = top_two[0][1]
    assert best_prompt == "prompt 1"