BareGit
#include "agent.hpp"

#include <algorithm>
#include <string>
#include <utility>
#include <variant>

#include <spdlog/spdlog.h>

Agent::Agent(std::unique_ptr<LlmClient> client, std::unique_ptr<Memory> memory,
             ToolRegistry& tool_registry)
    : client_(std::move(client)), memory_(std::move(memory)),
      tool_registry_(tool_registry)
{
}

/// Makes the registered tool @p tool_name available to the model in run().
///
/// @param tool_name  Name of a tool that must exist in the tool registry.
/// @return an empty success value, or a runtime error if the registry has no
///         tool with that name. Allowing a tool that is already allowed is a
///         no-op, so the tool is listed only once in the schema run() builds.
mw::E<void> Agent::allowTool(const std::string& tool_name)
{
    if(tool_registry_.getTool(tool_name) == nullptr)
    {
        return std::unexpected(mw::runtimeError("Tool not found in registry"));
    }

    // run() emits one schema entry per element of allowed_tools_, so a
    // duplicate name would advertise the same function to the LLM twice.
    const bool already_allowed =
        std::find(allowed_tools_.begin(), allowed_tools_.end(), tool_name) !=
        allowed_tools_.end();
    if(!already_allowed)
    {
        allowed_tools_.push_back(tool_name);
    }
    return {};
}

/// Switches the agent to @p skill.
///
/// Replaces the allowed-tool list with the skill's own list (previously
/// allowed tools are dropped), records the skill as the current one, and
/// appends the skill's system prompt to memory as a SystemMessage.
/// Tool names are not validated against the registry here.
void Agent::activateSkill(const Skill& skill)
{
    allowed_tools_ = skill.allowed_tools;
    current_skill_ = skill;

    memory_->addMessage(SystemMessage{skill.system_prompt});
}

/// Runs one user turn of the agent loop.
///
/// Appends @p user_input to memory as a UserMessage, then repeatedly:
///  1. builds an OpenAI-style "function" tools schema from allowed_tools_
///     (names the registry cannot resolve are skipped),
///  2. asks the LLM client for a response given the full history,
///  3. stores the assistant response in memory, and
///  4. if the response requests tool calls, executes each allowed one and
///     stores its output (or an "Error: ..." text) as a ToolResultMessage.
/// The loop ends when the model replies without tool calls.
///
/// @return the final assistant text ("" if it has no content); the client's
///         error if generateResponse fails; or a runtime error if the client
///         returns a message that is not an AssistantMessage.
///
/// NOTE(review): user_input is a reference parameter of a coroutine. It is only
/// read before the first co_await, but if Task starts lazily the caller's
/// string must still be alive when the task is first resumed — confirm
/// Task's initial_suspend behaviour.
/// NOTE(review): there is no cap on tool-call rounds; a model that keeps
/// requesting tools keeps this loop running.
Task<mw::E<std::string>> Agent::run(const std::string& user_input)
{
    memory_->addMessage(UserMessage{user_input});

    while(true)
    {
        auto history = memory_->getHistory();
        nlohmann::json tools_schema = nlohmann::json::array();

        for(const auto& tool_name : allowed_tools_)
        {
            if(Tool* tool = tool_registry_.getTool(tool_name))
            {
                nlohmann::json tool_json = {
                    {"type", "function"},
                    {"function",
                     {{"name", tool->name()},
                      {"description", tool->description()},
                      {"parameters", tool->parametersSchema()}}}};
                tools_schema.push_back(std::move(tool_json));
            }
        }

        auto response_res =
            co_await client_->generateResponse(history, tools_schema);
        if(!response_res.has_value())
        {
            co_return std::unexpected(response_res.error());
        }

        Message response_msg = std::move(response_res.value());

        // std::get would throw bad_variant_access on any other alternative;
        // report it as an error instead of escaping the coroutine.
        const auto* assistant_msg = std::get_if<AssistantMessage>(&response_msg);
        if(assistant_msg == nullptr)
        {
            co_return std::unexpected(
                mw::runtimeError("LLM client returned a non-assistant message"));
        }

        memory_->addMessage(response_msg);

        if(assistant_msg->tool_calls.empty())
        {
            // No further tool work requested: this is the final answer.
            if(assistant_msg->content)
            {
                co_return *assistant_msg->content;
            }
            co_return "";
        }

        // assistant_msg points into the local response_msg, which lives in
        // the coroutine frame, so it stays valid across the co_awaits below.
        for(const auto& call : assistant_msg->tool_calls)
        {
            // Enforce the allow-list at execution time as well: the model may
            // name a registered tool that was never offered in tools_schema.
            const bool allowed =
                std::find(allowed_tools_.begin(), allowed_tools_.end(),
                          call.name) != allowed_tools_.end();
            if(!allowed)
            {
                memory_->addMessage(
                    ToolResultMessage{call.id, "Error: Tool not allowed"});
                continue;
            }

            Tool* tool = tool_registry_.getTool(call.name);
            if(tool == nullptr)
            {
                memory_->addMessage(
                    ToolResultMessage{call.id, "Error: Tool not found"});
                continue;
            }

            auto result = co_await tool->execute(call.arguments);
            if(result.has_value())
            {
                memory_->addMessage(ToolResultMessage{call.id, result.value()});
            }
            else
            {
                memory_->addMessage(ToolResultMessage{
                    call.id, "Error: " + mw::errorMsg(result.error())});
            }
        }
    }
}