diff --git a/src/server.py b/src/server.py index eac59ef..7d98d6a 100644 --- a/src/server.py +++ b/src/server.py @@ -1007,9 +1007,31 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001, finally: await self.close_pool() + # Sync server start logic + def start(self, transport="stdio", host="127.0.0.1", port=9001, path="/mcp"): + exit_code = 0 -# --- Main Execution Block --- -if __name__ == "__main__": + try: + # 2. Use anyio.run to manage the event loop and call the main async server logic + anyio.run( + partial(self.run_async_server, + transport=transport, + host=host, + port=port, + path=path) + ) + logger.info("Server finished gracefully.") + + except KeyboardInterrupt: + logger.info("Server execution interrupted by user.") + except Exception as e: + logger.critical(f"Server failed to start or crashed: {e}", exc_info=True) + exit_code = 1 + finally: + logger.info(f"Server exiting with code {exit_code}.") + + +def get_arg_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="MariaDB MCP Server") parser.add_argument('--transport', type=str, default='stdio', choices=['stdio', 'sse', 'http'], help='MCP transport protocol (stdio, sse, or http)') @@ -1019,27 +1041,13 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001, help='Port for SSE or HTTP transport') parser.add_argument('--path', type=str, default='/mcp', help='Path for HTTP transport (default: /mcp)') - args = parser.parse_args() + return parser + + +# --- Main Execution Block --- +if __name__ == "__main__": + args = get_arg_parser().parse_args() # 1. Create the server instance server = MariaDBServer() - exit_code = 0 - - try: - # 2. Use anyio.run to manage the event loop and call the main async server logic - anyio.run( - partial(server.run_async_server, - transport=args.transport, - host=args.host, - port=args.port, - path=args.path) - ) - logger.info("Server finished gracefully.") - - except KeyboardInterrupt: - logger.info("Server execution interrupted by user.") - except Exception as e: - logger.critical(f"Server failed to start or crashed: {e}", exc_info=True) - exit_code = 1 - finally: - logger.info(f"Server exiting with code {exit_code}.") \ No newline at end of file + server.start(args.transport, args.host, args.port, args.path) \ No newline at end of file diff --git a/src/tests/test_custom_resource.py b/src/tests/test_custom_resource.py new file mode 100644 index 0000000..2b5a4b5 --- /dev/null +++ b/src/tests/test_custom_resource.py @@ -0,0 +1,51 @@ +""" +Tests that the MariaDBServer exposes its FastMCP instance so callers can +register additional resources before starting the server. + +This mirrors the pattern used by mcp_fmd_server.py, which adds a +``schema://context`` resource to the ``mcp`` object after construction. +""" + +import sys +import unittest +from pathlib import Path + +import anyio +from fastmcp.client import Client + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from server import MariaDBServer + +_SCHEMA_CONTENT = "# Test Schema\nThis is test schema documentation." + + +class TestCustomResource(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.server = MariaDBServer() + + # Register a custom resource on the exposed FastMCP instance — + # this is the pattern the FMD wrapper (mcp_fmd_server.py) uses. + @self.server.mcp.resource("schema://context") + def schema_context() -> str: + """FMD AI Schema documentation.""" + return _SCHEMA_CONTENT + + async def test_custom_resource_is_listed(self): + """The schema://context resource should appear in the resource list.""" + async with Client(self.server.mcp) as client: + resources = await client.list_resources() + uris = [str(r.uri) for r in resources] + self.assertIn("schema://context", uris) + + async def test_custom_resource_content(self): + """Reading schema://context should return the registered content.""" + async with Client(self.server.mcp) as client: + result = await client.read_resource("schema://context") + # result is a list of resource content objects; grab the first text + text = result[0].text if result else "" + self.assertEqual(text, _SCHEMA_CONTENT) + + +if __name__ == "__main__": + unittest.main()