"""
Convert Parquet files to Megatron's binary format, using JSONL as the
intermediate format.

Expected parquet schema: a single text column (default key: "text").

Usage:
    python /apps/yi/model_training/scripts/convert_phase_to_megatron.py \
        --input-dir /apps/yi/model_training/data/phase1 \
        --output-dir /ssd/yi/converted_data/megatron_phase1 \
        --tmp-dir /ssd/yi/converted_data/tmp_jsonl \
        --megatron-dir /apps/yi/model_training/Megatron-LM \
        --tokenizer-model /apps/yi/model_training/data/tokenizer \
        --text-key text \
        --num-shards 4 \
        --workers-per-shard 16 \
        --start 100 \
        --end 220  # slice files [100, 220) of 220 total parquets
"""

import argparse
import json
import os
import subprocess
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pyarrow.parquet as pq
from tqdm import tqdm


def parquet_to_jsonl(parquet_path: Path, jsonl_path: Path, text_key: str) -> int:
    """Stream one parquet file into a JSONL file, keeping non-empty text rows."""
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)
    rows = 0
    with jsonl_path.open("w", encoding="utf-8") as fout:
        pf = pq.ParquetFile(parquet_path)
        # Stream in batches so large parquets never load fully into memory.
        for batch in pf.iter_batches(columns=[text_key], batch_size=8192):
            col = batch.column(0).to_pylist()
            for text in col:
                if isinstance(text, str) and text.strip():
                    fout.write(json.dumps({text_key: text}, ensure_ascii=False) + "\n")
                    rows += 1
    return rows


def run_one(args_tuple):
    """Convert a single parquet file: parquet -> JSONL -> Megatron bin/idx."""
    (
        parquet_path,
        output_dir,
        tmp_dir,
        text_key,
        megatron_dir,
        tokenizer_type,
        tokenizer_model,
        workers_per_shard,
        keep_jsonl,
        overwrite,
    ) = args_tuple

    parquet_path = Path(parquet_path)
    stem = parquet_path.name.replace(".zstd.parquet", "").replace(".parquet", "")
    jsonl_path = Path(tmp_dir) / f"{stem}.jsonl"
    output_prefix = Path(output_dir) / f"phase1_{stem}"
    # preprocess_data.py names its outputs <prefix>_<json_key>_document.{bin,idx}.
    bin_file = Path(str(output_prefix) + f"_{text_key}_document.bin")
    idx_file = Path(str(output_prefix) + f"_{text_key}_document.idx")

    if not overwrite and bin_file.exists() and idx_file.exists():
        return f"[SKIP] {parquet_path.name}: existing bin/idx"

    print(f"[START] {parquet_path.name}", flush=True)
    rows = parquet_to_jsonl(parquet_path, jsonl_path, text_key)
    print(f"[JSONL DONE] {parquet_path.name}: rows={rows}, jsonl={jsonl_path}", flush=True)

    print(f"[MEGATRON START] {parquet_path.name}", flush=True)
    cmd = [
        "python",
        str(Path(megatron_dir) / "tools/preprocess_data.py"),
        "--input", str(jsonl_path),
        "--output-prefix", str(output_prefix),
        "--tokenizer-type", tokenizer_type,
        "--tokenizer-model", tokenizer_model,
        "--json-keys", text_key,
        "--workers", str(workers_per_shard),
        "--append-eod",
    ]
    env = os.environ.copy()
    proc = subprocess.run(
        cmd,
        cwd=megatron_dir,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    if proc.returncode != 0:
        # Surface only the tail of the combined output to keep logs readable.
        return f"[FAIL] {parquet_path.name}\n{proc.stdout[-4000:]}"

    if not keep_jsonl:
        jsonl_path.unlink(missing_ok=True)

    return f"[OK] {parquet_path.name}: rows={rows}, output_prefix={output_prefix}"


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--tmp-dir", required=True)
    parser.add_argument("--megatron-dir", default="/apps/model_training/Megatron-LM")
    parser.add_argument("--tokenizer-type", default="HuggingFaceTokenizer")
    parser.add_argument("--tokenizer-model", required=True)
    parser.add_argument("--text-key", default="text")
    parser.add_argument("--num-shards", type=int, default=1, help="parallel parquet shards")
    parser.add_argument("--workers-per-shard", type=int, default=8)
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=None)
    parser.add_argument("--keep-jsonl", action="store_true")
parser.add_argument("--overwrite", action="store_true") args = parser.parse_args() files = sorted(Path(args.input_dir).glob("*.zstd.parquet")) if not files: files = sorted(Path(args.input_dir).glob("*.parquet")) files = files[args.start:args.end] print(f"Converting {len(files)} files") print(f"Parallel shards: {args.num_shards}") print(f"Workers per shard: {args.workers_per_shard}") Path(args.output_dir).mkdir(parents=True, exist_ok=True) Path(args.tmp_dir).mkdir(parents=True, exist_ok=True) tasks = [ ( str(f), args.output_dir, args.tmp_dir, args.text_key, args.megatron_dir, args.tokenizer_type, args.tokenizer_model, args.workers_per_shard, args.keep_jsonl, args.overwrite, ) for f in files ] with ProcessPoolExecutor(max_workers=args.num_shards) as ex: futs = [ex.submit(run_one, t) for t in tasks] for fut in tqdm(as_completed(futs), total=len(futs)): print(fut.result(), flush=True) if __name__ == "__main__": main()