Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,12 +810,19 @@ def edge_exists(
raise ValueError(
f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
)

# Escape user-supplied values for safe embedding in Cypher
src_esc = escape_sql_string(source_id)
tgt_esc = escape_sql_string(target_id)
uname_esc = escape_sql_string(user_name or "")

query = f"SELECT * FROM cypher('{self.db_name}_graph', $$"
query += f"\nMATCH {pattern}"
query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
query += f"\nAND a.id = '{source_id}' AND b.id = '{target_id}'"
query += f"\nWHERE a.user_name = '{uname_esc}' AND b.user_name = '{uname_esc}'"
query += f"\nAND a.id = '{src_esc}' AND b.id = '{tgt_esc}'"
if type != "ANY":
query += f"\n AND type(r) = '{type}'"
type_esc = escape_sql_string(type)
query += f"\n AND type(r) = '{type_esc}'"

query += "\nRETURN r"
query += "\n$$) AS (r agtype)"
Expand Down
34 changes: 34 additions & 0 deletions tests/graph_dbs/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
conftest.py for tests/graph_dbs/

Patches the memos logging subsystem early so that importing
memos.graph_dbs.polardb does not fail when log file paths or
optional dependencies (transformers, etc.) are not available.
"""

import logging
import sys
import types

from unittest.mock import MagicMock


# ── Patch memos.log before anything else imports it ──────────────────────

_mock_log = types.ModuleType("memos.log")
_mock_log.get_logger = lambda name: logging.getLogger(name)
_mock_log.ContextFilter = MagicMock()
sys.modules.setdefault("memos.log", _mock_log)

# ── Patch heavy optional deps that are not needed for unit tests ─────────

for mod_name in (
"transformers",
"nebula3",
"nebula3.gclient",
"nebula3.gclient.net",
"nebula3.gclient.net.ConnectionPool",
"neo4j",
"ollama",
):
sys.modules.setdefault(mod_name, MagicMock())
257 changes: 257 additions & 0 deletions tests/graph_dbs/test_polardb_edge_exists_cwe89.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
"""
Tests for CWE-89 fix: Cypher injection in PolarDBGraphDB.edge_exists()

Validates that source_id, target_id, user_name, and type parameters are
escaped via escape_sql_string() before being interpolated into the Cypher
query, preventing injection attacks.
"""

from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def polardb_instance():
"""Create a PolarDBGraphDB instance with mocked dependencies."""
with patch("memos.graph_dbs.polardb.PolarDBGraphDB.__init__", return_value=None):
from memos.graph_dbs.polardb import PolarDBGraphDB

db = PolarDBGraphDB.__new__(PolarDBGraphDB)
db.db_name = "test_db"
db.user_name = "default_user"
db.config = MagicMock()
db.config.user_name = "default_user"
yield db


def _make_mock_conn_cursor():
"""Create mocked connection and cursor with context-manager support."""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock(return_value=False)
mock_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = None # no edge found
return mock_conn, mock_cursor


class TestEdgeExistsCWE89Fix:
"""Tests that edge_exists() escapes user-supplied values to prevent injection."""

def test_source_id_injection_escaped(self, polardb_instance):
"""A single-quote in source_id must be doubled, not passed raw."""
mock_conn, mock_cursor = _make_mock_conn_cursor()
malicious_source = "x' OR 1=1 --"

with patch.object(polardb_instance, "_get_connection") as mock_get_conn:
mock_get_conn.return_value = mock_conn

polardb_instance.edge_exists(
source_id=malicious_source,
target_id="legit_target",
user_name="user1",
)

query = mock_cursor.execute.call_args[0][0]
# The raw injection payload should NOT appear unescaped
assert "x' OR 1=1 --" not in query, (
"Raw injection payload in source_id must not appear in query"
)
# The escaped form (doubled single quotes) should be present
assert "x'' OR 1=1 --" in query, (
"source_id single quotes must be doubled via escape_sql_string()"
)

def test_target_id_injection_escaped(self, polardb_instance):
"""A single-quote in target_id must be doubled, not passed raw."""
mock_conn, mock_cursor = _make_mock_conn_cursor()
malicious_target = "y' RETURN n //"

with patch.object(polardb_instance, "_get_connection") as mock_get_conn:
mock_get_conn.return_value = mock_conn

polardb_instance.edge_exists(
source_id="legit_source",
target_id=malicious_target,
user_name="user1",
)

query = mock_cursor.execute.call_args[0][0]
assert "y' RETURN" not in query, (
"Raw injection payload in target_id must not appear in query"
)
assert "y'' RETURN" in query

def test_user_name_injection_escaped(self, polardb_instance):
"""A single-quote in user_name must be doubled, not passed raw."""
mock_conn, mock_cursor = _make_mock_conn_cursor()
malicious_user = "admin' OR 1=1 --"

with patch.object(polardb_instance, "_get_connection") as mock_get_conn:
mock_get_conn.return_value = mock_conn

polardb_instance.edge_exists(
source_id="src",
target_id="tgt",
user_name=malicious_user,
)

query = mock_cursor.execute.call_args[0][0]
assert "admin' OR 1=1 --" not in query, (
"Raw injection payload in user_name must not appear in query"
)
assert "admin'' OR 1=1 --" in query

def test_type_injection_escaped(self, polardb_instance):
"""A single-quote in type must be doubled, not passed raw."""
mock_conn, mock_cursor = _make_mock_conn_cursor()
malicious_type = "PARENT' OR 1=1 --"

with patch.object(polardb_instance, "_get_connection") as mock_get_conn:
mock_get_conn.return_value = mock_conn

polardb_instance.edge_exists(
source_id="src",
target_id="tgt",
type=malicious_type,
user_name="user1",
)

query = mock_cursor.execute.call_args[0][0]
assert "PARENT' OR 1=1 --" not in query, (
"Raw injection payload in type must not appear in query"
)
assert "PARENT'' OR 1=1 --" in query

def test_clean_values_no_double_escaping(self, polardb_instance):
"""Clean values without quotes should pass through unchanged."""
mock_conn, mock_cursor = _make_mock_conn_cursor()

with patch.object(polardb_instance, "_get_connection") as mock_get_conn:
mock_get_conn.return_value = mock_conn

polardb_instance.edge_exists(
source_id="abc-123",
target_id="def-456",
type="PARENT",
user_name="test_user",
)

query = mock_cursor.execute.call_args[0][0]
assert "a.id = 'abc-123'" in query
assert "b.id = 'def-456'" in query
assert "a.user_name = 'test_user'" in query
assert "type(r) = 'PARENT'" in query

def test_direction_any_works(self, polardb_instance):
"""Direction=ANY should use undirected match pattern."""
mock_conn, mock_cursor = _make_mock_conn_cursor()

with patch.object(polardb_instance, "_get_connection") as mock_get_conn:
mock_get_conn.return_value = mock_conn

polardb_instance.edge_exists(
source_id="src",
target_id="tgt",
direction="ANY",
user_name="user1",
)

query = mock_cursor.execute.call_args[0][0]
assert "(a:Memory)-[r]-(b:Memory)" in query

def test_direction_incoming_works(self, polardb_instance):
"""Direction=INCOMING should use incoming pattern."""
mock_conn, mock_cursor = _make_mock_conn_cursor()

with patch.object(polardb_instance, "_get_connection") as mock_get_conn:
mock_get_conn.return_value = mock_conn

polardb_instance.edge_exists(
source_id="src",
target_id="tgt",
direction="INCOMING",
user_name="user1",
)

query = mock_cursor.execute.call_args[0][0]
assert "(a:Memory)<-[r]-(b:Memory)" in query


class TestEdgeExistsSourceCodeSafety:
"""Verify the source code of edge_exists() directly to confirm no raw injection patterns."""

def _get_method_source(self) -> str:
"""Read the edge_exists() method source directly from the file.

Finds the @timed-decorated edge_exists (not edge_exists_old).
"""
polardb_path = (
Path(__file__).resolve().parents[2]
/ "src"
/ "memos"
/ "graph_dbs"
/ "polardb.py"
)
source = polardb_path.read_text()
lines = source.split("\n")
in_method = False
method_lines = []
indent = None
for i, line in enumerate(lines):
# Match the @timed edge_exists, not edge_exists_old
if (
"def edge_exists(" in line
and "edge_exists_old" not in line
and i > 0
and "@timed" in lines[i - 1]
):
in_method = True
indent = len(line) - len(line.lstrip())
method_lines.append(line)
continue
if in_method:
stripped = line.lstrip()
current_indent = len(line) - len(line.lstrip()) if stripped else indent + 4
if stripped.startswith("def ") and current_indent <= indent:
break
if stripped.startswith("@") and current_indent <= indent:
break
method_lines.append(line)
return "\n".join(method_lines)

def test_no_raw_user_name_interpolation(self):
"""edge_exists() should not have raw {user_name} in f-string."""
source = self._get_method_source()
assert source, "Should have found edge_exists() method source"
assert "'{user_name}'" not in source, (
"edge_exists() should NOT have raw f-string '{user_name}' interpolation"
)

def test_no_raw_source_id_interpolation(self):
"""edge_exists() should not have raw {source_id} in f-string."""
source = self._get_method_source()
assert source, "Should have found edge_exists() method source"
assert "'{source_id}'" not in source, (
"edge_exists() should NOT have raw f-string '{source_id}' interpolation"
)

def test_no_raw_target_id_interpolation(self):
"""edge_exists() should not have raw {target_id} in f-string."""
source = self._get_method_source()
assert source, "Should have found edge_exists() method source"
assert "'{target_id}'" not in source, (
"edge_exists() should NOT have raw f-string '{target_id}' interpolation"
)

def test_uses_escape_sql_string(self):
"""edge_exists() should call escape_sql_string for user-supplied values."""
source = self._get_method_source()
assert source, "Should have found edge_exists() method source"
assert "escape_sql_string" in source, (
"edge_exists() should use escape_sql_string() for escaping"
)