164 lines
5.1 KiB
Python
164 lines
5.1 KiB
Python
import argparse
import json
import os
import subprocess
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pyarrow.parquet as pq
from tqdm import tqdm
|
|
|
|
"""
|
|
Convert Pyarrow parquets to megatron format, use jsonl as intermediate format.
|
|
|
|
Takes in parquet schema:
|
|
|
|
text: <string>
|
|
|
|
Usage:
|
|
|
|
python /apps/yi/model_training/scripts/convert_phase_to_megatron.py \
|
|
--input-dir /apps/yi/model_training/data/phase1 \
|
|
--output-dir /ssd/yi/converted_data/megatron_phase1 \
|
|
--tmp-dir /ssd/yi/converted_data/tmp_jsonl \
|
|
--megatron-dir /apps/yi/model_training/Megatron-LM \
|
|
--tokenizer-model /apps/yi/model_training/data/tokenizer \
|
|
--text-key text \
|
|
--num-shards 4 \
|
|
--workers-per-shard 16 \
|
|
--start 100 \
|
|
--end 220 # 1 of total 220 parquets
|
|
"""
|
|
|
|
|
|
|
|
def parquet_to_jsonl(parquet_path: Path, jsonl_path: Path, text_key: str):
    """Stream one parquet file into a JSONL file, one ``{text_key: text}`` object per line.

    Rows whose value is not a non-blank string are silently dropped.
    Returns the number of JSON lines written.
    """
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)

    written = 0
    with jsonl_path.open("w", encoding="utf-8") as sink:
        reader = pq.ParquetFile(parquet_path)
        # Read only the text column, in bounded batches, to keep memory flat.
        for batch in reader.iter_batches(columns=[text_key], batch_size=8192):
            for value in batch.column(0).to_pylist():
                # Skip nulls, non-strings, and whitespace-only rows.
                if not (isinstance(value, str) and value.strip()):
                    continue
                sink.write(json.dumps({text_key: value}, ensure_ascii=False))
                sink.write("\n")
                written += 1
    return written
|
|
|
|
|
|
def run_one(args_tuple):
    """Convert a single parquet shard to Megatron bin/idx format.

    Designed as a ProcessPoolExecutor worker: takes one picklable tuple and
    returns a human-readable status string instead of raising, so a failed
    shard does not take down the whole pool.

    The tuple fields are: (parquet_path, output_dir, tmp_dir, text_key,
    megatron_dir, tokenizer_type, tokenizer_model, workers_per_shard,
    keep_jsonl, overwrite).
    """
    (
        parquet_path,
        output_dir,
        tmp_dir,
        text_key,
        megatron_dir,
        tokenizer_type,
        tokenizer_model,
        workers_per_shard,
        keep_jsonl,
        overwrite,
    ) = args_tuple

    parquet_path = Path(parquet_path)
    # Strip either ".zstd.parquet" or plain ".parquet" to get the shard stem.
    stem = parquet_path.name.replace(".zstd.parquet", "").replace(".parquet", "")
    jsonl_path = Path(tmp_dir) / f"{stem}.jsonl"
    output_prefix = Path(output_dir) / f"phase1_{stem}"

    # preprocess_data.py emits <prefix>_<key>_document.{bin,idx}; their
    # presence is used as a completion marker so reruns are resumable.
    bin_file = Path(str(output_prefix) + f"_{text_key}_document.bin")
    idx_file = Path(str(output_prefix) + f"_{text_key}_document.idx")

    if not overwrite and bin_file.exists() and idx_file.exists():
        return f"[SKIP] {parquet_path.name}: existing bin/idx"

    print(f"[START] {parquet_path.name}", flush=True)

    rows = parquet_to_jsonl(parquet_path, jsonl_path, text_key)
    print(f"[JSONL DONE] {parquet_path.name}: rows={rows}, jsonl={jsonl_path}", flush=True)

    print(f"[MEGATRON START] {parquet_path.name}", flush=True)

    cmd = [
        # Use the interpreter running this script, not whatever "python"
        # happens to be on PATH (which may be a different env without
        # Megatron's dependencies).
        sys.executable,
        str(Path(megatron_dir) / "tools/preprocess_data.py"),
        "--input", str(jsonl_path),
        "--output-prefix", str(output_prefix),
        "--tokenizer-type", tokenizer_type,
        "--tokenizer-model", tokenizer_model,
        "--json-keys", text_key,
        "--workers", str(workers_per_shard),
        "--append-eod",
    ]

    env = os.environ.copy()
    # Merge stderr into stdout so a failure tail carries both streams.
    proc = subprocess.run(
        cmd,
        cwd=megatron_dir,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    if proc.returncode != 0:
        # Only the tail of the output, to keep the returned status short.
        return f"[FAIL] {parquet_path.name}\n{proc.stdout[-4000:]}"

    if not keep_jsonl:
        # Drop the intermediate JSONL once Megatron has consumed it.
        jsonl_path.unlink(missing_ok=True)

    return f"[OK] {parquet_path.name}: rows={rows}, output_prefix={output_prefix}"
|
|
|
|
|
|
def main():
    """CLI entry point: discover parquet shards and fan them out across a process pool."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--tmp-dir", required=True)
    parser.add_argument("--megatron-dir", default="/apps/model_training/Megatron-LM")
    parser.add_argument("--tokenizer-type", default="HuggingFaceTokenizer")
    parser.add_argument("--tokenizer-model", required=True)
    parser.add_argument("--text-key", default="text")
    parser.add_argument("--num-shards", type=int, default=1, help="parallel parquet shards")
    parser.add_argument("--workers-per-shard", type=int, default=8)
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=None)
    parser.add_argument("--keep-jsonl", action="store_true")
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    # Prefer zstd-compressed shards; fall back to plain parquet names.
    files = sorted(input_dir.glob("*.zstd.parquet")) or sorted(input_dir.glob("*.parquet"))
    # Restrict to the requested [start, end) slice of the sorted listing.
    files = files[args.start:args.end]

    print(f"Converting {len(files)} files")
    print(f"Parallel shards: {args.num_shards}")
    print(f"Workers per shard: {args.workers_per_shard}")

    for directory in (args.output_dir, args.tmp_dir):
        Path(directory).mkdir(parents=True, exist_ok=True)

    # One picklable argument tuple per shard, in run_one's expected order.
    task_args = [
        (
            str(shard),
            args.output_dir,
            args.tmp_dir,
            args.text_key,
            args.megatron_dir,
            args.tokenizer_type,
            args.tokenizer_model,
            args.workers_per_shard,
            args.keep_jsonl,
            args.overwrite,
        )
        for shard in files
    ]

    with ProcessPoolExecutor(max_workers=args.num_shards) as pool:
        pending = [pool.submit(run_one, task) for task in task_args]
        # Report each shard's status as soon as it finishes, not in submit order.
        for done in tqdm(as_completed(pending), total=len(pending)):
            print(done.result(), flush=True)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main() |