feat: optimized dataset convertion efficiency, add on-demand training start/stop script

This commit is contained in:
2026-05-06 22:32:18 +08:00
parent 056df3b6ca
commit 0008288964
6 changed files with 607 additions and 114 deletions

View File

@@ -1,130 +1,240 @@
import argparse
import json
import os
import subprocess
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Pool
from pathlib import Path
from tqdm import tqdm
from types import SimpleNamespace
import pyarrow.parquet as pq
from tqdm import tqdm
"""
Convert Pyarrow parquets to megatron format, use jsonl as intermediate format.
Convert Kaiyuan parquet files directly to Megatron indexed dataset format.
Takes in parquet schema:
Expected parquet schema:
text: <string>
The previous implementation used parquet -> JSONL -> Megatron preprocess_data.py.
This implementation removes the JSONL intermediate file and writes .bin/.idx with
Megatron's IndexedDatasetBuilder directly.
Usage:
python /apps/yi/model_training/scripts/convert_phase_to_megatron.py \
--input-dir /apps/yi/model_training/data/phase1 \
--output-dir /ssd/yi/converted_data/megatron_phase1 \
--tmp-dir /ssd/yi/converted_data/tmp_jsonl \
--megatron-dir /apps/yi/model_training/Megatron-LM \
--tokenizer-model /apps/yi/model_training/data/tokenizer \
--text-key text \
--num-shards 4 \
--workers-per-shard 16 \
--start 100 \
--end 220 # 1 of total 220 parquets
--end 220
"""
_TOKENIZER = None
_APPEND_EOD = True
def parquet_to_jsonl(parquet_path: Path, jsonl_path: Path, text_key: str):
jsonl_path.parent.mkdir(parents=True, exist_ok=True)
rows = 0
with jsonl_path.open("w", encoding="utf-8") as fout:
pf = pq.ParquetFile(parquet_path)
for batch in pf.iter_batches(columns=[text_key], batch_size=8192):
col = batch.column(0).to_pylist()
for text in col:
if isinstance(text, str) and text.strip():
fout.write(json.dumps({text_key: text}, ensure_ascii=False) + "\n")
rows += 1
return rows
def run_one(args_tuple):
(
parquet_path,
output_dir,
tmp_dir,
text_key,
megatron_dir,
tokenizer_type,
tokenizer_model,
workers_per_shard,
keep_jsonl,
overwrite,
) = args_tuple
parquet_path = Path(parquet_path)
stem = parquet_path.name.replace(".zstd.parquet", "").replace(".parquet", "")
jsonl_path = Path(tmp_dir) / f"{stem}.jsonl"
output_prefix = Path(output_dir) / f"phase1_{stem}"
bin_file = Path(str(output_prefix) + f"_{text_key}_document.bin")
idx_file = Path(str(output_prefix) + f"_{text_key}_document.idx")
if not overwrite and bin_file.exists() and idx_file.exists():
return f"[SKIP] {parquet_path.name}: existing bin/idx"
print(f"[START] {parquet_path.name}", flush=True)
rows = parquet_to_jsonl(parquet_path, jsonl_path, text_key)
print(f"[JSONL DONE] {parquet_path.name}: rows={rows}, jsonl={jsonl_path}", flush=True)
print(f"[MEGATRON START] {parquet_path.name}", flush=True)
cmd = [
"python",
str(Path(megatron_dir) / "tools/preprocess_data.py"),
"--input", str(jsonl_path),
"--output-prefix", str(output_prefix),
"--tokenizer-type", tokenizer_type,
"--tokenizer-model", tokenizer_model,
"--json-keys", text_key,
"--workers", str(workers_per_shard),
"--append-eod",
]
env = os.environ.copy()
proc = subprocess.run(
cmd,
cwd=megatron_dir,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
def make_tokenizer_args(args):
return SimpleNamespace(
rank=0,
make_vocab_size_divisible_by=128,
tensor_model_parallel_size=1,
padded_vocab_size=None,
vocab_size=args.vocab_size,
vocab_file=args.vocab_file,
merge_file=args.merge_file,
vocab_extra_ids=0,
tokenizer_type=args.tokenizer_type,
tokenizer_model=args.tokenizer_model,
metadata_path=args.tokenizer_metadata,
special_tokens=args.tokenizer_special_tokens,
tokenizer_sentencepiece_legacy=args.tokenizer_sentencepiece_legacy,
tokenizer_hf_no_use_fast=args.tokenizer_hf_no_use_fast,
tokenizer_hf_no_include_special_tokens=args.tokenizer_hf_no_include_special_tokens,
trust_remote_code=args.trust_remote_code,
tiktoken_pattern=args.tiktoken_pattern,
tiktoken_num_special_tokens=args.tiktoken_num_special_tokens,
null_tokenizer_eod_id=args.null_tokenizer_eod_id,
null_tokenizer_pad_id=args.null_tokenizer_pad_id,
tokenizer_prompt_format=None,
image_tag_type=None,
force_system_message=False,
sft_tokenizer_prompt_format=None,
)
if proc.returncode != 0:
return f"[FAIL] {parquet_path.name}\n{proc.stdout[-4000:]}"
if not keep_jsonl:
jsonl_path.unlink(missing_ok=True)
return f"[OK] {parquet_path.name}: rows={rows}, output_prefix={output_prefix}"
def add_megatron_to_path(megatron_dir):
megatron_dir = str(Path(megatron_dir).resolve())
if megatron_dir not in sys.path:
sys.path.insert(0, megatron_dir)
def main():
def build_megatron_tokenizer(args):
add_megatron_to_path(args.megatron_dir)
from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer
return build_tokenizer(make_tokenizer_args(args))
def init_worker(args):
global _TOKENIZER, _APPEND_EOD
_APPEND_EOD = args.append_eod
_TOKENIZER = build_megatron_tokenizer(args)
if _APPEND_EOD and _TOKENIZER.eod is None:
raise ValueError("Tokenizer has no EOD/EOS token, but --append-eod is enabled.")
def encode_text(text):
if not isinstance(text, str):
return None
text = text.strip()
if not text:
return None
token_ids = _TOKENIZER.tokenize(text)
if not token_ids:
return None
sentence_lens = [len(token_ids)]
if _APPEND_EOD:
token_ids.append(_TOKENIZER.eod)
sentence_lens[-1] += 1
return token_ids, sentence_lens
def output_paths(output_prefix, text_key):
prefix = Path(output_prefix)
return (
Path(str(prefix) + f"_{text_key}_document.bin"),
Path(str(prefix) + f"_{text_key}_document.idx"),
)
def remove_partial_outputs(output_prefix, text_key):
bin_file, idx_file = output_paths(output_prefix, text_key)
bin_file.unlink(missing_ok=True)
idx_file.unlink(missing_ok=True)
def convert_one_parquet(args_tuple):
parquet_path, args = args_tuple
parquet_path = Path(parquet_path)
stem = parquet_path.name.replace(".zstd.parquet", "").replace(".parquet", "")
output_prefix = Path(args.output_dir) / f"{args.output_prefix_prefix}_{stem}"
bin_file, idx_file = output_paths(output_prefix, args.text_key)
if not args.overwrite and bin_file.exists() and idx_file.exists():
return f"[SKIP] {parquet_path.name}: existing bin/idx"
remove_partial_outputs(output_prefix, args.text_key)
add_megatron_to_path(args.megatron_dir)
from megatron.core.datasets import indexed_dataset
tokenizer = build_megatron_tokenizer(args)
dtype = indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size)
builder = indexed_dataset.IndexedDatasetBuilder(str(bin_file), dtype=dtype)
start_time = time.time()
rows = 0
docs = 0
tokens = 0
def consume_encoded(encoded):
nonlocal docs, tokens
if encoded is None:
return
token_ids, sentence_lens = encoded
builder.add_document(token_ids, sentence_lens)
docs += 1
tokens += len(token_ids)
if args.log_interval and docs % args.log_interval == 0:
elapsed = max(time.time() - start_time, 1e-6)
print(
f"[{parquet_path.name}] docs={docs} "
f"tokens={tokens} docs/s={docs / elapsed:.2f}",
flush=True,
)
pf = pq.ParquetFile(parquet_path)
if args.workers_per_shard == 1:
init_worker(args)
for batch in pf.iter_batches(columns=[args.text_key], batch_size=args.batch_size):
texts = batch.column(0).to_pylist()
rows += len(texts)
for text in texts:
consume_encoded(encode_text(text))
else:
with Pool(processes=args.workers_per_shard, initializer=init_worker, initargs=(args,)) as pool:
for batch in pf.iter_batches(columns=[args.text_key], batch_size=args.batch_size):
texts = batch.column(0).to_pylist()
rows += len(texts)
for encoded in pool.imap(encode_text, texts, chunksize=args.chunksize):
consume_encoded(encoded)
builder.finalize(str(idx_file))
elapsed = max(time.time() - start_time, 1e-6)
return (
f"[OK] {parquet_path.name}: rows={rows}, docs={docs}, tokens={tokens}, "
f"elapsed={elapsed:.1f}s, docs/s={docs / elapsed:.2f}, output_prefix={output_prefix}"
)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--input-dir", required=True)
parser.add_argument("--output-dir", required=True)
parser.add_argument("--tmp-dir", required=True)
parser.add_argument("--tmp-dir", default=None, help="Deprecated; kept for CLI compatibility.")
parser.add_argument("--megatron-dir", default="/apps/model_training/Megatron-LM")
parser.add_argument("--tokenizer-type", default="HuggingFaceTokenizer")
parser.add_argument("--tokenizer-model", required=True)
parser.add_argument("--tokenizer-metadata", default=None)
parser.add_argument("--tokenizer-special-tokens", nargs="*", default=None)
parser.add_argument("--tokenizer-sentencepiece-legacy", action="store_true")
parser.add_argument("--tokenizer-hf-no-use-fast", action="store_true")
parser.add_argument("--tokenizer-hf-no-include-special-tokens", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--vocab-file", default=None)
parser.add_argument("--merge-file", default=None)
parser.add_argument("--vocab-size", type=int, default=None)
parser.add_argument("--tiktoken-pattern", default=None)
parser.add_argument("--tiktoken-num-special-tokens", type=int, default=1000)
parser.add_argument("--null-tokenizer-eod-id", type=int, default=None)
parser.add_argument("--null-tokenizer-pad-id", type=int, default=-1)
parser.add_argument("--text-key", default="text")
parser.add_argument("--num-shards", type=int, default=1, help="parallel parquet shards")
parser.add_argument("--workers-per-shard", type=int, default=8)
parser.add_argument("--output-prefix-prefix", default="phase1")
parser.add_argument("--num-shards", type=int, default=1, help="Parallel parquet files.")
parser.add_argument("--workers-per-shard", type=int, default=max((os.cpu_count() or 8) // 2, 1))
parser.add_argument("--batch-size", type=int, default=8192, help="Parquet record batch size.")
parser.add_argument("--chunksize", type=int, default=64, help="Tokenizer pool imap chunk size.")
parser.add_argument("--log-interval", type=int, default=10000)
parser.add_argument("--start", type=int, default=0)
parser.add_argument("--end", type=int, default=None)
parser.add_argument("--keep-jsonl", action="store_true")
parser.add_argument("--append-eod", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--keep-jsonl", action="store_true", help="Deprecated; no JSONL is written.")
parser.add_argument("--overwrite", action="store_true")
args = parser.parse_args()
return parser.parse_args()
def main():
args = parse_args()
if args.num_shards < 1:
raise ValueError("--num-shards must be >= 1")
if args.workers_per_shard < 1:
raise ValueError("--workers-per-shard must be >= 1")
if args.batch_size < 1:
raise ValueError("--batch-size must be >= 1")
if args.chunksize < 1:
raise ValueError("--chunksize must be >= 1")
files = sorted(Path(args.input_dir).glob("*.zstd.parquet"))
if not files:
@@ -132,33 +242,24 @@ def main():
files = files[args.start:args.end]
print(f"Converting {len(files)} files")
print(f"Parallel shards: {args.num_shards}")
print(f"Workers per shard: {args.workers_per_shard}")
print(f"Parallel parquet files: {args.num_shards}")
print(f"Tokenizer workers per parquet: {args.workers_per_shard}")
print(f"Total tokenizer workers: {args.num_shards * args.workers_per_shard}")
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
Path(args.tmp_dir).mkdir(parents=True, exist_ok=True)
if args.tmp_dir:
Path(args.tmp_dir).mkdir(parents=True, exist_ok=True)
tasks = [
(
str(f),
args.output_dir,
args.tmp_dir,
args.text_key,
args.megatron_dir,
args.tokenizer_type,
args.tokenizer_model,
args.workers_per_shard,
args.keep_jsonl,
args.overwrite,
)
for f in files
]
with ProcessPoolExecutor(max_workers=args.num_shards) as ex:
futs = [ex.submit(run_one, t) for t in tasks]
for fut in tqdm(as_completed(futs), total=len(futs)):
print(fut.result(), flush=True)
tasks = [(str(f), args) for f in files]
if args.num_shards == 1:
for task in tqdm(tasks):
print(convert_one_parquet(task), flush=True)
else:
with ProcessPoolExecutor(max_workers=args.num_shards) as ex:
futs = [ex.submit(convert_one_parquet, task) for task in tasks]
for fut in tqdm(as_completed(futs), total=len(futs)):
print(fut.result(), flush=True)
if __name__ == "__main__":
main()
main()

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
ARTIFACT_ROOT=${ARTIFACT_ROOT:-/apps/yi/model_training/artifacts}
RUN_STATE_DIR="${ARTIFACT_ROOT}/run_state"
LOG_DIR="${ARTIFACT_ROOT}/logs"
usage() {
cat <<'EOF'
Usage:
bash start_training.sh <model> [mode] [train_name]
Models:
gpt_smoke
qwen3_1p7b
Examples:
bash start_training.sh gpt_smoke smoke smoke_gpt
bash start_training.sh qwen3_1p7b qwen3_1p7b_smoke_yi qwen3_1p7b_smoke_yi
Environment overrides:
CHECKPOINT_KEEP_RECENT=3
CHECKPOINT_CLEANUP_INTERVAL_SECONDS=300
EXTRA_ARGS="--exit-duration-in-mins 120"
EOF
}
model=${1:-}
mode=${2:-}
train_name=${3:-}
if [ -z "$model" ] || [ "$model" = "-h" ] || [ "$model" = "--help" ]; then
usage
exit 0
fi
case "$model" in
gpt_smoke)
train_script="${SCRIPT_DIR}/training_smoke_gpt2.sh"
mode=${mode:-smoke}
train_name=${train_name:-smoke_gpt}
;;
qwen3_1p7b)
train_script="${SCRIPT_DIR}/training_smoke_qwen3_1p7b.sh"
mode=${mode:-qwen3_1p7b_smoke_yi}
train_name=${train_name:-qwen3_1p7b_smoke_yi}
;;
*)
echo "Unknown model: $model" >&2
usage >&2
exit 1
;;
esac
mkdir -p "$RUN_STATE_DIR" "$LOG_DIR"
pid_file="${RUN_STATE_DIR}/${train_name}.pid"
meta_file="${RUN_STATE_DIR}/${train_name}.env"
log_file="${LOG_DIR}/${train_name}.log"
if [ -f "$pid_file" ]; then
old_pid=$(cat "$pid_file")
if [ -n "$old_pid" ] && kill -0 "$old_pid" 2>/dev/null; then
echo "Training already appears to be running: train_name=${train_name}, pid=${old_pid}" >&2
exit 1
fi
fi
combined_extra_args="--exit-signal-handler ${EXTRA_ARGS:-}"
cd "$SCRIPT_DIR"
EXTRA_ARGS="$combined_extra_args" setsid bash "$train_script" "$mode" "$train_name" > "$log_file" 2>&1 &
pid=$!
pgid=$(ps -o pgid= -p "$pid" | tr -d ' ' || true)
printf '%s\n' "$pid" > "$pid_file"
cat > "$meta_file" <<EOF
MODEL=${model}
MODE=${mode}
TRAIN_NAME=${train_name}
PID=${pid}
PGID=${pgid}
LOG_FILE=${log_file}
TRAIN_SCRIPT=${train_script}
CHECKPOINT_KEEP_RECENT=${CHECKPOINT_KEEP_RECENT:-3}
CHECKPOINT_CLEANUP_INTERVAL_SECONDS=${CHECKPOINT_CLEANUP_INTERVAL_SECONDS:-300}
EOF
echo "Started training: model=${model}, mode=${mode}, train_name=${train_name}, pid=${pid}, pgid=${pgid:-unknown}"
echo "Log: ${log_file}"
echo "Stop: bash ${SCRIPT_DIR}/stop_training.sh ${train_name}"

View File

@@ -0,0 +1,68 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
ARTIFACT_ROOT=${ARTIFACT_ROOT:-/apps/yi/model_training/artifacts}
RUN_STATE_DIR="${ARTIFACT_ROOT}/run_state"
GRACE_SECONDS=${GRACE_SECONDS:-300}
usage() {
cat <<'EOF'
Usage:
bash stop_training.sh <train_name>
Environment overrides:
GRACE_SECONDS=300
EOF
}
train_name=${1:-}
if [ -z "$train_name" ] || [ "$train_name" = "-h" ] || [ "$train_name" = "--help" ]; then
usage
exit 0
fi
pid_file="${RUN_STATE_DIR}/${train_name}.pid"
meta_file="${RUN_STATE_DIR}/${train_name}.env"
if [ ! -f "$pid_file" ]; then
echo "PID file not found: ${pid_file}" >&2
exit 1
fi
pid=$(cat "$pid_file")
if [ -z "$pid" ] || ! kill -0 "$pid" 2>/dev/null; then
echo "Training is not running for train_name=${train_name}; cleaning stale state."
rm -f "$pid_file" "$meta_file"
exit 0
fi
pgid=$(ps -o pgid= -p "$pid" | tr -d ' ' || true)
if [ -z "$pgid" ] && [ -f "$meta_file" ]; then
pgid=$(grep '^PGID=' "$meta_file" | cut -d= -f2- || true)
fi
echo "Sending SIGTERM to training process group: train_name=${train_name}, pid=${pid}, pgid=${pgid:-unknown}"
if [ -n "$pgid" ]; then
kill -TERM "-${pgid}" 2>/dev/null || kill -TERM "$pid" 2>/dev/null || true
else
kill -TERM "$pid" 2>/dev/null || true
fi
deadline=$((SECONDS + GRACE_SECONDS))
while kill -0 "$pid" 2>/dev/null; do
if [ "$SECONDS" -ge "$deadline" ]; then
echo "Training did not exit within ${GRACE_SECONDS}s." >&2
echo "If checkpoint saving is still running, wait and inspect logs before forcing termination." >&2
if [ -n "$pgid" ]; then
echo "Force kill manually if needed: kill -KILL -${pgid}" >&2
else
echo "Force kill manually if needed: kill -KILL ${pid}" >&2
fi
exit 2
fi
sleep 5
done
rm -f "$pid_file" "$meta_file"
echo "Stopped training: train_name=${train_name}"

View File

@@ -11,6 +11,9 @@ MEGATRON_PATH=/apps/yi/model_training/Megatron-LM
ARTIFACT_ROOT=/apps/yi/model_training/artifacts
TB_DIR="${ARTIFACT_ROOT}/tb_logs/${TRAIN_NAME}"
CKPT_DIR="${ARTIFACT_ROOT}/checkpoints/${TRAIN_NAME}"
CHECKPOINT_KEEP_RECENT=${CHECKPOINT_KEEP_RECENT:-3}
CHECKPOINT_CLEANUP_INTERVAL_SECONDS=${CHECKPOINT_CLEANUP_INTERVAL_SECONDS:-300}
EXTRA_ARGS=${EXTRA_ARGS:-}
source params/optim_common.sh
source params/gpt_smoke/model.sh
@@ -45,6 +48,64 @@ PARALLEL_ARGS="
mkdir -p "$CKPT_DIR" "$TB_DIR"
cleanup_old_checkpoints_once() {
local ckpt_dir=$1
local keep=$2
if ! [[ "$keep" =~ ^[0-9]+$ ]] || [ "$keep" -le 0 ] || [ ! -d "$ckpt_dir" ]; then
return 0
fi
local latest=""
if [ -f "${ckpt_dir}/latest_checkpointed_iteration.txt" ]; then
read -r latest < "${ckpt_dir}/latest_checkpointed_iteration.txt" || latest=""
if [[ "$latest" =~ ^[0-9]+$ ]]; then
latest=$(printf "iter_%07d" "$latest")
else
latest=""
fi
fi
local checkpoints=()
while IFS= read -r path; do
checkpoints+=("$path")
done < <(find "$ckpt_dir" -maxdepth 1 -type d -name 'iter_[0-9][0-9][0-9][0-9][0-9][0-9][0-9]' -print | sort)
local delete_count=$((${#checkpoints[@]} - keep))
if [ "$delete_count" -le 0 ]; then
return 0
fi
local i base
for ((i = 0; i < delete_count; i++)); do
base=$(basename "${checkpoints[$i]}")
if [ "$base" = "$latest" ]; then
continue
fi
echo "[checkpoint-cleanup] deleting old checkpoint: ${checkpoints[$i]}"
rm -rf -- "${checkpoints[$i]}"
done
}
checkpoint_cleanup_loop() {
local ckpt_dir=$1
local keep=$2
local interval=$3
if ! [[ "$interval" =~ ^[0-9]+$ ]] || [ "$interval" -le 0 ]; then
return 0
fi
while true; do
sleep "$interval"
cleanup_old_checkpoints_once "$ckpt_dir" "$keep"
done
}
checkpoint_cleanup_loop "$CKPT_DIR" "$CHECKPOINT_KEEP_RECENT" "$CHECKPOINT_CLEANUP_INTERVAL_SECONDS" &
CHECKPOINT_CLEANUP_PID=$!
trap 'kill "$CHECKPOINT_CLEANUP_PID" 2>/dev/null || true; cleanup_old_checkpoints_once "$CKPT_DIR" "$CHECKPOINT_KEEP_RECENT"' EXIT
DISTRIBUTED_ARGS="
--nproc_per_node 8
--nnodes 1
@@ -63,4 +124,5 @@ torchrun $DISTRIBUTED_ARGS \
$RUN_ARGS \
$LOGGING_ARGS\
--save "$CKPT_DIR" \
--load "$CKPT_DIR" \
--load "$CKPT_DIR" \
$EXTRA_ARGS

View File

@@ -13,6 +13,9 @@ SCRIPT_DIR=/apps/yi/model_training/scripts/kaiyuan2b-training
PARAMS_DIR="${SCRIPT_DIR}/params"
TB_DIR="${ARTIFACT_ROOT}/tb_logs/${TRAIN_NAME}"
CKPT_DIR="${ARTIFACT_ROOT}/checkpoints/${TRAIN_NAME}"
CHECKPOINT_KEEP_RECENT=${CHECKPOINT_KEEP_RECENT:-3}
CHECKPOINT_CLEANUP_INTERVAL_SECONDS=${CHECKPOINT_CLEANUP_INTERVAL_SECONDS:-300}
EXTRA_ARGS=${EXTRA_ARGS:-}
source "${PARAMS_DIR}/optim_common.sh"
source "${PARAMS_DIR}/qwen3_1p7b/model.sh"
@@ -56,6 +59,64 @@ fi
mkdir -p "$CKPT_DIR" "$TB_DIR"
cleanup_old_checkpoints_once() {
local ckpt_dir=$1
local keep=$2
if ! [[ "$keep" =~ ^[0-9]+$ ]] || [ "$keep" -le 0 ] || [ ! -d "$ckpt_dir" ]; then
return 0
fi
local latest=""
if [ -f "${ckpt_dir}/latest_checkpointed_iteration.txt" ]; then
read -r latest < "${ckpt_dir}/latest_checkpointed_iteration.txt" || latest=""
if [[ "$latest" =~ ^[0-9]+$ ]]; then
latest=$(printf "iter_%07d" "$latest")
else
latest=""
fi
fi
local checkpoints=()
while IFS= read -r path; do
checkpoints+=("$path")
done < <(find "$ckpt_dir" -maxdepth 1 -type d -name 'iter_[0-9][0-9][0-9][0-9][0-9][0-9][0-9]' -print | sort)
local delete_count=$((${#checkpoints[@]} - keep))
if [ "$delete_count" -le 0 ]; then
return 0
fi
local i base
for ((i = 0; i < delete_count; i++)); do
base=$(basename "${checkpoints[$i]}")
if [ "$base" = "$latest" ]; then
continue
fi
echo "[checkpoint-cleanup] deleting old checkpoint: ${checkpoints[$i]}"
rm -rf -- "${checkpoints[$i]}"
done
}
checkpoint_cleanup_loop() {
local ckpt_dir=$1
local keep=$2
local interval=$3
if ! [[ "$interval" =~ ^[0-9]+$ ]] || [ "$interval" -le 0 ]; then
return 0
fi
while true; do
sleep "$interval"
cleanup_old_checkpoints_once "$ckpt_dir" "$keep"
done
}
checkpoint_cleanup_loop "$CKPT_DIR" "$CHECKPOINT_KEEP_RECENT" "$CHECKPOINT_CLEANUP_INTERVAL_SECONDS" &
CHECKPOINT_CLEANUP_PID=$!
trap 'kill "$CHECKPOINT_CLEANUP_PID" 2>/dev/null || true; cleanup_old_checkpoints_once "$CKPT_DIR" "$CHECKPOINT_KEEP_RECENT"' EXIT
DISTRIBUTED_ARGS="
--nproc_per_node 8
--nnodes 1
@@ -79,5 +140,5 @@ torchrun $DISTRIBUTED_ARGS \
--cuda-graph-warmup-steps 3 \
--transformer-impl transformer_engine \
--cross-entropy-loss-fusion \
--cross-entropy-fusion-impl te
--cross-entropy-fusion-impl te \
$EXTRA_ARGS