Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions ci/lepton/model_convergence/configs/recipes/llama3_native_te_1b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# @package _global_
# Hydra convergence-test recipe: Llama3 1B, native PyTorch + Transformer Engine,
# composed on top of /base (which supplies launcher/checkout plumbing).
defaults:
  - /base
  - _self_

############################################################
# lepton job info
############################################################
node_group: yo-bom-lepton-001
mount_from: node-nfs:fs1
num_nodes: 1
device_type: gpu
num_devices: 8
gpu_type: h100-sxm
# Composed into e.g. "gpu.8xh100-sxm" for the Lepton scheduler.
resource_shape: "${device_type}.${num_devices}x${gpu_type}"

############################################################
# kratos info: where to log data
############################################################
kratos_subject: "convergence_tests_v0.0.3"

############################################################
# recipe identifiers
# mostly used for logging and observability
############################################################
recipe_subdir: llama3_native_te
model_type: llama3
variant: train

# Core identifiers for filtering
framework: native
precision: bf16
te_enabled: true     # Transformer Engine path for this recipe family
fp8_enabled: false   # default off; flipped on by individual product entries below
fp8_recipe: ""       # e.g. transformer_engine.common.recipe.DelayedScaling (set per product)
fp8_format: ""       # e.g. HYBRID (set per product)
cp_enabled: false    # context parallelism; overridden per product
thd_enabled: false   # THD / sequence-packed layout; overridden per product

# Catchall for additional features/configs
extras: []

############################################################
# wandb info (total_gpus used for group name)
############################################################
# NOTE(review): `multiply`, `sanitize`, `gitsha`, and `branch` are custom
# OmegaConf resolvers/keys not defined in this file — presumably registered
# by the launcher or /base; confirm they resolve before launch.
total_gpus: ${multiply:${num_devices},${num_nodes}}

wandb_init_args:
  project: "test_convergence__recipes__${sanitize:${branch}}"
  group: "${model_type}__${task_cmd}__${total_gpus}gpus__${sanitize:${gpu_type}}"
  job_type: "${recipe_subdir}"
  name: null  # per-run name is supplied via each product's `wandb_name`

############################################################
# task commands
# shared across all products (if not explicitly overridden)
# Matches L2_lingua_1b.yaml defaults
############################################################
config: L2_lingua_1b
task_cmd: train_fsdp2

# Training parameters
# NOTE: `10_000`-style underscores parse as int 10000 under YAML 1.1 loaders
# (PyYAML, used by OmegaConf); a strict YAML 1.2 loader would read a string.
num_train_steps: 10_000
use_torch_compile: false
use_meta_device: true
use_sequence_packing: true

# Dataset parameters (from L2_lingua_1b)
micro_batch_size: 4
max_seq_length: 4096
num_workers: 8
stride: 512
buffer_size: 50_000

# Optimizer (from L2_lingua_1b)
lr: 0.003
weight_decay: 0.033

# LR scheduler
num_warmup_steps: 1_000
num_decay_steps: 9_000

# Checkpoint controls
ckpt_dir: ""                     # empty string: consumer decides default location
save_final_model: false
resume_from_checkpoint: false
save_every_n_steps: 10_000       # == num_train_steps, so at most one mid-run save

# Context parallelism
cp_size: 1

