Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -724,8 +724,6 @@ def test_golden_values_thd(self, te_attn_backend):

if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")

input_data_bshd = self.get_test_input_data(format="bshd")
input_data_thd = self.get_test_input_data(format="thd")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,6 @@ def test_golden_values_thd(self, te_attn_backend):

if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")

golden_dir = Path(__file__).parent
golden_sd_path = golden_dir / "golden_state_dict.safetensors"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,8 +718,6 @@ def test_golden_values_thd(self, te_attn_backend):

if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")

input_data_bshd = self.get_test_input_data(format="bshd")
input_data_thd = self.get_test_input_data(format="thd")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,8 +724,6 @@ def test_golden_values_thd(self, te_attn_backend):

if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")

input_data_bshd = self.get_test_input_data(format="bshd")
input_data_thd = self.get_test_input_data(format="thd")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,8 +724,6 @@ def test_golden_values_thd(self, te_attn_backend):

if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")

input_data_bshd = self.get_test_input_data(format="bshd")
input_data_thd = self.get_test_input_data(format="thd")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,8 +724,6 @@ def test_golden_values_thd(self, te_attn_backend):

if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")

input_data_bshd = self.get_test_input_data(format="bshd")
input_data_thd = self.get_test_input_data(format="thd")
Expand Down
61 changes: 9 additions & 52 deletions bionemo-recipes/recipes/esm2_native_te/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,8 @@ def test_sanity_convergence_fsdp2_fp8_and_model_init(tmp_path, recipe_path):
assert final_loss < 3.0, f"Final loss {final_loss} is too high"


def test_sanity_convergence_fsdp2_thd(tmp_path, monkeypatch, recipe_path):
def test_sanity_convergence_fsdp2_thd(tmp_path, recipe_path):
"""For FSDP2, we check that the script can run successfully with FP8 and check convergence."""
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
Expand All @@ -386,12 +382,8 @@ def test_sanity_convergence_fsdp2_thd(tmp_path, monkeypatch, recipe_path):


@requires_fp8
def test_sanity_convergence_fsdp2_thd_fp8(tmp_path, monkeypatch, recipe_path):
def test_sanity_convergence_fsdp2_thd_fp8(tmp_path, recipe_path):
"""For FSDP2, we check that the script can run successfully with THD + FP8 and check convergence."""
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
Expand All @@ -408,12 +400,7 @@ def test_sanity_convergence_fsdp2_thd_fp8(tmp_path, monkeypatch, recipe_path):
assert final_loss < 3.0, f"Final loss {final_loss} is too high"


def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_sanity_ddp_thd(tmp_path, recipe_path):
# For DDP, we only check that the script can run successfully with THD, not convergence.
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
Expand All @@ -429,12 +416,7 @@ def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path):
main_ddp(sanity_config)


def test_sanity_mfsdp_thd(tmp_path, monkeypatch, recipe_path):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_sanity_mfsdp_thd(tmp_path, recipe_path):
# For MFSDP, we only check that the script can run successfully with THD, not convergence.
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
Expand All @@ -451,12 +433,7 @@ def test_sanity_mfsdp_thd(tmp_path, monkeypatch, recipe_path):


@requires_fp8
def test_sanity_ddp_thd_fp8(tmp_path, monkeypatch, recipe_path):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_sanity_ddp_thd_fp8(tmp_path, recipe_path):
# For DDP, we only check that the script can run successfully with THD, not convergence.
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
Expand All @@ -474,12 +451,7 @@ def test_sanity_ddp_thd_fp8(tmp_path, monkeypatch, recipe_path):


@requires_fp8
def test_sanity_mfsdp_thd_fp8(tmp_path, monkeypatch, recipe_path):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_sanity_mfsdp_thd_fp8(tmp_path, recipe_path):
# For MFSDP, we only check that the script can run successfully with THD, not convergence.
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
Expand Down Expand Up @@ -571,12 +543,7 @@ def test_sanity_convergence_fsdp2_huggingface_model(tmp_path, recipe_path):
assert final_loss < 3.0, f"Final loss {final_loss} is too high"


def test_sanity_ddp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_sanity_ddp_thd_token_packing(tmp_path, recipe_path):
# For DDP, we only check that the script can run successfully with THD, not convergence.
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
Expand All @@ -592,12 +559,7 @@ def test_sanity_ddp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
main_ddp(sanity_config)


def test_sanity_mfsdp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_sanity_mfsdp_thd_token_packing(tmp_path, recipe_path):
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
config_name="L0_sanity",
Expand All @@ -612,12 +574,7 @@ def test_sanity_mfsdp_thd_token_packing(tmp_path, monkeypatch, recipe_path):
main_mfsdp(sanity_config)


def test_sanity_fsdp2_thd_token_packing(tmp_path, monkeypatch, recipe_path):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_sanity_fsdp2_thd_token_packing(tmp_path, recipe_path):
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
config_name="L0_sanity",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from hydra import compose, initialize_config_dir

from train_lora_ddp import main as main_ddp
Expand Down Expand Up @@ -54,12 +53,7 @@ def test_sanity_convergence_ddp_non_streaming_dataset(tmp_path, recipe_path):
assert final_loss < 3.0, f"Final loss {final_loss} is too high"


def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path):
if torch.cuda.get_device_capability() == (12, 0):
# TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
# but it's missing this THD implementation.
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")

def test_sanity_ddp_thd(tmp_path, recipe_path):
# For DDP, we only check that the script can run successfully with THD, not convergence.
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
Expand Down
Loading