Spaces:
Running
Running
File size: 2,644 Bytes
45b4689 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import pytest
import numpy as np
from unittest.mock import MagicMock
from app.engine import PromptSearchEngine
@pytest.fixture
def mock_prompts():
    """Provide the three-entry prompt corpus shared by the engine tests."""
    corpus = ["prompt 1", "prompt 2", "prompt 3"]
    return corpus
@pytest.fixture
def mock_model():
    """Provide a stand-in model whose ``encode`` always returns a fixed 3x3 array."""
    # One row per entry of the three-prompt corpus used by the tests.
    canned_vectors = np.array([
        [0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6],
        [0.7, 0.8, 0.9],
    ])
    fake_model = MagicMock()
    fake_model.encode = MagicMock(return_value=canned_vectors)
    return fake_model
@pytest.mark.unit
def test_engine_initialization(mock_prompts, mock_model, monkeypatch):
    """Check that a new engine exposes the given prompts and the encoded corpus."""
    # Replace SentenceTransformer with a factory that hands back the mock model.
    transformer_factory = MagicMock(return_value=mock_model)
    monkeypatch.setattr("app.engine.SentenceTransformer", transformer_factory)

    expected_vectors = np.array(
        [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
    )

    engine = PromptSearchEngine(mock_prompts)

    # The engine should hold the prompts unchanged and the vectors the mock produced.
    assert engine.prompts == mock_prompts
    assert engine.corpus_vectors.shape == (3, 3)
    assert np.array_equal(engine.corpus_vectors, expected_vectors)
@pytest.mark.unit
def test_most_similar_valid_query(mock_prompts, mock_model, monkeypatch):
    """Check that ``most_similar`` with ``n=2`` yields two ``(float, str)`` pairs."""
    monkeypatch.setattr(
        "app.engine.SentenceTransformer",
        MagicMock(return_value=mock_model),
    )
    engine = PromptSearchEngine(mock_prompts)

    # Stub the vectorizer's transform so it reports a single fixed query vector.
    # NOTE(review): presumes the engine exposes a ``vectorizer`` attribute — confirm
    # against app.engine.
    query_vector = np.array([[0.1, 0.2, 0.3]])
    engine.vectorizer.transform = MagicMock(return_value=query_vector)

    results = engine.most_similar("test query", n=2)

    assert len(results) == 2
    for score, prompt in results:
        assert isinstance(score, float) and isinstance(prompt, str)
@pytest.mark.unit
def test_most_similar_exceeding_n(mock_prompts, mock_model, monkeypatch):
    """Check that asking for more matches than prompts returns one per prompt."""
    monkeypatch.setattr(
        "app.engine.SentenceTransformer",
        MagicMock(return_value=mock_model),
    )
    engine = PromptSearchEngine(mock_prompts)

    # NOTE(review): presumes the engine exposes a ``vectorizer`` attribute — confirm
    # against app.engine.
    query_vector = np.array([[0.1, 0.2, 0.3]])
    engine.vectorizer.transform = MagicMock(return_value=query_vector)

    # Request n=10 against a three-prompt corpus.
    results = engine.most_similar("test query", n=10)

    # The result count is expected to equal the number of prompts.
    assert len(results) == len(mock_prompts)
    for score, prompt in results:
        assert isinstance(score, float) and isinstance(prompt, str)
@pytest.mark.integration
def test_most_similar_integration(mock_prompts):
    """Run the engine without patching and expect "prompt 1" as the first match."""
    engine = PromptSearchEngine(mock_prompts)

    matches = engine.most_similar("prompt 1", n=2)

    # Two results, each a (float score, str prompt) pair.
    assert len(matches) == 2
    for score, prompt in matches:
        assert isinstance(score, float) and isinstance(prompt, str)

    # Querying with a corpus entry verbatim should return that entry first.
    best_match = matches[0]
    assert best_match[1] == "prompt 1"
|