From c779552e5cadc6037fc3c311ff2d1b5db853f39f Mon Sep 17 00:00:00 2001 From: "vipul.mittal" Date: Thu, 26 Mar 2026 10:09:57 +0530 Subject: [PATCH] adding missing unit testcases --- tests/configuration/__init__.py | 0 tests/configuration/test_loader.py | 99 ++++++++ tests/core/dataset/test_dataset_processor.py | 139 +++++++++++ tests/core/dataset/test_file_handler.py | 199 ++++++++++++++++ .../core/dataset/test_huggingface_handler.py | 218 ++++++++++++++++++ tests/core/graph/functions/__init__.py | 0 .../graph/functions/test_edge_condition.py | 42 ++++ .../graph/functions/test_lambda_function.py | 77 +++++++ .../graph/functions/test_node_processor.py | 92 ++++++++ .../graph/langgraph/test_graph_builder.py | 162 +++++++++++++ tests/core/graph/nodes/test_connector_node.py | 63 +++++ tests/core/graph/nodes/test_lambda_node.py | 180 +++++++++++++++ tests/core/graph/nodes/test_multi_llm_node.py | 160 +++++++++++++ tests/core/graph/nodes/test_node_utils.py | 90 ++++++++ .../graph/nodes/test_weighted_sampler_node.py | 102 ++++++++ tests/core/graph/test_backend_factory.py | 63 +++++ tests/core/graph/test_sygra_message.py | 30 +++ tests/core/graph/test_sygra_state.py | 31 +++ tests/data_mapper/__init__.py | 0 tests/data_mapper/test_mapper.py | 150 ++++++++++++ tests/utils/test_decorators.py | 95 ++++++++ tests/utils/test_graph_utils.py | 96 ++++++++ tests/utils/test_model_utils.py | 177 ++++++++++++++ 23 files changed, 2265 insertions(+) create mode 100644 tests/configuration/__init__.py create mode 100644 tests/configuration/test_loader.py create mode 100644 tests/core/dataset/test_dataset_processor.py create mode 100644 tests/core/dataset/test_file_handler.py create mode 100644 tests/core/dataset/test_huggingface_handler.py create mode 100644 tests/core/graph/functions/__init__.py create mode 100644 tests/core/graph/functions/test_edge_condition.py create mode 100644 tests/core/graph/functions/test_lambda_function.py create mode 100644 
tests/core/graph/functions/test_node_processor.py create mode 100644 tests/core/graph/langgraph/test_graph_builder.py create mode 100644 tests/core/graph/nodes/test_connector_node.py create mode 100644 tests/core/graph/nodes/test_lambda_node.py create mode 100644 tests/core/graph/nodes/test_multi_llm_node.py create mode 100644 tests/core/graph/nodes/test_node_utils.py create mode 100644 tests/core/graph/nodes/test_weighted_sampler_node.py create mode 100644 tests/core/graph/test_backend_factory.py create mode 100644 tests/core/graph/test_sygra_message.py create mode 100644 tests/core/graph/test_sygra_state.py create mode 100644 tests/data_mapper/__init__.py create mode 100644 tests/data_mapper/test_mapper.py create mode 100644 tests/utils/test_decorators.py create mode 100644 tests/utils/test_graph_utils.py create mode 100644 tests/utils/test_model_utils.py diff --git a/tests/configuration/__init__.py b/tests/configuration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/configuration/test_loader.py b/tests/configuration/test_loader.py new file mode 100644 index 00000000..0a67d2a4 --- /dev/null +++ b/tests/configuration/test_loader.py @@ -0,0 +1,99 @@ +import sys +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +import yaml + +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from sygra.configuration.loader import ConfigLoader + + +class TestConfigLoaderLoad(unittest.TestCase): + def test_load_returns_dict_unchanged_when_given_dict(self): + loader = ConfigLoader() + config = {"task_name": "test", "nodes": {}} + result = loader.load(config) + self.assertEqual(result, config) + + def test_load_reads_yaml_file(self): + config_data = {"task_name": "my_task", "nodes": {"n1": {"node_type": "llm"}}} + with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", delete=False) as f: + yaml.dump(config_data, f) + tmp_path = f.name + + loader = ConfigLoader() + result = 
loader.load(tmp_path) + self.assertEqual(result["task_name"], "my_task") + + def test_load_raises_file_not_found_for_missing_file(self): + loader = ConfigLoader() + with self.assertRaises(FileNotFoundError): + loader.load("/nonexistent/path/config.yaml") + + def test_load_raises_type_error_for_non_dict_yaml(self): + with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", delete=False) as f: + yaml.dump(["item1", "item2"], f) + tmp_path = f.name + + loader = ConfigLoader() + with self.assertRaises(TypeError): + loader.load(tmp_path) + + def test_load_accepts_path_object(self): + config_data = {"task_name": "path_test"} + with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", delete=False) as f: + yaml.dump(config_data, f) + tmp_path = Path(f.name) + + loader = ConfigLoader() + result = loader.load(tmp_path) + self.assertEqual(result["task_name"], "path_test") + + +class TestConfigLoaderLoadAndCreate(unittest.TestCase): + def test_load_and_create_returns_workflow_with_correct_flags(self): + config_data = {"task_name": "my_workflow", "nodes": {}} + with tempfile.NamedTemporaryFile(suffix=".yaml", mode="w", delete=False) as f: + yaml.dump(config_data, f) + tmp_path = f.name + + loader = ConfigLoader() + workflow = loader.load_and_create(tmp_path) + + self.assertTrue(workflow._supports_subgraphs) + self.assertTrue(workflow._supports_multimodal) + self.assertTrue(workflow._supports_resumable) + self.assertTrue(workflow._supports_quality) + self.assertTrue(workflow._supports_oasst) + + def test_load_and_create_sets_name_from_parent_directory(self): + config_data = {"task_name": "my_workflow"} + with tempfile.TemporaryDirectory() as tmpdir: + task_dir = Path(tmpdir) / "my_task_name" + task_dir.mkdir() + config_file = task_dir / "graph_config.yaml" + config_file.write_text(yaml.dump(config_data)) + + loader = ConfigLoader() + workflow = loader.load_and_create(str(config_file)) + + self.assertEqual(workflow.name, "my_task_name") + + def 
test_load_and_create_with_dict_sets_name_from_task_name(self): + config = {"task_name": "dict_task", "nodes": {}} + loader = ConfigLoader() + workflow = loader.load_and_create(config) + self.assertEqual(workflow.name, "dict_task") + + def test_load_and_create_with_dict_without_task_name_uses_default(self): + config = {"nodes": {}} + loader = ConfigLoader() + workflow = loader.load_and_create(config) + self.assertEqual(workflow.name, "loaded_workflow") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/dataset/test_dataset_processor.py b/tests/core/dataset/test_dataset_processor.py new file mode 100644 index 00000000..5e5a21d1 --- /dev/null +++ b/tests/core/dataset/test_dataset_processor.py @@ -0,0 +1,139 @@ +import sys +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + + +def _make_processor(**kwargs): + """Helper to create a DatasetProcessor with all heavy deps mocked.""" + defaults = dict( + input_dataset=[{"id": "1"}], + graph=MagicMock(), + graph_config=MagicMock(), + output_file="/tmp/tasks/my_task/output.jsonl", + num_records_total=10, + batch_size=10, + checkpoint_interval=10, + ) + defaults.update(kwargs) + + graph_config = defaults["graph_config"] + graph_config.config = {"task_name": "test_task"} + graph_config.oasst_mapper = None + + with patch("sygra.core.dataset.dataset_processor.tqdm") as mock_tqdm, \ + patch("sygra.core.dataset.dataset_processor.ResumableExecutionManager"): + mock_tqdm.tqdm.return_value = MagicMock() + from sygra.core.dataset.dataset_processor import DatasetProcessor + processor = DatasetProcessor(**defaults) + return processor + + +class TestDatasetProcessorInit(unittest.TestCase): + def test_raises_when_checkpoint_not_multiple_of_batch(self): + with patch("sygra.core.dataset.dataset_processor.tqdm"), \ + patch("sygra.core.dataset.dataset_processor.ResumableExecutionManager"): + from 
sygra.core.dataset.dataset_processor import DatasetProcessor + with self.assertRaises(AssertionError): + DatasetProcessor( + input_dataset=[{"id": "1"}], + graph=MagicMock(), + graph_config=MagicMock(), + output_file="/tmp/output.jsonl", + num_records_total=10, + batch_size=30, + checkpoint_interval=100, + ) + + def test_valid_checkpoint_multiple_of_batch(self): + proc = _make_processor(batch_size=10, checkpoint_interval=100) + self.assertEqual(proc.batch_size, 10) + self.assertEqual(proc.checkpoint_interval, 100) + + +class TestDetermineDatasetType(unittest.TestCase): + def setUp(self): + from sygra.core.dataset.dataset_processor import DatasetProcessor + self.DatasetProcessor = DatasetProcessor + + def test_list_returns_in_memory(self): + result = self.DatasetProcessor._determine_dataset_type([{"a": 1}]) + self.assertEqual(result, "in_memory") + + def test_streaming_attribute_true_returns_streaming(self): + mock_ds = MagicMock() + mock_ds.is_streaming = True + result = self.DatasetProcessor._determine_dataset_type(mock_ds) + self.assertEqual(result, "streaming") + + def test_iterable_dataset_returns_streaming(self): + import datasets + mock_ds = MagicMock(spec=datasets.IterableDataset) + result = self.DatasetProcessor._determine_dataset_type(mock_ds) + self.assertEqual(result, "streaming") + + def test_default_returns_in_memory(self): + mock_ds = MagicMock(spec=object) + del mock_ds.is_streaming + result = self.DatasetProcessor._determine_dataset_type(mock_ds) + self.assertEqual(result, "in_memory") + + +class TestExtractTaskName(unittest.TestCase): + def test_extracts_from_tasks_path(self): + proc = _make_processor(output_file="/data/tasks/my_task/output.jsonl") + result = proc._extract_task_name() + self.assertEqual(result, "my_task") + + def test_fallback_when_no_tasks_segment(self): + proc = _make_processor(output_file="/data/output/result.jsonl") + result = proc._extract_task_name() + self.assertIn("task_", result) + + +class 
TestIsErrorCodeInOutput(unittest.TestCase): + def setUp(self): + from sygra.core.dataset.dataset_processor import DatasetProcessor + self.DatasetProcessor = DatasetProcessor + + def test_returns_true_when_error_prefix_found(self): + output = {"key": "###SERVER_ERROR### something bad happened"} + self.assertTrue(self.DatasetProcessor.is_error_code_in_output(output)) + + def test_returns_false_when_no_error_prefix(self): + output = {"key": "all good", "count": 5} + self.assertFalse(self.DatasetProcessor.is_error_code_in_output(output)) + + def test_returns_false_for_non_string_values(self): + output = {"count": 42, "data": [1, 2, 3]} + self.assertFalse(self.DatasetProcessor.is_error_code_in_output(output)) + + def test_returns_false_for_empty_output(self): + self.assertFalse(self.DatasetProcessor.is_error_code_in_output({})) + + +class TestGetRecord(unittest.TestCase): + def test_assigns_uuid_when_record_has_no_id(self): + proc = _make_processor(input_dataset=[{"value": "hello"}]) + proc.resumable = False + record = proc._get_record() + self.assertIn("id", record) + self.assertTrue(len(record["id"]) > 0) + + def test_keeps_existing_id(self): + proc = _make_processor(input_dataset=[{"id": "existing-id", "value": "hello"}]) + proc.resumable = False + record = proc._get_record() + self.assertEqual(record["id"], "existing-id") + + def test_raises_stop_iteration_when_exhausted(self): + proc = _make_processor(input_dataset=[]) + proc.resumable = False + with self.assertRaises(StopIteration): + proc._get_record() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/dataset/test_file_handler.py b/tests/core/dataset/test_file_handler.py new file mode 100644 index 00000000..1ed40cdb --- /dev/null +++ b/tests/core/dataset/test_file_handler.py @@ -0,0 +1,199 @@ +import json +import sys +import tempfile +import unittest +from datetime import datetime +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np + 
+sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from sygra.core.dataset.file_handler import FileHandler + + +class TestFileHandlerRead(unittest.TestCase): + def setUp(self): + self.source_config = MagicMock() + self.source_config.file_path = "/some/path/data.json" + self.source_config.encoding = "utf-8" + self.source_config.shard = None + self.handler = FileHandler(source_config=self.source_config) + + def test_read_raises_when_no_path_and_no_source_config(self): + handler = FileHandler(source_config=None) + with self.assertRaises(ValueError): + handler.read() + + def test_read_raises_when_source_config_has_no_file_path(self): + self.source_config.file_path = None + with self.assertRaises(ValueError): + self.handler.read() + + def test_read_raises_for_unsupported_extension(self): + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: + f.write(b"hello") + tmp_path = f.name + with self.assertRaises(ValueError): + self.handler.read(tmp_path) + + def test_read_jsonl(self): + records = [{"a": 1}, {"b": 2}] + with tempfile.NamedTemporaryFile(suffix=".jsonl", mode="w", delete=False) as f: + for rec in records: + f.write(json.dumps(rec) + "\n") + tmp_path = f.name + result = self.handler.read(tmp_path) + self.assertEqual(result, records) + + def test_read_json(self): + records = [{"x": 10}, {"y": 20}] + with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f: + json.dump(records, f) + tmp_path = f.name + result = self.handler.read(tmp_path) + self.assertEqual(result, records) + + def test_read_json_raises_for_non_list(self): + with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f: + json.dump({"key": "value"}, f) + tmp_path = f.name + with self.assertRaises(ValueError): + self.handler.read(tmp_path) + + @patch("sygra.core.dataset.file_handler.pd.read_parquet") + def test_read_parquet(self, mock_read_parquet): + mock_df = MagicMock() + mock_df.to_dict.return_value = [{"col": 1}] + 
mock_read_parquet.return_value = mock_df + result = self.handler.read("/some/file.parquet") + mock_read_parquet.assert_called_once() + self.assertEqual(result, [{"col": 1}]) + + @patch("sygra.core.dataset.file_handler.pd.read_csv") + def test_read_csv(self, mock_read_csv): + mock_df = MagicMock() + mock_df.to_dict.return_value = [{"col": "val"}] + mock_read_csv.return_value = mock_df + result = self.handler.read("/some/file.csv") + mock_read_csv.assert_called_once() + self.assertEqual(result, [{"col": "val"}]) + + def test_read_uses_source_config_path_when_no_arg(self): + records = [{"z": 99}] + with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f: + json.dump(records, f) + tmp_path = f.name + self.source_config.file_path = tmp_path + result = self.handler.read() + self.assertEqual(result, records) + + +class TestFileHandlerWrite(unittest.TestCase): + def setUp(self): + self.output_config = MagicMock() + self.output_config.encoding = "utf-8" + self.handler = FileHandler(source_config=None, output_config=self.output_config) + + def test_write_json(self): + data = [{"key": "value"}] + with tempfile.TemporaryDirectory() as tmpdir: + path = str(Path(tmpdir) / "output.json") + self.handler.write(data, path) + with open(path, "r") as f: + result = json.load(f) + self.assertEqual(result, data) + + def test_write_jsonl(self): + data = [{"a": 1}, {"b": 2}] + with tempfile.TemporaryDirectory() as tmpdir: + path = str(Path(tmpdir) / "output.jsonl") + self.handler.write(data, path) + with open(path, "r") as f: + lines = [json.loads(l) for l in f] + self.assertEqual(lines, data) + + @patch("sygra.core.dataset.file_handler.pd.DataFrame") + def test_write_parquet(self, mock_df_cls): + mock_df = MagicMock() + mock_df_cls.return_value = mock_df + data = [{"col": 1}] + self.handler.write(data, "/some/file.parquet") + mock_df.to_parquet.assert_called_once() + + def test_write_creates_parent_directories(self): + data = [{"val": 1}] + with 
tempfile.TemporaryDirectory() as tmpdir: + path = str(Path(tmpdir) / "nested" / "deep" / "output.json") + self.handler.write(data, path) + self.assertTrue(Path(path).exists()) + + def test_write_serializes_datetime(self): + dt = datetime(2024, 1, 15, 12, 0, 0) + data = [{"ts": dt}] + with tempfile.TemporaryDirectory() as tmpdir: + path = str(Path(tmpdir) / "output.json") + self.handler.write(data, path) + with open(path, "r") as f: + result = json.load(f) + self.assertEqual(result[0]["ts"], dt.isoformat()) + + def test_write_serializes_numpy_array(self): + arr = np.array([1, 2, 3]) + data = [{"arr": arr}] + with tempfile.TemporaryDirectory() as tmpdir: + path = str(Path(tmpdir) / "output.json") + self.handler.write(data, path) + with open(path, "r") as f: + result = json.load(f) + self.assertEqual(result[0]["arr"], [1, 2, 3]) + + +class TestFileHandlerGetFiles(unittest.TestCase): + def test_get_files_raises_when_no_source_config(self): + handler = FileHandler(source_config=None) + with self.assertRaises(ValueError): + handler.get_files() + + def test_get_files_raises_when_no_file_path(self): + source_config = MagicMock() + source_config.file_path = None + handler = FileHandler(source_config=source_config) + with self.assertRaises(ValueError): + handler.get_files() + + def test_get_files_returns_matching_files(self): + with tempfile.TemporaryDirectory() as tmpdir: + (Path(tmpdir) / "data.json").write_text("[]") + (Path(tmpdir) / "data.jsonl").write_text("") + (Path(tmpdir) / "other.txt").write_text("skip") + + source_config = MagicMock() + source_config.file_path = str(Path(tmpdir) / "dummy.json") + source_config.shard = None + handler = FileHandler(source_config=source_config) + + files = handler.get_files() + exts = {Path(f).suffix for f in files} + self.assertTrue(exts.issubset({".json", ".jsonl", ".parquet"})) + self.assertNotIn(".txt", exts) + + def test_get_files_uses_shard_regex(self): + with tempfile.TemporaryDirectory() as tmpdir: + (Path(tmpdir) / 
"train-001.jsonl").write_text("") + (Path(tmpdir) / "test-001.jsonl").write_text("") + + source_config = MagicMock() + source_config.file_path = str(Path(tmpdir) / "dummy.jsonl") + shard = MagicMock() + shard.regex = "train-" + source_config.shard = shard + handler = FileHandler(source_config=source_config) + + files = handler.get_files() + self.assertTrue(all("train-" in f for f in files)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/dataset/test_huggingface_handler.py b/tests/core/dataset/test_huggingface_handler.py new file mode 100644 index 00000000..fa146444 --- /dev/null +++ b/tests/core/dataset/test_huggingface_handler.py @@ -0,0 +1,218 @@ +import sys +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from sygra.core.dataset.huggingface_handler import HuggingFaceHandler +from sygra.core.dataset.dataset_config import DataSourceConfig, OutputConfig + + +def _make_source_config(**kwargs): + defaults = dict( + repo_id="user/dataset", + config_name="default", + split="train", + streaming=False, + token=None, + shard=None, + encoding="utf-8", + ) + defaults.update(kwargs) + sc = MagicMock(spec=DataSourceConfig) + for k, v in defaults.items(): + setattr(sc, k, v) + return sc + + +def _make_output_config(**kwargs): + defaults = dict( + repo_id="user/output", + config_name="default", + split="train", + private=False, + token=None, + ) + defaults.update(kwargs) + oc = MagicMock(spec=OutputConfig) + for k, v in defaults.items(): + setattr(oc, k, v) + return oc + + +class TestHuggingFaceHandlerInit(unittest.TestCase): + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_init_stores_configs(self, mock_fs_cls): + sc = _make_source_config() + oc = _make_output_config() + handler = HuggingFaceHandler(source_config=sc, output_config=oc) + self.assertEqual(handler.source_config, sc) + 
self.assertEqual(handler.output_config, oc) + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_init_creates_hf_filesystem(self, mock_fs_cls): + sc = _make_source_config(token="my_token") + handler = HuggingFaceHandler(source_config=sc) + mock_fs_cls.assert_called_once_with(token="my_token") + + +class TestHuggingFaceHandlerRead(unittest.TestCase): + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_read_raises_when_no_source_config(self, mock_fs_cls): + handler = HuggingFaceHandler(source_config=None) + with self.assertRaises((ValueError, RuntimeError)): + handler.read() + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_read_calls_read_shard_when_path_and_shard_set(self, mock_fs_cls): + sc = _make_source_config(shard=MagicMock()) + handler = HuggingFaceHandler(source_config=sc) + with patch.object(handler, "_read_shard", return_value=[]) as mock_shard: + handler.read(path="some/path") + mock_shard.assert_called_once_with("some/path") + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_read_calls_read_dataset_when_no_path(self, mock_fs_cls): + sc = _make_source_config() + handler = HuggingFaceHandler(source_config=sc) + with patch.object(handler, "_read_dataset", return_value=[]) as mock_dataset: + handler.read() + mock_dataset.assert_called_once() + + +class TestHuggingFaceHandlerWrite(unittest.TestCase): + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_write_raises_when_no_output_config(self, mock_fs_cls): + sc = _make_source_config() + handler = HuggingFaceHandler(source_config=sc, output_config=None) + with self.assertRaises((ValueError, RuntimeError)): + handler.write([{"a": 1}]) + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + @patch("sygra.core.dataset.huggingface_handler.HfApi") + def test_create_repo_raises_when_no_output_config(self, mock_api_cls, mock_fs_cls): + sc = _make_source_config() + handler = 
HuggingFaceHandler(source_config=sc, output_config=None) + with self.assertRaises(ValueError): + handler._create_repo() + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + @patch("sygra.core.dataset.huggingface_handler.HfApi") + def test_create_repo_raises_when_repo_id_empty(self, mock_api_cls, mock_fs_cls): + sc = _make_source_config() + oc = _make_output_config(repo_id="") + handler = HuggingFaceHandler(source_config=sc, output_config=oc) + with self.assertRaises(ValueError): + handler._create_repo() + + +class TestHuggingFaceHandlerGetFiles(unittest.TestCase): + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_get_files_raises_when_no_source_config(self, mock_fs_cls): + handler = HuggingFaceHandler(source_config=None) + with self.assertRaises((ValueError, RuntimeError)): + handler.get_files() + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_get_files_returns_all_files_when_no_shard(self, mock_fs_cls): + sc = _make_source_config(shard=None) + handler = HuggingFaceHandler(source_config=sc) + handler.fs.glob.return_value = ["file1.parquet", "file2.parquet"] + result = handler.get_files() + self.assertEqual(result, ["file1.parquet", "file2.parquet"]) + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_get_files_filters_by_shard_index(self, mock_fs_cls): + shard = MagicMock() + shard.regex = "/*.parquet" + shard.index = {0, 2} + sc = _make_source_config(shard=shard) + handler = HuggingFaceHandler(source_config=sc) + handler.fs.glob.return_value = ["file0.parquet", "file1.parquet", "file2.parquet"] + result = handler.get_files() + self.assertEqual(len(result), 2) + self.assertIn("file0.parquet", result) + self.assertIn("file2.parquet", result) + + +class TestHuggingFaceHandlerDecodeBase64Media(unittest.TestCase): + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_decodes_valid_data_url(self, mock_fs_cls): + import base64 + sc = _make_source_config() + 
handler = HuggingFaceHandler(source_config=sc) + data = b"hello bytes" + encoded = base64.b64encode(data).decode("utf-8") + data_url = f"data:image/png;base64,{encoded}" + result = handler._decode_base64_media(data_url) + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["bytes"], data) + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_returns_none_for_invalid_item(self, mock_fs_cls): + sc = _make_source_config() + handler = HuggingFaceHandler(source_config=sc) + result = handler._decode_base64_media(["not_a_data_url"]) + self.assertEqual(len(result), 1) + self.assertIsNone(result[0]) + + +class TestHuggingFaceHandlerDetectMediaColumns(unittest.TestCase): + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + @patch("sygra.core.dataset.huggingface_handler.image_utils") + @patch("sygra.core.dataset.huggingface_handler.audio_utils") + def test_detects_image_str_column(self, mock_audio, mock_image, mock_fs_cls): + import pandas as pd + mock_image.is_data_url.return_value = True + mock_image.is_image_file_path.return_value = False + mock_audio.is_data_url.return_value = False + mock_audio.is_audio_file_path.return_value = False + + sc = _make_source_config() + handler = HuggingFaceHandler(source_config=sc) + df = pd.DataFrame({"img": ["data:image/png;base64,abc"]}) + result = handler._detect_media_columns(df) + self.assertEqual(len(result["image_str"]), 1) + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + @patch("sygra.core.dataset.huggingface_handler.image_utils") + @patch("sygra.core.dataset.huggingface_handler.audio_utils") + def test_detects_no_media_for_text_column(self, mock_audio, mock_image, mock_fs_cls): + import pandas as pd + mock_image.is_data_url.return_value = False + mock_image.is_image_file_path.return_value = False + mock_audio.is_data_url.return_value = False + mock_audio.is_audio_file_path.return_value = False + + sc = _make_source_config() + handler = 
HuggingFaceHandler(source_config=sc) + df = pd.DataFrame({"text": ["hello world"]}) + result = handler._detect_media_columns(df) + self.assertEqual(result["image_str"], []) + self.assertEqual(result["audio_str"], []) + + +class TestHuggingFaceHandlerStoreDatasetMetadata(unittest.TestCase): + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_stores_fingerprint_when_available(self, mock_fs_cls): + sc = _make_source_config() + handler = HuggingFaceHandler(source_config=sc) + mock_ds = MagicMock() + mock_ds._fingerprint = "abc123" + mock_ds.info = None + handler._store_dataset_metadata(mock_ds) + self.assertEqual(handler.dataset_hash, "abc123") + + @patch("sygra.core.dataset.huggingface_handler.HfFileSystem") + def test_handles_dataset_without_fingerprint(self, mock_fs_cls): + sc = _make_source_config() + handler = HuggingFaceHandler(source_config=sc) + mock_ds = MagicMock(spec=[]) + # No _fingerprint attribute + handler._store_dataset_metadata(mock_ds) + self.assertIsNone(handler.dataset_hash) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/functions/__init__.py b/tests/core/graph/functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/graph/functions/test_edge_condition.py b/tests/core/graph/functions/test_edge_condition.py new file mode 100644 index 00000000..bdbc9e76 --- /dev/null +++ b/tests/core/graph/functions/test_edge_condition.py @@ -0,0 +1,42 @@ +import sys +import unittest +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent)) + +from sygra.core.graph.functions.edge_condition import EdgeCondition + + +class TestEdgeConditionABC(unittest.TestCase): + def test_cannot_instantiate_directly(self): + with self.assertRaises(TypeError): + EdgeCondition() + + def test_concrete_subclass_can_be_instantiated(self): + class MyCondition(EdgeCondition): + @staticmethod + def apply(state) -> str: + return "path_a" + + cond = MyCondition() 
+ self.assertIsInstance(cond, EdgeCondition) + + def test_apply_is_called_on_concrete_subclass(self): + class MyCondition(EdgeCondition): + @staticmethod + def apply(state) -> str: + return "path_b" + + result = MyCondition.apply({"key": "val"}) + self.assertEqual(result, "path_b") + + def test_incomplete_subclass_cannot_be_instantiated(self): + class IncompleteCondition(EdgeCondition): + pass + + with self.assertRaises(TypeError): + IncompleteCondition() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/functions/test_lambda_function.py b/tests/core/graph/functions/test_lambda_function.py new file mode 100644 index 00000000..2b0e56fa --- /dev/null +++ b/tests/core/graph/functions/test_lambda_function.py @@ -0,0 +1,77 @@ +import sys +import unittest +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent)) + +from sygra.core.graph.functions.lambda_function import AsyncLambdaFunction, LambdaFunction + + +class TestLambdaFunctionABC(unittest.TestCase): + def test_cannot_instantiate_directly(self): + with self.assertRaises(TypeError): + LambdaFunction() + + def test_concrete_subclass_can_be_instantiated(self): + class MyLambda(LambdaFunction): + @staticmethod + def apply(config, state): + return state + + fn = MyLambda() + self.assertIsInstance(fn, LambdaFunction) + + def test_apply_is_callable_on_concrete_subclass(self): + class MyLambda(LambdaFunction): + @staticmethod + def apply(config, state): + return {**state, "processed": True} + + result = MyLambda.apply({}, {"x": 1}) + self.assertTrue(result["processed"]) + + def test_incomplete_subclass_cannot_be_instantiated(self): + class IncompleteLambda(LambdaFunction): + pass + + with self.assertRaises(TypeError): + IncompleteLambda() + + +class TestAsyncLambdaFunctionABC(unittest.TestCase): + def test_cannot_instantiate_directly(self): + with self.assertRaises(TypeError): + AsyncLambdaFunction() + + def 
test_concrete_subclass_can_be_instantiated(self): + class MyAsyncLambda(AsyncLambdaFunction): + @staticmethod + async def apply(config, state): + return state + + fn = MyAsyncLambda() + self.assertIsInstance(fn, AsyncLambdaFunction) + + def test_apply_is_async(self): + import asyncio + + class MyAsyncLambda(AsyncLambdaFunction): + @staticmethod + async def apply(config, state): + return {**state, "async_processed": True} + + result = asyncio.get_event_loop().run_until_complete( + MyAsyncLambda.apply({}, {"x": 1}) + ) + self.assertTrue(result["async_processed"]) + + def test_incomplete_subclass_cannot_be_instantiated(self): + class IncompleteAsync(AsyncLambdaFunction): + pass + + with self.assertRaises(TypeError): + IncompleteAsync() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/functions/test_node_processor.py b/tests/core/graph/functions/test_node_processor.py new file mode 100644 index 00000000..2a8b3db4 --- /dev/null +++ b/tests/core/graph/functions/test_node_processor.py @@ -0,0 +1,92 @@ +import sys +import unittest +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent)) + +from sygra.core.graph.functions.node_processor import ( + NodePostProcessor, + NodePostProcessorWithState, + NodePreProcessor, +) + + +class TestNodePreProcessorABC(unittest.TestCase): + def test_cannot_instantiate_directly(self): + with self.assertRaises(TypeError): + NodePreProcessor() + + def test_concrete_subclass_can_be_instantiated(self): + class MyPreProcessor(NodePreProcessor): + def apply(self, state): + return state + + proc = MyPreProcessor() + self.assertIsInstance(proc, NodePreProcessor) + + def test_apply_receives_state_and_returns_state(self): + class MyPreProcessor(NodePreProcessor): + def apply(self, state): + state["preprocessed"] = True + return state + + proc = MyPreProcessor() + result = proc.apply({"x": 1}) + self.assertTrue(result["preprocessed"]) + + +class 
TestNodePostProcessorABC(unittest.TestCase): + def test_cannot_instantiate_directly(self): + with self.assertRaises(TypeError): + NodePostProcessor() + + def test_concrete_subclass_can_be_instantiated(self): + class MyPostProcessor(NodePostProcessor): + def apply(self, resp): + return {"result": resp.message} + + proc = MyPostProcessor() + self.assertIsInstance(proc, NodePostProcessor) + + def test_apply_receives_sygra_message(self): + from sygra.core.graph.sygra_message import SygraMessage + + class MyPostProcessor(NodePostProcessor): + def apply(self, resp): + return {"content": resp.message} + + proc = MyPostProcessor() + msg = SygraMessage("hello") + result = proc.apply(msg) + self.assertEqual(result["content"], "hello") + + +class TestNodePostProcessorWithStateABC(unittest.TestCase): + def test_cannot_instantiate_directly(self): + with self.assertRaises(TypeError): + NodePostProcessorWithState() + + def test_concrete_subclass_can_be_instantiated(self): + class MyPostProcessorWithState(NodePostProcessorWithState): + def apply(self, resp, state): + return {**state, "response": resp.message} + + proc = MyPostProcessorWithState() + self.assertIsInstance(proc, NodePostProcessorWithState) + + def test_apply_receives_response_and_state(self): + from sygra.core.graph.sygra_message import SygraMessage + + class MyPostProcessorWithState(NodePostProcessorWithState): + def apply(self, resp, state): + return {**state, "response": resp.message} + + proc = MyPostProcessorWithState() + msg = SygraMessage("world") + result = proc.apply(msg, {"existing": "val"}) + self.assertEqual(result["response"], "world") + self.assertEqual(result["existing"], "val") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/langgraph/test_graph_builder.py b/tests/core/graph/langgraph/test_graph_builder.py new file mode 100644 index 00000000..bbdbf34c --- /dev/null +++ b/tests/core/graph/langgraph/test_graph_builder.py @@ -0,0 +1,162 @@ +import sys +import unittest 
class TestLangGraphBuilderInit(unittest.TestCase):
    """Constructor state of LangGraphBuilder."""

    def test_init_stores_graph_config(self):
        cfg = MagicMock()
        self.assertEqual(LangGraphBuilder(cfg).graph_config, cfg)

    def test_init_sets_workflow_to_none(self):
        # Workflow is only materialized by build(), never in __init__.
        self.assertIsNone(LangGraphBuilder(MagicMock()).workflow)

    def test_special_nodes_map_contains_start_and_end(self):
        mapping = LangGraphBuilder.SPECIAL_NODES_MAP
        self.assertEqual(mapping["START"], START)
        self.assertEqual(mapping["END"], END)


