From 3de9092442cdf1826ad1490ac836d549c74fc567 Mon Sep 17 00:00:00 2001 From: jli10004 Date: Fri, 20 Mar 2026 03:30:40 +0000 Subject: [PATCH 1/2] [Fix] fix make_coord to require tuple arg and add crd2idx tests Change make_coord(*coord) to make_coord(coord) so that passing multiple args like make_coord(r, c) raises a clear TypeError instead of silently creating a wrong-rank coordinate. Update the one caller in kernels/layout_utils.py and add comprehensive pytest-compatible tests covering col-major, row-major, 1D, 3D layouts and error cases. --- kernels/layout_utils.py | 2 +- python/flydsl/expr/primitive.py | 4 +- tests/unit/Layout/test_crd2idx.py | 174 ++++++++++++++++++++++++++++++ 3 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 tests/unit/Layout/test_crd2idx.py diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py index 2a50cac1..3f78e4cf 100644 --- a/kernels/layout_utils.py +++ b/kernels/layout_utils.py @@ -109,7 +109,7 @@ def crd2idx(crd, layout): if isinstance(cv, ir.Value) and isinstance(cv.type, ir.IndexType): cv = arith.index_cast(T.i32, cv) crd_i32.append(cv) - coord_val = fx.make_coord(*crd_i32) + coord_val = fx.make_coord(tuple(crd_i32)) result = fx.crd2idx(coord_val, layout) scalar = fx.get_scalar(result) if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 02a60025..1c47d7ae 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -191,7 +191,9 @@ def make_stride(*stride, loc=None, ip=None): @traced_op -def make_coord(*coord, loc=None, ip=None): +def make_coord(coord, loc=None, ip=None): + if not isinstance(coord, tuple): + raise TypeError(f"make_coord expects a tuple, e.g. make_coord((r, c)), got {type(coord).__name__}") IntTupleTy, dyncElems = fly.infer_int_tuple_type(coord) return fly.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) diff --git a/tests/unit/Layout/test_crd2idx.py b/tests/unit/Layout/test_crd2idx.py new file mode 100644 index 00000000..9d0e7b95 --- /dev/null +++ b/tests/unit/Layout/test_crd2idx.py @@ -0,0 +1,174 @@ +import os +import subprocess +import sys + +import pytest + +import flydsl.compiler as flyc +import flydsl.expr as fx + + +# --- jit kernels (not collected by pytest due to _run_ prefix) --- + + +@flyc.jit +def _run_crd2idx_col_major(): + """(4,8) col-major: idx = r + 4*c""" + layout = fx.make_layout((4, 8), (1, 4)) + for r in range_constexpr(4): + for c in range_constexpr(8): + idx = fx.crd2idx(fx.make_coord((r, c)), layout) + fx.printf("{}", idx) + + +@flyc.jit +def _run_crd2idx_row_major(): + """(4,8) row-major: idx = r*8 + c""" + layout = fx.make_layout((4, 8), (8, 1)) + for r in range_constexpr(4): + for c in range_constexpr(8): + idx = fx.crd2idx(fx.make_coord((r, c)), layout) + fx.printf("{}", idx) + + +@flyc.jit +def _run_crd2idx_1d(): + """1D layout: shape=(8,), stride=(2,)""" + layout = fx.make_layout((8,), (2,)) + for c in range_constexpr(8): + idx = fx.crd2idx(fx.make_coord((c,)), layout) + fx.printf("{}", idx) + + +@flyc.jit +def _run_crd2idx_3d(): + """3D layout: (2,3,4) row-major strides (12,4,1)""" + layout = fx.make_layout((2, 3, 4), (12, 4, 1)) + for i in range_constexpr(2): + for j in range_constexpr(3): + for k in range_constexpr(4): + idx = fx.crd2idx(fx.make_coord((i, j, k)), layout) + fx.printf("{}", idx) + + +# --- subprocess helper to capture C-level printf --- + +JIT_KERNELS = { + "col_major": _run_crd2idx_col_major, + "row_major": _run_crd2idx_row_major, + "1d": _run_crd2idx_1d, + "3d": _run_crd2idx_3d, +} + +EXPECTED = { + "col_major": [r + 4 * c for r in range(4) for c in range(8)], + "row_major": [r * 8 + c for r in range(4) for c in range(8)], + "1d": [c * 2 for c in range(8)], + "3d": [i * 12 + j * 4 + k for i in range(2) for j in range(3) for k in range(4)], +} + + +def _run_jit_and_capture(test_name): + """Run a jit kernel in a subprocess and return parsed int output.""" + env = os.environ.copy() + env["FLYDSL_RUNTIME_ENABLE_CACHE"] = "0" + result = subprocess.run( + [sys.executable, __file__, "--run", test_name], + capture_output=True, + text=True, + env=env, + ) + assert result.returncode == 0, f"subprocess failed:\n{result.stderr}" + lines = [l for l in result.stdout.strip().split("\n") if l.strip()] + return [int(x) for x in lines] + + +def _run_error_test(snippet_name): + """Run an error-test snippet in a subprocess, return (returncode, stderr).""" + env = os.environ.copy() + env["FLYDSL_RUNTIME_ENABLE_CACHE"] = "0" + result = subprocess.run( + [sys.executable, __file__, "--error", snippet_name], + capture_output=True, + text=True, + env=env, + ) + return result.returncode, result.stdout.strip(), result.stderr.strip() + + +# --- pytest test cases: correctness --- + + +def test_crd2idx_col_major(): + """(4,8) col-major layout: idx = r + 4*c""" + actual = _run_jit_and_capture("col_major") + assert actual == EXPECTED["col_major"] + + +def test_crd2idx_row_major(): + """(4,8) row-major layout: idx = r*8 + c""" + actual = _run_jit_and_capture("row_major") + assert actual == EXPECTED["row_major"] + + +def test_crd2idx_1d(): + """1D layout: shape=(8,), stride=(2,)""" + actual = _run_jit_and_capture("1d") + assert actual == EXPECTED["1d"] + + +def test_crd2idx_3d(): + """3D layout: (2,3,4) row-major strides (12,4,1)""" + actual = _run_jit_and_capture("3d") + assert actual == EXPECTED["3d"] + + +# --- pytest test cases: error handling (via subprocess to avoid conftest path issues) --- + + +def test_make_coord_rejects_varargs(): + """make_coord(r, c) must raise TypeError.""" + rc, stdout, stderr = _run_error_test("varargs") + assert rc != 0, f"Expected failure but succeeded.\nstdout: {stdout}" + assert "make_coord expects a tuple" in stderr, f"Wrong error message:\n{stderr}" + + +def test_make_coord_rejects_int(): + """make_coord(42) must raise TypeError.""" + rc, stdout, stderr = _run_error_test("int_arg") + assert rc != 0, f"Expected failure but succeeded.\nstdout: {stdout}" + assert "make_coord expects a tuple" in stderr, f"Wrong error message:\n{stderr}" + + +# --- subprocess entry point --- + +if __name__ == "__main__": + if len(sys.argv) >= 3 and sys.argv[1] == "--run": + JIT_KERNELS[sys.argv[2]]() + + elif len(sys.argv) >= 3 and sys.argv[1] == "--error": + # These are intentionally broken snippets that should raise TypeError. + # Import fresh (no conftest interference). + import flydsl.compiler as flyc_fresh + import flydsl.expr as fx_fresh + + if sys.argv[2] == "varargs": + + @flyc_fresh.jit + def _bad(): + layout = fx_fresh.make_layout((4, 8), (1, 4)) + idx = fx_fresh.crd2idx(fx_fresh.make_coord(0, 1), layout) + + _bad() + + elif sys.argv[2] == "int_arg": + + @flyc_fresh.jit + def _bad(): + layout = fx_fresh.make_layout((4,), (1,)) + idx = fx_fresh.crd2idx(fx_fresh.make_coord(0), layout) + + _bad() + + else: + pytest.main([__file__, "-v"]) From 01e66c3e8e19a4b162fe0ddbcc710790e4216742 Mon Sep 17 00:00:00 2001 From: jli10004 Date: Fri, 20 Mar 2026 06:29:25 +0000 Subject: [PATCH 2/2] keep make_coord *args and support both tuple and varargs forms Revert make_coord to *args signature (consistent with make_shape and make_stride), but auto-unwrap a single tuple argument so both make_coord(r, c) and make_coord((r, c)) produce the same result. Update tests to verify both calling conventions. Co-Authored-By: Claude Opus 4.6 --- kernels/layout_utils.py | 2 +- python/flydsl/expr/primitive.py | 7 ++- tests/unit/Layout/test_crd2idx.py | 94 +++++++++++-------------------- 3 files changed, 37 insertions(+), 66 deletions(-) diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py index 3f78e4cf..2a50cac1 100644 --- a/kernels/layout_utils.py +++ b/kernels/layout_utils.py @@ -109,7 +109,7 @@ def crd2idx(crd, layout): if isinstance(cv, ir.Value) and isinstance(cv.type, ir.IndexType): cv = arith.index_cast(T.i32, cv) crd_i32.append(cv) - coord_val = fx.make_coord(tuple(crd_i32)) + coord_val = fx.make_coord(*crd_i32) result = fx.crd2idx(coord_val, layout) scalar = fx.get_scalar(result) if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 1c47d7ae..20673833 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -191,9 +191,10 @@ def make_stride(*stride, loc=None, ip=None): @traced_op -def make_coord(coord, loc=None, ip=None): - if not isinstance(coord, tuple): - raise TypeError(f"make_coord expects a tuple, e.g. make_coord((r, c)), got {type(coord).__name__}") +def make_coord(*coord, loc=None, ip=None): + # Support both make_coord(r, c) and make_coord((r, c)) + if len(coord) == 1 and isinstance(coord[0], tuple): + coord = coord[0] IntTupleTy, dyncElems = fly.infer_int_tuple_type(coord) return fly.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) diff --git a/tests/unit/Layout/test_crd2idx.py b/tests/unit/Layout/test_crd2idx.py index 9d0e7b95..23edbef3 100644 --- a/tests/unit/Layout/test_crd2idx.py +++ b/tests/unit/Layout/test_crd2idx.py @@ -12,8 +12,8 @@ @flyc.jit -def _run_crd2idx_col_major(): - """(4,8) col-major: idx = r + 4*c""" +def _run_crd2idx_col_major_tuple(): + """(4,8) col-major with make_coord((r, c)) tuple form""" layout = fx.make_layout((4, 8), (1, 4)) for r in range_constexpr(4): for c in range_constexpr(8): @@ -21,6 +21,16 @@ def _run_crd2idx_col_major(): fx.printf("{}", idx) +@flyc.jit +def _run_crd2idx_col_major_varargs(): + """(4,8) col-major with make_coord(r, c) varargs form""" + layout = fx.make_layout((4, 8), (1, 4)) + for r in range_constexpr(4): + for c in range_constexpr(8): + idx = fx.crd2idx(fx.make_coord(r, c), layout) + fx.printf("{}", idx) + + @flyc.jit def _run_crd2idx_row_major(): """(4,8) row-major: idx = r*8 + c""" @@ -54,14 +64,16 @@ def _run_crd2idx_3d(): # --- subprocess helper to capture C-level printf --- JIT_KERNELS = { - "col_major": _run_crd2idx_col_major, + "col_major_tuple": _run_crd2idx_col_major_tuple, + "col_major_varargs": _run_crd2idx_col_major_varargs, "row_major": _run_crd2idx_row_major, "1d": _run_crd2idx_1d, "3d": _run_crd2idx_3d, } EXPECTED = { - "col_major": [r + 4 * c for r in range(4) for c in range(8)], + "col_major_tuple": [r + 4 * c for r in range(4) for c in range(8)], + "col_major_varargs": [r + 4 * c for r in range(4) for c in range(8)], "row_major": [r * 8 + c for r in range(4) for c in range(8)], "1d": [c * 2 for c in range(8)], "3d": [i * 12 + j * 4 + k for i in range(2) for j in range(3) for k in range(4)], @@ -83,26 +95,26 @@ def _run_jit_and_capture(test_name): return [int(x) for x in lines] -def _run_error_test(snippet_name): - """Run an error-test snippet in a subprocess, return (returncode, stderr).""" - env = os.environ.copy() - env["FLYDSL_RUNTIME_ENABLE_CACHE"] = "0" - result = subprocess.run( - [sys.executable, __file__, "--error", snippet_name], - capture_output=True, - text=True, - env=env, - ) - return result.returncode, result.stdout.strip(), result.stderr.strip() +# --- pytest test cases --- + +def test_crd2idx_col_major_tuple(): + """make_coord((r, c)) tuple form works""" + actual = _run_jit_and_capture("col_major_tuple") + assert actual == EXPECTED["col_major_tuple"] -# --- pytest test cases: correctness --- +def test_crd2idx_col_major_varargs(): + """make_coord(r, c) varargs form works""" + actual = _run_jit_and_capture("col_major_varargs") + assert actual == EXPECTED["col_major_varargs"] -def test_crd2idx_col_major(): - """(4,8) col-major layout: idx = r + 4*c""" - actual = _run_jit_and_capture("col_major") - assert actual == EXPECTED["col_major"] + +def test_crd2idx_col_major_both_forms_equal(): + """make_coord((r, c)) and make_coord(r, c) produce identical results""" + tuple_result = _run_jit_and_capture("col_major_tuple") + varargs_result = _run_jit_and_capture("col_major_varargs") + assert tuple_result == varargs_result def test_crd2idx_row_major(): @@ -123,52 +135,10 @@ def test_crd2idx_3d(): assert actual == EXPECTED["3d"] -# --- pytest test cases: error handling (via subprocess to avoid conftest path issues) --- - - -def test_make_coord_rejects_varargs(): - """make_coord(r, c) must raise TypeError.""" - rc, stdout, stderr = _run_error_test("varargs") - assert rc != 0, f"Expected failure but succeeded.\nstdout: {stdout}" - assert "make_coord expects a tuple" in stderr, f"Wrong error message:\n{stderr}" - - -def test_make_coord_rejects_int(): - """make_coord(42) must raise TypeError.""" - rc, stdout, stderr = _run_error_test("int_arg") - assert rc != 0, f"Expected failure but succeeded.\nstdout: {stdout}" - assert "make_coord expects a tuple" in stderr, f"Wrong error message:\n{stderr}" - - # --- subprocess entry point --- if __name__ == "__main__": if len(sys.argv) >= 3 and sys.argv[1] == "--run": JIT_KERNELS[sys.argv[2]]() - - elif len(sys.argv) >= 3 and sys.argv[1] == "--error": - # These are intentionally broken snippets that should raise TypeError. - # Import fresh (no conftest interference). - import flydsl.compiler as flyc_fresh - import flydsl.expr as fx_fresh - - if sys.argv[2] == "varargs": - - @flyc_fresh.jit - def _bad(): - layout = fx_fresh.make_layout((4, 8), (1, 4)) - idx = fx_fresh.crd2idx(fx_fresh.make_coord(0, 1), layout) - - _bad() - - elif sys.argv[2] == "int_arg": - - @flyc_fresh.jit - def _bad(): - layout = fx_fresh.make_layout((4,), (1,)) - idx = fx_fresh.crd2idx(fx_fresh.make_coord(0), layout) - - _bad() - else: pytest.main([__file__, "-v"])