File size: 2,739 Bytes
45b4689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from app.engine import PromptSearchEngine


@pytest.fixture
def mock_prompts():
    """Provide the three-prompt corpus shared by the tests in this module."""
    corpus = ["prompt 1", "prompt 2", "prompt 3"]
    return corpus


@pytest.fixture
def mock_model():
    """Provide a stand-in model whose ``encode`` returns random embeddings.

    ``encode`` ignores its arguments and always returns the same
    ``(3, 384)`` array of uniform random values, one row per prompt in
    ``mock_prompts``.
    """
    # 384 matches the embedding width asserted in test_engine_initialization.
    embedding_dim = 384
    fake_embeddings = np.random.rand(3, embedding_dim)
    encoder = MagicMock()
    encoder.encode = MagicMock(return_value=fake_embeddings)
    return encoder


@pytest.mark.unit
def test_engine_initialization(mock_prompts, mock_model):
    """Check that a new engine stores its prompts and a (3, 384) corpus matrix.

    ``PromptSearchEngine.vectorizer`` is swapped for a mock whose
    ``transform`` returns the fake embeddings produced by ``mock_model``.
    The swap is scoped with ``patch.object`` so the class is restored when
    the test ends; a bare class-attribute assignment would leak the mock
    into every test that runs afterwards.
    """
    corpus_embeddings = mock_model.encode(mock_prompts)
    # create=True: tolerate the class not defining ``vectorizer`` at class level.
    with patch.object(
        PromptSearchEngine, "vectorizer", MagicMock(), create=True
    ) as class_vectorizer:
        class_vectorizer.transform = MagicMock(return_value=corpus_embeddings)
        # Build and inspect the engine while the stub is still installed.
        engine = PromptSearchEngine(mock_prompts)
        assert engine.prompts == mock_prompts
        assert engine.corpus_vectors.shape == (3, 384)


@pytest.mark.unit
def test_most_similar_valid_query(mock_prompts, mock_model):
    """Check that ``most_similar`` returns ``n`` (float, str) pairs for a query.

    The vectorizer is stubbed at class level while the engine is built and
    at instance level for the query, both returning one random ``(1, 384)``
    embedding. The class-level stub is installed through ``patch.object``
    so it is removed when the test finishes instead of leaking into later
    tests.
    """
    embedding_dim = 384
    query_embedding = np.random.rand(1, embedding_dim)
    # create=True: tolerate the class not defining ``vectorizer`` at class level.
    with patch.object(
        PromptSearchEngine, "vectorizer", MagicMock(), create=True
    ) as class_vectorizer:
        class_vectorizer.transform = MagicMock(return_value=query_embedding)
        engine = PromptSearchEngine(mock_prompts)
        # Instance-level stub: guarantees the query path sees the fake embedding.
        engine.vectorizer = MagicMock()
        engine.vectorizer.transform = MagicMock(return_value=query_embedding)
        results = engine.most_similar("test query", n=2)
    assert len(results) == 2
    assert all(isinstance(score, float) and isinstance(prompt, str) for score, prompt in results)


@pytest.mark.unit
def test_most_similar_empty_query(mock_prompts):
    """Check that a ``ValueError`` from the vectorizer surfaces from ``most_similar``.

    The engine's vectorizer is replaced on the instance with a mock whose
    ``transform`` always raises ``ValueError("Invalid query")``; the test
    then expects an empty query to raise ``ValueError``.
    """
    engine = PromptSearchEngine(mock_prompts)
    failing_vectorizer = MagicMock()
    failing_vectorizer.transform = MagicMock(side_effect=ValueError("Invalid query"))
    engine.vectorizer = failing_vectorizer
    with pytest.raises(ValueError):
        engine.most_similar("", n=2)


@pytest.mark.unit
def test_most_similar_exceeding_n(mock_prompts, mock_model):
    """Check that asking for more results than prompts caps at the corpus size.

    Requests ``n=10`` from the three-prompt corpus and expects exactly
    ``len(mock_prompts)`` results. The class-level vectorizer mock is
    installed through ``patch.object`` so it is removed when the test
    finishes instead of leaking into later tests.
    """
    # create=True: tolerate the class not defining ``vectorizer`` at class level.
    with patch.object(PromptSearchEngine, "vectorizer", MagicMock(), create=True):
        engine = PromptSearchEngine(mock_prompts)
        # n is deliberately larger than the number of prompts.
        results = engine.most_similar("test query", n=10)
    assert len(results) == len(mock_prompts)  # Should return at most the number of prompts


@pytest.mark.integration
def test_most_similar_integration(mock_prompts):
    """Run a query through an unmocked engine and check the top hit.

    Searching for ``"prompt 1"`` with ``n=2`` must return two
    ``(float, str)`` pairs, and the first pair's prompt must be
    ``"prompt 1"`` itself.
    """
    hits = PromptSearchEngine(mock_prompts).most_similar("prompt 1", n=2)
    assert len(hits) == 2
    for score, prompt in hits:
        assert isinstance(score, float) and isinstance(prompt, str)
    top_prompt = hits[0][1]
    assert top_prompt == "prompt 1"