[feat][training-script][add 3b fineweb-edu training script]
[feat][tokenizer][add MythosTokenizer class with encode decode]
[improvement][deps][add transformers and datasets dependencies]
[docs][readme-training][add training section with run commands]
[improvement][pyproject][pin torch and add new deps]
parent 97bc414977
commit 5ffb897dcf
4 changed files with 33 additions and 1 deletion
28
README.md
@@ -106,6 +106,34 @@ print(f"Parameters: {total:,}")
 
 ---
 
+## Training
+
+The training script for the 3B model on FineWeb-Edu is at [`training/3b_fine_web_edu.py`](training/3b_fine_web_edu.py).
+
+**Single GPU:**
+```bash
+python training/3b_fine_web_edu.py
+```
+
+**Multi-GPU (auto-detects GPU count):**
+```bash
+torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py
+```
+
+Key design choices:
+
+| Feature | Detail |
+|---|---|
+| Optimizer | Muon for 2D weight matrices, AdamW for embeddings/norms |
+| Dataset | `HuggingFaceFW/fineweb-edu` (`sample-10BT` by default, swap to `sample-100BT` or `default` for full run) |
+| Tokenizer | `openai/gpt-oss-20b` via `MythosTokenizer` |
+| Parallelism | PyTorch DDP via `torchrun`, sharded streaming dataset |
+| Precision | bfloat16 on H100/A100, float16 + GradScaler on older GPUs |
+| Schedule | Linear warmup (2000 steps) → cosine decay |
+| Target | 30B tokens (~Chinchilla-adjusted for looped architecture) |
+
+---
+
 ## Documentation
 
 | Page | Description |
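The "Schedule" row in the README table above (linear warmup for 2000 steps, then cosine decay) can be sketched as a plain function of the step count. This is only an illustration under stated assumptions, not the code in `training/3b_fine_web_edu.py`, which this diff view does not show. The names `lr_at_step`, `max_lr`, `min_lr`, and `total_steps` and their values are hypothetical; only the 2000-step warmup comes from the README.

```python
import math


def lr_at_step(step, max_lr=3e-4, min_lr=3e-5, warmup_steps=2000, total_steps=100_000):
    """Learning rate at a given 0-indexed optimizer step.

    max_lr, min_lr, and total_steps are hypothetical illustration values;
    warmup_steps=2000 matches the README table.
    """
    # Linear warmup: ramp from max_lr / warmup_steps up to max_lr.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # Past the end of the schedule, hold the floor value.
    if step >= total_steps:
        return min_lr
    # Cosine decay from max_lr down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine


if __name__ == "__main__":
    for s in (0, 1999, 2000, 51_000, 100_000):
        print(s, lr_at_step(s))
```

In a PyTorch loop, a function like this would typically be applied by writing the returned value into each `param_group["lr"]` before `optimizer.step()`.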
@@ -38,7 +38,9 @@ classifiers = [
 
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
-torch = "*"
+torch = "2.11.0"
+transformers = ">=4.40.0"
+datasets = ">=2.18.0"
 
 
 [tool.poetry.group.lint.dependencies]
@@ -1,2 +1,4 @@
 torch>=2.1.0
+transformers>=4.40.0
+datasets>=2.18.0
 pytest>=7.0.0
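The "Optimizer" row in the README hunk (Muon for 2D weight matrices, AdamW for embeddings/norms) implies that parameters are partitioned into two groups before the optimizers are built. A minimal sketch of one way to do that partition follows. It is an assumption, not the training script's actual logic: the helper name `split_param_names` and the name keywords `"embed"` and `"norm"` are hypothetical. The function accepts any `(name, param)` pairs whose `param` has an `.ndim` attribute, which is the shape of what `model.named_parameters()` yields in PyTorch.

```python
from types import SimpleNamespace


def split_param_names(named_params, adamw_keywords=("embed", "norm")):
    """Partition parameter names into (muon_names, adamw_names).

    Assumed rule: 2D tensors go to Muon unless their name contains one of
    adamw_keywords (hypothetical); everything else goes to AdamW.
    """
    muon_names, adamw_names = [], []
    for name, param in named_params:
        is_matrix = param.ndim == 2
        is_excluded = any(key in name for key in adamw_keywords)
        if is_matrix and not is_excluded:
            muon_names.append(name)
        else:
            adamw_names.append(name)
    return muon_names, adamw_names


if __name__ == "__main__":
    # Stand-in parameters; only .ndim is inspected.
    fake_params = [
        ("tok_embed.weight", SimpleNamespace(ndim=2)),
        ("blocks.0.attn.qkv.weight", SimpleNamespace(ndim=2)),
        ("blocks.0.norm.weight", SimpleNamespace(ndim=1)),
    ]
    print(split_param_names(fake_params))
```

Matching on parameter names is brittle; a real script might instead tag modules by type when the model is constructed.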