feat: incorporate multi-device training scripts and README
This commit is contained in:
@@ -14,6 +14,12 @@ CKPT_DIR="${ARTIFACT_ROOT}/checkpoints/${TRAIN_NAME}"
|
||||
# ---------------------------------------------------------------------------
# Tunables. Each one keeps a value already present in the environment and
# otherwise falls back to the default shown here.
# ---------------------------------------------------------------------------

# Checkpoint housekeeping: both values are consumed by the checkpoint cleanup
# logic further down in this script.
CHECKPOINT_KEEP_RECENT=${CHECKPOINT_KEEP_RECENT:-3}
CHECKPOINT_CLEANUP_INTERVAL_SECONDS=${CHECKPOINT_CLEANUP_INTERVAL_SECONDS:-300}

# Free-form extra arguments; empty unless the caller provides some.
EXTRA_ARGS=${EXTRA_ARGS:-}

# Launcher topology: these are substituted into DISTRIBUTED_ARGS, which is
# handed to torchrun.
NPROC_PER_NODE=${NPROC_PER_NODE:-8}
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-localhost}
MASTER_PORT=${MASTER_PORT:-6000}

# Selects which ZERO_ARGS are built below; only 0, 1 and 2 are accepted.
ZERO_STAGE=${ZERO_STAGE:-0}
|
||||
|
||||
source params/optim_common.sh
|
||||
source params/gpt_smoke/model.sh
|
||||
@@ -46,6 +52,28 @@ PARALLEL_ARGS="
|
||||
# --sequence-parallel
|
||||
# "
|
||||
|
||||
#######################################
# Print the extra pretrain flags that correspond to a ZeRO stage.
# Arguments: $1 - ZeRO stage; must be 0, 1, or 2
# Outputs:   one flag (with its value, if any) per line on stdout; nothing
#            for stage 0. An error message on stderr for any other stage.
# Returns:   0 on success, 1 when the stage is not supported
#######################################
zero_stage_args() {
  case "$1" in
    0)
      # Stage 0 adds no flags at all.
      ;;
    1)
      printf '%s\n' \
        "--use-distributed-optimizer" \
        "--data-parallel-sharding-strategy optim"
      ;;
    2)
      printf '%s\n' \
        "--use-distributed-optimizer" \
        "--data-parallel-sharding-strategy optim_grads"
      ;;
    *)
      echo "Unsupported ZERO_STAGE=$1; expected 0, 1, or 2" >&2
      return 1
      ;;
  esac
}

# ZERO_ARGS is expanded unquoted on the training command line, so the
# newline-separated output above word-splits into individual arguments.
# An unsupported stage aborts the whole script with status 1.
ZERO_ARGS=$(zero_stage_args "${ZERO_STAGE:-0}") || exit 1
|
||||
|
||||
# Create the checkpoint directory and $TB_DIR up front (-p: no error if they
# already exist). NOTE(review): TB_DIR is presumably the TensorBoard log
# directory; it is defined outside this section, so confirm.
mkdir -p "$CKPT_DIR" "$TB_DIR"
|
||||
|
||||
cleanup_old_checkpoints_once() {
|
||||
@@ -107,17 +135,18 @@ CHECKPOINT_CLEANUP_PID=$!
|
||||
# On script exit: stop the background process whose PID was saved in
# CHECKPOINT_CLEANUP_PID (kill errors are deliberately ignored, e.g. when it
# has already exited), then run cleanup_old_checkpoints_once one last time
# with "$CKPT_DIR" and "$CHECKPOINT_KEEP_RECENT".
trap 'kill "$CHECKPOINT_CLEANUP_PID" 2>/dev/null || true; cleanup_old_checkpoints_once "$CKPT_DIR" "$CHECKPOINT_KEEP_RECENT"' EXIT
|
||||
|
||||
# Launcher flags for torchrun. Every value comes from the NPROC_PER_NODE /
# NNODES / NODE_RANK / MASTER_ADDR / MASTER_PORT settings defined near the top
# of this script; each flag appears exactly once, with no hard-coded copy.
# Kept as a whitespace-separated string (not an array) because it is expanded
# unquoted as `torchrun $DISTRIBUTED_ARGS`.
DISTRIBUTED_ARGS="
    --nproc_per_node ${NPROC_PER_NODE}
    --nnodes ${NNODES}
    --node_rank ${NODE_RANK}
    --master_addr ${MASTER_ADDR}
    --master_port ${MASTER_PORT}
"
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS \
|
||||
$MEGATRON_PATH/pretrain_gpt.py \
|
||||
$MODEL_ARGS \
|
||||
$OPTIM_ARGS \
|
||||
$ZERO_ARGS \
|
||||
$PRECISION_ARGS \
|
||||
$PARALLEL_ARGS \
|
||||
$DATA_ARGS \
|
||||
|
||||
Reference in New Issue
Block a user