# promptsearchengine / tests/test_engine.py
# Author: Jokica17
# Commit 45b4689: "Added tests for app module" (2.64 kB)
import pytest
import numpy as np
from unittest.mock import MagicMock
from app.engine import PromptSearchEngine
@pytest.fixture
def mock_prompts():
    """Provide a small, fixed corpus of three prompt strings.

    Returns:
        ``["prompt 1", "prompt 2", "prompt 3"]`` as a fresh list per test.
    """
    return [f"prompt {idx}" for idx in range(1, 4)]


@pytest.fixture
def mock_model():
    """Provide a stand-in for an embedding model.

    The returned mock's ``encode`` method ignores its arguments and always
    returns the same 3x3 float array: three rows, matching the three prompts
    supplied by the ``mock_prompts`` fixture.
    """
    fake_embeddings = np.array(
        [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9],
        ]
    )
    model = MagicMock()
    model.encode.return_value = fake_embeddings
    return model


@pytest.mark.unit
def test_engine_initialization(mock_prompts, mock_model, monkeypatch):
    """The engine keeps its prompts and stores the model's encodings as corpus vectors."""
    # Replace SentenceTransformer inside app.engine with a factory that hands
    # back the fake model, so the engine uses fixed embeddings.
    fake_transformer_cls = MagicMock(return_value=mock_model)
    monkeypatch.setattr("app.engine.SentenceTransformer", fake_transformer_cls)

    engine = PromptSearchEngine(mock_prompts)

    expected_vectors = np.array(
        [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
    )
    assert engine.prompts == mock_prompts
    assert engine.corpus_vectors.shape == (3, 3)
    assert np.array_equal(engine.corpus_vectors, expected_vectors)


@pytest.mark.unit
def test_most_similar_valid_query(mock_prompts, mock_model, monkeypatch):
    """A query with n=2 yields exactly two (float score, str prompt) pairs."""
    monkeypatch.setattr(
        "app.engine.SentenceTransformer", MagicMock(return_value=mock_model)
    )
    engine = PromptSearchEngine(mock_prompts)

    # Replace the vectorizer's transform so any query maps to one fixed
    # vector (identical to the first corpus row of the fake model).
    query_vector = np.array([[0.1, 0.2, 0.3]])
    engine.vectorizer.transform = MagicMock(return_value=query_vector)

    results = engine.most_similar("test query", n=2)

    assert len(results) == 2
    for score, prompt in results:
        assert isinstance(score, float)
        assert isinstance(prompt, str)


@pytest.mark.unit
def test_most_similar_exceeding_n(mock_prompts, mock_model, monkeypatch):
    """Asking for more results than there are prompts returns one result per prompt."""
    monkeypatch.setattr(
        "app.engine.SentenceTransformer", MagicMock(return_value=mock_model)
    )
    engine = PromptSearchEngine(mock_prompts)

    query_vector = np.array([[0.1, 0.2, 0.3]])
    engine.vectorizer.transform = MagicMock(return_value=query_vector)

    # n=10 exceeds the 3-prompt corpus, so the result count should be capped
    # at exactly the number of prompts.
    oversized_n = 10
    results = engine.most_similar("test query", n=oversized_n)

    assert len(results) == len(mock_prompts)
    for score, prompt in results:
        assert isinstance(score, float)
        assert isinstance(prompt, str)


@pytest.mark.integration
def test_most_similar_integration(mock_prompts):
    """End-to-end search with no mocking: an exact textual match ranks first.

    Despite its name, ``mock_prompts`` is only plain data here; nothing in
    ``app.engine`` is patched. NOTE(review): this presumably loads the real
    embedding model, so the weights must be available locally or downloadable
    — confirm for the CI environment.
    """
    engine = PromptSearchEngine(mock_prompts)

    results = engine.most_similar("prompt 1", n=2)

    assert len(results) == 2
    for score, prompt in results:
        assert isinstance(score, float)
        assert isinstance(prompt, str)
    # The query text is identical to the first prompt, so it should come first.
    top_prompt = results[0][1]
    assert top_prompt == "prompt 1"