# Source file: pretrain_kaiyuan2b/scripts/convert_phase_to_megatron.py
# (268 lines, 9.8 KiB, Python — file-browser header retained from the original paste)
import argparse
import os
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Pool
from pathlib import Path
from types import SimpleNamespace
import pyarrow.parquet as pq
from tqdm import tqdm
"""
Convert Kaiyuan parquet files directly to Megatron indexed dataset format.
Expected parquet schema:
text: <string>
The previous implementation used parquet -> JSONL -> Megatron preprocess_data.py.
This implementation removes the JSONL intermediate file and writes .bin/.idx with
Megatron's IndexedDatasetBuilder directly.
Usage:
python /ssd1/yi/pretrain_kaiyuan2b/scripts/convert_phase_to_megatron.py \
--input-dir /ssd1/yi/data/phase1 \
--output-dir /ssd1/yi/converted_data/phase1 \
--megatron-dir /ssd1/yi/pretrain_kaiyuan2b/Megatron-LM \
--tokenizer-model /ssd1/yi/data/tokenizer \
--text-key text \
--num-shards 16 \
--workers-per-shard 12 \
--batch-size 16384 \
--chunksize 128 \
--start 0 \
--end 210
"""
_TOKENIZER = None
_APPEND_EOD = True
def make_tokenizer_args(args):
    """Assemble the argument namespace Megatron's tokenizer builder expects.

    Forwards the tokenizer-related CLI options from *args* and pins the
    distributed / vocab-padding knobs to single-process defaults.
    """
    # Options copied verbatim from the parsed CLI namespace.
    forwarded = dict(
        vocab_size=args.vocab_size,
        vocab_file=args.vocab_file,
        merge_file=args.merge_file,
        tokenizer_type=args.tokenizer_type,
        tokenizer_model=args.tokenizer_model,
        metadata_path=args.tokenizer_metadata,
        special_tokens=args.tokenizer_special_tokens,
        tokenizer_sentencepiece_legacy=args.tokenizer_sentencepiece_legacy,
        tokenizer_hf_no_use_fast=args.tokenizer_hf_no_use_fast,
        tokenizer_hf_no_include_special_tokens=args.tokenizer_hf_no_include_special_tokens,
        trust_remote_code=args.trust_remote_code,
        tiktoken_pattern=args.tiktoken_pattern,
        tiktoken_num_special_tokens=args.tiktoken_num_special_tokens,
        null_tokenizer_eod_id=args.null_tokenizer_eod_id,
        null_tokenizer_pad_id=args.null_tokenizer_pad_id,
    )
    # Fixed values: rank 0, no tensor parallelism, no prompt-format plumbing.
    fixed = dict(
        rank=0,
        make_vocab_size_divisible_by=128,
        tensor_model_parallel_size=1,
        padded_vocab_size=None,
        vocab_extra_ids=0,
        tokenizer_prompt_format=None,
        image_tag_type=None,
        force_system_message=False,
        sft_tokenizer_prompt_format=None,
    )
    return SimpleNamespace(**fixed, **forwarded)
def add_megatron_to_path(megatron_dir):
    """Prepend the resolved Megatron checkout to ``sys.path`` (idempotent)."""
    resolved = str(Path(megatron_dir).resolve())
    if resolved in sys.path:
        return
    sys.path.insert(0, resolved)
def build_megatron_tokenizer(args):
    """Import Megatron from ``args.megatron_dir`` and build its tokenizer."""
    add_megatron_to_path(args.megatron_dir)
    # Imported lazily so Megatron is resolved from the path added above.
    from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer

    tokenizer_args = make_tokenizer_args(args)
    return build_tokenizer(tokenizer_args)
def init_worker(args):
    """Pool-worker initializer: build one tokenizer per worker process.

    Populates the module globals ``_TOKENIZER`` and ``_APPEND_EOD`` that
    ``encode_text`` reads, so pool workers never have to pickle a tokenizer.

    Raises:
        ValueError: if ``--append-eod`` is enabled but the tokenizer exposes
            no EOD/EOS token id.
    """
    global _TOKENIZER, _APPEND_EOD
    _APPEND_EOD = args.append_eod
    _TOKENIZER = build_megatron_tokenizer(args)
    # Fail fast in the worker rather than producing documents without EOD.
    if _APPEND_EOD and _TOKENIZER.eod is None:
        raise ValueError("Tokenizer has no EOD/EOS token, but --append-eod is enabled.")
