Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions babs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import datalad.api as dlapi
import pandas as pd
import yaml

from babs.input_datasets import InputDatasets, OutputDatasets
from babs.scheduler import (
Expand Down Expand Up @@ -45,7 +46,7 @@
class BABS:
"""The BABS base class holds common attributes and methods for all BABS classes."""

def __init__(self, project_root):
def __init__(self, project_root, container_config=None):
"""The BABS class is for babs projects of BIDS Apps.

The constructor only initializes the attributes.
Expand Down Expand Up @@ -108,13 +109,28 @@ def __init__(self, project_root):
# attributes:
self.project_root = str(project_root)

self.analysis_path = op.join(self.project_root, 'analysis')
# Load the optional path overrides (`analysis_path`, `input_ria_path`,
# `output_ria_path`). An explicitly passed `container_config` is read first;
# otherwise fall back to the copy kept under `<project_root>/.babs/`.
if container_config is not None:
    with open(container_config) as f:
        # `yaml.safe_load` returns None for an empty document; normalise to
        # a dict so the `cfg.get(...)` lookups below cannot raise
        # AttributeError (same guard as the fallback branch).
        cfg = yaml.safe_load(f) or {}
else:
    root_config_path = op.join(self.project_root, '.babs', 'babs_init_config.yaml')
    cfg = {}
    if op.exists(root_config_path):
        with open(root_config_path) as f:
            cfg = yaml.safe_load(f) or {}

analysis_dir = cfg.get('analysis_path', 'analysis')
self.analysis_path = op.normpath(op.join(self.project_root, analysis_dir))
self._analysis_datalad_handle = None

self.config_path = op.join(self.analysis_path, 'code/babs_proj_config.yaml')

self.input_ria_path = op.join(self.project_root, 'input_ria')
self.output_ria_path = op.join(self.project_root, 'output_ria')
self.input_ria_path = op.normpath(
op.join(self.project_root, cfg.get('input_ria_path', 'input_ria'))
)
self.output_ria_path = op.normpath(
op.join(self.project_root, cfg.get('output_ria_path', 'output_ria'))
)

self.input_ria_url = 'ria+file://' + self.input_ria_path
self.output_ria_url = 'ria+file://' + self.output_ria_path
Expand All @@ -128,6 +144,7 @@ def __init__(self, project_root):
self.job_status_path_rel = 'code/job_status.csv'
self.job_status_path_abs = op.join(self.analysis_path, self.job_status_path_rel)
self.job_submit_path_abs = op.join(self.analysis_path, 'code/job_submit.csv')
self.analysis_root = op.dirname(self.analysis_path)
self._shared_group_enabled_cache = None
self._apply_config()

Expand Down Expand Up @@ -178,7 +195,7 @@ def _apply_config(self) -> None:
self.wtf_key_info(flag_output_ria_only=True)

self.input_datasets = InputDatasets(self.processing_level, config_yaml['input_datasets'])
self.input_datasets.update_abs_paths(Path(self.project_root) / 'analysis')
self.input_datasets.update_abs_paths(Path(self.analysis_path))
self.ensure_shared_group_git_safe_directories()

def _validate_pipeline_config(self) -> None:
Expand Down
28 changes: 25 additions & 3 deletions babs/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import csv
import os
import os.path as op
import shutil
import subprocess
import tempfile
from pathlib import Path
Expand All @@ -25,6 +26,9 @@
class BABSBootstrap(BABS):
"""A BABS subclass that implements the bootstrap process."""

def __init__(self, project_root, container_config=None):
    """Initialize the bootstrap object by delegating to the parent constructor.

    Parameters
    ----------
    project_root
        Passed unchanged as the first positional argument of the parent
        constructor.
    container_config
        Forwarded to the parent constructor as the ``container_config``
        keyword argument. Defaults to ``None``.
    """
    super().__init__(project_root, container_config=container_config)

def _apply_config(self):
    """Do nothing; this override has an empty body."""
    # NOTE(review): presumably this disables the parent class's config
    # loading because the project config file does not exist yet while
    # bootstrapping -- confirm against the parent implementation.
    pass

Expand Down Expand Up @@ -114,14 +118,24 @@ def babs_bootstrap(
self.queue = validate_queue(queue)
system = System(self.queue)

# Create `analysis` folder: -----------------------------
# Create analysis folder: -----------------------------
print('DataLad version: ' + get_datalad_version())
print('\nCreating `analysis` folder (also a datalad dataset)...')
print(f'\nCreating `{self.analysis_path}` folder (also a datalad dataset)...')
create_kwargs = {'cfg_proc': 'yoda', 'annex': True}
if self.shared_group is not None:
create_kwargs['initopts'] = ['--shared=group']
self._analysis_datalad_handle = dlapi.create(self.analysis_path, **create_kwargs)
self.input_datasets.update_abs_paths(Path(self.analysis_path))

# Persist original config so other BABS commands can find it:
babs_dir = op.join(self.project_root, '.babs')
os.makedirs(babs_dir, exist_ok=True)
shutil.copy2(container_config, op.join(babs_dir, 'babs_init_config.yaml'))
if op.normpath(self.analysis_path) == op.normpath(self.project_root):
self.datalad_save(
path='.babs/babs_init_config.yaml',
message='Save babs init config',
)
self.input_datasets.set_inclusion_dataframe(initial_inclusion_df, processing_level)

# Prepare `.gitignore` ------------------------------
Expand All @@ -132,6 +146,9 @@ def babs_bootstrap(
os.remove(gitignore_path)
gitignore_file = open(gitignore_path, 'a') # open in append mode

# not to track input/output RIA stores:
gitignore_file.write('\n' + op.basename(self.input_ria_path))
gitignore_file.write('\n' + op.basename(self.output_ria_path))
# not to track `logs` folder:
gitignore_file.write('\nlogs')
# not to track `.*_datalad_lock`:
Expand All @@ -153,7 +170,7 @@ def babs_bootstrap(

# Create `babs_proj_config.yaml` file: ----------------------
print('Save BABS project configurations in a YAML file ...')
print("Path to this yaml file will be: 'analysis/code/babs_proj_config.yaml'")
print(f"Path to this yaml file will be: '{self.config_path}'")

env = Environment(
loader=PackageLoader('babs', 'templates'),
Expand All @@ -165,6 +182,7 @@ def babs_bootstrap(
with open(self.config_path, 'w') as f:
f.write(
template.render(
analysis_dir=op.basename(self.analysis_path),
processing_level=self.processing_level,
queue=self.queue,
input_ds=self.input_datasets,
Expand Down Expand Up @@ -443,6 +461,7 @@ def _bootstrap_single_app_scripts(
self.processing_level,
system,
project_root=op.dirname(self.analysis_path),
analysis_dir=op.basename(self.analysis_path),
shared_group_mode=shared_group_mode,
)

Expand Down Expand Up @@ -518,6 +537,7 @@ def _bootstrap_pipeline_scripts(self, container_ds, container_config, system):
container_images=container_images,
datalad_run_message='pipeline',
project_root=op.dirname(self.analysis_path),
analysis_dir=op.basename(self.analysis_path),
)

with open(bash_path, 'w') as f:
Expand Down Expand Up @@ -599,6 +619,8 @@ def clean_up(self):
if op.exists(self.analysis_path): # analysis folder is created by datalad
print('Removing input dataset(s) if cloned...')
for in_ds in self.input_datasets:
if in_ds._babs_project_analysis_path is None:
continue
if op.exists(in_ds.babs_project_analysis_path):
# use `datalad remove` to remove:
_ = self.analysis_datalad_handle.remove(
Expand Down
9 changes: 5 additions & 4 deletions babs/check_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def babs_check_setup(self, submit_a_test_job):
print('Did not request `--job-test`; will not submit a test job.')

# Print out the saved configuration info: ----------------
analysis_dir = op.basename(self.analysis_path)
print(
'Below is the configuration information saved during `babs init`'
" in file 'analysis/code/babs_proj_config.yaml':\n"
f" in file '{analysis_dir}/code/babs_proj_config.yaml':\n"
)
with open(op.join(self.analysis_path, 'code/babs_proj_config.yaml')) as f:
file_contents = f.read()
Expand All @@ -52,13 +53,13 @@ def babs_check_setup(self, submit_a_test_job):
print('Checking the BABS project itself...')
if not op.exists(self.analysis_path):
raise FileNotFoundError(
"Folder 'analysis' does not exist in this BABS project!"
f"Folder '{analysis_dir}' does not exist in this BABS project!"
' Current path to analysis folder: ' + self.analysis_path
)
print(CHECK_MARK + ' All good!')

# Check `analysis` datalad dataset: ----------------------
print("\nCheck status of 'analysis' DataLad dataset...")
# Check analysis datalad dataset: ----------------------
print(f"\nCheck status of '{analysis_dir}' DataLad dataset...")
# Is there anything unsaved? ref: CuBIDS function
analysis_statuses = {
status['state']
Expand Down
2 changes: 1 addition & 1 deletion babs/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def babs_init_main(

from babs import BABSBootstrap

babs_proj = BABSBootstrap(project_root)
babs_proj = BABSBootstrap(project_root, container_config=container_config)
try:
babs_proj.babs_bootstrap(
processing_level,
Expand Down
2 changes: 2 additions & 0 deletions babs/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def generate_bash_participant_job(
processing_level,
system,
project_root=None,
analysis_dir='analysis',
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New `analysis_dir` argument was added, but the docstring wasn't updated to describe it.

shared_group_mode=False,
):
"""Generate bash script for participant job.
Expand Down Expand Up @@ -188,6 +189,7 @@ def generate_bash_participant_job(
container_name=self.container_name,
zip_foldernames=self.config['zip_foldernames'],
project_root=project_root,
analysis_dir=analysis_dir,
)

with open(bash_path, 'w') as f:
Expand Down
2 changes: 2 additions & 0 deletions babs/generate_submit_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def generate_submit_script(
container_images=None,
datalad_run_message=None,
project_root=None,
analysis_dir='analysis',
):
"""
Generate a bash script that runs the BIDS App singularity image.
Expand Down Expand Up @@ -129,6 +130,7 @@ def generate_submit_script(
container_image_paths=container_image_paths,
datalad_run_message=datalad_run_message,
project_root=project_root,
analysis_dir=analysis_dir,
)


Expand Down
2 changes: 1 addition & 1 deletion babs/templates/job_submit.yaml.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# '${max_array}' is a placeholder.
{% endif %}

cmd_template: '{{ submit_head }} {{ env_flags }} {{ name_flag_str }}{{ job_name }} {{ eo_args }} {{ array_args }} {% if test %}{{ babs.analysis_path }}/code/check_setup/call_test_job.sh{% else %}{{ babs.analysis_path }}/code/participant_job.sh {{ dssource }} {{ pushgitremote }} {{ babs.job_submit_path_abs }} {{ babs.project_root }}{% endif %}'
cmd_template: '{{ submit_head }} {{ env_flags }} {{ name_flag_str }}{{ job_name }} {{ eo_args }} {{ array_args }} {% if test %}{{ babs.analysis_path }}/code/check_setup/call_test_job.sh{% else %}{{ babs.analysis_path }}/code/participant_job.sh {{ dssource }} {{ pushgitremote }} {{ babs.job_submit_path_abs }} {{ babs.analysis_root }}{% endif %}'
job_name_template: '{{ job_name }}'
2 changes: 1 addition & 1 deletion babs/templates/participant_job.sh.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ CONTAINER_IMAGE_PATHS=(
)

for CONTAINER_JOB in "${CONTAINER_IMAGE_PATHS[@]}"; do
CONTAINER_SHARED="${PROJECT_ROOT}/analysis/${CONTAINER_JOB}"
CONTAINER_SHARED="${PROJECT_ROOT}/{{ analysis_dir }}/${CONTAINER_JOB}"

if [ ! -e "${CONTAINER_SHARED}" ] && [ ! -L "${CONTAINER_SHARED}" ]; then
echo "ERROR: shared container image not found at ${CONTAINER_SHARED}" >&2
Expand Down
5 changes: 4 additions & 1 deletion tests/test_babs_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ def test_init_forwards_shared_group(tmp_path):
with mock.patch('babs.BABSBootstrap') as mock_bootstrap_cls:
_enter_init()

mock_bootstrap_cls.assert_called_once_with(options.project_root)
mock_bootstrap_cls.assert_called_once_with(
options.project_root,
container_config=options.container_config,
)
mock_bootstrap_cls.return_value.babs_bootstrap.assert_called_once_with(
options.processing_level,
options.queue,
Expand Down
Loading