|
from langchain.chat_models.base import BaseChatModel |
|
from langchain_community.chat_models.fake import FakeListChatModel |
|
|
|
from gpt_engineer.core.ai import AI |
|
|
|
|
|
def mock_create_chat_model(self) -> BaseChatModel: |
|
return FakeListChatModel(responses=["response1", "response2", "response3"]) |
|
|
|
|
|
def test_start(monkeypatch): |
|
monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model) |
|
|
|
ai = AI("gpt-4") |
|
|
|
|
|
response_messages = ai.start("system prompt", "user prompt", step_name="step name") |
|
|
|
|
|
assert response_messages[-1].content == "response1" |
|
|
|
|
|
def test_next(monkeypatch): |
|
|
|
monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model) |
|
|
|
ai = AI("gpt-4") |
|
response_messages = ai.start("system prompt", "user prompt", step_name="step name") |
|
|
|
|
|
response_messages = ai.next( |
|
response_messages, "next user prompt", step_name="step name" |
|
) |
|
|
|
|
|
assert response_messages[-1].content == "response2" |
|
|
|
|
|
def test_token_logging(monkeypatch): |
|
|
|
monkeypatch.setattr(AI, "_create_chat_model", mock_create_chat_model) |
|
|
|
ai = AI("gpt-4") |
|
|
|
|
|
response_messages = ai.start("system prompt", "user prompt", step_name="step name") |
|
usageCostAfterStart = ai.token_usage_log.usage_cost() |
|
ai.next(response_messages, "next user prompt", step_name="step name") |
|
usageCostAfterNext = ai.token_usage_log.usage_cost() |
|
|
|
|
|
assert usageCostAfterStart > 0 |
|
assert usageCostAfterNext > usageCostAfterStart |
|
|