"""Unit and integration tests for ``app.engine.PromptSearchEngine``.

Unit tests replace ``app.engine.SentenceTransformer`` with a mock so no real
model is loaded; the integration test exercises the engine unmocked.
"""
import pytest
import numpy as np
from unittest.mock import MagicMock

from app.engine import PromptSearchEngine

# Embeddings the mocked model returns for the three mock prompts (one row each).
MOCK_CORPUS_VECTORS = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
])


@pytest.fixture
def mock_prompts():
    """Return a small, fixed prompt corpus."""
    return ["prompt 1", "prompt 2", "prompt 3"]


@pytest.fixture
def mock_model():
    """Return a stand-in SentenceTransformer whose ``encode`` yields fixed vectors."""
    model = MagicMock()
    # Hand out a copy so an engine that modifies its corpus vectors in place
    # cannot corrupt the module-level expected values.
    model.encode = MagicMock(return_value=MOCK_CORPUS_VECTORS.copy())
    return model


@pytest.fixture
def engine(mock_prompts, mock_model, monkeypatch):
    """Build a PromptSearchEngine with ``SentenceTransformer`` patched to ``mock_model``."""
    monkeypatch.setattr(
        "app.engine.SentenceTransformer", MagicMock(return_value=mock_model)
    )
    return PromptSearchEngine(mock_prompts)


def _stub_query_vector(search_engine):
    """Make the engine's query transform return one fixed 1x3 vector."""
    # NOTE(review): queries are stubbed via ``engine.vectorizer.transform``
    # while the corpus is encoded through the mocked SentenceTransformer's
    # ``encode``. Confirm app.engine still exposes ``vectorizer`` and routes
    # queries through it; otherwise this stub is silently unused.
    search_engine.vectorizer.transform = MagicMock(
        return_value=np.array([[0.1, 0.2, 0.3]])
    )


def _assert_score_prompt_pairs(results):
    """Check every result is a ``(float score, str prompt)`` pair."""
    assert all(
        isinstance(score, float) and isinstance(prompt, str)
        for score, prompt in results
    )


@pytest.mark.unit
def test_engine_initialization(engine, mock_prompts):
    """The engine keeps the prompts and the vectors produced by the model."""
    assert engine.prompts == mock_prompts
    assert engine.corpus_vectors.shape == (3, 3)
    np.testing.assert_array_equal(engine.corpus_vectors, MOCK_CORPUS_VECTORS)


@pytest.mark.unit
def test_most_similar_valid_query(engine):
    """A query with n smaller than the corpus returns exactly n pairs."""
    _stub_query_vector(engine)

    results = engine.most_similar("test query", n=2)

    assert len(results) == 2
    _assert_score_prompt_pairs(results)


@pytest.mark.unit
def test_most_similar_exceeding_n(engine, mock_prompts):
    """Requesting more results than prompts returns one result per prompt."""
    _stub_query_vector(engine)

    results = engine.most_similar("test query", n=10)

    # n is larger than the corpus, so the result count is capped at its size.
    assert len(results) == len(mock_prompts)
    _assert_score_prompt_pairs(results)


@pytest.mark.integration
def test_most_similar_integration(mock_prompts):
    """End-to-end search with no mocks: an exact-match query ranks its prompt first."""
    # No monkeypatching here, so this uses the real SentenceTransformer that
    # app.engine references (presumably loading an actual model).
    engine = PromptSearchEngine(mock_prompts)

    results = engine.most_similar("prompt 1", n=2)

    assert len(results) == 2
    _assert_score_prompt_pairs(results)
    assert results[0][1] == "prompt 1"