import unittest
from unittest.mock import MagicMock, patch
import json
import io
import sys
from src.trilium_client import TriliumClient
from src.handlers import ToolHandlers
from src.server import MCPServer
class TestTriliumMCP(unittest.TestCase):
    """End-to-end tests for MCPServer's JSON-RPC handling over stdin/stdout.

    Each test feeds one newline-terminated JSON-RPC request to
    ``MCPServer.run()`` through a patched ``sys.stdin``, captures what the
    server writes to a patched ``sys.stdout``, and inspects the decoded reply.
    The Trilium client is a ``MagicMock`` so the tests never touch a real
    client; individual tests configure return values on it as needed.
    """

    def setUp(self):
        self.mock_client = MagicMock()
        self.handlers = ToolHandlers(self.mock_client)
        self.server = MCPServer(self.handlers)

    def _send_request(self, request_id, method, params=None):
        """Run the server on a single JSON-RPC request and return its reply.

        Args:
            request_id: Value for the request's ``id`` field.
            method: JSON-RPC method name to invoke.
            params: Request parameters; ``None`` means an empty object.

        Returns:
            The JSON object the server wrote to stdout, decoded with
            ``json.loads``.
        """
        # Key order matches the hand-built requests this helper replaced, so
        # the bytes fed to the server are unchanged.
        request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": {} if params is None else params,
        }
        stdin = io.StringIO(json.dumps(request) + "\n")
        with patch('sys.stdin', stdin), patch('sys.stdout', io.StringIO()) as mock_stdout:
            self.server.run()
            output = mock_stdout.getvalue().strip()
        # Fail with a readable message instead of a bare JSONDecodeError when
        # the server produced no reply at all.
        self.assertTrue(output, "server wrote nothing to stdout")
        return json.loads(output)

    def testInitialize(self):
        response = self._send_request(1, "initialize")
        self.assertEqual(response["id"], 1)
        self.assertIn("capabilities", response["result"])
        self.assertIn("tools", response["result"]["capabilities"])

    def testListTools(self):
        response = self._send_request(2, "tools/list")
        self.assertEqual(response["id"], 2)
        tools = response["result"]["tools"]
        self.assertGreater(len(tools), 0)
        # search_notes is the tool exercised by testCallSearchNotes.
        tool_names = [t["name"] for t in tools]
        self.assertIn("search_notes", tool_names)

    def testCallSearchNotes(self):
        # The handler's result should reflect whatever the client returns.
        self.mock_client.searchNotes.return_value = [{"noteId": "abc", "title": "Test Note"}]
        response = self._send_request(
            3,
            "tools/call",
            {
                "name": "search_notes",
                "arguments": {"query": "test"},
            },
        )
        self.assertEqual(response["id"], 3)
        content = response["result"]["content"][0]["text"]
        self.assertIn("abc", content)
        self.assertIn("Test Note", content)

    def testInvalidMethod(self):
        response = self._send_request(4, "non_existent_method")
        self.assertEqual(response["id"], 4)
        # -32601 is the JSON-RPC 2.0 "Method not found" error code.
        self.assertEqual(response["error"]["code"], -32601)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()