diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a160922..a428d21 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -90,18 +90,24 @@ jobs: - run: mix credo --strict test: - name: Tests (OTP ${{ matrix.otp }} / Elixir ${{ matrix.elixir }}) - runs-on: ubuntu-latest + name: Tests (${{ matrix.os }} / OTP ${{ matrix.otp }} / Elixir ${{ matrix.elixir }}) + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: include: - - elixir: "1.17.3" + - os: ubuntu-latest + elixir: "1.17.3" otp: "27.2" - - elixir: "1.18.3" + - os: ubuntu-latest + elixir: "1.18.3" otp: "27.2" - - elixir: "1.20.0-rc.1" + - os: ubuntu-latest + elixir: "1.20.0-rc.1" otp: "28.3.3" + - os: macos-latest + elixir: "1.17.3" + otp: "27.2" steps: - uses: actions/checkout@v4 diff --git a/Makefile b/Makefile index f8ac582..3c58910 100644 --- a/Makefile +++ b/Makefile @@ -16,16 +16,18 @@ ERL_INTERFACE_LIB_DIR ?= $(shell erl -noshell -eval "io:format(\"~ts\", [code:li UNAME_S := $(shell uname -s) CC ?= cc -CFLAGS_BASE = -O2 -Wall -Wextra -Werror -std=c99 +CFLAGS_BASE = -O2 -Wall -Wextra -Werror -std=c99 -fstack-protector-strong -D_FORTIFY_SOURCE=2 ifeq ($(UNAME_S),Darwin) # macOS needs _DARWIN_C_SOURCE for SCM_RIGHTS, CMSG_SPACE, etc. CFLAGS = $(CFLAGS_BASE) -D_DARWIN_C_SOURCE NIF_LDFLAGS = -dynamiclib -undefined dynamic_lookup + SHEPHERD_LDFLAGS = -fPIE NIF_EXT = .so else CFLAGS = $(CFLAGS_BASE) -D_GNU_SOURCE - NIF_LDFLAGS = -shared + NIF_LDFLAGS = -shared -Wl,-z,relro,-z,now -Wl,-z,noexecstack + SHEPHERD_LDFLAGS = -fPIE -pie -Wl,-z,relro,-z,now -Wl,-z,noexecstack NIF_EXT = .so endif @@ -52,7 +54,7 @@ $(PRIV_DIR): # Shepherd binary $(SHEPHERD): $(SHEPHERD_OBJ) - $(CC) -o $@ $< + $(CC) $(SHEPHERD_LDFLAGS) -o $@ $< $(SHEPHERD_OBJ): $(SHEPHERD_SRC) $(HEADERS) $(CC) $(CFLAGS) -I$(C_SRC_DIR) -c -o $@ $< diff --git a/c_src/net_runner_nif.c b/c_src/net_runner_nif.c index c2e8780..d860379 100644 --- a/c_src/net_runner_nif.c +++ b/c_src/net_runner_nif.c @@ -324,11 +324,17 @@ static ERL_NIF_TERM nif_close(ErlNifEnv *env, int argc, res->monitor_active = 0; } + /* Close FD inside critical section to prevent TOCTOU race: + * a concurrent nif_read/nif_write on a dirty scheduler could copy the FD + * under lock then use it after we release the lock but before close(). */ + int close_ret = close(fd); + int close_errno = errno; + enif_mutex_unlock(res->lock); - if (close(fd) != 0) { + if (close_ret != 0 && close_errno != EINTR) { return enif_make_tuple2(env, atom_error, - MAKE_ATOM(env, errno_to_atom(errno))); + MAKE_ATOM(env, errno_to_atom(close_errno))); } return atom_ok; diff --git a/c_src/shepherd.c b/c_src/shepherd.c index 4416938..2ea0d83 100644 --- a/c_src/shepherd.c +++ b/c_src/shepherd.c @@ -84,6 +84,10 @@ static int send_fds(int uds_fd, int *fds, int nfds) { msg.msg_controllen = cmsg_space; struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + if (!cmsg) { + free(cmsg_buf); + return -1; + } cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_type = SCM_RIGHTS; cmsg->cmsg_len = CMSG_LEN((size_t)nfds * sizeof(int)); @@ -234,7 +238,7 @@ static void kill_child(pid_t child_pid) { /* Wait for graceful exit (configurable, default 5s) */ int poll_interval_us = 100000; /* 100ms */ - int iterations = kill_timeout_ms * 1000 / poll_interval_us; + int iterations = (int)((long)kill_timeout_ms * 1000 / poll_interval_us); if (iterations < 1) iterations = 1; for (int i = 0; i < iterations; i++) { @@ -264,6 +268,8 @@ static void handle_command(int uds_fd, pid_t child_pid, int stdin_w, case CMD_KILL: if (len >= 2 && child_pid > 0) { int sig = buf[1]; + /* Validate signal is in POSIX range */ + if (sig < 1 || sig > 31) break; /* Kill the process group (catches grandchildren). * Fall back to direct kill if group doesn't exist. */ if (kill(-child_pid, sig) != 0) { @@ -398,14 +404,25 @@ int main(int argc, char *argv[]) { /* Parse optional flags */ while (cmd_idx < argc && argv[cmd_idx][0] == '-') { if (strcmp(argv[cmd_idx], "--kill-timeout") == 0 && cmd_idx + 1 < argc) { - kill_timeout_ms = atoi(argv[cmd_idx + 1]); - if (kill_timeout_ms < 0) kill_timeout_ms = DEFAULT_KILL_TIMEOUT_MS; + char *endptr; + long val = strtol(argv[cmd_idx + 1], &endptr, 10); + if (*endptr != '\0' || endptr == argv[cmd_idx + 1] || val <= 0 || val > 60000) { + fprintf(stderr, "error: --kill-timeout must be 1-60000 ms\n"); + return 1; + } + kill_timeout_ms = (int)val; cmd_idx += 2; } else if (strcmp(argv[cmd_idx], "--pty") == 0) { pty_mode = MODE_PTY; cmd_idx += 1; } else if (strcmp(argv[cmd_idx], "--cgroup-path") == 0 && cmd_idx + 1 < argc) { - strncpy(cgroup_path, argv[cmd_idx + 1], CGROUP_PATH_MAX - 1); + const char *path = argv[cmd_idx + 1]; + /* Reject path traversal: no ".." components, no leading "/" */ + if (path[0] == '/' || strstr(path, "..") != NULL) { + fprintf(stderr, "error: invalid cgroup path (must be relative, no '..')\n"); + return 1; + } + strncpy(cgroup_path, path, CGROUP_PATH_MAX - 1); cgroup_path[CGROUP_PATH_MAX - 1] = '\0'; cmd_idx += 2; } else { @@ -436,7 +453,10 @@ int main(int argc, char *argv[]) { memset(&sa, 0, sizeof(sa)); sa.sa_handler = sigchld_handler; sa.sa_flags = SA_RESTART | SA_NOCLDSTOP; - sigaction(SIGCHLD, &sa, NULL); + if (sigaction(SIGCHLD, &sa, NULL) != 0) { + perror("sigaction"); + return 1; + } /* Connect to BEAM's UDS listener */ int uds_fd = socket(AF_UNIX, SOCK_STREAM, 0); @@ -473,6 +493,8 @@ int main(int argc, char *argv[]) { child_pid = fork(); if (child_pid < 0) { send_error(uds_fd, "fork failed"); + close(master_fd); + close(slave_fd); close(uds_fd); return 1; } @@ -511,12 +533,14 @@ int main(int argc, char *argv[]) { if (send_fds(uds_fd, fds_to_send, 1) != 0) { send_error(uds_fd, "failed to send PTY FD"); kill_child(child_pid); + close(master_fd); close(uds_fd); return 1; } if (send_child_started(uds_fd, child_pid) != 0) { kill_child(child_pid); + close(master_fd); close(uds_fd); return 1; } @@ -526,20 +550,54 @@ int main(int argc, char *argv[]) { } else { /* === Pipe mode (default) === */ - int stdin_pipe[2]; /* [0]=read (child), [1]=write (beam) */ - int stdout_pipe[2]; /* [0]=read (beam), [1]=write (child) */ - int stderr_pipe[2]; /* [0]=read (beam), [1]=write (child) */ + int stdin_pipe[2] = {-1, -1}; + int stdout_pipe[2] = {-1, -1}; + int stderr_pipe[2] = {-1, -1}; - if (pipe(stdin_pipe) != 0 || pipe(stdout_pipe) != 0 || - pipe(stderr_pipe) != 0) { +#ifdef __linux__ + /* Use pipe2 with O_CLOEXEC to atomically set close-on-exec */ + if (pipe2(stdin_pipe, O_CLOEXEC) != 0) { +#else + if (pipe(stdin_pipe) != 0) { +#endif send_error(uds_fd, "failed to create pipes"); close(uds_fd); return 1; } +#ifdef __linux__ + if (pipe2(stdout_pipe, O_CLOEXEC) != 0) { +#else + if (pipe(stdout_pipe) != 0) { +#endif + send_error(uds_fd, "failed to create pipes"); + close(stdin_pipe[0]); close(stdin_pipe[1]); + close(uds_fd); + return 1; + } +#ifdef __linux__ + if (pipe2(stderr_pipe, O_CLOEXEC) != 0) { +#else + if (pipe(stderr_pipe) != 0) { +#endif + send_error(uds_fd, "failed to create pipes"); + close(stdin_pipe[0]); close(stdin_pipe[1]); + close(stdout_pipe[0]); close(stdout_pipe[1]); + close(uds_fd); + return 1; + } +#ifndef __linux__ + /* On macOS, set close-on-exec manually */ + set_cloexec(stdin_pipe[0]); set_cloexec(stdin_pipe[1]); + set_cloexec(stdout_pipe[0]); set_cloexec(stdout_pipe[1]); + set_cloexec(stderr_pipe[0]); set_cloexec(stderr_pipe[1]); +#endif child_pid = fork(); if (child_pid < 0) { send_error(uds_fd, "fork failed"); + close(stdin_pipe[0]); close(stdin_pipe[1]); + close(stdout_pipe[0]); close(stdout_pipe[1]); + close(stderr_pipe[0]); close(stderr_pipe[1]); close(uds_fd); return 1; } @@ -579,12 +637,18 @@ int main(int argc, char *argv[]) { if (send_fds(uds_fd, fds_to_send, 3) != 0) { send_error(uds_fd, "failed to send FDs"); kill_child(child_pid); + close(stdin_pipe[1]); + close(stdout_pipe[0]); + close(stderr_pipe[0]); close(uds_fd); return 1; } if (send_child_started(uds_fd, child_pid) != 0) { kill_child(child_pid); + close(stdin_pipe[1]); + close(stdout_pipe[0]); + close(stderr_pipe[0]); close(uds_fd); return 1; } diff --git a/lib/net_runner/process.ex b/lib/net_runner/process.ex index e02a5ba..5c17daa 100644 --- a/lib/net_runner/process.ex +++ b/lib/net_runner/process.ex @@ -222,9 +222,23 @@ defmodule NetRunner.Process do when port == state.shepherd_port do # Shepherd died. Read exit status from UDS if we haven't already. state = maybe_read_exit_status(state) + + # If we still haven't received exit status, schedule a forced timeout + if state.status != :exited do + Process.send_after(self(), :force_exit_timeout, 5_000) + end + {:noreply, state} end + def handle_info(:force_exit_timeout, state) do + if state.status != :exited do + {:noreply, finish_exit(state, 137)} + else + {:noreply, state} + end + end + # UDS message from shepherd (via active socket) def handle_info({:"$socket", socket, :select, _info}, state) when socket == state.uds_socket do diff --git a/lib/net_runner/process/exec.ex b/lib/net_runner/process/exec.ex index 9c0b77a..1c7cd5d 100644 --- a/lib/net_runner/process/exec.ex +++ b/lib/net_runner/process/exec.ex @@ -25,11 +25,20 @@ defmodule NetRunner.Process.Exec do uds_path = uds_socket_path() pty_mode = Keyword.get(opts, :pty, false) - with {:ok, listen_socket} <- create_uds_listener(uds_path), + with :ok <- validate_cgroup_path(Keyword.get(opts, :cgroup_path, nil)), + {:ok, listen_socket} <- create_uds_listener(uds_path), shepherd_port <- open_shepherd(uds_path, cmd, args, opts), {:ok, conn_socket} <- accept_connection(listen_socket), - :ok <- cleanup_listener(listen_socket, uds_path), - {:ok, fds, iov_rest} <- receive_fds(conn_socket, pty_mode), + :ok <- cleanup_listener(listen_socket, uds_path) do + # conn_socket and shepherd_port are now live — clean up on any failure + setup_after_connection(conn_socket, shepherd_port, owner, cmd, args, opts, pty_mode) + else + {:error, reason} -> {:error, reason} + end + end + + defp setup_after_connection(conn_socket, shepherd_port, owner, cmd, args, opts, pty_mode) do + with {:ok, fds, iov_rest} <- receive_fds(conn_socket, pty_mode), {:ok, os_pid} <- extract_child_started(conn_socket, iov_rest), {:ok, pipes} <- wrap_fds(fds, owner, pty_mode) do stderr_mode = if pty_mode, do: :disabled, else: Keyword.get(opts, :stderr, :consume) @@ -48,7 +57,41 @@ defmodule NetRunner.Process.Exec do status: :running }} else - {:error, reason} -> {:error, reason} + {:error, reason} -> + safe_close_socket(conn_socket) + safe_port_close(shepherd_port) + {:error, reason} + end + end + + defp safe_port_close(port) when is_port(port) do + Port.close(port) + catch + _, _ -> :ok + end + + defp safe_port_close(_), do: :ok + + defp safe_close_socket(socket) do + :socket.close(socket) + catch + _, _ -> :ok + end + + defp validate_cgroup_path(nil), do: :ok + + defp validate_cgroup_path(path) do + path_str = to_string(path) + + cond do + String.starts_with?(path_str, "/") -> + {:error, {:invalid_cgroup_path, "must be relative, got: #{path_str}"}} + + String.contains?(path_str, "..") -> + {:error, {:invalid_cgroup_path, "cannot contain '..', got: #{path_str}"}} + + true -> + :ok end end diff --git a/lib/net_runner/stream.ex b/lib/net_runner/stream.ex index e35cb46..d39415b 100644 --- a/lib/net_runner/stream.ex +++ b/lib/net_runner/stream.ex @@ -53,7 +53,10 @@ defmodule NetRunner.Stream do Stream.resource( fn -> start_writer(pid, input) end, fn acc -> read_next(pid, acc) end, - fn _acc -> cleanup(pid) end + fn + {:error, _pid, _reason} = err -> cleanup(err) + _acc -> cleanup(pid) + end ) end @@ -98,7 +101,7 @@ defmodule NetRunner.Stream do # Check if writer is done, but don't block case Task.yield(writer, 0) do {:ok, _} -> read_next(pid, :reading) - {:exit, reason} -> raise "writer task crashed: #{inspect(reason)}" + {:exit, reason} -> {:halt, {:error, pid, reason}} nil -> do_read(pid, acc) end end @@ -123,7 +126,25 @@ defmodule NetRunner.Stream do end end - defp cleanup(pid) do + defp cleanup({:error, pid, reason}) do + # Writer task crashed — clean up process first, then re-raise + cleanup_process(pid) + raise "writer task crashed: #{inspect(reason)}" + end + + defp cleanup(pid) when is_pid(pid) do + cleanup_process(pid) + end + + defp cleanup({:writing, _writer}) do + :ok + end + + defp cleanup(:reading) do + :ok + end + + defp cleanup_process(pid) do if Process.alive?(pid) do Proc.close_stdin(pid) diff --git a/mix.exs b/mix.exs index da7c22c..e76b438 100644 --- a/mix.exs +++ b/mix.exs @@ -1,7 +1,7 @@ defmodule NetRunner.MixProject do use Mix.Project - @version "1.0.0" + @version "1.0.1" @source_url "https://github.com/nyo16/net_runner" def project do diff --git a/test/leak_test.exs b/test/leak_test.exs new file mode 100644 index 0000000..fe167e1 --- /dev/null +++ b/test/leak_test.exs @@ -0,0 +1,156 @@ +defmodule NetRunner.LeakTest do + use ExUnit.Case, async: false + + alias NetRunner.Process, as: Proc + + describe "FD leak prevention" do + @tag :linux_only + test "rapid spawn/kill cycle does not leak FDs" do + # Warm-up run to stabilize FD baseline + for _ <- 1..5 do + {:ok, pid} = Proc.start("true", []) + Proc.await_exit(pid) + end + + :erlang.garbage_collect() + Process.sleep(500) + + initial_fd_count = count_open_fds() + + for _ <- 1..20 do + {:ok, pid} = Proc.start("sleep", ["100"]) + Proc.kill(pid, :sigkill) + Proc.await_exit(pid) + GenServer.stop(pid, :normal) + end + + # Allow GC and cleanup time + :erlang.garbage_collect() + Process.sleep(1_000) + :erlang.garbage_collect() + Process.sleep(500) + + final_fd_count = count_open_fds() + + # Allow margin for BEAM-internal FD activity + assert final_fd_count <= initial_fd_count + 30, + "FD leak detected: started with #{initial_fd_count}, ended with #{final_fd_count}" + end + + test "stream abort cleans up process" do + # Start a long-running stream and abort mid-read + stream = NetRunner.stream!(["sh", "-c", "while true; do echo line; sleep 0.01; done"]) + + # Take only a few elements then halt + result = Enum.take(stream, 3) + assert length(result) == 3 + + # Give cleanup time to run + Process.sleep(500) + end + + test "process exit before read gives clean error" do + {:ok, pid} = Proc.start("true", []) + {:ok, 0} = Proc.await_exit(pid) + + # Read after exit should return eof or error, not crash + result = Proc.read(pid) + assert result in [:eof, {:error, :process_exited}, {:error, :closed}] + end + end + + describe "concurrent close+read" do + test "concurrent close and read does not crash" do + {:ok, pid} = Proc.start("cat", []) + + tasks = + for _ <- 1..5 do + Task.async(fn -> + try do + Proc.read(pid, 1024) + catch + :exit, _ -> :exited + end + end) + end + + # Close stdin and kill to trigger cleanup + Process.sleep(50) + Proc.kill(pid, :sigkill) + + results = + Enum.map(tasks, fn task -> + case Task.yield(task, 5_000) do + {:ok, result} -> result + nil -> Task.shutdown(task, :brutal_kill) + end + end) + + # All tasks should have completed without crashing the BEAM + assert length(results) == 5 + end + end + + describe "nif_close idempotency" do + test "closing an already-closed pipe returns :ok" do + {:ok, pid} = Proc.start("echo", ["test"]) + :ok = Proc.close_stdin(pid) + # Second close should be idempotent + :ok = Proc.close_stdin(pid) + Proc.await_exit(pid) + end + end + + describe "write to closed stdin" do + test "write after close_stdin returns error" do + {:ok, pid} = Proc.start("cat", []) + :ok = Proc.close_stdin(pid) + + result = Proc.write(pid, "should fail") + assert {:error, :closed} = result + + Proc.kill(pid, :sigkill) + Proc.await_exit(pid) + end + end + + describe "cgroup path validation" do + test "rejects path traversal with .." do + assert {:error, {:invalid_cgroup_path, msg}} = + Proc.start("echo", ["test"], cgroup_path: "../../etc/evil") + + assert msg =~ ".." + end + + test "rejects absolute cgroup path" do + assert {:error, {:invalid_cgroup_path, msg}} = + Proc.start("echo", ["test"], cgroup_path: "/sys/fs/cgroup/evil") + + assert msg =~ "relative" + end + end + + describe "multiple concurrent await_exit" do + test "all callers receive exit status" do + {:ok, pid} = Proc.start("echo", ["hello"]) + + tasks = + for _ <- 1..3 do + Task.async(fn -> + Proc.await_exit(pid, 5_000) + end) + end + + results = Task.await_many(tasks, 10_000) + assert Enum.all?(results, &match?({:ok, 0}, &1)) + end + end + + # Helper to count open FDs via /proc/self/fd + defp count_open_fds do + case File.ls("/proc/self/fd") do + {:ok, entries} -> length(entries) + {:error, _} -> 0 + end + end +end diff --git a/test/test_helper.exs b/test/test_helper.exs index 869559e..d1d11e6 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -1 +1,7 @@ -ExUnit.start() +exclude = + case :os.type() do + {:unix, :linux} -> [] + _ -> [:linux_only] + end + +ExUnit.start(exclude: exclude)