def encode_text(text):
    """Tokenize one document; return ``(token_ids, sentence_lens)`` or None.

    None is returned for non-string, blank, or zero-token inputs so callers
    can simply drop them. When EOD appending is enabled, the EOD id is
    appended and counted in the (single) sentence length.
    """
    if not isinstance(text, str):
        return None
    stripped = text.strip()
    if not stripped:
        return None
    token_ids = _TOKENIZER.tokenize(stripped)
    if not token_ids:
        return None
    # One sentence per document; its length covers all tokens.
    lengths = [len(token_ids)]
    if _APPEND_EOD:
        token_ids.append(_TOKENIZER.eod)
        lengths[0] += 1
    return token_ids, lengths
def output_paths(output_prefix, text_key):
    """Return the ``(bin, idx)`` paths Megatron derives from an output prefix."""
    # Normalize the prefix through Path first (matches Megatron's naming).
    base = str(Path(output_prefix)) + f"_{text_key}_document"
    return Path(base + ".bin"), Path(base + ".idx")
def remove_partial_outputs(output_prefix, text_key):
    """Delete any existing .bin/.idx pair so a rerun starts from scratch."""
    for stale in output_paths(output_prefix, text_key):
        stale.unlink(missing_ok=True)
def convert_one_parquet(args_tuple):
    """Convert one parquet file into a Megatron ``.bin``/``.idx`` pair.

    ``args_tuple`` is ``(parquet_path, args)`` — packed into a single tuple so
    this function can be submitted as-is to a ProcessPoolExecutor.
    Returns a one-line human-readable status string ("[SKIP] ..." or "[OK] ...").
    """
    parquet_path, args = args_tuple
    parquet_path = Path(parquet_path)
    # Strip ".zstd.parquet" (or plain ".parquet") to get the shard stem.
    stem = parquet_path.name.replace(".zstd.parquet", "").replace(".parquet", "")
    output_prefix = Path(args.output_dir) / f"{args.output_prefix_prefix}_{stem}"
    bin_file, idx_file = output_paths(output_prefix, args.text_key)
    # Only a complete bin+idx pair counts as done; a lone file is a partial run.
    if not args.overwrite and bin_file.exists() and idx_file.exists():
        return f"[SKIP] {parquet_path.name}: existing bin/idx"
    remove_partial_outputs(output_prefix, args.text_key)
    # Import Megatron lazily so each worker process resolves it from
    # args.megatron_dir rather than whatever is on the parent's path.
    add_megatron_to_path(args.megatron_dir)
    from megatron.core.datasets import indexed_dataset
    tokenizer = build_megatron_tokenizer(args)
    # Smallest integer dtype that can represent every token id.
    dtype = indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size)
    builder = indexed_dataset.IndexedDatasetBuilder(str(bin_file), dtype=dtype)
    start_time = time.time()
    rows = 0
    docs = 0
    tokens = 0

    def consume_encoded(encoded):
        # Append one encoded document to the builder; None means "skipped doc".
        nonlocal docs, tokens
        if encoded is None:
            return
        token_ids, sentence_lens = encoded
        builder.add_document(token_ids, sentence_lens)
        docs += 1
        tokens += len(token_ids)
        if args.log_interval and docs % args.log_interval == 0:
            # Guard against a zero elapsed time on very fast first batches.
            elapsed = max(time.time() - start_time, 1e-6)
            print(
                f"[{parquet_path.name}] docs={docs} "
                f"tokens={tokens} docs/s={docs / elapsed:.2f}",
                flush=True,
            )

    pf = pq.ParquetFile(parquet_path)
    if args.workers_per_shard == 1:
        # Single-worker path: tokenize inline in this process.
        init_worker(args)
        for batch in pf.iter_batches(columns=[args.text_key], batch_size=args.batch_size):
            texts = batch.column(0).to_pylist()
            rows += len(texts)
            for text in texts:
                consume_encoded(encode_text(text))
    else:
        # Fan tokenization out to a worker pool. imap yields results in input
        # order, so document order in the .bin matches parquet row order.
        with Pool(processes=args.workers_per_shard, initializer=init_worker, initargs=(args,)) as pool:
            for batch in pf.iter_batches(columns=[args.text_key], batch_size=args.batch_size):
                texts = batch.column(0).to_pylist()
                rows += len(texts)
                for encoded in pool.imap(encode_text, texts, chunksize=args.chunksize):
                    consume_encoded(encoded)
    builder.finalize(str(idx_file))
    elapsed = max(time.time() - start_time, 1e-6)
    return (
        f"[OK] {parquet_path.name}: rows={rows}, docs={docs}, tokens={tokens}, "
        f"elapsed={elapsed:.1f}s, docs/s={docs / elapsed:.2f}, output_prefix={output_prefix}"
    )
