"""
Convert PyArrow Parquet files to Megatron's binary format, using JSONL as the
intermediate format.

Expected Parquet schema:
    text: <string>

Usage:
    python /apps/yi/model_training/scripts/convert_phase_to_megatron.py \
        --input-dir /apps/yi/model_training/data/phase1 \
        --output-dir /ssd/yi/converted_data/megatron_phase1 \
        --tmp-dir /ssd/yi/converted_data/tmp_jsonl \
        --megatron-dir /apps/yi/model_training/Megatron-LM \
        --tokenizer-model /apps/yi/model_training/data/tokenizer \
        --text-key text \
        --num-shards 4 \
        --workers-per-shard 16 \
        --start 100 \
        --end 220  # process files[100:220] of the 220 parquets

Total tokenizer processes are roughly num-shards * workers-per-shard; size them
to the machine.
"""
import argparse
import json
import os
import subprocess
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pyarrow.parquet as pq
from tqdm import tqdm


def parquet_to_jsonl(parquet_path: Path, jsonl_path: Path, text_key: str):
    """Stream one Parquet file to JSONL, dropping empty or non-string rows."""
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)
    rows = 0
    with jsonl_path.open("w", encoding="utf-8") as fout:
        pf = pq.ParquetFile(parquet_path)
        # Read only the text column in batches to keep memory bounded.
        for batch in pf.iter_batches(columns=[text_key], batch_size=8192):
            col = batch.column(0).to_pylist()
            for text in col:
                if isinstance(text, str) and text.strip():
                    fout.write(json.dumps({text_key: text}, ensure_ascii=False) + "\n")
                    rows += 1
    return rows
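
# Each emitted JSONL line is {"<text-key>": "..."}; the same key is passed to
# preprocess_data.py via --json-keys below, so reader and writer always agree.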


def run_one(args_tuple):
    """Convert a single parquet: parquet -> JSONL -> Megatron bin/idx."""
    (
        parquet_path,
        output_dir,
        tmp_dir,
        text_key,
        megatron_dir,
        tokenizer_type,
        tokenizer_model,
        workers_per_shard,
        keep_jsonl,
        overwrite,
    ) = args_tuple
    parquet_path = Path(parquet_path)
    stem = parquet_path.name.replace(".zstd.parquet", "").replace(".parquet", "")
    jsonl_path = Path(tmp_dir) / f"{stem}.jsonl"
    output_prefix = Path(output_dir) / f"phase1_{stem}"
    bin_file = Path(str(output_prefix) + f"_{text_key}_document.bin")
    idx_file = Path(str(output_prefix) + f"_{text_key}_document.idx")
    # Resumability: skip shards whose outputs already exist unless --overwrite.
    if not overwrite and bin_file.exists() and idx_file.exists():
        return f"[SKIP] {parquet_path.name}: existing bin/idx"
    print(f"[START] {parquet_path.name}", flush=True)
    rows = parquet_to_jsonl(parquet_path, jsonl_path, text_key)
    print(f"[JSONL DONE] {parquet_path.name}: rows={rows}, jsonl={jsonl_path}", flush=True)
    print(f"[MEGATRON START] {parquet_path.name}", flush=True)
    cmd = [
        "python",
        str(Path(megatron_dir) / "tools/preprocess_data.py"),
        "--input", str(jsonl_path),
        "--output-prefix", str(output_prefix),
        "--tokenizer-type", tokenizer_type,
        "--tokenizer-model", tokenizer_model,
        "--json-keys", text_key,
        "--workers", str(workers_per_shard),
        "--append-eod",
    ]
    env = os.environ.copy()
    # Capture stdout+stderr so failures can be reported back from the worker.
    proc = subprocess.run(
        cmd,
        cwd=megatron_dir,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    if proc.returncode != 0:
        # Return only the tail of the combined log to keep summaries readable.
        return f"[FAIL] {parquet_path.name}\n{proc.stdout[-4000:]}"
    if not keep_jsonl:
        # The intermediate JSONL is only useful for debugging; drop it by default.
        jsonl_path.unlink(missing_ok=True)
    return f"[OK] {parquet_path.name}: rows={rows}, output_prefix={output_prefix}"


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--tmp-dir", required=True)
    parser.add_argument("--megatron-dir", default="/apps/model_training/Megatron-LM")
    parser.add_argument("--tokenizer-type", default="HuggingFaceTokenizer")
    parser.add_argument("--tokenizer-model", required=True)
    parser.add_argument("--text-key", default="text")
    parser.add_argument("--num-shards", type=int, default=1, help="parallel parquet shards")
    parser.add_argument("--workers-per-shard", type=int, default=8)
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=None)
    parser.add_argument("--keep-jsonl", action="store_true")
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()

    # Prefer zstd-compressed parquets; fall back to plain .parquet.
    files = sorted(Path(args.input_dir).glob("*.zstd.parquet"))
    if not files:
        files = sorted(Path(args.input_dir).glob("*.parquet"))
    files = files[args.start:args.end]
    print(f"Converting {len(files)} files")
    print(f"Parallel shards: {args.num_shards}")
    print(f"Workers per shard: {args.workers_per_shard}")

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.tmp_dir).mkdir(parents=True, exist_ok=True)
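
    # Tuple order must match the unpacking at the top of run_one().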
    tasks = [
        (
            str(f),
            args.output_dir,
            args.tmp_dir,
            args.text_key,
            args.megatron_dir,
            args.tokenizer_type,
            args.tokenizer_model,
            args.workers_per_shard,
            args.keep_jsonl,
            args.overwrite,
        )
        for f in files
    ]
    with ProcessPoolExecutor(max_workers=args.num_shards) as ex:
        futs = [ex.submit(run_one, t) for t in tasks]
        for fut in tqdm(as_completed(futs), total=len(futs)):
            print(fut.result(), flush=True)


if __name__ == "__main__":
    main()