class TestLangGraphBuilderGetNode(unittest.TestCase):
    """Static name resolution via LangGraphBuilder.get_node."""

    def test_get_node_returns_start_for_start_key(self):
        self.assertEqual(LangGraphBuilder.get_node("START"), START)

    def test_get_node_returns_end_for_end_key(self):
        self.assertEqual(LangGraphBuilder.get_node("END"), END)

    def test_get_node_returns_value_unchanged_for_regular_name(self):
        # Non-special names pass through untranslated.
        self.assertEqual(LangGraphBuilder.get_node("my_node"), "my_node")
class TestLangGraphBuilderCompile(unittest.TestCase):
    """compile() behavior before and after build()."""

    def test_compile_raises_when_build_not_called(self):
        # compile() requires build() to have populated the workflow first.
        self.assertRaises(RuntimeError, LangGraphBuilder(MagicMock()).compile)

    def test_compile_returns_compiled_graph_after_build(self):
        cfg = MagicMock()
        cfg.get_nodes.return_value = {}
        cfg.get_edges.return_value = []
        cfg.state_variables = set()
        cfg.sub_graphs = {}

        workflow = MagicMock()
        compiled = MagicMock()
        workflow.compile.return_value = compiled

        with patch(
            "sygra.core.graph.langgraph.graph_builder.backend_factory"
        ) as factory, patch(
            "sygra.core.graph.langgraph.graph_builder.EdgeFactory"
        ) as edge_factory:
            factory.build_workflow.return_value = workflow
            edge_factory.return_value.get_edges.return_value = []

            builder = LangGraphBuilder(cfg)
            builder.build()
            outcome = builder.compile()

        workflow.compile.assert_called_once()
        self.assertEqual(outcome, compiled)
