BareGit

Refactor AST to use std::variant instead of inheritance

Author: MetroWind <chris.corsair@gmail.com>
Date: Sun Jan 11 11:39:58 2026 -0800
Commit: 66ae6af935739538faafb5b2d8291a1185a81383

Changes

diff --git a/include/macro_engine.h b/include/macro_engine.h
index ae4882a..399175f 100644
--- a/include/macro_engine.h
+++ b/include/macro_engine.h
@@ -41,7 +41,7 @@ public:
     void defineIntrinsic(const std::string& name, MacroCallback callback);
 
     std::string evaluate(const Node& node);
-    std::string evaluateMacro(const MacroNode& macro);
+    std::string evaluateMacro(const Macro& macro);
 
 private:
     std::map<std::string, MacroDefinition> macros_;
@@ -49,4 +49,4 @@ private:
 
 } // namespace macrodown
 
-#endif // MACRODOWN_MACRO_ENGINE_H
\ No newline at end of file
+#endif // MACRODOWN_MACRO_ENGINE_H
diff --git a/include/nodes.h b/include/nodes.h
index 3b44d2f..f15444b 100644
--- a/include/nodes.h
+++ b/include/nodes.h
@@ -4,69 +4,48 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <variant>
 
 namespace macrodown
 {
 
-enum class NodeType
-{
-    Text,
-    Macro,
-    Group
-};
+struct Node; // Forward declaration
 
-// Abstract Base Class
-struct Node
+struct Text
 {
-    virtual ~Node() = default;
-    virtual NodeType type() const = 0;
+    std::string content;
 };
 
-// Represents a raw text segment
-struct TextNode : public Node
+struct Macro
 {
-    std::string content;
-
-    TextNode(std::string c) : content(std::move(c)) {}
-    
-    NodeType type() const override
-    {
-        return NodeType::Text;
-    }
+    std::string name;
+    std::vector<std::unique_ptr<Node>> arguments;
+    bool is_special = false;
 };
 
-// Represents a collection of nodes (used for macro arguments)
-struct GroupNode : public Node
+struct Group
 {
     std::vector<std::unique_ptr<Node>> children;
 
-    NodeType type() const override
-    {
-        return NodeType::Group;
-    }
-
-    void addChild(std::unique_ptr<Node> node)
-    {
-        children.push_back(std::move(node));
-    }
+    void addChild(std::unique_ptr<Node> node);
 };
 
-// Represents a macro call: %name{arg1}{arg2}...
-struct MacroNode : public Node
+struct Node
 {
-    std::string name;
-    std::vector<std::unique_ptr<Node>> arguments;
-    bool is_special = false; // Intrinsic macros like %def
+    using Data = std::variant<Text, Macro, Group>;
+    Data data;
 
-    MacroNode(std::string n, bool special = false) 
-        : name(std::move(n)), is_special(special) {}
-    
-    NodeType type() const override
-    {
-        return NodeType::Macro;
-    }
+    Node(Text t) : data(std::move(t)) {}
+    Node(Macro m) : data(std::move(m)) {}
+    Node(Group g) : data(std::move(g)) {}
 };
 
+// Defined out of class, after Node, so that Node is a complete type in the body
+inline void Group::addChild(std::unique_ptr<Node> node)
+{
+    children.push_back(std::move(node));
+}
+
 } // namespace macrodown
 
-#endif // MACRODOWN_NODES_H
+#endif // MACRODOWN_NODES_H
\ No newline at end of file
diff --git a/src/converter.cpp b/src/converter.cpp
index 8c3db2c..a6f6081 100644
--- a/src/converter.cpp
+++ b/src/converter.cpp
@@ -35,7 +35,6 @@ std::unique_ptr<Node> Converter::convert_block(const Block* block)
     if(!block) return nullptr;
 
     std::string macro_name;
-    std::vector<std::unique_ptr<Node>> args;
 
     switch(block->type)
     {
@@ -52,37 +51,37 @@ std::unique_ptr<Node> Converter::convert_block(const Block* block)
             return nullptr; // Ignore unknown blocks for now
     }
 
