diff --git a/tests/integration/gradient_accumulation_test.py b/tests/integration/gradient_accumulation_test.py index 468c7aced8..4a43fdc610 100644 --- a/tests/integration/gradient_accumulation_test.py +++ b/tests/integration/gradient_accumulation_test.py @@ -65,10 +65,11 @@ def test_grad_accumulate_same_loss(self): "gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off) "enable_checkpointing=False", "enable_goodput_recording=False", + "decoder_block=simple", "base_emb_dim=256", "base_num_decoder_layers=4", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", - "steps=20", + "steps=2", ] # Run with gradient accumulation with accumulate_steps=10, per_device_batch=1 --> simulating per_device_batch=10 train_main(