class TestLangGraphBuilderUpdateStateVariables(unittest.TestCase):
    """_update_state_variables collects each active node's state variables
    into graph_config.state_variables."""

    def test_adds_node_state_variables_to_graph_config(self):
        graph_config = MagicMock()
        graph_config.state_variables = set()

        node = MagicMock()
        node.is_active.return_value = True
        node.get_state_variables.return_value = ["var_a", "var_b"]
        graph_config.get_nodes.return_value = {"node1": node}

        builder = LangGraphBuilder(graph_config)
        builder._update_state_variables()

        self.assertIn("var_a", graph_config.state_variables)
        self.assertIn("var_b", graph_config.state_variables)

    def test_does_not_add_duplicate_state_variables(self):
        graph_config = MagicMock()
        graph_config.state_variables = {"var_a"}

        node = MagicMock()
        node.is_active.return_value = True
        node.get_state_variables.return_value = ["var_a"]
        graph_config.get_nodes.return_value = {"node1": node}

        builder = LangGraphBuilder(graph_config)
        builder._update_state_variables()

        # The previous assertion was a tautology: a set has no .count(), so
        # the hasattr branch always reduced to assertEqual(1, 1) and the test
        # could never fail. Assert the real invariant instead: re-adding an
        # already-present variable leaves the set unchanged.
        self.assertEqual(graph_config.state_variables, {"var_a"})
