feat: incorporate multi-device training scripts and README

2026-05-09 21:35:42 +08:00
parent 02868ec01a
commit 75eacf00c2
6 changed files with 1082 additions and 10 deletions


@@ -14,6 +14,12 @@ CKPT_DIR="${ARTIFACT_ROOT}/checkpoints/${TRAIN_NAME}"
 CHECKPOINT_KEEP_RECENT=${CHECKPOINT_KEEP_RECENT:-3}
 CHECKPOINT_CLEANUP_INTERVAL_SECONDS=${CHECKPOINT_CLEANUP_INTERVAL_SECONDS:-300}
 EXTRA_ARGS=${EXTRA_ARGS:-}
+NPROC_PER_NODE=${NPROC_PER_NODE:-8}
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+MASTER_ADDR=${MASTER_ADDR:-localhost}
+MASTER_PORT=${MASTER_PORT:-6000}
+ZERO_STAGE=${ZERO_STAGE:-0}
 source params/optim_common.sh
 source params/gpt_smoke/model.sh
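Every new knob uses ${VAR:-default} expansion, so the script keeps its previous single-node behavior (8 local ranks, rendezvous on localhost:6000, ZeRO off) unless a variable is set in the environment. A minimal invocation sketch, assuming the script file is named train_gpt_smoke.sh (a hypothetical name, not shown in this diff):

# Hypothetical invocations; train_gpt_smoke.sh is an assumed file name.
# Single node with only 4 GPUs and a non-default rendezvous port:
NPROC_PER_NODE=4 MASTER_PORT=29500 bash train_gpt_smoke.sh

# Same topology, but with optimizer-state sharding enabled:
ZERO_STAGE=1 bash train_gpt_smoke.sh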
@@ -46,6 +52,28 @@ PARALLEL_ARGS="
 # --sequence-parallel
 # "
+case "$ZERO_STAGE" in
+    0)
+        ZERO_ARGS=""
+        ;;
+    1)
+        ZERO_ARGS="
+            --use-distributed-optimizer
+            --data-parallel-sharding-strategy optim
+        "
+        ;;
+    2)
+        ZERO_ARGS="
+            --use-distributed-optimizer
+            --data-parallel-sharding-strategy optim_grads
+        "
+        ;;
+    *)
+        echo "Unsupported ZERO_STAGE=${ZERO_STAGE}; expected 0, 1, or 2" >&2
+        exit 1
+        ;;
+esac
 mkdir -p "$CKPT_DIR" "$TB_DIR"
 cleanup_old_checkpoints_once() {
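The case block maps the conventional ZeRO stage numbers onto Megatron-LM's distributed-optimizer flags: stage 1 shards only optimizer state across data-parallel ranks (optim), stage 2 additionally shards gradients (optim_grads), and any other value fails fast before torchrun starts. As a hedged sketch only: recent Megatron-LM versions also accept optim_grads_params to shard parameters as well, so a stage-3 arm (not part of this commit, and dependent on the installed Megatron-LM supporting that strategy) could plausibly look like:

# Hypothetical extension, not in this commit; assumes the installed
# Megatron-LM accepts the optim_grads_params sharding strategy.
3)
    ZERO_ARGS="
        --use-distributed-optimizer
        --data-parallel-sharding-strategy optim_grads_params
    "
    ;;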
@@ -107,17 +135,18 @@ CHECKPOINT_CLEANUP_PID=$!
 trap 'kill "$CHECKPOINT_CLEANUP_PID" 2>/dev/null || true; cleanup_old_checkpoints_once "$CKPT_DIR" "$CHECKPOINT_KEEP_RECENT"' EXIT
 DISTRIBUTED_ARGS="
-    --nproc_per_node 8
-    --nnodes 1
-    --node_rank 0
-    --master_addr localhost
-    --master_port 6000
+    --nproc_per_node ${NPROC_PER_NODE}
+    --nnodes ${NNODES}
+    --node_rank ${NODE_RANK}
+    --master_addr ${MASTER_ADDR}
+    --master_port ${MASTER_PORT}
 "
 torchrun $DISTRIBUTED_ARGS \
     $MEGATRON_PATH/pretrain_gpt.py \
     $MODEL_ARGS \
     $OPTIM_ARGS \
+    $ZERO_ARGS \
     $PRECISION_ARGS \
     $PARALLEL_ARGS \
     $DATA_ARGS \
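With the rendezvous parameters now environment-driven, a multi-node run is the same script executed on every host with matching MASTER_ADDR/MASTER_PORT and a distinct NODE_RANK per host. A two-node sketch, assuming node 0 is reachable at 10.0.0.1 (an example address) and the hypothetical script name from above:

# On node 0 (hosts the torchrun rendezvous endpoint):
NNODES=2 NODE_RANK=0 MASTER_ADDR=10.0.0.1 MASTER_PORT=6000 bash train_gpt_smoke.sh
# On node 1, pointing at the same rendezvous endpoint:
NNODES=2 NODE_RANK=1 MASTER_ADDR=10.0.0.1 MASTER_PORT=6000 bash train_gpt_smoke.sh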