############################################################
# Each product is a different config to run, alongside
# config-specific arguments. Must have a `wandb_name`.
############################################################
products:
  # Lingua 1B baseline - FSDP2 with THD (sequence packing)
  - config: L2_lingua_1b
    task_cmd: train_fsdp2
    thd_enabled: true
    use_sequence_packing: true
    fp8_enabled: false
    cp_enabled: false
    wandb_name: "llama3_lingua_1b__fsdp2__thd__${now:%Y%m%d-%H%M%S}__${gitsha:}"
    job_name: "llama3-lingua-1b-fsdp2-thd"

  # Lingua 1B - FSDP2 with FP8 + THD
  - config: L2_lingua_1b
    task_cmd: train_fsdp2
    thd_enabled: true
    use_sequence_packing: true
    fp8_enabled: true
    # NOTE(review): reviewer requested switching to
    # transformer_engine.common.recipe.Float8BlockScaling — confirm the
    # accompanying fp8_format value is compatible with block scaling
    # before making that change.
    fp8_recipe: transformer_engine.common.recipe.DelayedScaling
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use Float8BlockScaling

    fp8_format: HYBRID
    cp_enabled: false
    wandb_name: "llama3_lingua_1b__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
    job_name: "llama3-lingua-1b-fsdp2-thd-fp8"

  # Lingua 1B - FSDP2 with Context Parallelism
  - config: L2_lingua_1b
    task_cmd: train_fsdp2_cp
    thd_enabled: false
    # NOTE(review): per the review thread, CP runs should probably enable
    # sequence packing too, but that requires cuDNN >= 9.18 — confirm the
    # image's cuDNN version before flipping this on.
    use_sequence_packing: false
    fp8_enabled: false
    cp_enabled: true
    cp_size: 2
    wandb_name: "llama3_lingua_1b__fsdp2__cp__${now:%Y%m%d-%H%M%S}__${gitsha:}"
    job_name: "llama3-lingua-1b-fsdp2-cp"

  # Lingua 1B - FSDP2 with Context Parallelism + FP8
  - config: L2_lingua_1b
    task_cmd: train_fsdp2_cp
    thd_enabled: false
    use_sequence_packing: false
    fp8_enabled: true
    fp8_recipe: transformer_engine.common.recipe.DelayedScaling
    fp8_format: HYBRID
    cp_enabled: true
    cp_size: 2
    wandb_name: "llama3_lingua_1b__fsdp2__cp__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
    job_name: "llama3-lingua-1b-fsdp2-cp-fp8"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cp runs should probably use sequence packing, but we need to be careful to use cudnn>=9.18


############################################################
# run script
# This gets called right after `checkout_script` in the base config.
############################################################
# Fetches Lepton's env bootstrap (exports NNODES/NODE_RANK/MASTER_ADDR/
# MASTER_PORT), then launches the Hydra entry point `${task_cmd}.py` via
# torchrun with one process per visible GPU, forwarding the recipe's
# training/dataset/optimizer/checkpoint/fp8 settings as overrides.
# NOTE(review): `${wandb_init_args.mode}` and `${wandb_name}` are not defined
# at this file's top level — presumably `mode` comes from /base and
# `wandb_name` from each product entry; confirm both resolve at launch time.
run_script: |
  wget -O init.sh https://raw.githubusercontent.com/leptonai/scripts/main/lepton_env_to_pytorch.sh;
  chmod +x init.sh;
  source init.sh;

  HYDRA_FULL_ERROR=1 torchrun \
    --nnodes=$NNODES \
    --nproc_per_node=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l) \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    ${task_cmd}.py \
    --config-name ${config}.yaml \
    +wandb.mode=${wandb_init_args.mode} \
    +wandb.project=${wandb_init_args.project} \
    +wandb.name=${wandb_name} \
    num_train_steps=${num_train_steps} \
    use_torch_compile=${use_torch_compile} \
    use_meta_device=${use_meta_device} \
    use_sequence_packing=${use_sequence_packing} \
    cp_size=${cp_size} \
    dataset.micro_batch_size=${micro_batch_size} \
    dataset.max_seq_length=${max_seq_length} \
    dataset.num_workers=${num_workers} \
    dataset.stride=${stride} \
    dataset.buffer_size=${buffer_size} \
    adamw_kwargs.lr=${lr} \
    adamw_kwargs.weight_decay=${weight_decay} \
    lr_scheduler_kwargs.num_warmup_steps=${num_warmup_steps} \
    lr_scheduler_kwargs.num_decay_steps=${num_decay_steps} \
    checkpoint.ckpt_dir=${ckpt_dir} \
    checkpoint.save_final_model=${save_final_model} \
    checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \
    checkpoint.save_every_n_steps=${save_every_n_steps} \
    fp8_config.enabled=${fp8_enabled} \
    fp8_config.fp8_recipe=${fp8_recipe} \
    fp8_config.fp8_format=${fp8_format}