@patch("sygra.core.graph.nodes.connector_node.utils") + def test_init_sets_node_state_active(self, mock_utils): + mock_utils.backend_factory = MagicMock() + node = ConnectorNode("connector_1") + self.assertEqual(node.node_state, NodeState.ACTIVE.value) + + @patch("sygra.core.graph.nodes.connector_node.utils") + def test_init_sets_node_type_connector(self, mock_utils): + mock_utils.backend_factory = MagicMock() + node = ConnectorNode("connector_1") + self.assertEqual(node.node_type, NodeType.CONNECTOR.value) + + @patch("sygra.core.graph.nodes.connector_node.utils") + def test_init_stores_node_name(self, mock_utils): + mock_utils.backend_factory = MagicMock() + node = ConnectorNode("my_connector") + self.assertEqual(node.name, "my_connector") + + +class TestConnectorNodeIsValid(unittest.TestCase): + @patch("sygra.core.graph.nodes.connector_node.utils") + def test_is_valid_returns_true(self, mock_utils): + mock_utils.backend_factory = MagicMock() + node = ConnectorNode("connector_1") + self.assertTrue(node.is_valid()) + + +class TestConnectorNodeIsActive(unittest.TestCase): + @patch("sygra.core.graph.nodes.connector_node.utils") + def test_is_active_returns_true(self, mock_utils): + mock_utils.backend_factory = MagicMock() + node = ConnectorNode("connector_1") + self.assertTrue(node.is_active()) + + +class TestConnectorNodeToBackend(unittest.TestCase): + @patch("sygra.core.graph.nodes.connector_node.utils") + def test_to_backend_calls_create_connector_runnable(self, mock_utils): + mock_utils.backend_factory = MagicMock() + expected = MagicMock() + mock_utils.backend_factory.create_connector_runnable.return_value = expected + + node = ConnectorNode("connector_1") + result = node.to_backend() + + mock_utils.backend_factory.create_connector_runnable.assert_called_once_with() + self.assertEqual(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/nodes/test_lambda_node.py b/tests/core/graph/nodes/test_lambda_node.py new file 
class TestLambdaNodeInit(unittest.TestCase):
    """Constructor resolves the configured callable and classifies it as
    sync or async."""

    @patch("sygra.core.graph.nodes.lambda_node.utils")
    def test_init_sync_function_sets_func_type_sync(self, mock_utils):
        resolved = MagicMock()
        mock_utils.get_func_from_str.return_value = resolved
        mock_utils.backend_factory = MagicMock()

        node = LambdaNode("my_node", _make_config())

        mock_utils.get_func_from_str.assert_called_once_with("some.module.my_func")
        self.assertEqual(node.func_type, "sync")
        self.assertEqual(node.func_to_execute, resolved)

    @patch("sygra.core.graph.nodes.lambda_node.utils")
    def test_init_async_function_sets_func_type_async(self, mock_utils):
        async def coro_fn(config, state):
            pass

        mock_utils.get_func_from_str.return_value = coro_fn
        mock_utils.backend_factory = MagicMock()

        self.assertEqual(LambdaNode("my_node", _make_config()).func_type, "async")

    @patch("sygra.core.graph.nodes.lambda_node.utils")
    def test_init_with_class_uses_apply_method(self, mock_utils):
        class Handler:
            @staticmethod
            def apply(config, state):
                return state

        mock_utils.get_func_from_str.return_value = Handler
        mock_utils.backend_factory = MagicMock()

        node = LambdaNode("my_node", _make_config())

        # When a class is configured, its apply() becomes the executable.
        self.assertEqual(node.func_to_execute, Handler.apply)
        self.assertEqual(node.func_type, "sync")
class TestLambdaNodeToBackend(unittest.TestCase):
    """to_backend() picks the sync or async wrapper for the backend factory."""

    @patch("sygra.core.graph.nodes.lambda_node.utils")
    def test_to_backend_sync_calls_create_lambda_runnable_with_async_false(self, mock_utils):
        mock_utils.get_func_from_str.return_value = MagicMock()
        mock_utils.backend_factory = MagicMock()
        runnable = MagicMock()
        mock_utils.backend_factory.create_lambda_runnable.return_value = runnable

        node = LambdaNode("my_node", _make_config())
        outcome = node.to_backend()

        # Sync lambdas are registered with async_func=False.
        mock_utils.backend_factory.create_lambda_runnable.assert_called_once_with(
            node._sync_exec_wrapper, async_func=False
        )
        self.assertEqual(outcome, runnable)

    @patch("sygra.core.graph.nodes.lambda_node.utils")
    def test_to_backend_async_calls_create_lambda_runnable(self, mock_utils):
        async def coro_fn(config, state):
            pass

        mock_utils.get_func_from_str.return_value = coro_fn
        mock_utils.backend_factory = MagicMock()
        runnable = MagicMock()
        mock_utils.backend_factory.create_lambda_runnable.return_value = runnable

        node = LambdaNode("my_node", _make_config())
        outcome = node.to_backend()

        # Async lambdas are registered via the async wrapper alone.
        mock_utils.backend_factory.create_lambda_runnable.assert_called_once_with(
            node._async_exec_wrapper
        )
        self.assertEqual(outcome, runnable)
LambdaNode("my_node", _make_config()) + + with patch.object(node, "_record_execution_metadata") as mock_record: + result = node._sync_exec_wrapper(state) + + sync_func.assert_called_once_with(node.node_config, state) + self.assertEqual(result, expected_result) + mock_record.assert_called_once() + self.assertTrue(mock_record.call_args[0][1]) + + @patch("sygra.core.graph.nodes.lambda_node.utils") + def test_sync_exec_wrapper_records_failure_on_exception(self, mock_utils): + sync_func = MagicMock(side_effect=RuntimeError("boom")) + mock_utils.get_func_from_str.return_value = sync_func + mock_utils.backend_factory = MagicMock() + + node = LambdaNode("my_node", _make_config()) + + with patch.object(node, "_record_execution_metadata") as mock_record: + with self.assertRaises(RuntimeError): + node._sync_exec_wrapper({"x": 1}) + + mock_record.assert_called_once() + self.assertFalse(mock_record.call_args[0][1]) + + +class TestLambdaNodeAsyncExecWrapper(unittest.IsolatedAsyncioTestCase): + @patch("sygra.core.graph.nodes.lambda_node.utils") + async def test_async_exec_wrapper_calls_func_and_records_metadata(self, mock_utils): + state = {"input": "hello"} + expected_result = {"output": "world"} + + async def async_func(config, st): + return expected_result + + mock_utils.get_func_from_str.return_value = async_func + mock_utils.backend_factory = MagicMock() + + node = LambdaNode("my_node", _make_config()) + + with patch.object(node, "_record_execution_metadata") as mock_record: + result = await node._async_exec_wrapper(state) + + self.assertEqual(result, expected_result) + mock_record.assert_called_once() + self.assertTrue(mock_record.call_args[0][1]) + + @patch("sygra.core.graph.nodes.lambda_node.utils") + async def test_async_exec_wrapper_records_failure_on_exception(self, mock_utils): + async def async_func(config, state): + raise ValueError("async boom") + + mock_utils.get_func_from_str.return_value = async_func + mock_utils.backend_factory = MagicMock() + + node = 
LambdaNode("my_node", _make_config()) + + with patch.object(node, "_record_execution_metadata") as mock_record: + with self.assertRaises(ValueError): + await node._async_exec_wrapper({"x": 1}) + + mock_record.assert_called_once() + self.assertFalse(mock_record.call_args[0][1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/nodes/test_multi_llm_node.py b/tests/core/graph/nodes/test_multi_llm_node.py new file mode 100644 index 00000000..c48af6b3 --- /dev/null +++ b/tests/core/graph/nodes/test_multi_llm_node.py @@ -0,0 +1,160 @@ +import sys +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent)) + +from sygra.core.graph.nodes.multi_llm_node import MultiLLMNode + + +def _make_config(models=None, extra=None): + config = { + "node_type": "multi_llm", + "prompt": [{"role": "system", "content": "You are helpful."}], + "models": models or { + "model_a": {"name": "a", "model": "gpt-a", "model_type": "openai"}, + "model_b": {"name": "b", "model": "gpt-b", "model_type": "openai"}, + }, + } + if extra: + config.update(extra) + return config + + +class TestMultiLLMNodeInit(unittest.TestCase): + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + def test_init_creates_llm_node_per_model(self, mock_utils, mock_llm_node_cls): + mock_utils.backend_factory = MagicMock() + mock_llm_node_cls.return_value = MagicMock() + + node = MultiLLMNode("multi_node", _make_config()) + + self.assertEqual(mock_llm_node_cls.call_count, 2) + self.assertIn("model_a", node.llm_dict) + self.assertIn("model_b", node.llm_dict) + + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + def test_init_uses_default_post_process_when_not_in_config(self, mock_utils, mock_llm_node_cls): + mock_utils.backend_factory = MagicMock() + 
mock_llm_node_cls.return_value = MagicMock() + + node = MultiLLMNode("multi_node", _make_config()) + + self.assertEqual(node.multi_llm_post_process, node._default_multi_llm_post_process) + + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + def test_init_with_custom_post_process_function(self, mock_utils, mock_llm_node_cls): + mock_utils.backend_factory = MagicMock() + mock_llm_node_cls.return_value = MagicMock() + custom_func = MagicMock() + mock_utils.get_func_from_str.return_value = custom_func + + config = _make_config(extra={"multi_llm_post_process": "some.module.custom_func"}) + node = MultiLLMNode("multi_node", config) + + self.assertEqual(node.multi_llm_post_process, custom_func) + + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + def test_init_with_post_process_class_uses_instance_apply(self, mock_utils, mock_llm_node_cls): + mock_utils.backend_factory = MagicMock() + mock_llm_node_cls.return_value = MagicMock() + + class MyPostProcess: + def apply(self, model_outputs): + return model_outputs + + mock_utils.get_func_from_str.return_value = MyPostProcess + + config = _make_config(extra={"multi_llm_post_process": "some.module.MyPostProcess"}) + node = MultiLLMNode("multi_node", config) + + self.assertEqual(node.multi_llm_post_process.__name__, "apply") + + +class TestDefaultMultiLLMPostProcess(unittest.TestCase): + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + def test_default_post_process_aggregates_outputs(self, mock_utils, mock_llm_node_cls): + mock_utils.backend_factory = MagicMock() + mock_llm_node_cls.return_value = MagicMock() + + config = _make_config(models={"m1": {"name": "m1", "model": "gpt", "model_type": "openai"}}) + node = MultiLLMNode("multi_node", config) + node.output_key = "messages" + + result = node._default_multi_llm_post_process({ + "m1": 
{"messages": ["response from m1"]}, + }) + + self.assertIn("messages", result) + self.assertEqual(result["messages"][0]["m1"], ["response from m1"]) + + +class TestMultiLLMNodeValidate(unittest.TestCase): + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + def test_validate_raises_when_models_missing(self, mock_utils, mock_llm_node_cls): + mock_utils.validate_required_keys.side_effect = ValueError("Missing key: models") + mock_utils.backend_factory = MagicMock() + + with self.assertRaises((ValueError, Exception)): + MultiLLMNode("multi_node", {"node_type": "multi_llm", "prompt": []}) + + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + def test_validate_raises_when_prompt_missing(self, mock_utils, mock_llm_node_cls): + mock_utils.validate_required_keys.side_effect = ValueError("Missing key: prompt") + mock_utils.backend_factory = MagicMock() + + with self.assertRaises((ValueError, Exception)): + MultiLLMNode("multi_node", {"node_type": "multi_llm", "models": {}}) + + +class TestMultiLLMNodeToBackend(unittest.TestCase): + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + def test_to_backend_calls_backend_factory(self, mock_utils, mock_llm_node_cls): + mock_utils.backend_factory = MagicMock() + mock_llm_node_cls.return_value = MagicMock() + expected = MagicMock() + mock_utils.backend_factory.create_multi_llm_runnable.return_value = expected + + node = MultiLLMNode("multi_node", _make_config()) + result = node.to_backend() + + mock_utils.backend_factory.create_multi_llm_runnable.assert_called_once_with( + node.llm_dict, node.multi_llm_post_process + ) + self.assertEqual(result, expected) + + +class TestMultiLLMNodeExecWrapper(unittest.IsolatedAsyncioTestCase): + @patch("sygra.core.graph.nodes.multi_llm_node.LLMNode") + @patch("sygra.core.graph.nodes.multi_llm_node.utils") + 
async def test_exec_wrapper_calls_llm_nodes_and_post_process(self, mock_utils, mock_llm_node_cls): + mock_utils.backend_factory = MagicMock() + + fake_llm_a = MagicMock() + fake_llm_a._exec_wrapper = MagicMock(return_value={"messages": ["output_a"]}) + fake_llm_b = MagicMock() + fake_llm_b._exec_wrapper = MagicMock(return_value={"messages": ["output_b"]}) + mock_llm_node_cls.side_effect = [fake_llm_a, fake_llm_b] + + node = MultiLLMNode("multi_node", _make_config()) + post_process_mock = MagicMock(return_value={"messages": ["aggregated"]}) + node.multi_llm_post_process = post_process_mock + + with patch.object(node, "_record_execution_metadata"): + result = await node._exec_wrapper({"input": "hello"}) + + post_process_mock.assert_called_once() + self.assertEqual(result, {"messages": ["aggregated"]}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/nodes/test_node_utils.py b/tests/core/graph/nodes/test_node_utils.py new file mode 100644 index 00000000..9a8b7e8f --- /dev/null +++ b/tests/core/graph/nodes/test_node_utils.py @@ -0,0 +1,90 @@ +import sys +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent)) + +from sygra.core.graph.nodes.node_utils import get_node, get_node_config + + +class TestGetNode(unittest.TestCase): + def test_raises_assertion_when_node_type_missing(self): + with self.assertRaises(AssertionError): + get_node("my_node", {}) + + def test_raises_not_implemented_for_unknown_type(self): + with self.assertRaises(NotImplementedError): + get_node("my_node", {"node_type": "nonexistent_type"}) + + def test_returns_special_node_for_special_type(self): + from sygra.core.graph.nodes.special_node import SpecialNode + node = get_node("START", {"node_type": "special"}) + self.assertIsInstance(node, SpecialNode) + self.assertEqual(node.name, "START") + + @patch("sygra.core.graph.nodes.node_utils.ConnectorNode") + def 
test_returns_connector_node_for_connector_type(self, mock_cls): + mock_instance = MagicMock() + mock_cls.return_value = mock_instance + result = get_node("conn_node", {"node_type": "connector"}) + mock_cls.assert_called_once_with("conn_node") + self.assertEqual(result, mock_instance) + + @patch("sygra.core.graph.nodes.node_utils.LLMNode") + def test_returns_llm_node_for_llm_type(self, mock_cls): + mock_instance = MagicMock() + mock_cls.return_value = mock_instance + config = {"node_type": "llm", "model": {}, "prompt": []} + result = get_node("llm_node", config) + mock_cls.assert_called_once_with("llm_node", config) + self.assertEqual(result, mock_instance) + + @patch("sygra.core.graph.nodes.node_utils.LambdaNode") + def test_returns_lambda_node_for_lambda_type(self, mock_cls): + mock_instance = MagicMock() + mock_cls.return_value = mock_instance + config = {"node_type": "lambda", "lambda": "some.module.func"} + result = get_node("lambda_node", config) + mock_cls.assert_called_once_with("lambda_node", config) + self.assertEqual(result, mock_instance) + + @patch("sygra.core.graph.nodes.node_utils.MultiLLMNode") + def test_returns_multi_llm_node_for_multi_llm_type(self, mock_cls): + mock_instance = MagicMock() + mock_cls.return_value = mock_instance + config = {"node_type": "multi_llm", "models": {}, "prompt": []} + result = get_node("multi_node", config) + mock_cls.assert_called_once_with("multi_node", config) + self.assertEqual(result, mock_instance) + + +class TestGetNodeConfig(unittest.TestCase): + def test_raises_assertion_when_node_not_in_config(self): + with self.assertRaises(AssertionError): + get_node_config("missing_node", {"nodes": {"other_node": {}}}) + + def test_returns_correct_config(self): + node_cfg = {"node_type": "llm", "prompt": [], "model": {}} + result = get_node_config("my_node", {"nodes": {"my_node": node_cfg}}) + self.assertEqual(result, node_cfg) + + def test_raises_assertion_when_nodes_key_absent(self): + with 
self.assertRaises(AssertionError): + get_node_config("my_node", {}) + + def test_returns_config_among_multiple_nodes(self): + target_cfg = {"node_type": "lambda", "lambda": "my.func"} + graph_config = { + "nodes": { + "node_a": {"node_type": "llm"}, + "node_b": target_cfg, + "node_c": {"node_type": "special"}, + } + } + result = get_node_config("node_b", graph_config) + self.assertEqual(result, target_cfg) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/nodes/test_weighted_sampler_node.py b/tests/core/graph/nodes/test_weighted_sampler_node.py new file mode 100644 index 00000000..8dbb127d --- /dev/null +++ b/tests/core/graph/nodes/test_weighted_sampler_node.py @@ -0,0 +1,102 @@ +import sys +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent)) + +from sygra.core.graph.nodes.weighted_sampler_node import WeightedSamplerNode + + +def _make_config(attributes=None, extra=None): + config = { + "node_type": "weighted_sampler", + "attributes": attributes or {"color": {"values": ["red", "blue", "green"]}}, + } + if extra: + config.update(extra) + return config + + +class TestWeightedSamplerNodeInit(unittest.TestCase): + @patch("sygra.core.graph.nodes.weighted_sampler_node.utils") + def test_init_adds_attribute_keys_to_state_variables(self, mock_utils): + mock_utils.backend_factory = MagicMock() + config = _make_config(attributes={"color": {"values": ["red"]}, "size": {"values": ["S"]}}) + node = WeightedSamplerNode("sampler", config) + self.assertIn("color", node.state_variables) + self.assertIn("size", node.state_variables) + + @patch("sygra.core.graph.nodes.weighted_sampler_node.utils") + def test_init_raises_when_attributes_not_dict(self, mock_utils): + mock_utils.backend_factory = MagicMock() + config = {"node_type": "weighted_sampler", "attributes": ["not", "a", "dict"]} + with self.assertRaises(ValueError): + 
WeightedSamplerNode("sampler", config) + + +class TestWeightedSamplerWeightedSampler(unittest.TestCase): + @patch("sygra.core.graph.nodes.weighted_sampler_node.utils") + def test_samples_from_static_list(self, mock_utils): + mock_utils.backend_factory = MagicMock() + config = _make_config(attributes={"color": {"values": ["red", "blue", "green"]}}) + node = WeightedSamplerNode("sampler", config) + result = node._weighted_sampler(config["attributes"], {}) + self.assertIn("color", result) + self.assertIn(result["color"], ["red", "blue", "green"]) + + @patch("sygra.core.graph.nodes.weighted_sampler_node.utils") + def test_respects_weights(self, mock_utils): + mock_utils.backend_factory = MagicMock() + config = _make_config(attributes={"color": {"values": ["red", "blue"], "weights": [0, 1]}}) + node = WeightedSamplerNode("sampler", config) + results = {node._weighted_sampler(config["attributes"], {})["color"] for _ in range(10)} + self.assertEqual(results, {"blue"}) + + @patch("sygra.core.graph.nodes.weighted_sampler_node.utils") + def test_samples_from_datasource(self, mock_utils): + mock_utils.backend_factory = MagicMock() + mock_utils.fetch_next_record.return_value = "persona_val" + datasrc = {"type": "hf", "repo_id": "some/repo", "split": "train"} + config = _make_config(attributes={"role": {"values": {"column": "persona", "source": datasrc}}}) + node = WeightedSamplerNode("sampler", config) + result = node._weighted_sampler(config["attributes"], {}) + mock_utils.fetch_next_record.assert_called_once_with(datasrc, "persona") + self.assertEqual(result["role"], "persona_val") + + +class TestWeightedSamplerNodeValidate(unittest.TestCase): + @patch("sygra.core.graph.nodes.weighted_sampler_node.utils") + def test_validate_raises_when_attributes_key_missing(self, mock_utils): + mock_utils.validate_required_keys.side_effect = ValueError("Missing key: attributes") + mock_utils.backend_factory = MagicMock() + with self.assertRaises((ValueError, Exception)): + 
WeightedSamplerNode("sampler", {"node_type": "weighted_sampler"}) + + +class TestWeightedSamplerNodeExecWrapper(unittest.IsolatedAsyncioTestCase): + @patch("sygra.core.graph.nodes.weighted_sampler_node.utils") + async def test_exec_wrapper_merges_sampled_values_into_state(self, mock_utils): + mock_utils.backend_factory = MagicMock() + config = _make_config(attributes={"color": {"values": ["red"]}}) + node = WeightedSamplerNode("sampler", config) + state = {"existing": "val"} + with patch.object(node, "_record_execution_metadata"): + result = await node._exec_wrapper(state) + self.assertEqual(result["existing"], "val") + self.assertEqual(result["color"], "red") + + @patch("sygra.core.graph.nodes.weighted_sampler_node.utils") + async def test_exec_wrapper_records_failure_on_exception(self, mock_utils): + mock_utils.backend_factory = MagicMock() + config = _make_config() + node = WeightedSamplerNode("sampler", config) + with patch.object(node, "_weighted_sampler", side_effect=RuntimeError("err")): + with patch.object(node, "_record_execution_metadata") as mock_record: + with self.assertRaises(RuntimeError): + await node._exec_wrapper({}) + self.assertFalse(mock_record.call_args[0][1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/test_backend_factory.py b/tests/core/graph/test_backend_factory.py new file mode 100644 index 00000000..c9f30aaf --- /dev/null +++ b/tests/core/graph/test_backend_factory.py @@ -0,0 +1,63 @@ +import sys +import unittest +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from sygra.core.graph.backend_factory import BackendFactory + + +class TestBackendFactoryIsABC(unittest.TestCase): + def test_cannot_instantiate_directly(self): + with self.assertRaises(TypeError): + BackendFactory() + + def test_has_all_abstract_methods(self): + abstract_methods = { + "create_lambda_runnable", + "create_llm_runnable", + "create_multi_llm_runnable", + 
"create_weighted_sampler_runnable", + "create_connector_runnable", + "build_workflow", + "get_message_content", + "convert_to_chat_format", + "get_test_message", + } + self.assertEqual(BackendFactory.__abstractmethods__, abstract_methods) + + def test_concrete_subclass_can_be_instantiated(self): + class ConcreteFactory(BackendFactory): + def create_lambda_runnable(self, exec_wrapper, async_func=True): + return None + + def create_llm_runnable(self, exec_wrapper): + return None + + def create_multi_llm_runnable(self, llm_dict, post_process): + return None + + def create_weighted_sampler_runnable(self, exec_wrapper): + return None + + def create_connector_runnable(self): + return None + + def build_workflow(self, graph_config): + return None + + def get_message_content(self, msg): + return "" + + def convert_to_chat_format(self, msgs): + return [] + + def get_test_message(self, model_config): + return None + + factory = ConcreteFactory() + self.assertIsInstance(factory, BackendFactory) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/graph/test_sygra_message.py b/tests/core/graph/test_sygra_message.py new file mode 100644 index 00000000..b321e144 --- /dev/null +++ b/tests/core/graph/test_sygra_message.py @@ -0,0 +1,30 @@ +import sys +import unittest +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from sygra.core.graph.sygra_message import SygraMessage + + +class TestSygraMessage(unittest.TestCase): + def test_init_stores_message(self): + msg = SygraMessage("hello") + self.assertEqual(msg._message, "hello") + + def test_message_property_returns_stored_message(self): + msg = SygraMessage({"role": "user", "content": "hi"}) + self.assertEqual(msg.message, {"role": "user", "content": "hi"}) + + def test_stores_none_message(self): + msg = SygraMessage(None) + self.assertIsNone(msg.message) + + def test_stores_list_message(self): + messages = [1, 2, 3] + msg = SygraMessage(messages) + 
class TestSygraState(unittest.TestCase):
    """Runtime shape of the SygraState mapping type."""

    def test_is_typed_dict_subclass(self):
        # At runtime a TypedDict class is a plain dict subclass.
        self.assertTrue(issubclass(SygraState, dict))

    def test_can_be_instantiated_empty(self):
        self.assertEqual(SygraState(), {})

    def test_can_hold_arbitrary_keys(self):
        state = SygraState(messages=["hello"], count=1)
        self.assertEqual(state["messages"], ["hello"])
        self.assertEqual(state["count"], 1)

    def test_has_total_false_so_no_required_fields(self):
        # total=False means all fields are optional; constructing with no
        # arguments must succeed. (TypedDict does not enforce keys at
        # runtime, so this documents the declared contract.)
        self.assertIsInstance(SygraState(), dict)
def test_init_creates_sft_pipeline(self): + mapper = DataMapper(config={"type": "sft"}) + self.assertEqual(mapper.transform_type, "sft") + self.assertIsNotNone(mapper.pipeline) + self.assertGreater(len(mapper.pipeline), 0) + + def test_init_creates_dpo_pipeline(self): + mapper = DataMapper(config={"type": "dpo"}) + self.assertEqual(mapper.transform_type, "dpo") + self.assertIsNotNone(mapper.pipeline) + self.assertGreater(len(mapper.pipeline), 0) + + +class TestDataMapperOrderPipeline(unittest.TestCase): + def test_order_pipeline_with_active_false_returns_pipeline_unchanged(self): + mapper = DataMapper(config={"type": "sft"}) + result = mapper.order_pipeline(active=False) + self.assertEqual(result, mapper.pipeline) + + def test_order_pipeline_with_active_true_returns_ordered_list(self): + mapper = DataMapper(config={"type": "sft"}) + result = mapper.order_pipeline(active=True) + self.assertIsInstance(result, list) + self.assertGreater(len(result), 0) + + +class TestDataMapperMapAllItems(unittest.TestCase): + def test_map_all_items_calls_map_single_item_for_each(self): + mapper = DataMapper(config={"type": "sft"}) + items = [{"id": "1"}, {"id": "2"}] + + with patch.object(mapper, "map_single_item", return_value=[{"mapped": True}]) as mock_map: + result = mapper.map_all_items(items) + + self.assertEqual(mock_map.call_count, 2) + self.assertEqual(len(result), 2) + + def test_map_all_items_converts_non_list_to_list(self): + mapper = DataMapper(config={"type": "sft"}) + items = iter([{"id": "1"}]) + + with patch.object(mapper, "map_single_item", return_value=[]) as mock_map: + mapper.map_all_items(items) + + mock_map.assert_called_once() + + def test_map_all_items_continues_on_item_error(self): + mapper = DataMapper(config={"type": "sft"}) + items = [{"id": "1"}, {"id": "2"}] + + call_count = 0 + + def side_effect(item): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("item error") + return [{"mapped": True}] + + with patch.object(mapper, 
"map_single_item", side_effect=side_effect): + result = mapper.map_all_items(items) + + self.assertEqual(len(result), 1) + + def test_map_single_item_returns_empty_list_on_error(self): + mapper = DataMapper(config={"type": "sft"}) + # Item missing required fields will cause pipeline to fail gracefully + result = mapper.map_single_item({"id": "bad_item_no_conversation"}) + self.assertIsInstance(result, list) + + +class TestDataMapperBuildRowsAndValidate(unittest.TestCase): + def test_build_rows_and_validate_builds_rows_from_context(self): + context = { + "conversation_id": "conv-1", + "root_message_id": "msg-1", + "messages": [ + { + "message_id": "msg-1", + "parent_id": None, + "level": 0, + "role": "user", + "content": "Hello", + "instruction_tags": [], + "quality": {}, + "length": {}, + "data_characteristics": {}, + } + ], + } + rows = DataMapper.build_rows_and_validate(context) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["conversation_id"], "conv-1") + self.assertEqual(rows[0]["role"], "user") + self.assertEqual(rows[0]["content"], "Hello") + + def test_build_rows_and_validate_returns_multiple_rows_for_multiple_messages(self): + context = { + "conversation_id": "conv-2", + "root_message_id": "msg-1", + "messages": [ + { + "message_id": "msg-1", + "parent_id": None, + "level": 0, + "role": "user", + "content": "Hi", + "instruction_tags": [], + "quality": {}, + "length": {}, + "data_characteristics": {}, + }, + { + "message_id": "msg-2", + "parent_id": "msg-1", + "level": 1, + "role": "assistant", + "content": "Hello!", + "instruction_tags": [], + "quality": {}, + "length": {}, + "data_characteristics": {}, + }, + ], + } + rows = DataMapper.build_rows_and_validate(context) + self.assertEqual(len(rows), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py new file mode 100644 index 00000000..2e0874de --- /dev/null +++ b/tests/utils/test_decorators.py @@ -0,0 +1,95 @@ 
# === tests/utils/test_decorators.py (new file in this patch) ===
import sys
import unittest
import warnings
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent.parent))

