# BareGit
import unittest
from unittest.mock import MagicMock, patch
import json
import io
import sys
from src.trilium_client import TriliumClient
from src.handlers import ToolHandlers
from src.server import MCPServer

class TestTriliumMCP(unittest.TestCase):
    """Drive ``MCPServer`` through patched stdin/stdout and check its replies.

    Every test feeds one newline-terminated JSON-RPC 2.0 request to
    ``server.run()`` and parses everything the server printed as a single
    JSON document. The Trilium client is a ``MagicMock``, so the handlers
    never talk to a real backend.
    """

    def setUp(self):
        """Build a fresh server wired to a mocked Trilium client."""
        self.mock_client = MagicMock()
        self.handlers = ToolHandlers(self.mock_client)
        self.server = MCPServer(self.handlers)

    def _exchange(self, request_id, method, params):
        """Send one JSON-RPC request through the server and return its reply.

        Args:
            request_id: Value placed in the request's ``id`` field.
            method: JSON-RPC method name to invoke.
            params: Value placed in the request's ``params`` field.

        Returns:
            The decoded JSON object the server wrote to stdout.

        Raises:
            json.JSONDecodeError: If the captured output is empty or is not
                exactly one JSON document.
        """
        request_line = json.dumps({
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": params,
        }) + "\n"

        # The request is the only content on the fake stdin; run() is called
        # once and everything it printed is read back afterwards.
        with patch('sys.stdin', io.StringIO(request_line)), \
                patch('sys.stdout', io.StringIO()) as captured:
            self.server.run()
            raw_reply = captured.getvalue().strip()

        return json.loads(raw_reply)

    def testInitialize(self):
        """``initialize`` echoes the id and advertises tool capabilities."""
        response = self._exchange(1, "initialize", {})

        self.assertEqual(response["id"], 1)
        self.assertIn("capabilities", response["result"])
        self.assertIn("tools", response["result"]["capabilities"])

    def testListTools(self):
        """``tools/list`` returns a non-empty list including ``search_notes``."""
        response = self._exchange(2, "tools/list", {})

        self.assertEqual(response["id"], 2)
        tools = response["result"]["tools"]
        self.assertGreater(len(tools), 0)
        # Verify search_notes tool exists
        tool_names = [tool["name"] for tool in tools]
        self.assertIn("search_notes", tool_names)

    def testCallSearchNotes(self):
        """``tools/call`` for ``search_notes`` surfaces the client's results."""
        # Canned search result returned by the mocked client.
        self.mock_client.searchNotes.return_value = [
            {"noteId": "abc", "title": "Test Note"},
        ]

        response = self._exchange(3, "tools/call", {
            "name": "search_notes",
            "arguments": {"query": "test"},
        })

        self.assertEqual(response["id"], 3)
        content = response["result"]["content"][0]["text"]
        self.assertIn("abc", content)
        self.assertIn("Test Note", content)

    def testInvalidMethod(self):
        """An unknown method yields JSON-RPC error -32601 (method not found)."""
        response = self._exchange(4, "non_existent_method", {})

        self.assertEqual(response["id"], 4)
        self.assertEqual(response["error"]["code"], -32601)

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()