-    auto macro = std::make_unique<MacroNode>(macro_name);
+    Macro macro;
+    macro.name = macro_name;
 
     // Handle Content
     if(block->children.empty())
     {
         // Leaf block: Parse literal content
-        // Wrap result in GroupNode
-        auto group = std::make_unique<GroupNode>();
+        Group group;
         auto inline_nodes = Parser::parse(block->literal_content);
         for(auto& n : inline_nodes)
         {
-            group->addChild(std::move(n));
+            group.addChild(std::move(n));
         }
-        macro->arguments.push_back(std::move(group));
+        macro.arguments.push_back(std::make_unique<Node>(std::move(group)));
     }
     else
     {
         // Container block: Recursively convert children
-        auto group = std::make_unique<GroupNode>();
+        Group group;
         for(const auto& child : block->children)
         {
             auto child_node = convert_block(child.get());
             if(child_node)
             {
-                group->addChild(std::move(child_node));
+                group.addChild(std::move(child_node));
             }
         }
-        macro->arguments.push_back(std::move(group));
+        macro.arguments.push_back(std::make_unique<Node>(std::move(group)));
     }
 
-    return macro;
+    return std::make_unique<Node>(std::move(macro));
 }
 
-} // namespace macrodown
\ No newline at end of file
+} // namespace macrodown
diff --git a/src/macro_engine.cpp b/src/macro_engine.cpp
index 06c00fe..4016a50 100644
--- a/src/macro_engine.cpp
+++ b/src/macro_engine.cpp
@@ -3,20 +3,25 @@
 #include <sstream>
 #include <stdexcept>
 #include <iostream>
+#include <variant>
 