from sygra.utils.decorators import future_deprecation


class TestFutureDeprecationDecorator(unittest.TestCase):
    """future_deprecation: wraps a callable, warns on call, preserves behavior."""

    def test_decorated_function_still_works(self):
        @future_deprecation()
        def my_func(x):
            return x * 2

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = my_func(5)

        self.assertEqual(result, 10)

    def test_emits_deprecation_warning(self):
        @future_deprecation()
        def deprecated_func():
            return "ok"

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            deprecated_func()

        self.assertEqual(len(w), 1)
        self.assertTrue(issubclass(w[0].category, DeprecationWarning))

    def test_warning_message_contains_function_name(self):
        @future_deprecation()
        def my_deprecated_func():
            pass

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            my_deprecated_func()

        self.assertIn("my_deprecated_func", str(w[0].message))

    def test_warning_message_contains_reason_when_provided(self):
        @future_deprecation(reason="use new_func instead")
        def old_func():
            pass

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            old_func()

        self.assertIn("use new_func instead", str(w[0].message))

    def test_warning_message_without_reason(self):
        @future_deprecation()
        def no_reason_func():
            pass

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            no_reason_func()

        self.assertIn("scheduled for future deprecation", str(w[0].message))

    def test_preserves_function_name_via_functools_wraps(self):
        @future_deprecation(reason="test")
        def original_name():
            pass

        self.assertEqual(original_name.__name__, "original_name")

    def test_sets_future_deprecation_attribute(self):
        @future_deprecation()
        def tagged_func():
            pass

        self.assertTrue(tagged_func._future_deprecation)

    def test_decorated_function_passes_args_and_kwargs(self):
        @future_deprecation()
        def add(a, b, c=0):
            return a + b + c

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = add(1, 2, c=3)

        self.assertEqual(result, 6)