def parse_args():
    """Define and parse the converter's command-line interface."""
    p = argparse.ArgumentParser()
    # Required locations.
    p.add_argument("--input-dir", required=True)
    p.add_argument("--output-dir", required=True)
    p.add_argument("--tmp-dir", default=None, help="Deprecated; kept for CLI compatibility.")
    p.add_argument("--megatron-dir", default="/apps/model_training/Megatron-LM")
    # Tokenizer selection and behavior.
    p.add_argument("--tokenizer-type", default="HuggingFaceTokenizer")
    p.add_argument("--tokenizer-model", required=True)
    p.add_argument("--tokenizer-metadata", default=None)
    p.add_argument("--tokenizer-special-tokens", nargs="*", default=None)
    p.add_argument("--tokenizer-sentencepiece-legacy", action="store_true")
    p.add_argument("--tokenizer-hf-no-use-fast", action="store_true")
    p.add_argument("--tokenizer-hf-no-include-special-tokens", action="store_true")
    p.add_argument("--trust-remote-code", action="store_true")
    p.add_argument("--vocab-file", default=None)
    p.add_argument("--merge-file", default=None)
    p.add_argument("--vocab-size", type=int, default=None)
    p.add_argument("--tiktoken-pattern", default=None)
    p.add_argument("--tiktoken-num-special-tokens", type=int, default=1000)
    p.add_argument("--null-tokenizer-eod-id", type=int, default=None)
    p.add_argument("--null-tokenizer-pad-id", type=int, default=-1)
    # Dataset layout.
    p.add_argument("--text-key", default="text")
    p.add_argument("--output-prefix-prefix", default="phase1")
    # Parallelism / batching knobs.
    p.add_argument("--num-shards", type=int, default=1, help="Parallel parquet files.")
    p.add_argument("--workers-per-shard", type=int, default=max((os.cpu_count() or 8) // 2, 1))
    p.add_argument("--batch-size", type=int, default=16384, help="Parquet record batch size.")
    p.add_argument("--chunksize", type=int, default=128, help="Tokenizer pool imap chunk size.")
    p.add_argument("--log-interval", type=int, default=10000)
    # File-range selection and output policy.
    p.add_argument("--start", type=int, default=0)
    p.add_argument("--end", type=int, default=None)
    p.add_argument("--append-eod", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--keep-jsonl", action="store_true", help="Deprecated; no JSONL is written.")
    p.add_argument("--overwrite", action="store_true")
    return p.parse_args()
def main():
    """CLI entry point: validate options, discover parquet files, convert them."""
    args = parse_args()
    # Positive-integer sanity checks on the parallelism/batching knobs.
    for flag, value in (
        ("--num-shards", args.num_shards),
        ("--workers-per-shard", args.workers_per_shard),
        ("--batch-size", args.batch_size),
        ("--chunksize", args.chunksize),
    ):
        if value < 1:
            raise ValueError(f"{flag} must be >= 1")
    # Prefer .zstd.parquet shards; fall back to plain .parquet.
    input_dir = Path(args.input_dir)
    files = sorted(input_dir.glob("*.zstd.parquet")) or sorted(input_dir.glob("*.parquet"))
    files = files[args.start:args.end]
    print(f"Converting {len(files)} files")
    print(f"Parallel parquet files: {args.num_shards}")
    print(f"Tokenizer workers per parquet: {args.workers_per_shard}")
    print(f"Total tokenizer workers: {args.num_shards * args.workers_per_shard}")
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    if args.tmp_dir:
        Path(args.tmp_dir).mkdir(parents=True, exist_ok=True)
    tasks = [(str(f), args) for f in files]
    if args.num_shards == 1:
        # Sequential path: convert one parquet at a time in this process.
        for task in tqdm(tasks):
            print(convert_one_parquet(task), flush=True)
        return
    # Parallel path: one process per parquet shard; report as each finishes.
    with ProcessPoolExecutor(max_workers=args.num_shards) as executor:
        pending = [executor.submit(convert_one_parquet, task) for task in tasks]
        for done in tqdm(as_completed(pending), total=len(pending)):
            print(done.result(), flush=True)


if __name__ == "__main__":
    main()