Better support for target models on the ensemble attack#135
Conversation
Actionable comments posted: 2
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tests/unit/attacks/ensemble/test_process_data_split.py (1)
41-44: ⚠️ Potential issue | 🟡 Minor: Duplicate assertion detected.

Line 44 duplicates the assertion from line 43 (both check `real_test.csv`). This appears to be a copy-paste error and doesn't add test coverage.

🔧 Suggested fix

```diff
 # Assert that the split real data files are saved in the provided path
 assert (output_dir / "real_train.csv").exists()
 assert (output_dir / "real_val.csv").exists()
 assert (output_dir / "real_test.csv").exists()
-assert (output_dir / "real_test.csv").exists()
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/unit/attacks/ensemble/test_process_data_split.py` around lines 41-44, remove the duplicate assertion that repeats checking (output_dir / "real_test.csv").exists() so the test only asserts real_train.csv, real_val.csv, and real_test.csv once each (or, if the intent was to verify a different file, replace the duplicate with the correct filename instead of duplicating the real_test.csv assertion).

examples/ensemble_attack/test_attack_model.py (1)
345-353: ⚠️ Potential issue | 🟡 Minor: Typo in variable names: "mataclassifier" should be "metaclassifier".

Lines 345-351 have consistent typos: `mataclassifier_path` and `trained_mataclassifier_model`. While this doesn't affect functionality, it reduces code readability.

📝 Suggested fix

```diff
-mataclassifier_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{metaclassifier_model_name}.pkl"
-assert mataclassifier_path.exists(), (
-    f"No metaclassifier model found at {mataclassifier_path}. Make sure to run the training script first."
+metaclassifier_path = Path(config.metaclassifier.metaclassifier_model_path) / f"{metaclassifier_model_name}.pkl"
+assert metaclassifier_path.exists(), (
+    f"No metaclassifier model found at {metaclassifier_path}. Make sure to run the training script first."
 )
-with open(mataclassifier_path, "rb") as f:
-    trained_mataclassifier_model = pickle.load(f)
+with open(metaclassifier_path, "rb") as f:
+    trained_metaclassifier_model = pickle.load(f)
-log(INFO, f"Metaclassifier model loaded from {mataclassifier_path}, starting the test...")
+log(INFO, f"Metaclassifier model loaded from {metaclassifier_path}, starting the test...")
```

Also update line 422:

```diff
-blending_attacker.trained_model = trained_mataclassifier_model
+blending_attacker.trained_model = trained_metaclassifier_model
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@examples/ensemble_attack/test_attack_model.py` around lines 345-353, rename the misspelled variables in the test to use "metaclassifier" consistently: change mataclassifier_path to metaclassifier_path (the Path construction and the existence assertion) and change trained_mataclassifier_model to trained_metaclassifier_model (the pickle.load assignment and any subsequent uses, including the referenced occurrence around line 422); update all references so variable names match.

src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1)
53-80: ⚠️ Potential issue | 🟡 Minor: Docstring references removed parameters.

Lines 67 and 73-74 document parameters `fine_tuning_config` and `number_of_points_to_synthesize` that no longer exist in the function signature after the refactor.

📝 Suggested fix

```diff
 training_json_config_paths: Configuration dictionary containing paths to the data JSON config files.
     An example of this config is provided in ``examples/ensemble_attack/config.yaml``. Required keys are:
     - table_domain_file_path (str): Path to the table domain json file.
     - dataset_meta_file_path (str): Path to dataset meta json file.
     - training_config_path (str): Path to table's training config json file.
-fine_tuning_config: Configuration dictionary containing shadow model fine-tuning specific information.
 init_model_id: An ID to assign to the pre-trained initial models. This can be used to save
     multiple pre-trained models with different IDs.
 table_name: Name of the main table to be used for training the TabDDPM model.
 id_column_name: Name of the ID column in the data.
 pre_training_data_size: Size of the initial training set, defaults to 60,000.
-number_of_points_to_synthesize: Size of the synthetic data to be generated by each shadow model,
-    defaults to 20,000.
 init_data_seed: Random seed for the initial training set.
 random_seed: Random seed used for reproducibility, defaults to None.
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py` around lines 53-80, update the function's docstring to match the current signature by removing references to `fine_tuning_config` and `number_of_points_to_synthesize`, ensure the Args section lists only existing parameters, and adjust the Returns description if needed so the docstring accurately reflects the function's current behavior.
🧹 Nitpick comments (3)
examples/ensemble_attack/run_attack.py (1)
87-96: Consider adding error handling for JSON config loading.

The JSON is loaded directly into `EnsembleAttackTabDDPMTrainingConfig` without validation. If the JSON structure doesn't match the expected schema, the error message may be unclear.

💡 Optional: Add a try-except for clearer error messages

```diff
 with open(config.shadow_training.training_json_config_paths.training_config_path, "r") as file:
-    training_config = EnsembleAttackTabDDPMTrainingConfig(**json.load(file))
+    try:
+        training_config = EnsembleAttackTabDDPMTrainingConfig(**json.load(file))
+    except (TypeError, ValueError) as e:
+        raise ValueError(
+            f"Failed to parse training config from "
+            f"{config.shadow_training.training_json_config_paths.training_config_path}: {e}"
+        ) from e
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@examples/ensemble_attack/run_attack.py` around lines 87-96, wrap the open/json.load/EnsembleAttackTabDDPMTrainingConfig(...) call in a try-except that catches json.JSONDecodeError and TypeError/ValueError and re-raise with a clear message including the config path and the original exception, then only proceed to set fine_tuning_* fields and instantiate EnsembleAttackTabDDPMModelRunner when parsing succeeds.

src/midst_toolkit/attacks/ensemble/models.py (1)
115-117: Hardcoded table name "trans" should be parameterized.

The `load_tables` call uses a hardcoded `"trans"` key in the `train_data` dictionary. This assumes the table is always named "trans", but other datasets may use different table names. Consider either:

- Adding a `table_name` field to `EnsembleAttackTabDDPMTrainingConfig` (similar to how `EnsembleAttackCTGANTrainingConfig` has it)
- Deriving the table name from the dataset metadata

♻️ Suggested approach

```diff
 class EnsembleAttackTabDDPMTrainingConfig(ClavaDDPMTrainingConfig, EnsembleAttackTrainingConfig):
     fine_tuning_diffusion_iterations: int = 100
     fine_tuning_classifier_iterations: int = 10
+    table_name: str = "trans"  # Default for backward compatibility
```

Then in `train_or_fine_tune_and_synthesize`:

```diff
-tables, relation_order, _ = load_tables(self.training_config.general.data_dir, train_data={"trans": dataset})
+tables, relation_order, _ = load_tables(
+    self.training_config.general.data_dir,
+    train_data={self.training_config.table_name: dataset}
+)
```

And update synthesis line 182:

```diff
-result.synthetic_data = cleaned_tables["trans"]
+result.synthetic_data = cleaned_tables[self.training_config.table_name]
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/midst_toolkit/attacks/ensemble/models.py` around lines 115-117, add a table_name field to EnsembleAttackTabDDPMTrainingConfig (mirroring EnsembleAttackCTGANTrainingConfig) or derive the table name from the dataset metadata, then replace the hardcoded "trans" with that value in the load_tables call and in the synthesis usage around line 182.

examples/ensemble_attack/run_shadow_model_training.py (1)
61-72: Shared `model_runner` state is mutated by multiple callers.

Based on the context snippets, the same `model_runner` instance flows through `run_shadow_model_training` → `train_three_sets_of_shadow_models` → multiple shadow training functions, each of which overwrites `model_runner.training_config.general.*` fields. Then the same instance is passed to `run_target_model_training`, which resets these fields here (lines 68-70).

While this appears to work correctly because the reset happens before the target model uses the config, this pattern is fragile. If the call order changes or a caller forgets to reset, stale paths could cause models to be saved to unexpected locations.

Consider either:

- Documenting this behavior explicitly in the function docstring
- Creating a fresh config copy for each training phase instead of mutating the shared instance

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@examples/ensemble_attack/run_shadow_model_training.py` around lines 61-72, create and use a fresh config copy for each phase instead of mutating the shared instance: deep-copy model_runner.training_config and assign that copy to a separate ModelRunner or pass it into train_or_fine_tune_and_synthesize, leaving the original model_runner unchanged; update run_shadow_model_training, train_three_sets_of_shadow_models, and run_target_model_training to operate on their own config copies.
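The fresh-copy approach suggested above can be sketched as follows. This is an illustrative stand-in, not the toolkit's actual classes: the `TrainingConfig`/`GeneralConfig` dataclasses and their field names here are hypothetical, and a real fix would deep-copy the toolkit's own config object instead.

```python
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class GeneralConfig:  # hypothetical stand-in for training_config.general
    save_dir: str = "default/save/dir"


@dataclass
class TrainingConfig:  # hypothetical stand-in for the shared training config
    general: GeneralConfig = field(default_factory=GeneralConfig)


def run_training_phase(shared_config: TrainingConfig, save_dir: str) -> TrainingConfig:
    # Work on a deep copy so the caller's config is never mutated.
    phase_config = deepcopy(shared_config)
    phase_config.general.save_dir = save_dir
    return phase_config


shared = TrainingConfig()
shadow_config = run_training_phase(shared, "outputs/shadow_models")
target_config = run_training_phase(shared, "outputs/target_model")

# The shared instance keeps its original paths, so later phases cannot
# pick up stale state from earlier ones.
assert shared.general.save_dir == "default/save/dir"
assert shadow_config.general.save_dir == "outputs/shadow_models"
```

With this pattern, the call order of the training phases no longer matters, since no phase can observe another phase's paths.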
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gan/ensemble_attack/utils.py`:
- Around line 32-43: The docstring for function make_training_config contains a
typo "attacktraining"; update the docstring text to read "attack training" (in
the description line and any other occurrences within the function's docstring)
so it clearly states "Make the ensemble attack training config for the CTGAN
model..." and preserves existing punctuation and formatting.
In `@src/midst_toolkit/attacks/ensemble/shadow_model_utils.py`:
- Around line 19-34: Update the docstring for the function that modifies and
loads training configurations (the function taking config_type, data_dir,
training_config_json_path, final_config_json_path, experiment_name,
workspace_name) to remove the specific "TabDDPM" mention and describe it as a
generic modifier/loader for any EnsembleAttackTrainingConfig subclass;
explicitly state that config_type accepts an EnsembleAttackTrainingConfig
subclass, and ensure the Args and Returns sections reflect the generic behavior
and returned configs/save_dir values instead of TabDDPM-specific language.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8fdb8e7f-6866-45c8-bf8d-cb1f088f2238
📒 Files selected for processing (15)
- .gitignore
- examples/ensemble_attack/run_attack.py
- examples/ensemble_attack/run_metaclassifier_training.py
- examples/ensemble_attack/run_shadow_model_training.py
- examples/ensemble_attack/test_attack_model.py
- examples/gan/ensemble_attack/test_attack_model.py
- examples/gan/ensemble_attack/train_attack_model.py
- examples/gan/ensemble_attack/utils.py
- src/midst_toolkit/attacks/ensemble/models.py
- src/midst_toolkit/attacks/ensemble/process_split_data.py
- src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py
- src/midst_toolkit/attacks/ensemble/shadow_model_utils.py
- tests/integration/attacks/ensemble/assets/data_configs/trans.json
- tests/integration/attacks/ensemble/test_shadow_model_training.py
- tests/unit/attacks/ensemble/test_process_data_split.py
emersodb
left a comment
A really nice step in the right direction! I like the design you chose.
```python
    config.shadow_training.fine_tuning_config.fine_tune_classifier_iterations
)

model_runner = EnsembleAttackTabDDPMModelRunner(training_config=training_config)
```
Perhaps you've already thought of this, but should the code above be part of the base for the ModelRunner? That is, should lines 87-94 actually happen inside that class rather than in the attack script here?
This would also slightly simplify the process of subbing out the model, since you would just need to sub the runner class instead of both the running and the config class? I might be missing a complexity though.
Not sure if I understood your idea, but I thought maybe if I pass the config dictionary to the init of the model runner class we would be able to skip making the config. Is that it?
Sort of. My thought was that you could simply have the EnsembleAttackTabDDPMModelRunner init take a path to the configuration file. Then you could load the file and do all of the steps to properly construct the EnsembleAttackTabDDPMTrainingConfig object within the runner class? That way a user doesn't have to do that themselves.
It's possible I'm missing something where that would be a bad idea though 🙂
Let me know if my explanation of what I was trying to suggest isn't clear. We can talk about it together.
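The suggestion in this thread could look roughly like the sketch below. The class name mirrors the PR's runner, but the body is hypothetical: a real implementation would validate the loaded dict into an `EnsembleAttackTabDDPMTrainingConfig` rather than store the raw JSON.

```python
import json
from pathlib import Path
from typing import Any


class EnsembleAttackTabDDPMModelRunner:  # simplified stand-in for the PR's runner
    def __init__(self, training_config: dict[str, Any]) -> None:
        self.training_config = training_config

    @classmethod
    def from_config_path(cls, config_path: Path) -> "EnsembleAttackTabDDPMModelRunner":
        # Load the JSON config and construct the runner in one step,
        # so callers no longer repeat the parsing boilerplate.
        with open(config_path, "r") as f:
            raw_config = json.load(f)
        # A real implementation would build and validate an
        # EnsembleAttackTabDDPMTrainingConfig here before storing it.
        return cls(training_config=raw_config)
```

Callers would then write `EnsembleAttackTabDDPMModelRunner.from_config_path(path)` instead of constructing the config themselves, which also makes swapping in a different runner subclass a one-line change.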
```python
    config.shadow_training.fine_tuning_config.fine_tune_classifier_iterations
)

model_runner = EnsembleAttackTabDDPMModelRunner(training_config=training_config)
```
Similar comment here about config processing.
```python
    }
    json.dump(training_config, f)

ctgan_training_config = EnsembleAttackCTGANTrainingConfig(**training_config)  # type: ignore[arg-type]
```
Is the type ignore here just because of the ** magic?
No, this is because the training_config dictionary comes from the OmegaConf.to_container statement on line 49, which types it weirdly, and it doesn't go well with the typing pydantic is expecting. See error below.
```text
examples/gan/ensemble_attack/utils.py:60: error: Argument after ** must be a
mapping, not
"dict[str | bytes | int | Enum | float | bool, Any] | dict[Any, Any]"  [arg-type]
    ctgan_training_config = EnsembleAttackCTGANTrainingConfig(**traini...
                            ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~...
Found 1 error in 1 file (checked 111 source files)
```
`dict[str | bytes | int | Enum | float | bool, Any]` is the OmegaConf typing and `dict[Any, Any]` is pydantic's. It's a bit of a silly error, so I felt it was easier to ignore it than to solve it. I can try to do a cast if you think it's better.
Yeah I see why you punted here. Hmmm...Cast is also the same as an ignore, but maybe it will help preserve downstream typing. So perhaps we can do the cast?
```python
import pandas as pd
from pydantic import BaseModel, ConfigDict, Field
from sdv.metadata import SingleTableMetadata  # type: ignore[import-untyped]
```
I think I already asked you this, but I forget the answer (maybe not I don't know 😂). Why do we need these type ignores?
The SDV framework does not have a py.typed file or an sdv-stubs package on pypi and mypy does not like that. Nothing we can do, really, maybe open a bug report with them.
Do you think we should leverage the "nuclear" option of adding it to the mypy.ini similar to, for example, catboost as:
```ini
[mypy-catboost.*]
ignore_missing_imports = True
```
```python
        raise NotImplementedError("Subclasses must implement this method.")


# TabDDPM/ClavaDDPM implementation
```
I know why it's a bit convoluted (because the ensemble method only works for single table, i.e. TabDDPM), but it's a bit weird to have a ClavaDDPM training config inside a TabDDPM Ensemble attack config, since technically TabDDPM is less general than ClavaDDPM. Should we just call everything a ClavaDDPM config as a catch-all, with the understanding that the attacks we implement are only single table?
Not sure about that. Looks like the attack only supports single table, so I think it would be even more confusing to have ClavaDDPM here, no? Or we can have a note in the docstrings saying that even though we are naming it ClavaDDPM, the attack only supports single table?
Yeah I know what you mean. If I were voting, I guess I would say call everything ClavaDDPM and have a docstring for the attack as being single table to avoid confusion. From a user perspective, it seems more clear to have everything be Clava but the attack is restricted, but I'm open to you disagreeing 🙂
I agree with David as well, mostly because the function names for training/fine-tuning the shadow models use ClavaDDPM. A consistent naming with an explanation makes it less confusing.
emersodb
left a comment
The changes you made and most of the responses you gave make sense to me! Just a few final pieces I think.
sarakodeiri
left a comment
Minor suggestions and questions for my own learning. Thanks Marcelo!
```python
# Base Classes
class EnsembleAttackTrainingConfig(TrainingConfig):
```
I don't understand why this base class is needed and what it adds to the pipeline (flexibility?) that using the config file doesn't.
```python
    save_dir=save_dir,
    synthesize=True,
)
train_result = model_runner.train_or_fine_tune_and_synthesize(dataset=df_real_data, synthesize=True)
```
This is a much needed change!
```python
training_config.number_of_points_to_synthesize = number_of_points_to_synthesize
model_runner = EnsembleAttackCTGANModelRunner(training_config=training_config)

master_challenge_train = get_master_challenge_train_data(config)
```
I'm assuming you don't have a flag in the pipeline to control whether data collection should be done. Do you want to run it every time you're training a new meta classifier?
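If such a flag were wanted, it might look like the following sketch. Everything here is hypothetical: the function name mirrors the one in the snippet above, but the caching behavior, the `cache_path`/`force_collect` parameters, and the `collect_challenge_data` helper are invented for illustration only.

```python
import json
from pathlib import Path
from typing import Any


def collect_challenge_data() -> list[dict[str, Any]]:
    # Stand-in for the expensive data collection step.
    return [{"record_id": 1, "is_member": True}]


def get_master_challenge_train_data(
    cache_path: Path, force_collect: bool = False
) -> list[dict[str, Any]]:
    # Reuse cached collection results unless the caller forces a re-run,
    # so retraining the metaclassifier doesn't redo data collection.
    if cache_path.exists() and not force_collect:
        return json.loads(cache_path.read_text())
    data = collect_challenge_data()
    cache_path.write_text(json.dumps(data))
    return data
```

A config-driven boolean would play the role of `force_collect` here, letting the pipeline skip collection on repeated metaclassifier training runs.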
Fantastic refactoring, thank you!
```python
# TODO: The following function is directly copied from the midst reference code since
# I need it to run the attack code, but, it should probably be moved to somewhere else
# as it is an essential part of a working TabDDPM training pipeline.
def setup_save_dir(configs: TrainingConfig) -> Path:
```
This could also be moved to models.py imo
PR Type
Feature
Short Description
Clickup Ticket(s): https://app.clickup.com/t/868h6nmfc
Refactoring the ensemble attack to allow for more flexibility when setting the target model.
The way I went about this was to use inheritance, with the introduction of the following abstract classes in midst_toolkit/attacks/ensemble/models.py:

- EnsembleAttackModelRunner: the main class responsible for running the model (training, fine-tuning, and synthesizing). It also stores the configs. Its train_or_fine_tune_and_synthesize method is the main method any model that we run the ensemble attack against needs to implement.
- EnsembleAttackTrainingConfig: stores the configs for training, synthesizing, and fine-tuning. Inherits from midst_toolkit.common.config.TrainingConfig.
- EnsembleAttackTrainingResult: stores the result of the ensemble attack model training.

Each of those classes has its respective TabDDPM and ClavaDDPM counterparts, with their implementations moved from the utils file into the train_or_fine_tune_and_synthesize functions.

Most of the other changes are just general low-hanging-fruit refactorings.
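The three abstractions described above might be sketched like this. The shape is taken from the PR description, but the fields and defaults are illustrative only; the actual classes live in midst_toolkit/attacks/ensemble/models.py.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


@dataclass
class EnsembleAttackTrainingConfig:
    # In the toolkit this inherits from midst_toolkit.common.config.TrainingConfig;
    # the field below is illustrative.
    number_of_points_to_synthesize: int = 20_000


@dataclass
class EnsembleAttackTrainingResult:
    # Stores the outcome of one model training/synthesis run.
    trained_model: Any = None
    synthetic_data: Any = None


class EnsembleAttackModelRunner(ABC):
    def __init__(self, training_config: EnsembleAttackTrainingConfig) -> None:
        self.training_config = training_config

    @abstractmethod
    def train_or_fine_tune_and_synthesize(
        self, dataset: Any, synthesize: bool = True
    ) -> EnsembleAttackTrainingResult:
        """Each target model (e.g. TabDDPM, CTGAN) implements this."""
        raise NotImplementedError("Subclasses must implement this method.")
```

Supporting a new target model then means subclassing the runner (and config) and implementing one method, rather than threading model-specific logic through the attack scripts.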
Tests Added
Just fixing the currently existing tests