diff --git a/babs/cli.py b/babs/cli.py index c06fbcb6..f7dee487 100644 --- a/babs/cli.py +++ b/babs/cli.py @@ -464,6 +464,18 @@ def _parse_status(): default=Path.cwd(), type=PathExists, ) + parser.add_argument( + '--wait', + action='store_true', + default=False, + help='Poll until all submitted jobs complete or fail.', + ) + parser.add_argument( + '--wait-interval', + type=int, + default=300, + help='Seconds between status checks when using --wait.', + ) return parser @@ -486,6 +498,8 @@ def _enter_status(argv=None): def babs_status_main( project_root: str, + wait: bool = False, + wait_interval: int = 300, ): """ This is the core function of `babs status`. @@ -494,11 +508,18 @@ def babs_status_main( ---------- project_root: str absolute path to the directory of BABS project + wait: bool + whether to poll until all submitted jobs complete or fail + wait_interval: int + seconds between status checks when using --wait """ from babs import BABSInteraction babs_proj = BABSInteraction(project_root) - babs_proj.babs_status() + if wait: + babs_proj.babs_status_wait(interval=wait_interval) + else: + babs_proj.babs_status() def _parse_merge(): diff --git a/babs/interaction.py b/babs/interaction.py index 004c294b..aa567906 100644 --- a/babs/interaction.py +++ b/babs/interaction.py @@ -1,5 +1,8 @@ """This is the main module.""" +import sys +import time + import numpy as np from babs.base import BABS @@ -149,3 +152,42 @@ def babs_status(self): self.ensure_shared_group_runtime_ready() statuses = self._update_results_status() report_job_status(statuses, self.analysis_path) + + def babs_status_wait(self, interval=300): + """Poll job status until all submitted jobs complete or fail. + + Exits 0 if nothing has been submitted or all submitted jobs + succeeded; exits 1 only if a submitted job failed; exits 130 + on Ctrl-C. + + Parameters + ---------- + interval: int + Seconds between status checks. + """ + try: + while True: + statuses = self._update_results_status() + report_job_status(statuses, self.analysis_path) + sys.stdout.flush() + + submitted = [j for j in statuses.values() if j.submitted] + if not submitted: + print('No jobs have been submitted; nothing to wait on.') + return + + done = all(j.has_results or j.is_failed for j in submitted) + if done: + n_results = sum(1 for j in submitted if j.has_results) + n_failed = sum(1 for j in submitted if j.is_failed) + print( + f'\nAll submitted jobs finished: {n_results} succeeded, {n_failed} failed.' + ) + if n_failed > 0: + sys.exit(1) + return + + time.sleep(interval) + except KeyboardInterrupt: + print('\nInterrupted by user.') + sys.exit(130) diff --git a/docker/environment.yml b/docker/environment.yml index 2a76413d..f1944360 100644 --- a/docker/environment.yml +++ b/docker/environment.yml @@ -25,7 +25,6 @@ dependencies: - pytest - pytest-cov==5.0.0 - pytest-env==1.1.3 - - pytest-timeout>=2.2.0 - pytest-xdist - python=3.11 - shellcheck diff --git a/environment_hpc.yml b/environment_hpc.yml index 314bfa63..1ef30a32 100644 --- a/environment_hpc.yml +++ b/environment_hpc.yml @@ -31,5 +31,4 @@ dependencies: - qstat>=0.0.5 - pytest-cov>=5.0.0 - pytest-env>=1.1.0 - - pytest-timeout>=2.2.0 - pytest-xdist diff --git a/pyproject.toml b/pyproject.toml index 7a2402ac..68479b9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dev = ["ruff ~= 0.4.3", "pre-commit"] tests = [ "coverage", "pytest", + "pytest-timeout>=2.2.0", "pytest-xdist", # for running pytest in parallel "pytest-cov", # for ordering test execution "datalad-osf", diff --git a/tests/e2e-slurm/container/walkthrough-tests.sh b/tests/e2e-slurm/container/walkthrough-tests.sh index c8947af8..ba1bf8bc 100755 --- a/tests/e2e-slurm/container/walkthrough-tests.sh +++ b/tests/e2e-slurm/container/walkthrough-tests.sh @@ -73,44 +73,26 @@ echo "Job submitted: Check setup, with job" babs submit -# # Wait for all running jobs to finish -while [[ -n $(squeue -u "$USER" -t RUNNING,PENDING --noheader) ]]; do - echo "squeue -u \"$USER\" -t RUNNING,PENDING" - squeue -u "$USER" -t RUNNING,PENDING - echo "Waiting for running jobs to finish..." - sleep 5 # Wait for 60 seconds before checking again -done - -echo "=========================================================================" -echo "babs status:" -babs status -echo "=========================================================================" - -# Check for failed jobs TODO see above -# if sacct -u $USER --state=FAILED --noheader | grep -q "FAILED"; then -sacct -u "$USER" -if sacct -u "$USER" --noheader | grep -q "FAILED"; then - echo "=========================================================================" - echo "There are failed jobs." - LOGS_DIR="analysis/logs" - if [ -d "$LOGS_DIR" ]; then - echo "=========================================================================" - echo "Failed job / task logs from $LOGS_DIR:" - for f in "$LOGS_DIR"/*; do - if [ -f "$f" ]; then - echo "---------- $f ----------" - cat "$f" - echo "" - fi - done - fi - exit 1 # Exit with failure status -else - echo "=========================================================================" - echo "PASSED: No failed jobs." -fi +babs status --wait --wait-interval 5 +echo "PASSED: No failed jobs." babs merge + +echo "Checking job_status.csv after merge..." +cat analysis/code/job_status.csv +python -c " +import csv, sys +with open('analysis/code/job_status.csv') as f: + for row in csv.DictReader(f): + if row['submitted'].strip().lower() == 'true': + if row['has_results'].strip().lower() != 'true': + print(f'FAIL: {row[\"sub_id\"]} submitted but has_results={row[\"has_results\"]}') + sys.exit(1) + if row['is_failed'].strip().lower() == 'true': + print(f'FAIL: {row[\"sub_id\"]} has_results=True but is_failed=True') + sys.exit(1) +print('PASSED: job_status.csv is consistent') +" echo "PASSED: e2e walkthrough successful!" popd @@ -134,14 +116,22 @@ pushd "${PWD}/${TEST2_NAME}" babs check-setup babs submit -# # Wait for all running jobs to finish -while [[ -n $(squeue -u "$USER" -t RUNNING,PENDING --noheader) ]]; do - echo "squeue -u \"$USER\" -t RUNNING,PENDING" - squeue -u "$USER" -t RUNNING,PENDING - echo "Waiting for running jobs to finish..." - sleep 5 # Wait for 60 seconds before checking again -done - -babs status +babs status --wait --wait-interval 5 babs merge + +echo "Checking job_status.csv after merge (multiinput)..." +cat analysis/code/job_status.csv +python -c " +import csv, sys +with open('analysis/code/job_status.csv') as f: + for row in csv.DictReader(f): + if row['submitted'].strip().lower() == 'true': + if row['has_results'].strip().lower() != 'true': + print(f'FAIL: {row[\"sub_id\"]} submitted but has_results={row[\"has_results\"]}') + sys.exit(1) + if row['is_failed'].strip().lower() == 'true': + print(f'FAIL: {row[\"sub_id\"]} has_results=True but is_failed=True') + sys.exit(1) +print('PASSED: job_status.csv is consistent (multiinput)') +" diff --git a/tests/test_interaction.py b/tests/test_interaction.py index abef84e5..6b086546 100644 --- a/tests/test_interaction.py +++ b/tests/test_interaction.py @@ -4,6 +4,7 @@ import pytest from babs.interaction import BABSInteraction +from babs.status import JobStatus, SchedulerState from babs.utils import scheduler_status_columns @@ -209,3 +210,219 @@ def test_get_latest_submitted_jobs_df_missing_job_id_column(babs_project_subject assert latest_df['sub_id'].tolist() == ['sub-01'] assert latest_df['task_id'].tolist() == [1] assert latest_df['job_id'].isna().all() + + +# -- babs_status_wait tests -- + + +def _make_statuses(submitted, has_results): + """Build a statuses dict from parallel lists of booleans.""" + statuses = {} + for i, (sub, res) in enumerate(zip(submitted, has_results, strict=True)): + sub_id = f'sub-{i + 1:02d}' + if sub and not res: + state = SchedulerState.DONE + elif sub: + state = SchedulerState.DONE + else: + state = SchedulerState.NOT_SUBMITTED + job = JobStatus( + sub_id=sub_id, + ses_id=None, + scheduler_state=state, + has_results=res, + job_id=100 + i if sub else None, + task_id=i + 1 if sub else None, + time_used='', + time_limit='', + nodes=0, + cpus=0, + partition='', + name='', + ) + statuses[job.key] = job + return statuses + + +def _patch_wait(monkeypatch, babs_proj, statuses_list): + """Patch a BABSInteraction for babs_status_wait testing. + + Parameters + ---------- + statuses_list : list[dict] + Sequence of statuses dicts returned by successive _update_results_status calls. + """ + call_count = {'n': 0} + + def _update(): + idx = min(call_count['n'], len(statuses_list) - 1) + call_count['n'] += 1 + return statuses_list[idx] + + monkeypatch.setattr(babs_proj, '_update_results_status', _update) + monkeypatch.setattr('babs.interaction.report_job_status', lambda *a, **kw: None) + monkeypatch.setattr('babs.interaction.time.sleep', lambda s: None) + + return call_count + + +def test_status_wait_all_succeeded(babs_project_subjectlevel, monkeypatch, capsys): + """All submitted jobs already have results — should exit immediately.""" + babs_proj = BABSInteraction(project_root=babs_project_subjectlevel) + statuses = _make_statuses(submitted=[True, True], has_results=[True, True]) + call_count = _patch_wait(monkeypatch, babs_proj, [statuses]) + + babs_proj.babs_status_wait(interval=1) + + assert call_count['n'] == 1 + captured = capsys.readouterr() + assert '2 succeeded' in captured.out + assert '0 failed' in captured.out + + +def test_status_wait_all_failed(babs_project_subjectlevel, monkeypatch): + """All submitted jobs failed — should exit with sys.exit(1).""" + babs_proj = BABSInteraction(project_root=babs_project_subjectlevel) + statuses = _make_statuses(submitted=[True, True], has_results=[False, False]) + _patch_wait(monkeypatch, babs_proj, [statuses]) + + with pytest.raises(SystemExit, match='1'): + babs_proj.babs_status_wait(interval=1) + + +def test_status_wait_mixed_results(babs_project_subjectlevel, monkeypatch, capsys): + """Some succeeded, some failed — should exit(1).""" + babs_proj = BABSInteraction(project_root=babs_project_subjectlevel) + statuses = _make_statuses(submitted=[True, True], has_results=[True, False]) + _patch_wait(monkeypatch, babs_proj, [statuses]) + + with pytest.raises(SystemExit, match='1'): + babs_proj.babs_status_wait(interval=1) + + +def test_status_wait_loops_until_done(babs_project_subjectlevel, monkeypatch, capsys): + """Jobs still running on first check, done on second — should loop once.""" + babs_proj = BABSInteraction(project_root=babs_project_subjectlevel) + + # First poll: running + running = {} + for i in range(2): + sub_id = f'sub-{i + 1:02d}' + job = JobStatus( + sub_id=sub_id, + ses_id=None, + scheduler_state=SchedulerState.RUNNING, + has_results=False, + job_id=100 + i, + task_id=i + 1, + time_used='', + time_limit='', + nodes=0, + cpus=0, + partition='', + name='', + ) + running[job.key] = job + + # Second poll: done + done = _make_statuses(submitted=[True, True], has_results=[True, True]) + + call_count = _patch_wait(monkeypatch, babs_proj, [running, done]) + + babs_proj.babs_status_wait(interval=1) + + assert call_count['n'] == 2 + captured = capsys.readouterr() + assert '2 succeeded' in captured.out + + +def test_status_wait_no_submitted_jobs(babs_project_subjectlevel, monkeypatch, capsys): + """No jobs submitted — should return cleanly (exit 0).""" + babs_proj = BABSInteraction(project_root=babs_project_subjectlevel) + statuses = _make_statuses(submitted=[False, False], has_results=[False, False]) + _patch_wait(monkeypatch, babs_proj, [statuses]) + + babs_proj.babs_status_wait(interval=1) + captured = capsys.readouterr() + assert 'No jobs have been submitted' in captured.out + + +def test_status_wait_report_called_each_iteration(babs_project_subjectlevel, monkeypatch): + """report_job_status should be called on every iteration.""" + babs_proj = BABSInteraction(project_root=babs_project_subjectlevel) + + running = {} + sub_id = 'sub-01' + job = JobStatus( + sub_id=sub_id, + ses_id=None, + scheduler_state=SchedulerState.RUNNING, + has_results=False, + job_id=100, + task_id=1, + time_used='', + time_limit='', + nodes=0, + cpus=0, + partition='', + name='', + ) + running[job.key] = job + + done = _make_statuses(submitted=[True], has_results=[True]) + + report_calls = [] + + call_count = {'n': 0} + + def _update(): + idx = min(call_count['n'], 1) + call_count['n'] += 1 + return [running, done][idx] + + monkeypatch.setattr(babs_proj, '_update_results_status', _update) + monkeypatch.setattr( + 'babs.interaction.report_job_status', + lambda *a, **kw: report_calls.append(a), + ) + monkeypatch.setattr('babs.interaction.time.sleep', lambda s: None) + + babs_proj.babs_status_wait(interval=1) + + assert len(report_calls) == 2 + + +def test_status_wait_keyboard_interrupt(babs_project_subjectlevel, monkeypatch, capsys): + """Ctrl-C should print a message and exit(130).""" + babs_proj = BABSInteraction(project_root=babs_project_subjectlevel) + + running = {} + job = JobStatus( + sub_id='sub-01', + ses_id=None, + scheduler_state=SchedulerState.RUNNING, + has_results=False, + job_id=100, + task_id=1, + time_used='', + time_limit='', + nodes=0, + cpus=0, + partition='', + name='', + ) + running[job.key] = job + + monkeypatch.setattr(babs_proj, '_update_results_status', lambda: running) + monkeypatch.setattr('babs.interaction.report_job_status', lambda *a, **kw: None) + monkeypatch.setattr( + 'babs.interaction.time.sleep', + lambda s: (_ for _ in ()).throw(KeyboardInterrupt), + ) + + with pytest.raises(SystemExit) as exc_info: + babs_proj.babs_status_wait(interval=1) + + assert exc_info.value.code == 130 + captured = capsys.readouterr() + assert 'Interrupted' in captured.out