Spaces:
Running
Running
import pytest | |
import numpy as np | |
from unittest.mock import MagicMock | |
from app.engine import PromptSearchEngine | |
@pytest.fixture
def mock_prompts():
    """Provide a small, fixed list of prompts for the engine tests.

    Registered as a pytest fixture because the test functions in this
    module request ``mock_prompts`` by parameter name; without the
    decorator pytest cannot resolve that parameter.

    Returns:
        A list of three prompt strings.
    """
    return ["prompt 1", "prompt 2", "prompt 3"]
@pytest.fixture
def mock_model():
    """Provide a stand-in model whose ``encode`` returns a fixed 3x3 matrix.

    Registered as a pytest fixture because the test functions in this
    module request ``mock_model`` by parameter name; without the
    decorator pytest cannot resolve that parameter.

    Returns:
        A ``MagicMock`` whose ``encode`` attribute is itself a mock that
        returns the same NumPy array on every call, whatever it is given.
    """
    model = MagicMock()
    # One row per prompt in ``mock_prompts``; values are arbitrary but
    # distinct so the tests can compare them exactly.
    model.encode = MagicMock(
        return_value=np.array(
            [
                [0.1, 0.2, 0.3],
                [0.4, 0.5, 0.6],
                [0.7, 0.8, 0.9],
            ]
        )
    )
    return model
def test_engine_initialization(mock_prompts, mock_model, monkeypatch):
    """Check that a new engine stores its prompts and the encoded vectors."""
    # Replace the SentenceTransformer name inside app.engine with a factory
    # that hands back the mock model.
    fake_transformer_cls = MagicMock(return_value=mock_model)
    monkeypatch.setattr("app.engine.SentenceTransformer", fake_transformer_cls)

    search_engine = PromptSearchEngine(mock_prompts)

    # The corpus vectors should be exactly what the mock model's encode returns.
    expected_vectors = np.array(
        [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
    )
    assert search_engine.prompts == mock_prompts
    assert search_engine.corpus_vectors.shape == (3, 3)
    assert np.array_equal(search_engine.corpus_vectors, expected_vectors)
def test_most_similar_valid_query(mock_prompts, mock_model, monkeypatch):
    """Check that ``most_similar`` returns ``n`` (score, prompt) pairs."""
    fake_transformer_cls = MagicMock(return_value=mock_model)
    monkeypatch.setattr("app.engine.SentenceTransformer", fake_transformer_cls)
    search_engine = PromptSearchEngine(mock_prompts)

    # Stub the vectorizer's transform so it yields a single query vector.
    # NOTE(review): presumes the engine exposes ``vectorizer`` - confirm
    # against app.engine.
    query_vector = np.array([[0.1, 0.2, 0.3]])
    search_engine.vectorizer.transform = MagicMock(return_value=query_vector)

    matches = search_engine.most_similar("test query", n=2)

    assert len(matches) == 2
    for score, prompt in matches:
        assert isinstance(score, float) and isinstance(prompt, str)
def test_most_similar_exceeding_n(mock_prompts, mock_model, monkeypatch):
    """Check that asking for more results than prompts returns every prompt."""
    fake_transformer_cls = MagicMock(return_value=mock_model)
    monkeypatch.setattr("app.engine.SentenceTransformer", fake_transformer_cls)
    search_engine = PromptSearchEngine(mock_prompts)

    # NOTE(review): presumes the engine exposes ``vectorizer`` - confirm
    # against app.engine.
    query_vector = np.array([[0.1, 0.2, 0.3]])
    search_engine.vectorizer.transform = MagicMock(return_value=query_vector)

    # Request far more matches than there are prompts in the corpus.
    matches = search_engine.most_similar("test query", n=10)

    # The result count is capped at the number of prompts.
    assert len(matches) == len(mock_prompts)
    for score, prompt in matches:
        assert isinstance(score, float) and isinstance(prompt, str)
def test_most_similar_integration(mock_prompts):
    """Run ``most_similar`` with nothing patched and expect the exact match first."""
    search_engine = PromptSearchEngine(mock_prompts)

    matches = search_engine.most_similar("prompt 1", n=2)

    # Two results, each a (float score, str prompt) pair.
    assert len(matches) == 2
    for score, prompt in matches:
        assert isinstance(score, float) and isinstance(prompt, str)

    # Querying with a prompt from the corpus should rank that prompt on top.
    _top_score, top_prompt = matches[0][0], matches[0][1]
    assert top_prompt == "prompt 1"