Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,9 +1007,31 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001,
finally:
await self.close_pool()

# Sync server start logic
def start(self, transport="stdio", host="127.0.0.1", port=9001, path="/mcp"):
exit_code = 0

# --- Main Execution Block ---
if __name__ == "__main__":
try:
# 2. Use anyio.run to manage the event loop and call the main async server logic
anyio.run(
partial(self.run_async_server,
transport=transport,
host=host,
port=port,
path=path)
)
logger.info("Server finished gracefully.")

except KeyboardInterrupt:
logger.info("Server execution interrupted by user.")
except Exception as e:
logger.critical(f"Server failed to start or crashed: {e}", exc_info=True)
exit_code = 1
finally:
logger.info(f"Server exiting with code {exit_code}.")


def get_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="MariaDB MCP Server")
parser.add_argument('--transport', type=str, default='stdio', choices=['stdio', 'sse', 'http'],
help='MCP transport protocol (stdio, sse, or http)')
Expand All @@ -1019,27 +1041,13 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001,
help='Port for SSE or HTTP transport')
parser.add_argument('--path', type=str, default='/mcp',
help='Path for HTTP transport (default: /mcp)')
args = parser.parse_args()
return parser


# --- Main Execution Block ---
if __name__ == "__main__":
args = get_arg_parser().parse_args()

# 1. Create the server instance
server = MariaDBServer()
exit_code = 0

try:
# 2. Use anyio.run to manage the event loop and call the main async server logic
anyio.run(
partial(server.run_async_server,
transport=args.transport,
host=args.host,
port=args.port,
path=args.path)
)
logger.info("Server finished gracefully.")

except KeyboardInterrupt:
logger.info("Server execution interrupted by user.")
except Exception as e:
logger.critical(f"Server failed to start or crashed: {e}", exc_info=True)
exit_code = 1
finally:
logger.info(f"Server exiting with code {exit_code}.")
server.start(args.transport, args.host, args.port, args.path)
51 changes: 51 additions & 0 deletions src/tests/test_custom_resource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Tests that the MariaDBServer exposes its FastMCP instance so callers can
register additional resources before starting the server.

This mirrors the pattern used by mcp_fmd_server.py, which adds a
``schema://context`` resource to the ``mcp`` object after construction.
"""

import sys
import unittest
from pathlib import Path

import anyio
from fastmcp.client import Client

sys.path.insert(0, str(Path(__file__).parent.parent))

from server import MariaDBServer

_SCHEMA_CONTENT = "# Test Schema\nThis is test schema documentation."


class TestCustomResource(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.server = MariaDBServer()

# Register a custom resource on the exposed FastMCP instance —
# this is the pattern the FMD wrapper (mcp_fmd_server.py) uses.
@self.server.mcp.resource("schema://context")
def schema_context() -> str:
"""FMD AI Schema documentation."""
return _SCHEMA_CONTENT

async def test_custom_resource_is_listed(self):
"""The schema://context resource should appear in the resource list."""
async with Client(self.server.mcp) as client:
resources = await client.list_resources()
uris = [str(r.uri) for r in resources]
self.assertIn("schema://context", uris)

async def test_custom_resource_content(self):
"""Reading schema://context should return the registered content."""
async with Client(self.server.mcp) as client:
result = await client.read_resource("schema://context")
# result is a list of resource content objects; grab the first text
text = result[0].text if result else ""
self.assertEqual(text, _SCHEMA_CONTENT)


if __name__ == "__main__":
unittest.main()