#include "agent.hpp"

#include <algorithm>

#include <spdlog/spdlog.h>
Agent::Agent(std::unique_ptr<LlmClient> client, std::unique_ptr<Memory> memory,
ToolRegistry& tool_registry)
: client_(std::move(client)), memory_(std::move(memory)),
tool_registry_(tool_registry)
{
}
/// Adds @p tool_name to this agent's tool allow-list.
///
/// The name is checked against the tool registry first. Allowing a tool that
/// is already allowed is a no-op, so the tool is never advertised to the model
/// twice.
///
/// @param tool_name Name of a tool registered in the tool registry.
/// @return An error if the registry has no tool with that name; success
///         otherwise.
mw::E<void> Agent::allowTool(const std::string& tool_name)
{
    if(tool_registry_.getTool(tool_name) == nullptr)
    {
        return std::unexpected(mw::runtimeError("Tool not found in registry"));
    }
    // run() emits one schema entry per list element, so keep names unique.
    if(std::find(allowed_tools_.begin(), allowed_tools_.end(), tool_name) ==
       allowed_tools_.end())
    {
        allowed_tools_.push_back(tool_name);
    }
    return {};
}
void Agent::activateSkill(const Skill& skill)
{
current_skill_ = skill;
allowed_tools_ = skill.allowed_tools;
memory_->addMessage(SystemMessage{skill.system_prompt});
}
/// Runs the agent loop for one user request.
///
/// The loop works as follows:
/// - Appends @p user_input to memory as a UserMessage.
/// - Builds a function-tool schema from allowed_tools_.
/// - Asks the LLM client for a response over the history returned by
///   memory_->getHistory(), and records the response in memory.
/// - If the response requests tool calls, executes them and records each
///   outcome as a ToolResultMessage.
/// - Repeats until the model replies without any tool calls.
///
/// Only tools listed in allowed_tools_ are executed. A call naming any other
/// tool gets an "Error: Tool not allowed" result instead, even if the registry
/// knows that tool.
///
/// NOTE(review): there is no iteration cap, so a model that never stops
/// requesting tools keeps this loop running indefinitely.
///
/// @param user_input The user's message for this turn.
/// @return One of:
///         - the final assistant text (empty if that message had no content);
///         - the client's error if generateResponse fails;
///         - an error if the client returns a message that is not an
///           AssistantMessage.
Task<mw::E<std::string>> Agent::run(const std::string& user_input)
{
    memory_->addMessage(UserMessage{user_input});
    while(true)
    {
        auto history = memory_->getHistory();

        // Advertise only allowed tools that the registry still knows about.
        // Rebuilt every iteration so changes to allowed_tools_ are picked up.
        nlohmann::json tools_schema = nlohmann::json::array();
        for(const auto& tool_name : allowed_tools_)
        {
            if(Tool* tool = tool_registry_.getTool(tool_name))
            {
                nlohmann::json tool_json = {
                    {"type", "function"},
                    {"function",
                     {{"name", tool->name()},
                      {"description", tool->description()},
                      {"parameters", tool->parametersSchema()}}}};
                tools_schema.push_back(std::move(tool_json));
            }
        }

        auto response_res =
            co_await client_->generateResponse(history, tools_schema);
        if(!response_res.has_value())
        {
            co_return std::unexpected(response_res.error());
        }
        Message response_msg = std::move(response_res.value());

        // Report a non-assistant reply as an error instead of letting
        // std::get throw std::bad_variant_access out of the coroutine.
        const auto* assistant_msg =
            std::get_if<AssistantMessage>(&response_msg);
        if(assistant_msg == nullptr)
        {
            co_return std::unexpected(mw::runtimeError(
                "LLM client returned a non-assistant message"));
        }
        memory_->addMessage(response_msg);

        if(assistant_msg->tool_calls.empty())
        {
            // No tool requests: this is the model's final answer.
            if(assistant_msg->content)
            {
                co_return *assistant_msg->content;
            }
            co_return "";
        }

        for(const auto& call : assistant_msg->tool_calls)
        {
            // Enforce the allow-list at execution time as well. The model can
            // name a tool it was never offered (e.g. a hallucinated or
            // injected call), and such calls must not be executed.
            const bool allowed =
                std::find(allowed_tools_.begin(), allowed_tools_.end(),
                          call.name) != allowed_tools_.end();
            if(!allowed)
            {
                memory_->addMessage(
                    ToolResultMessage{call.id, "Error: Tool not allowed"});
                continue;
            }

            Tool* tool = tool_registry_.getTool(call.name);
            if(tool == nullptr)
            {
                memory_->addMessage(
                    ToolResultMessage{call.id, "Error: Tool not found"});
                continue;
            }

            auto result = co_await tool->execute(call.arguments);
            if(result.has_value())
            {
                memory_->addMessage(ToolResultMessage{call.id, result.value()});
            }
            else
            {
                // Feed the failure back to the model rather than aborting.
                memory_->addMessage(ToolResultMessage{
                    call.id, "Error: " + mw::errorMsg(result.error())});
            }
        }
    }
}