if __name__ == "__main__":
    unittest.main()


# === tests/utils/test_graph_utils.py (new file in this patch) ===
import sys
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch  # NOTE(review): MagicMock and patch are unused here

sys.path.append(str(Path(__file__).parent.parent.parent))

from sygra.utils.graph_utils import convert_graph_output_to_records, execute_graph


class TestConvertGraphOutputToRecords(unittest.TestCase):
    """convert_graph_output_to_records: optional per-record transform, skipping None/failed records."""

    def test_returns_all_records_when_no_generator(self):
        results = [{"a": 1}, {"b": 2}]
        output = convert_graph_output_to_records(results)
        self.assertEqual(output, results)

    def test_applies_output_record_generator_to_each_result(self):
        results = [{"raw": "data1"}, {"raw": "data2"}]
        generator = lambda r: {**r, "transformed": True}
        output = convert_graph_output_to_records(results, output_record_generator=generator)
        self.assertTrue(all(r["transformed"] for r in output))

    def test_skips_none_results_from_generator(self):
        results = [{"a": 1}, {"b": 2}, {"c": 3}]
        generator = lambda r: None if r.get("b") else r
        output = convert_graph_output_to_records(results, output_record_generator=generator)
        self.assertEqual(len(output), 2)

    def test_skips_results_where_generator_raises_exception(self):
        results = [{"a": 1}, {"bad": "data"}, {"c": 3}]

        def generator(r):
            if "bad" in r:
                raise ValueError("bad data")
            return r

        output = convert_graph_output_to_records(results, output_record_generator=generator)
        self.assertEqual(len(output), 2)
        self.assertNotIn({"bad": "data"}, output)

    def test_returns_empty_list_for_empty_input(self):
        output = convert_graph_output_to_records([])
        self.assertEqual(output, [])

    def test_none_results_skipped_without_generator(self):
        results = [{"a": 1}, None, {"b": 2}]
        output = convert_graph_output_to_records(results)
        # None items are skipped
        self.assertEqual(len(output), 2)