-namespace macrodown {
+namespace macrodown
+{
 
-namespace {
+namespace
+{
 
 // Helper to split a string by delimiter
-std::vector<std::string> split(const std::string& s, char delimiter) {
+std::vector<std::string> split(const std::string& s, char delimiter)
+{
     std::vector<std::string> tokens;
     std::string token;
     std::istringstream tokenStream(s);
-    while (std::getline(tokenStream, token, delimiter)) {
+    while(std::getline(tokenStream, token, delimiter))
+    {
         // Trim whitespace
         size_t first = token.find_first_not_of(" \t");
-        if (first == std::string::npos) continue;
+        if(first == std::string::npos) continue;
         size_t last = token.find_last_not_of(" \t");
         tokens.push_back(token.substr(first, (last - first + 1)));
     }
@@ -24,9 +29,11 @@ std::vector<std::string> split(const std::string& s, char delimiter) {
 }
 
 // Helper to replace all occurrences of a substring
-std::string replace_all(std::string str, const std::string& from, const std::string& to) {
+std::string replace_all(std::string str, const std::string& from, const std::string& to)
+{
     size_t start_pos = 0;
-    while((start_pos = str.find(from, start_pos)) != std::string::npos) {
+    while((start_pos = str.find(from, start_pos)) != std::string::npos)
+    {
         str.replace(start_pos, from.length(), to);
         start_pos += to.length();
     }
@@ -35,26 +42,16 @@ std::string replace_all(std::string str, const std::string& from, const std::str
 
 } // namespace
 
-Evaluator::Evaluator() {
+Evaluator::Evaluator()
+{
     // Register intrinsic %def macro
-    // Syntax: %def[name]{args...}{body}
-    // args... is a comma-separated list of argument names
-    defineIntrinsic("def", [this](const std::vector<std::string>& args) -> std::string {
-        if (args.size() < 3) {
-            // Error handling: %def requires at least 3 arguments (name, args, body)
-            // But wait, the structure of MacroNode has 'arguments'.
-            // For %def[name]{args}{body}:
-            // The AST will look like: 
-            // Name: "def"
-            // Argument 0 (Bracket): "name" (Technically CommonMark doesn't distinguish [] from {},
-            // but our parser might. The user prompt says %def[name]{args}{body}.
-            // Let's assume our parser maps [name] to arg 0, {args} to arg 1, {body} to arg 2.
+    defineIntrinsic("def", [this](const std::vector<std::string>& args) -> std::string
+    {
+        if(args.size() < 3)
+        {
             return ""; 
         }
         
-        // We actually need to access the logic *outside* this callback if we want to change state.
-        // But lambda captures 'this', so we can modify macros_.
-        
         std::string name = args[0];
         std::string arg_list_str = args[1];
         std::string body = args[2];
@@ -62,41 +59,54 @@ Evaluator::Evaluator() {
         std::vector<std::string> arg_names = split(arg_list_str, ',');
         
         this->define(name, arg_names, body);
-        return ""; // Definitions expand to nothing
+        return ""; 
     });
 }
 
-void Evaluator::define(const std::string& name, const std::vector<std::string>& args, const std::string& body) {
+void Evaluator::define(const std::string& name, const std::vector<std::string>& args, const std::string& body)
+{
     macros_[name] = MacroDefinition(name, args, body);
 }
 
-void Evaluator::defineIntrinsic(const std::string& name, MacroCallback callback) {
+void Evaluator::defineIntrinsic(const std::string& name, MacroCallback callback)
+{
     macros_[name] = MacroDefinition(name, callback);
 }
 
-std::string Evaluator::evaluate(const Node& node) {
-    if (node.type() == NodeType::Text) {
-        return static_cast<const TextNode&>(node).content;
-    } else if (node.type() == NodeType::Macro) {
-        return evaluateMacro(static_cast<const MacroNode&>(node));
-    } else if (node.type() == NodeType::Group) {
-        std::string result;
-        const auto& group = static_cast<const GroupNode&>(node);
-        for (const auto& child : group.children) {
-            result += evaluate(*child);
+std::string Evaluator::evaluate(const Node& node)
+{
+    return std::visit([this](auto&& arg) -> std::string
+    {
+        using T = std::decay_t<decltype(arg)>;
+        if constexpr (std::is_same_v<T, Text>)
+        {
+            return arg.content;
         }
-        return result;
-    }
-    return "";
+        else if constexpr (std::is_same_v<T, Macro>)
+        {
+            return this->evaluateMacro(arg);
+        }
+        else if constexpr (std::is_same_v<T, Group>)
+        {
+            std::string result;
+            for(const auto& child : arg.children)
+            {
+                result += this->evaluate(*child);
+            }
+            return result;
+        }
+        return "";
+    }, node.data);
 }
 
-std::string Evaluator::evaluateMacro(const MacroNode& macro) {
+std::string Evaluator::evaluateMacro(const Macro& macro)
+{
     auto it = macros_.find(macro.name);
-    if (it == macros_.end()) {
-        // Undefined macro: return literal representation
-        // (This is a simplified behavior; usually we might want to warn)
+    if(it == macros_.end())
+    {
         std::string result = "%" + macro.name;
-        for (const auto& arg : macro.arguments) {
+        for(const auto& arg : macro.arguments)
+        {
             result += "{" + evaluate(*arg) + "}"; 
         }
         return result;
@@ -104,47 +114,36 @@ std::string Evaluator::evaluateMacro(const MacroNode& macro) {
 
     const MacroDefinition& def = it->second;
 
-    // Evaluate arguments
-    // Note: For %def, we might want raw arguments, but for general macros, 
-    // we evaluate arguments first?
-    // The design doc says: "The Evaluator recursively expands macros".
-    // Usually, strict evaluation means evaluating args first.
-    // But %def needs raw strings for name/arg_names.
-    // Let's assume we evaluate arguments to strings first, THEN pass to macro.
-    // Exception: If the macro expects specific raw syntax, we might need a flag.
-    // For now, we evaluate all arguments.
-    
     std::vector<std::string> evaluated_args;
-    for (const auto& arg : macro.arguments) {
+    for(const auto& arg : macro.arguments)
+    {
         evaluated_args.push_back(evaluate(*arg));
     }
 
-    if (def.is_intrinsic) {
+    if(def.is_intrinsic)
+    {
         return def.callback(evaluated_args);
-    } else {
-        // User defined macro
-        // 1. Check arg count
-        // (ignoring mismatch for now, or just filling empty)
-        
+    }
+    else
+    {
         std::string body = def.body;
         
-        // 2. Substitute arguments
-        for (size_t i = 0; i < def.arg_names.size(); ++i) {
+        for(size_t i = 0; i < def.arg_names.size(); ++i)
+        {
             std::string placeholder = "%" + def.arg_names[i];
             std::string value = (i < evaluated_args.size()) ? evaluated_args[i] : "";
             body = replace_all(body, placeholder, value);
         }
         
-        // 3. Parse and evaluate the body
-        // We use the static Parser::parse for this
         auto nodes = Parser::parse(body);
         
         std::string result;
-        for (const auto& n : nodes) {
+        for(const auto& n : nodes)
+        {
             result += evaluate(*n);
         }
         return result;
     }
 }
 
-} // namespace macrodown
\ No newline at end of file
+} // namespace macrodown
diff --git a/src/parser.cpp b/src/parser.cpp
index c2bd69a..ba0f62a 100644
--- a/src/parser.cpp
+++ b/src/parser.cpp
@@ -13,7 +13,7 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
     {
         if(!current_text.empty())
         {
-            nodes.push_back(std::make_unique<TextNode>(current_text));
+            nodes.push_back(std::make_unique<Node>(Text{current_text}));
             current_text.clear();
         }
     };
@@ -51,7 +51,8 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
                 continue;
             }
 
-            auto macro = std::make_unique<MacroNode>(name);
+            Macro macro;
+            macro.name = name;
 
             // Parse Arguments
             while(i < input.length())
@@ -83,20 +84,20 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
                         i++;
                     }
                     
-                    auto group = std::make_unique<GroupNode>();
+                    Group group;
                     std::vector<std::unique_ptr<Node>> sub_nodes = parse(arg_content);
                     for(auto& n : sub_nodes)
                     {
-                        group->addChild(std::move(n));
+                        group.addChild(std::move(n));
                     }
-                    macro->arguments.push_back(std::move(group));
+                    macro.arguments.push_back(std::make_unique<Node>(std::move(group)));
                 }
                 else
                 {
                     break; 
                 }
             }
-            nodes.push_back(std::move(macro));
+            nodes.push_back(std::make_unique<Node>(std::move(macro)));
             continue;
         }
         
@@ -109,11 +110,15 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
             {
                 flush_text();
                 std::string content = input.substr(start, end - start);
-                auto macro = std::make_unique<MacroNode>("code");
-                auto group = std::make_unique<GroupNode>();
-                group->addChild(std::make_unique<TextNode>(content));
-                macro->arguments.push_back(std::move(group));
-                nodes.push_back(std::move(macro));
+                
+                Macro macro;
+                macro.name = "code";
+                
+                Group group;
+                group.addChild(std::make_unique<Node>(Text{content}));
+                
+                macro.arguments.push_back(std::make_unique<Node>(std::move(group)));
+                nodes.push_back(std::make_unique<Node>(std::move(macro)));
                 i = end + 1;
                 continue;
             }
@@ -122,10 +127,7 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
         // Link: [text](url)
         if(c == '[')
         {
-            // Find matching ]
-            // Simplified: doesn't handle nested brackets well yet
             size_t label_start = i + 1;
-            // We need a helper for balanced search, but for now simple find
             size_t j = label_start;
             int bracket_bal = 1;
             while(j < input.length() && bracket_bal > 0)
@@ -138,7 +140,6 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
             if(j < input.length() && bracket_bal == 0)
             {
                 size_t close_bracket = j;
-                // Check for (url)
                 if(close_bracket + 1 < input.length() && input[close_bracket + 1] == '(')
                 {
                     size_t url_start = close_bracket + 2;
@@ -149,20 +150,21 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
                         std::string label = input.substr(label_start, close_bracket - label_start);
                         std::string url = input.substr(url_start, url_end - url_start);
                         
-                        auto macro = std::make_unique<MacroNode>("link");
+                        Macro macro;
+                        macro.name = "link";
                         
                         // Arg 1: URL
-                        auto group1 = std::make_unique<GroupNode>();
-                        group1->addChild(std::make_unique<TextNode>(url));
-                        macro->arguments.push_back(std::move(group1));
+                        Group group1;
+                        group1.addChild(std::make_unique<Node>(Text{url}));
+                        macro.arguments.push_back(std::make_unique<Node>(std::move(group1)));
                         
                         // Arg 2: Text (parsed)
-                        auto group2 = std::make_unique<GroupNode>();
+                        Group group2;
                         auto sub = parse(label);
-                        for(auto& n : sub) group2->addChild(std::move(n));
-                        macro->arguments.push_back(std::move(group2));
+                        for(auto& n : sub) group2.addChild(std::move(n));
+                        macro.arguments.push_back(std::make_unique<Node>(std::move(group2)));
                         
-                        nodes.push_back(std::move(macro));
+                        nodes.push_back(std::make_unique<Node>(std::move(macro)));
                         i = url_end + 1;
                         continue;
                     }
@@ -171,31 +173,28 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
         }
         
         // Emphasis: * or **
-        // Simplified: Greedy scan for matching *
         if(c == '*')
         {
             bool strong = (i + 1 < input.length() && input[i+1] == '*');
             size_t start_content = i + (strong ? 2 : 1);
             
-            // Find delimiter
             std::string delim = strong ? "**" : "*";
             size_t end = input.find(delim, start_content);
             
-            // Handle edge case: *foo **bar*** -> *foo *bar**
-            // Simplified: just find first match. 
-            // This is NOT correct CommonMark but sufficient for prototype.
-            
             if(end != std::string::npos)
             {
                 flush_text();
                 std::string content = input.substr(start_content, end - start_content);
-                auto macro = std::make_unique<MacroNode>(strong ? "strong" : "em");
-                auto group = std::make_unique<GroupNode>();
+                
+                Macro macro;
+                macro.name = strong ? "strong" : "em";
+                
+                Group group;
                 auto sub = parse(content);
-                for(auto& n : sub) group->addChild(std::move(n));
-                macro->arguments.push_back(std::move(group));
+                for(auto& n : sub) group.addChild(std::move(n));
+                macro.arguments.push_back(std::make_unique<Node>(std::move(group)));
                 
-                nodes.push_back(std::move(macro));
+                nodes.push_back(std::make_unique<Node>(std::move(macro)));
                 i = end + delim.length();
                 continue;
             }
@@ -209,4 +208,4 @@ std::vector<std::unique_ptr<Node>> Parser::parse(const std::string& input)
     return nodes;
 }
 
-} // namespace macrodown
+} // namespace macrodown
\ No newline at end of file
diff --git a/tests/test_macro_engine.cpp b/tests/test_macro_engine.cpp
index 10a295b..ce886f0 100644
--- a/tests/test_macro_engine.cpp
+++ b/tests/test_macro_engine.cpp
@@ -2,6 +2,7 @@
 #include "macro_engine.h"
 #include "parser.h"
 #include "nodes.h"
+#include <variant>
 
 using namespace macrodown;
 
@@ -16,8 +17,9 @@ TEST_F(MacroEngineTest, ParseText)
 {
     auto nodes = Parser::parse("Hello World");
     ASSERT_EQ(nodes.size(), 1);
-    EXPECT_EQ(nodes[0]->type(), NodeType::Text);
-    EXPECT_EQ(static_cast<TextNode*>(nodes[0].get())->content, "Hello World");
+    
+    ASSERT_TRUE(std::holds_alternative<Text>(nodes[0]->data));
+    EXPECT_EQ(std::get<Text>(nodes[0]->data).content, "Hello World");
 }
 
 // Test parsing of macro
@@ -25,17 +27,19 @@ TEST_F(MacroEngineTest, ParseMacro)
 {
     auto nodes = Parser::parse("%m{arg}");
     ASSERT_EQ(nodes.size(), 1);
-    EXPECT_EQ(nodes[0]->type(), NodeType::Macro);
     
-    auto* macro = static_cast<MacroNode*>(nodes[0].get());
-    EXPECT_EQ(macro->name, "m");
-    ASSERT_EQ(macro->arguments.size(), 1);
+    ASSERT_TRUE(std::holds_alternative<Macro>(nodes[0]->data));
+    const auto& macro = std::get<Macro>(nodes[0]->data);
+    EXPECT_EQ(macro.name, "m");
+    ASSERT_EQ(macro.arguments.size(), 1);
+    
+    // Argument should be a Group
+    ASSERT_TRUE(std::holds_alternative<Group>(macro.arguments[0]->data));
+    const auto& group = std::get<Group>(macro.arguments[0]->data);
+    ASSERT_EQ(group.children.size(), 1);
     
-    // Argument should be a GroupNode
-    EXPECT_EQ(macro->arguments[0]->type(), NodeType::Group);
-    auto* group = static_cast<GroupNode*>(macro->arguments[0].get());
-    ASSERT_EQ(group->children.size(), 1);
-    EXPECT_EQ(static_cast<TextNode*>(group->children[0].get())->content, "arg");
+    ASSERT_TRUE(std::holds_alternative<Text>(group.children[0]->data));
+    EXPECT_EQ(std::get<Text>(group.children[0]->data).content, "arg");
 }
 
 // Test intrinsic %def and expansion
@@ -89,4 +93,4 @@ TEST_F(MacroEngineTest, NestedMacros)
     }
     
     EXPECT_EQ(result, "<p>Hello <b>World</b></p>");
-}
\ No newline at end of file
+}