
Better support for target models on the ensemble attack #135

Open
lotif wants to merge 46 commits into main from marcelo/support-attack-models

Conversation

@lotif (Collaborator) commented Mar 19, 2026

PR Type

Feature

Short Description

Clickup Ticket(s): https://app.clickup.com/t/868h6nmfc

Refactoring the ensemble attack to allow for more flexibility when setting the target model.

The way I went about this was to use inheritance, introducing the following abstract classes in midst_toolkit/attacks/ensemble/models.py:

  • EnsembleAttackModelRunner: the main class responsible for running the model (training, fine-tuning and synthesizing). It also stores the configs.
    • The train_or_fine_tune_and_synthesize method is the main method that any model we run the ensemble attack against must implement.
  • EnsembleAttackTrainingConfig: stores the configs for training, synthesizing and fine-tuning. Inherits from midst_toolkit.common.config.TrainingConfig.
  • EnsembleAttackTrainingResult: stores the result of the ensemble attack model training.

Each of those classes has its respective TabDDPM and ClavaDDPM counterpart, with their implementations moved from the utils file into the train_or_fine_tune_and_synthesize functions.
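A minimal sketch of the class shape described above (class and method names follow the PR; the bodies and signatures are illustrative, not the actual implementation):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


class EnsembleAttackTrainingConfig:
    """Stores the configs for training, synthesizing and fine-tuning.

    In the PR this inherits from midst_toolkit.common.config.TrainingConfig;
    a plain class is used here to keep the sketch self-contained.
    """


@dataclass
class EnsembleAttackTrainingResult:
    """Stores the result of the ensemble attack model training."""

    trained_model: Any = None
    synthetic_data: Any = None


class EnsembleAttackModelRunner(ABC):
    """Runs the model (training, fine-tuning and synthesizing) and holds the configs."""

    def __init__(self, training_config: EnsembleAttackTrainingConfig) -> None:
        self.training_config = training_config

    @abstractmethod
    def train_or_fine_tune_and_synthesize(
        self, dataset: Any, synthesize: bool = False
    ) -> EnsembleAttackTrainingResult:
        """The one method every target model (TabDDPM, ClavaDDPM, ...) must implement."""
        raise NotImplementedError("Subclasses must implement this method.")
```

A concrete runner such as EnsembleAttackTabDDPMModelRunner then subclasses EnsembleAttackModelRunner and implements train_or_fine_tune_and_synthesize for its model type.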

Most of the other changes are general low-hanging-fruit refactorings.

Tests Added

Just fixing the currently existing tests

lotif and others added 30 commits January 8, 2026 14:51
coderabbitai bot commented Mar 19, 2026

📝 Walkthrough

The PR introduces a new abstraction model for ensemble attack training by creating an EnsembleAttackModelRunner interface with implementations for TabDDPM and CTGAN. This replaces previous model-type–driven branching with polymorphic training workflows. Hardcoded filenames are replaced with module-level constants from process_split_data. Shadow model training functions are refactored to accept model_runner parameters instead of passing configuration and model-type enums directly. Legacy training utility functions for TabDDPM and CTGAN are consolidated into the new runner implementations. Supporting functions like save_additional_training_config are simplified to work with the new config type system.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 3 passed
  • Title check — ✅ Passed: The title clearly and specifically describes the main change: improving support for target models on the ensemble attack through refactoring.
  • Description check — ✅ Passed: The description covers the PR type, includes a ClickUp ticket reference, explains the refactoring approach with abstract classes, documents the key abstractions introduced, and mentions tests were fixed accordingly.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 86.21%, which is sufficient. The required threshold is 80.00%.


coderabbitai bot left a comment
Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tests/unit/attacks/ensemble/test_process_data_split.py (1)

41-44: ⚠️ Potential issue | 🟡 Minor

Duplicate assertion detected.

Line 44 duplicates the assertion from line 43 (both check real_test.csv). This appears to be a copy-paste error and doesn't add test coverage.

🔧 Suggested fix
     # Assert that the split real data files are saved in the provided path
     assert (output_dir / "real_train.csv").exists()
     assert (output_dir / "real_val.csv").exists()
     assert (output_dir / "real_test.csv").exists()
-    assert (output_dir / "real_test.csv").exists()
examples/ensemble_attack/test_attack_model.py (1)

345-353: ⚠️ Potential issue | 🟡 Minor

Typo in variable names: "mataclassifier" should be "metaclassifier".

Lines 345-351 have consistent typos: mataclassifier_path and trained_mataclassifier_model. While this doesn't affect functionality, it reduces code readability.

📝 Suggested fix
-    mataclassifier_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{metaclassifier_model_name}.pkl"
-    assert mataclassifier_path.exists(), (
-        f"No metaclassifier model found at {mataclassifier_path}. Make sure to run the training script first."
+    metaclassifier_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{metaclassifier_model_name}.pkl"
+    assert metaclassifier_path.exists(), (
+        f"No metaclassifier model found at {metaclassifier_path}. Make sure to run the training script first."
     )

-    with open(mataclassifier_path, "rb") as f:
-        trained_mataclassifier_model = pickle.load(f)
+    with open(metaclassifier_path, "rb") as f:
+        trained_metaclassifier_model = pickle.load(f)

-    log(INFO, f"Metaclassifier model loaded from {mataclassifier_path}, starting the test...")
+    log(INFO, f"Metaclassifier model loaded from {metaclassifier_path}, starting the test...")

Also update line 422:

-    blending_attacker.trained_model = trained_mataclassifier_model
+    blending_attacker.trained_model = trained_metaclassifier_model
src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1)

53-80: ⚠️ Potential issue | 🟡 Minor

Docstring references removed parameters.

Lines 67 and 73-74 document parameters fine_tuning_config and number_of_points_to_synthesize that no longer exist in the function signature after the refactor.

📝 Suggested fix
         training_json_config_paths: Configuration dictionary containing paths to the data JSON config files.
             An example of this config is provided in ``examples/ensemble_attack/config.yaml``. Required keys are:
             - table_domain_file_path (str): Path to the table domain json file.
             - dataset_meta_file_path (str): Path to dataset meta json file.
             - training_config_path (str): Path to table's training config json file.
-        fine_tuning_config: Configuration dictionary containing shadow model fine-tuning specific information.
         init_model_id: An ID to assign to the pre-trained initial models. This can be used to save multiple
             pre-trained models with different IDs.
         table_name: Name of the main table to be used for training the TabDDPM model.
         id_column_name: Name of the ID column in the data.
         pre_training_data_size: Size of the initial training set, defaults to 60,000.
-        number_of_points_to_synthesize: Size of the synthetic data to be generated by each shadow model,
-            defaults to 20,000.
         init_data_seed: Random seed for the initial training set.
         random_seed: Random seed used for reproducibility, defaults to None.
🧹 Nitpick comments (3)
examples/ensemble_attack/run_attack.py (1)

87-96: Consider adding error handling for JSON config loading.

The JSON is loaded directly into EnsembleAttackTabDDPMTrainingConfig without validation. If the JSON structure doesn't match the expected schema, the error message may be unclear.

💡 Optional: Add a try-except for clearer error messages
     with open(config.shadow_training.training_json_config_paths.training_config_path, "r") as file:
-        training_config = EnsembleAttackTabDDPMTrainingConfig(**json.load(file))
+        try:
+            training_config = EnsembleAttackTabDDPMTrainingConfig(**json.load(file))
+        except (TypeError, ValueError) as e:
+            raise ValueError(
+                f"Failed to parse training config from "
+                f"{config.shadow_training.training_json_config_paths.training_config_path}: {e}"
+            ) from e
src/midst_toolkit/attacks/ensemble/models.py (1)

115-117: Hardcoded table name "trans" should be parameterized.

The load_tables call uses a hardcoded "trans" key in the train_data dictionary. This assumes the table is always named "trans", but other datasets may use different table names.

Consider either:

  1. Adding a table_name field to EnsembleAttackTabDDPMTrainingConfig (similar to how EnsembleAttackCTGANTrainingConfig has it)
  2. Deriving the table name from the dataset metadata
♻️ Suggested approach
+class EnsembleAttackTabDDPMTrainingConfig(ClavaDDPMTrainingConfig, EnsembleAttackTrainingConfig):
+    fine_tuning_diffusion_iterations: int = 100
+    fine_tuning_classifier_iterations: int = 10
+    table_name: str = "trans"  # Default for backward compatibility

Then in train_or_fine_tune_and_synthesize:

-        tables, relation_order, _ = load_tables(self.training_config.general.data_dir, train_data={"trans": dataset})
+        tables, relation_order, _ = load_tables(
+            self.training_config.general.data_dir, 
+            train_data={self.training_config.table_name: dataset}
+        )

And update synthesis line 182:

-            result.synthetic_data = cleaned_tables["trans"]
+            result.synthetic_data = cleaned_tables[self.training_config.table_name]
examples/ensemble_attack/run_shadow_model_training.py (1)

61-72: Shared model_runner state is mutated by multiple callers.

Based on the context snippets, the same model_runner instance flows through run_shadow_model_training → train_three_sets_of_shadow_models → multiple shadow training functions, each of which overwrites model_runner.training_config.general.* fields. Then the same instance is passed to run_target_model_training, which resets these fields here (lines 68-70).

While this appears to work correctly because the reset happens before the target model uses the config, this pattern is fragile. If the call order changes or a caller forgets to reset, stale paths could cause models to be saved to unexpected locations.

Consider either:

  1. Documenting this behavior explicitly in the function docstring
  2. Creating a fresh config copy for each training phase instead of mutating the shared instance
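The second option could be sketched roughly like this (model_runner and the training_config attribute come from the review context; run_phase_with_isolated_config, the save_dir field, and the stand-in objects are hypothetical):

```python
from copy import deepcopy
from types import SimpleNamespace


def run_phase_with_isolated_config(model_runner, save_dir):
    """Hypothetical helper: deep-copy the config before a phase mutates it.

    The shared model_runner.training_config is left untouched, so call order
    no longer matters and no caller has to remember to reset fields.
    """
    phase_config = deepcopy(model_runner.training_config)
    phase_config.general.save_dir = save_dir  # only the copy is mutated
    return phase_config


# Illustration with stand-in objects (the real configs are richer objects):
runner = SimpleNamespace(
    training_config=SimpleNamespace(general=SimpleNamespace(save_dir="shadow/models"))
)
target_config = run_phase_with_isolated_config(runner, "target/models")
# target_config.general.save_dir is "target/models" while the shared
# runner.training_config.general.save_dir still reads "shadow/models".
```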
🤖 Prompt for AI agents — inline comments
Verify each finding against the current code and only fix it if needed.

In `@examples/gan/ensemble_attack/utils.py`:
- Around line 32-43: The docstring for function make_training_config contains a
typo "attacktraining"; update the docstring text to read "attack training" (in
the description line and any other occurrences within the function's docstring)
so it clearly states "Make the ensemble attack training config for the CTGAN
model..." and preserves existing punctuation and formatting.

In `@src/midst_toolkit/attacks/ensemble/shadow_model_utils.py`:
- Around line 19-34: Update the docstring for the function that modifies and
loads training configurations (the function taking config_type, data_dir,
training_config_json_path, final_config_json_path, experiment_name,
workspace_name) to remove the specific "TabDDPM" mention and describe it as a
generic modifier/loader for any EnsembleAttackTrainingConfig subclass;
explicitly state that config_type accepts an EnsembleAttackTrainingConfig
subclass, and ensure the Args and Returns sections reflect the generic behavior
and returned configs/save_dir values instead of TabDDPM-specific language.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8fdb8e7f-6866-45c8-bf8d-cb1f088f2238

📥 Commits

Reviewing files that changed from the base of the PR and between c6718eb and 94da62e.

📒 Files selected for processing (15)
  • .gitignore
  • examples/ensemble_attack/run_attack.py
  • examples/ensemble_attack/run_metaclassifier_training.py
  • examples/ensemble_attack/run_shadow_model_training.py
  • examples/ensemble_attack/test_attack_model.py
  • examples/gan/ensemble_attack/test_attack_model.py
  • examples/gan/ensemble_attack/train_attack_model.py
  • examples/gan/ensemble_attack/utils.py
  • src/midst_toolkit/attacks/ensemble/models.py
  • src/midst_toolkit/attacks/ensemble/process_split_data.py
  • src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py
  • src/midst_toolkit/attacks/ensemble/shadow_model_utils.py
  • tests/integration/attacks/ensemble/assets/data_configs/trans.json
  • tests/integration/attacks/ensemble/test_shadow_model_training.py
  • tests/unit/attacks/ensemble/test_process_data_split.py

@emersodb (Collaborator) left a comment

A really nice step in the right direction! I like the design you chose.

config.shadow_training.fine_tuning_config.fine_tune_classifier_iterations
)

model_runner = EnsembleAttackTabDDPMModelRunner(training_config=training_config)
Collaborator

Perhaps you've already thought of this, but should the code above be part of the base for the ModelRunner? That is, should lines 87-94 actually happen inside that class rather than in the attack script here?

Collaborator

This would also slightly simplify the process of subbing out the model, since you would just need to sub the runner class instead of both the running and the config class? I might be missing a complexity though.

Collaborator Author

Not sure if I understood your idea, but I thought maybe if I pass the config dictionary to the init of the model runner class we would be able to skip making the config. Is that it?

Collaborator

Sort of. My thought was that you could simply have the EnsembleAttackTabDDPMModelRunner init take a path to the configuration file. Then you could load the file and do all of the steps to properly construct EnsembleAttackTabDDPMTrainingConfig object within the runner class? That way a user doesn't have to do that themselves.

It's possible I'm missing something where that would be a bad idea though 🙂

Collaborator

Let me know if my explanation of what I was trying to suggest isn't clear. We can talk about it together.

config.shadow_training.fine_tuning_config.fine_tune_classifier_iterations
)

model_runner = EnsembleAttackTabDDPMModelRunner(training_config=training_config)
Collaborator

Similar comment here about config processing.

}
json.dump(training_config, f)

ctgan_training_config = EnsembleAttackCTGANTrainingConfig(**training_config) # type: ignore[arg-type]
Collaborator

Is the type ignore here just because of the ** magic?

@lotif (Collaborator, Author) commented Mar 23, 2026

No, this is because the training_config dictionary comes from the OmegaConf.to_container statement on line 49, which types it weirdly, and it doesn't go well with the typing pydantic is expecting. See error below.

examples/gan/ensemble_attack/utils.py:60: error: Argument after ** must be a
mapping, not
"dict[str | bytes | int | Enum | float | bool, Any] | dict[Any, Any]" 
[arg-type]
        ctgan_training_config = EnsembleAttackCTGANTrainingConfig(**traini...
                                ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~...
Found 1 error in 1 file (checked 111 source files)

dict[str | bytes | int | Enum | float | bool, Any] is the OmegaConf typing and dict[Any, Any] is pydantic's. It's a bit of a silly error, I felt it was easier to ignore it than to solve it. I can try to do a cast if you think it's better.

Collaborator

Yeah, I see why you punted here. Hmmm... A cast is also the same as an ignore, but maybe it will help preserve downstream typing. So perhaps we can do the cast?
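A sketch of what the cast might look like (the dict value here is a stand-in for the OmegaConf.to_container output; the config keys are invented):

```python
from typing import Any, cast

# Stand-in for: training_config = OmegaConf.to_container(config, resolve=True)
training_config: Any = {"epochs": 300, "batch_size": 500}

# cast() is a runtime no-op; it only narrows the broad OmegaConf container
# type to the mapping type pydantic expects, so the ** expansion type-checks
# without a blanket `# type: ignore` and keeps downstream typing intact.
narrowed = cast(dict[str, Any], training_config)
```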


import pandas as pd
from pydantic import BaseModel, ConfigDict, Field
from sdv.metadata import SingleTableMetadata # type: ignore[import-untyped]
Collaborator

I think I already asked you this, but I forget the answer (maybe not I don't know 😂). Why do we need these type ignores?

Collaborator Author

The SDV framework does not have a py.typed file or an sdv-stubs package on pypi and mypy does not like that. Nothing we can do, really, maybe open a bug report with them.

Collaborator

Do you think we should leverage the "nuclear" option of adding it to the mypy.ini similar to, for example, catboost as:

[mypy-catboost.*]
ignore_missing_imports = True
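Assuming the package name is simply sdv, the analogous stanza would be:

```ini
[mypy-sdv.*]
ignore_missing_imports = True
```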

raise NotImplementedError("Subclasses must implement this method.")


# TabDDPM/ClavaDDPM implementation
Collaborator

I know why it's a bit convoluted (because the ensemble method only works for single table, i.e. TabDDPM), but it's a bit weird to have a ClavaDDPM training config inside a TabDDPM ensemble attack config, since technically TabDDPM is less general than ClavaDDPM. Should we just call everything a ClavaDDPM config as a catch-all, with the understanding that the attacks we implement are only single table?

Collaborator Author

Not sure about that. Looks like the attack only supports single table, so I think it would be even more confusing to have ClavaDDPM here, no? Or we can note in the docstrings that even though we are naming it ClavaDDPM, the attack only supports single table?

Collaborator

Yeah I know what you mean. If I were voting, I guess I would say call everything ClavaDDPM and have a docstring for the attack as being single table to avoid confusion. From a user perspective, it seems more clear to have everything be Clava but the attack is restricted, but I'm open to you disagreeing 🙂

Collaborator

I agree with David as well, mostly because the function names for training/fine-tuning the shadow models use ClavaDDPM. A consistent naming with an explanation makes it less confusing.

@lotif lotif requested a review from emersodb March 23, 2026 15:26
@emersodb (Collaborator) left a comment

The changes you made and most of the responses you gave make sense to me! Just a few final pieces I think.

@sarakodeiri (Collaborator) left a comment

Minor suggestions and questions for my own learning. Thanks Marcelo!



# Base Classes
class EnsembleAttackTrainingConfig(TrainingConfig):
Collaborator

I don't understand why this base class is needed and what it adds to the pipeline (flexibility?) that using the config file doesn't.

save_dir=save_dir,
synthesize=True,
)
train_result = model_runner.train_or_fine_tune_and_synthesize(dataset=df_real_data, synthesize=True)
Collaborator

This is a much needed change!

training_config.number_of_points_to_synthesize = number_of_points_to_synthesize
model_runner = EnsembleAttackCTGANModelRunner(training_config=training_config)

master_challenge_train = get_master_challenge_train_data(config)
Collaborator

I'm assuming you don't have a flag in the pipeline to control whether data collection should be done. Do you want to run it every time you're training a new meta classifier?


Collaborator

Fantastic refactoring, thank you!

# TODO: The following function is directly copied from the midst reference code since
# I need it to run the attack code, but, it should probably be moved to somewhere else
# as it is an essential part of a working TabDDPM training pipeline.
def setup_save_dir(configs: TrainingConfig) -> Path:
Collaborator

This could also be moved to models.py imo