class TestExecuteGraph(unittest.IsolatedAsyncioTestCase):
    """execute_graph: drives graph.ainvoke with optional input transform / debug flag."""

    async def test_execute_graph_calls_graph_ainvoke(self):
        record = {"id": "1", "input": "hello"}
        mock_graph = AsyncMock()
        mock_graph.ainvoke.return_value = {"output": "world"}

        result = await execute_graph(record, mock_graph)

        # NOTE(review): assert_awaited_once() would additionally verify the
        # coroutine was awaited, not merely created; assert_called_once() passes
        # for AsyncMock either way.
        mock_graph.ainvoke.assert_called_once()
        self.assertEqual(result, {"output": "world"})

    async def test_execute_graph_applies_input_record_generator(self):
        record = {"id": "1"}
        mock_graph = AsyncMock()
        mock_graph.ainvoke.return_value = {"result": "ok"}
        generator = lambda r: {**r, "extra": "added"}

        await execute_graph(record, mock_graph, input_record_generator=generator)

        call_args = mock_graph.ainvoke.call_args[0][0]
        self.assertEqual(call_args["extra"], "added")

    async def test_execute_graph_returns_error_dict_on_exception(self):
        record = {"id": "1"}
        mock_graph = AsyncMock()
        mock_graph.ainvoke.side_effect = RuntimeError("graph exploded")

        result = await execute_graph(record, mock_graph)

        self.assertIn("execution_error", result)
        self.assertTrue(result["execution_error"])

    async def test_execute_graph_passes_debug_flag(self):
        record = {"id": "1"}
        mock_graph = AsyncMock()
        mock_graph.ainvoke.return_value = {}

        await execute_graph(record, mock_graph, debug=True)

        call_kwargs = mock_graph.ainvoke.call_args[1]
        self.assertTrue(call_kwargs.get("debug"))


