#include <gtest/gtest.h>
#include "agent.hpp"
#include "llm_client.hpp"
#include "memory.hpp"
#include "tool.hpp"
// Test double for LlmClient.
// Each generateResponse() call hands out the next entry of `responses` in
// order. Once the queue is exhausted it answers with a fixed
// AssistantMessage whose text is "default_mock_end". Both arguments are ignored.
class MockLlmClient : public LlmClient
{
public:
    // Scripted replies, consumed front to back.
    std::vector<Message> responses;
    // How many scripted replies have been handed out so far.
    size_t call_count = 0;

    Task<mw::E<Message>> generateResponse(const std::vector<Message>&,
                                          const nlohmann::json&) override
    {
        // Guard clause: nothing scripted left, fall back to the sentinel reply.
        if(call_count >= responses.size())
        {
            co_return AssistantMessage{"default_mock_end", {}};
        }
        auto next_reply = responses[call_count];
        ++call_count;
        co_return next_reply;
    }
};
// Minimal Tool stub used by the agent tests.
// It reports the name "dummy" and an empty-object parameter schema.
// execute() ignores its argument and always yields "dummy_result".
class DummyTool : public Tool
{
public:
    std::string name() const override { return std::string{"dummy"}; }

    std::string description() const override { return std::string{"dummy desc"}; }

    nlohmann::json parametersSchema() const override
    {
        nlohmann::json schema = nlohmann::json::object();
        return schema;
    }

    Task<mw::E<std::string>> execute(const nlohmann::json& /*args*/) override
    {
        co_return std::string{"dummy_result"};
    }
};
// With one scripted plain reply (no tool calls) and no tools registered,
// run() is expected to yield that reply's text.
TEST(AgentTest, BasicRunAndReturn)
{
    auto mock_llm = std::make_unique<MockLlmClient>();
    mock_llm->responses.push_back(AssistantMessage{"Hello from Agent", {}});

    auto store = std::make_unique<InMemoryMemory>();
    ToolRegistry no_tools;
    Agent agent(std::move(mock_llm), std::move(store), no_tools);

    const auto outcome = agent.run("Hi").get();
    ASSERT_TRUE(outcome.has_value());
    EXPECT_EQ(outcome.value(), "Hello from Agent");
}
// Scripted two-turn exchange:
//   1. The LLM requests the "dummy" tool.
//   2. The LLM gives a plain final answer.
// run() is expected to return the final answer. The mock must also have been
// queried twice, which shows the agent went back to the LLM after the tool call.
TEST(AgentTest, ToolCallingFlow)
{
    auto client = std::make_unique<MockLlmClient>();
    // Non-owning handle for the post-run check below. This assumes Agent keeps
    // the client alive for its own lifetime (the same pattern this file relies
    // on for the memory pointer). TODO confirm if Agent's ownership model changes.
    auto* client_ptr = client.get();

    AssistantMessage tool_call_msg;
    tool_call_msg.tool_calls.push_back(
        {"call_1", "dummy", nlohmann::json::object()});
    client->responses.push_back(std::move(tool_call_msg));
    client->responses.push_back(AssistantMessage{"Final Result", {}});

    auto memory = std::make_unique<InMemoryMemory>();
    ToolRegistry registry;
    registry.registerTool(std::make_unique<DummyTool>());
    Agent agent(std::move(client), std::move(memory), registry);
    agent.allowTool("dummy");

    auto result = agent.run("Hi").get();
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(result.value(), "Final Result");
    // The final text alone does not prove the tool-call turn was processed.
    // Both scripted replies must have been consumed.
    EXPECT_EQ(client_ptr->call_count, 2u);
}
// activateSkill() is expected to leave exactly one entry in memory: a
// SystemMessage carrying the skill's instructions. run() is never called
// here, so no scripted LLM reply is queued on the mock.
TEST(AgentTest, ActivateSkill)
{
    auto client = std::make_unique<MockLlmClient>();
    auto memory = std::make_unique<InMemoryMemory>();
    auto* mem_ptr = memory.get();
    ToolRegistry registry;
    Agent agent(std::move(client), std::move(memory), registry);

    Skill test_skill{"test_skill", "You are a tester", {}};
    agent.activateSkill(test_skill);

    const auto history = mem_ptr->getHistory();
    // Unsigned literal: avoids a signed/unsigned comparison against size().
    ASSERT_EQ(history.size(), 1u);
    // Fatal check: the std::get below throws bad_variant_access if the
    // entry holds a different alternative, so stop here on mismatch.
    ASSERT_TRUE(std::holds_alternative<SystemMessage>(history[0]));
    EXPECT_EQ(std::get<SystemMessage>(history[0]).content, "You are a tester");
}