Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 50 additions & 27 deletions babs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import os.path as op
import subprocess
from dataclasses import replace
from pathlib import Path
from urllib.parse import urlparse

Expand All @@ -12,6 +13,13 @@
from babs.input_datasets import InputDatasets, OutputDatasets
from babs.scheduler import (
request_all_job_status,
run_squeue,
)
from babs.status import (
read_job_status_csv,
update_from_branches,
update_from_scheduler,
write_job_status_csv,
)
from babs.system import validate_queue
from babs.utils import (
Expand All @@ -20,12 +28,9 @@
get_results_branches,
identify_running_jobs,
read_yaml,
results_branch_dataframe,
results_status_columns,
scheduler_status_columns,
status_dtypes,
update_job_batch_status,
update_results_status,
validate_processing_level,
)

Expand Down Expand Up @@ -381,32 +386,50 @@ def _get_results_branches(self) -> list[str]:
"""Get the results branch names from the output RIA in a list."""
return get_results_branches(self.output_ria_data_dir)

def _update_results_status(self) -> None:
"""
Update the status of jobs based on results in the output RIA and zip files.
"""

previous_job_completion_df = self.get_job_status_df()

# Step 1: get a list of branches in the output ria to update the status
list_branches = self._get_results_branches()
completed_branches_df = results_branch_dataframe(list_branches, self.processing_level)
def _update_results_status(self) -> dict:
"""Update job statuses from external sources and write to CSV.

# Get any completed merged zip files
merged_zip_completion_df = self._get_merged_results_from_analysis_dir()

# Update the results status
current_status_df = update_results_status(
previous_job_completion_df, completed_branches_df, merged_zip_completion_df
)

# Part 2: Update which jobs are running
currently_running_df = self.get_currently_running_jobs_df()
current_status_df = update_job_batch_status(current_status_df, currently_running_df)
current_status_df['has_results'] = (
current_status_df['has_results'].astype('boolean').fillna(False)
Returns
-------
dict[tuple, JobStatus]
Updated statuses keyed by (sub_id,) or (sub_id, ses_id).
"""
# Read current state from CSV
if op.exists(self.job_status_path_abs):
statuses = read_job_status_csv(self.job_status_path_abs)
else:
statuses = {}

# Update from results branches in output RIA
branches = self._get_results_branches()
statuses = update_from_branches(statuses, branches)

# Update from merged zip files in analysis dir
merged_zip_df = self._get_merged_results_from_analysis_dir()
if not merged_zip_df.empty:
for _, row in merged_zip_df.iterrows():
ses_id = row.get('ses_id') if 'ses_id' in merged_zip_df.columns else None
key = (row['sub_id'], ses_id) if ses_id else (row['sub_id'],)
if key in statuses:
statuses[key] = replace(statuses[key], has_results=True)

# Update from scheduler (squeue)
job_ids = sorted(
{
job.job_id
for job in statuses.values()
if job.submitted and not job.has_results and job.job_id is not None
}
)
current_status_df.to_csv(self.job_status_path_abs, index=False)
raw_squeue_parts = []
for jid in job_ids:
raw_squeue_parts.append(run_squeue(self.queue, jid))
raw_squeue = ''.join(raw_squeue_parts)
statuses = update_from_scheduler(statuses, raw_squeue)

# Write updated statuses
write_job_status_csv(self.job_status_path_abs, statuses)
return statuses