if __name__ == "__main__":
    unittest.main()

# (diff metadata introducing tests/utils/test_model_utils.py followed here)
# === tests/utils/test_model_utils.py (new file in this patch) ===
import sys
import unittest
from pathlib import Path
from unittest.mock import MagicMock

sys.path.append(str(Path(__file__).parent.parent.parent))

# NOTE(review): should_route_to_transcription is imported but never exercised below.
from sygra.utils.model_utils import (
    InputType,
    OutputType,
    detect_input_type,
    get_model_capabilities,
    get_output_type,
    has_audio_input,
    has_image_input,
    is_gpt4o_audio_model,
    should_route_to_image,
    should_route_to_speech,
    should_route_to_transcription,
    validate_input_output_compatibility,
)


def _make_chat_prompt(content):
    """Helper: create a mock ChatPromptValue with one message."""
    msg = MagicMock()
    msg.content = content
    prompt = MagicMock()
    prompt.messages = [msg]
    return prompt


class TestHasAudioInput(unittest.TestCase):
    """Checks has_audio_input detects audio payloads in string or list content."""

    def test_detects_audio_data_url_string(self):
        prompt = _make_chat_prompt("data:audio/wav;base64,UklGR...")
        self.assertTrue(has_audio_input(prompt))

    def test_detects_audio_url_in_list_content(self):
        prompt = _make_chat_prompt([{"type": "audio_url", "audio_url": {"url": "..."}}])
        self.assertTrue(has_audio_input(prompt))

    def test_returns_false_for_plain_text(self):
        prompt = _make_chat_prompt("Hello, world!")
        self.assertFalse(has_audio_input(prompt))

    def test_returns_false_for_image_content(self):
        prompt = _make_chat_prompt("data:image/png;base64,abc")
        self.assertFalse(has_audio_input(prompt))


class TestHasImageInput(unittest.TestCase):
    """Checks has_image_input detects image payloads and ignores audio/plain text."""

    def test_detects_image_data_url_string(self):
        prompt = _make_chat_prompt("data:image/png;base64,iVBOR...")
        self.assertTrue(has_image_input(prompt))

    def test_detects_image_url_in_list_content(self):
        prompt = _make_chat_prompt([{"type": "image_url", "image_url": {"url": "..."}}])
        self.assertTrue(has_image_input(prompt))

    def test_returns_false_for_plain_text(self):
        prompt = _make_chat_prompt("Hello!")
        self.assertFalse(has_image_input(prompt))

    def test_returns_false_for_audio_content(self):
        prompt = _make_chat_prompt("data:audio/wav;base64,abc")
        self.assertFalse(has_image_input(prompt))


class TestDetectInputType(unittest.TestCase):
    """Checks detect_input_type classification, with audio taking priority over image."""

    def test_returns_audio_when_audio_present(self):
        prompt = _make_chat_prompt("data:audio/wav;base64,abc")
        self.assertEqual(detect_input_type(prompt), InputType.AUDIO)

    def test_returns_image_when_image_present(self):
        prompt = _make_chat_prompt("data:image/png;base64,abc")
        self.assertEqual(detect_input_type(prompt), InputType.IMAGE)

    def test_returns_text_for_plain_text(self):
        prompt = _make_chat_prompt("Hello")
        self.assertEqual(detect_input_type(prompt), InputType.TEXT)

    def test_audio_takes_priority_over_image(self):
        # Two-message prompt built by hand (the helper only makes one message).
        msg1 = MagicMock()
        msg1.content = "data:audio/wav;base64,abc"
        msg2 = MagicMock()
        msg2.content = "data:image/png;base64,abc"
        prompt = MagicMock()
        prompt.messages = [msg1, msg2]
        self.assertEqual(detect_input_type(prompt), InputType.AUDIO)


class TestGetOutputType(unittest.TestCase):
    """Checks get_output_type reads `output_type` from config, defaulting to text."""

    def test_returns_text_by_default(self):
        self.assertEqual(get_output_type({}), OutputType.TEXT)

    def test_returns_audio_when_configured(self):
        self.assertEqual(get_output_type({"output_type": "audio"}), OutputType.AUDIO)

    def test_returns_image_when_configured(self):
        self.assertEqual(get_output_type({"output_type": "image"}), OutputType.IMAGE)


class TestIsGpt4oAudioModel(unittest.TestCase):
    """Checks is_gpt4o_audio_model matches gpt-4o audio names case-insensitively."""

    def test_returns_true_for_gpt4o_audio_model(self):
        self.assertTrue(is_gpt4o_audio_model({"model": "gpt-4o-audio-preview"}))

    def test_returns_false_for_whisper(self):
        self.assertFalse(is_gpt4o_audio_model({"model": "whisper-1"}))

    def test_returns_false_for_regular_gpt4(self):
        self.assertFalse(is_gpt4o_audio_model({"model": "gpt-4"}))

    def test_case_insensitive(self):
        self.assertTrue(is_gpt4o_audio_model({"model": "GPT-4O-AUDIO-PREVIEW"}))


class TestShouldRouteToSpeech(unittest.TestCase):
    """Checks should_route_to_speech is true only for output_type == audio."""

    def test_returns_true_when_output_type_is_audio(self):
        self.assertTrue(should_route_to_speech({"output_type": "audio"}))

    def test_returns_false_for_text_output(self):
        self.assertFalse(should_route_to_speech({"output_type": "text"}))

    def test_returns_false_when_output_type_not_set(self):
        self.assertFalse(should_route_to_speech({}))


class TestShouldRouteToImage(unittest.TestCase):
    """Checks should_route_to_image is true only for output_type == image."""

    def test_returns_true_when_output_type_is_image(self):
        self.assertTrue(should_route_to_image({"output_type": "image"}))

    def test_returns_false_for_text_output(self):
        self.assertFalse(should_route_to_image({"output_type": "text"}))


class TestGetModelCapabilities(unittest.TestCase):
    """Checks capability reporting (input types, multimodal/audio flags, output type)."""

    def test_all_models_support_text_input(self):
        caps = get_model_capabilities({"model": "gpt-3.5"})
        self.assertIn(InputType.TEXT, caps["input_types"])

    def test_gpt4_supports_image_input(self):
        caps = get_model_capabilities({"model": "gpt-4"})
        self.assertIn(InputType.IMAGE, caps["input_types"])
        self.assertTrue(caps["is_multimodal"])

    def test_gpt4o_audio_supports_audio_input(self):
        caps = get_model_capabilities({"model": "gpt-4o-audio-preview"})
        self.assertTrue(caps["is_audio_chat"])
        self.assertIn(InputType.AUDIO, caps["input_types"])

    def test_audio_output_type_reflected(self):
        caps = get_model_capabilities({"model": "tts-1", "output_type": "audio"})
        self.assertEqual(caps["output_type"], "audio")


class TestValidateInputOutputCompatibility(unittest.TestCase):
    """Checks the (valid, error-or-None) contract for model/input pairings."""

    def test_valid_for_text_input_text_output(self):
        prompt = _make_chat_prompt("Hello")
        valid, error = validate_input_output_compatibility(prompt, {"model": "gpt-4"})
        self.assertTrue(valid)
        self.assertIsNone(error)

    def test_returns_error_for_image_input_with_non_vision_model(self):
        prompt = _make_chat_prompt("data:image/png;base64,abc")
        valid, error = validate_input_output_compatibility(prompt, {"model": "gpt-3.5-turbo"})
        self.assertFalse(valid)
        self.assertIsNotNone(error)

    def test_valid_for_image_input_with_gpt4_vision(self):
        prompt = _make_chat_prompt("data:image/png;base64,abc")
        valid, error = validate_input_output_compatibility(prompt, {"model": "gpt-4-vision"})
        self.assertTrue(valid)
        self.assertIsNone(error)


if __name__ == "__main__":
    unittest.main()