From 7ba690797b4949db9da0238cfcf1223052acae2f Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Mon, 20 Apr 2026 08:25:00 -0400 Subject: [PATCH] [improvement][loguru-logging][replace print with loguru in training script][feat][ckpt-logging][add checkpoint start and success log events][docs][readme-optimizer][remove muon optimizer reference][feat][train-requirements][add requirements txt to training folder] --- README.md | 2 +- training/3b_fine_web_edu.py | 18 ++++++++++-------- training/requirements.txt | 4 ++++ 3 files changed, 15 insertions(+), 9 deletions(-) create mode 100644 training/requirements.txt diff --git a/README.md b/README.md index b1d9984..afc5517 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Key design choices: | Feature | Detail | |---|---| -| Optimizer | Muon for 2D weight matrices, AdamW for embeddings/norms | +| Optimizer | AdamW | | Dataset | `HuggingFaceFW/fineweb-edu` (`sample-10BT` by default, swap to `sample-100BT` or `default` for full run) | | Tokenizer | `openai/gpt-oss-20b` via `MythosTokenizer` | | Parallelism | PyTorch DDP via `torchrun`, sharded streaming dataset | diff --git a/training/3b_fine_web_edu.py b/training/3b_fine_web_edu.py index f9f20b8..92603f3 100644 --- a/training/3b_fine_web_edu.py +++ b/training/3b_fine_web_edu.py @@ -15,6 +15,7 @@ import time import torch import torch.nn as nn import torch.distributed as dist +from loguru import logger from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, ShardingStrategy, @@ -113,7 +114,7 @@ def main(): master = rank == 0 if master: - print( + logger.info( f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}" ) @@ -124,7 +125,7 @@ def main(): vocab_size = encoding.vocab_size if master: - print(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}") + logger.info(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}") # ------------------------------------------------------------------ # Hyperparameters @@ -144,8 +145,8 @@ def main(): dataset_subset = "sample-10BT" # → sample-100BT or "default" for full run if master: - print( - f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum}\n" + logger.info( + f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | " f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}" ) @@ -188,7 +189,7 @@ def main(): if master: n_params = sum(p.numel() for p in model.parameters()) - print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}") + logger.info(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}") # ------------------------------------------------------------------ # Optimizer @@ -255,7 +256,7 @@ def main(): dt = time.perf_counter() - t0 tok_per_sec = global_batch_tok * log_every / dt tokens_seen = step * global_batch_tok - print( + logger.info( f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} " f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s " f"| {tokens_seen / 1e9:.1f}B tokens seen" @@ -264,6 +265,7 @@ def main(): if master and step % ckpt_every == 0: path = os.path.join(ckpt_dir, f"step_{step:07d}.pt") + logger.info(f"Saving checkpoint at step {step} → {path}") if ddp: with FSDP.state_dict_type( model, @@ -283,13 +285,13 @@ def main(): }, path, ) - print(f"Checkpoint saved → {path}") + logger.success(f"Checkpoint saved → {path}") if ddp: dist.destroy_process_group() if master: - print("Training complete.") + logger.success("Training complete.") if __name__ == "__main__": diff --git a/training/requirements.txt b/training/requirements.txt new file mode 100644 index 0000000..e3348c5 --- /dev/null +++ b/training/requirements.txt @@ -0,0 +1,4 @@ +torch>=2.11.0 +datasets>=3.6.0 +loguru>=0.7.3 +open-mythos \ No newline at end of file