def get_latest_submitted_jobs_df(self):
"""
Expand Down
26 changes: 7 additions & 19 deletions babs/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
"""This is the main module."""

import csv
import os
import os.path as op
import subprocess
import tempfile
from pathlib import Path

import datalad.api as dlapi
import pandas as pd
import yaml
from jinja2 import Environment, PackageLoader, StrictUndefined

from babs.base import BABS
from babs.container import Container
from babs.input_datasets import InputDatasets
from babs.status import create_initial_statuses, write_job_status_csv
from babs.system import System, validate_queue
from babs.utils import (
get_datalad_version,
results_status_columns,
results_status_default_values,
status_dtypes,
validate_processing_level,
)

Expand Down Expand Up @@ -611,22 +609,12 @@ def clean_up(self):
print('\nCreated BABS project has been cleaned up.')

def _create_initial_job_status_csv(self):
"""
Create the initial job status csv file.
"""
"""Create the initial job status csv file."""
if op.exists(self.job_status_path_abs):
return

# Load the complete list of subjects and optionally sessions
df_sub = pd.read_csv(self.list_sub_path_abs)
df_job = df_sub.copy()

# Fill the columns that should get default values
for column_name, default_value in results_status_default_values.items():
df_job[column_name] = default_value

# ensure dtypes for all the columns
for column_name in results_status_columns:
df_job[column_name] = df_job[column_name].astype(status_dtypes[column_name])
with open(self.list_sub_path_abs, newline='') as f:
sub_ses_list = list(csv.DictReader(f))

df_job.to_csv(self.job_status_path_abs, index=False)
statuses = create_initial_statuses(sub_ses_list)
write_job_status_csv(self.job_status_path_abs, statuses)
6 changes: 2 additions & 4 deletions babs/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,5 @@ def babs_status(self):
"""
Check job status and makes a nice report.
"""
self._update_results_status()
currently_running_df = self.get_currently_running_jobs_df()
current_results_df = self.get_job_status_df()
report_job_status(current_results_df, currently_running_df, self.analysis_path)
statuses = self._update_results_status()
report_job_status(statuses, self.analysis_path)
76 changes: 58 additions & 18 deletions babs/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,53 @@
import pandas as pd
import yaml

from babs.status import SchedulerState
from babs.utils import get_username, scheduler_status_columns, status_dtypes


def run_squeue(queue, job_id: int) -> str:
    """Query squeue for a single job array and return its raw output.

    Parameters
    ----------
    queue : str
        Job scheduling system type (only 'slurm' supported).
    job_id : int
        The job array ID to query.

    Returns
    -------
    str
        Raw squeue stdout (pipe-delimited lines), or empty string if
        no jobs found.
    """
    # Guard clauses: only Slurm is implemented, and it must be installed.
    if queue != 'slurm':
        raise NotImplementedError(f'Queue {queue!r} is not supported.')
    if not check_slurm_available():
        raise RuntimeError('Slurm commands are not available on this system.')

    squeue_cmd = [
        'squeue',
        '-u',
        get_username(),
        '-r',
        '--noheader',
        '--format=%i|%t|%M|%l|%D|%C|%P|%j',
        f'-j{job_id}',
    ]
    proc = subprocess.run(squeue_cmd, capture_output=True, text=True)

    # squeue exits 1 with this stderr text when the job id is unknown;
    # treat that as "no jobs found" rather than a hard error.
    if proc.returncode == 1 and 'Invalid job id specified' in proc.stderr:
        return ''
    if proc.returncode == 0:
        return proc.stdout
    raise RuntimeError(
        f'squeue failed with return code {proc.returncode}\nstderr: {proc.stderr}'
    )


def check_slurm_available() -> bool:
"""Check if Slurm commands are available on the system.

Expand Down Expand Up @@ -264,21 +308,16 @@ def submit_one_test_job(analysis_path, queue):
return job_id


def report_job_status(current_results_df, currently_running_df, analysis_path):
def report_job_status(statuses, analysis_path):
"""
Print a report that summarizes the overall status of a BABS project.

This will show how many of the jobs have been completed,
how many are still running, and how many have failed.

Parameters:
-------------
current_results_df: pd.DataFrame
dataframe that accurately reflects which tasks have finished
currently_running_df: pd.DataFrame
dataframe of currently running tasks
analysis_path: str
path to the `analysis` folder of a `BABS` project
Parameters
----------
statuses : dict[tuple, JobStatus]
Current job statuses keyed by (sub_id,) or (sub_id, ses_id).
analysis_path : str
Path to the ``analysis`` folder of a BABS project.
"""
from jinja2 import Environment, PackageLoader, StrictUndefined

Expand All @@ -291,12 +330,13 @@ def report_job_status(current_results_df, currently_running_df, analysis_path):
)
template = env.get_template('job_status_report.jinja')

total_jobs = current_results_df.shape[0]
total_submitted = int(current_results_df['submitted'].sum())
total_is_done = int(current_results_df['has_results'].sum())
total_pending = int((currently_running_df['state'] == 'PD').sum())
total_running = int((currently_running_df['state'] == 'R').sum())
total_failed = int(current_results_df['is_failed'].sum())
jobs = list(statuses.values())
total_jobs = len(jobs)
total_submitted = sum(1 for j in jobs if j.submitted)
total_is_done = sum(1 for j in jobs if j.has_results)
total_pending = sum(1 for j in jobs if j.scheduler_state == SchedulerState.PENDING)
total_running = sum(1 for j in jobs if j.scheduler_state == SchedulerState.RUNNING)
total_failed = sum(1 for j in jobs if j.is_failed)

print(
template.render(
Expand Down
Loading
Loading