Try the demo here

I was feeling ambitious after completing the MNIST training from scratch in Rust. It was addictive seeing it work - I can see why people get into machine learning.

My next experiment was to evolve my codebase to train a language model entirely in WASM. I wanted a demo I could show off in public, hence running it in WASM, so it could be easily shared on this blog like the MNIST training demo.

This was hard and taught me a lot about language model training.

At a high level, language model training looks like this:

graph LR
  data[Data Fetching] --> dataset[(Dataset)]
  dataset --> tokenizer[Train Tokenizer]
  tokenizer --> tokenize[Tokenize Dataset]
  tokenize --> train[Train Model]
  train --> inference[Inference]
  inference --> eval[Evaluation]

The first issue was fetching training data in the browser. Originally I tried scraping Wikipedia, which was slow; then I tried Project Gutenberg’s collection of books but ran into CORS issues; finally I was able to work around CORS and get a lot of high-quality text via Project Gitenberg.
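
The browser-side fetch itself is small once you have a CORS-friendly host. Here is a minimal sketch, assuming reqwest’s WASM backend rather than whatever HTTP layer the project actually uses; the URL is a placeholder, not the real dataset endpoint:

```rust
// Minimal sketch of pulling plain text into the browser build.
// reqwest compiles down to the browser's fetch API on wasm32-unknown-unknown.
// The URL is a placeholder - the important part is that the host serves
// permissive CORS headers so a cross-origin fetch succeeds.
async fn fetch_training_text(url: &str) -> Result<String, reqwest::Error> {
    reqwest::get(url).await?.text().await
}
```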

Once the UI could fetch a dataset, I had to actually implement the transformer training architecture. This was done collaboratively, with the help of Claude 4.5 Opus - I avoided any machine learning libraries and implemented it from scratch with Claude. It would have been better to implement it completely from scratch myself, but I didn’t have the time for that, and I still learned a lot doing it from scratch with Claude - much more than I would have by just using a pre-built library like Candle.
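
To give a sense of what “from scratch” means here: the heart of the model is scaled dot-product attention with a causal mask. A single-head CPU sketch of it looks roughly like this - the real code is batched, multi-headed, and runs as WebGPU shaders, so treat this as the idea rather than the implementation:

```rust
// Single-head causal self-attention: for each position i, attend only to
// positions 0..=i, softmax the scaled scores, and take a weighted sum of
// the value vectors. q, k, v are seq_len x d, row per position.
fn causal_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (seq_len, d) = (q.len(), q[0].len());
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = vec![vec![0.0; d]; seq_len];
    for i in 0..seq_len {
        // Scores against positions 0..=i only (the causal mask).
        let mut scores: Vec<f32> = (0..=i)
            .map(|j| q[i].iter().zip(&k[j]).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        // Softmax over the unmasked scores (subtract the max for stability).
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for s in scores.iter_mut() {
            *s = (*s - max).exp();
            sum += *s;
        }
        // Weighted sum of value vectors.
        for (j, s) in scores.iter().enumerate() {
            let w = *s / sum;
            for t in 0..d {
                out[i][t] += w * v[j][t];
            }
        }
    }
    out
}
```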

I used WebGPU to do the training. Originally I implemented a basic training loop on the CPU, but was only able to get 2 GFLOPS, which was far too slow. This is obvious in hindsight but was a learning experience for me - ML training is mostly big matrix multiplications made of independent multiply-adds. GPUs are much faster at this!
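
For reference, the CPU baseline bottoms out in a triple loop like the minimal version below; the comment also shows where the usual 2·M·N·K FLOP count behind GFLOPS figures comes from:

```rust
// Naive CPU matmul: C = A * B for row-major matrices.
// Each output element costs 2*K floating point ops (a multiply and an add),
// so the whole product is roughly 2*M*N*K FLOPs - which is why matmul
// dominates training cost and why the GPU wins so decisively.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = acc;
        }
    }
    c
}
```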

The key to iterating on the WebGPU implementation was:

1. Efficiently load memory (weights and biases) into the GPU.
2. Avoid interrupting the GPU: keep everything running on the GPU and don’t synchronize back to the CPU.
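
In wgpu terms, those two rules look something like the sketch below. This is illustrative rather than my actual code - the pipeline and bind group setup is omitted and the names are made up - but it shows the shape: weights live in GPU storage buffers, and a whole training step is recorded into one command encoder and submitted without reading anything back:

```rust
use wgpu::util::DeviceExt;

// Rule 1: upload weights once into a GPU-resident storage buffer.
fn upload_weights(device: &wgpu::Device, weights: &[f32]) -> wgpu::Buffer {
    device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("weights"),
        contents: bytemuck::cast_slice(weights),
        // STORAGE so shaders can read/write it in place; no copy back each step.
        usage: wgpu::BufferUsages::STORAGE,
    })
}

// Rule 2: record every kernel of a training step into one encoder and submit
// once, with no CPU synchronization in the middle.
fn run_training_step(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    layer_pipelines: &[wgpu::ComputePipeline],
    layer_bind_groups: &[wgpu::BindGroup],
    workgroups: u32,
) {
    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("step") });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
        // Chain every layer's kernels back to back; nothing returns to the CPU.
        for (pipeline, bind_group) in layer_pipelines.iter().zip(layer_bind_groups) {
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, bind_group, &[]);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }
    }
    // One submit for the whole step; read the loss back only occasionally.
    queue.submit(std::iter::once(encoder.finish()));
}
```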

This took a while and accounted for the majority of the project’s difficulty, but eventually it worked well - I was able to get 50 GFLOPS in the browser on my M1 MacBook Air. This is still well below the 1-2 TFLOPS theoretical limit, but much better than 2 GFLOPS on the CPU!

Side note - I’m used to writing things in Rust for performance reasons. Rust programs tend to just run faster, since most other high-level languages require a garbage collector and are thus less performant. That matters much less for model training - since the critical part is the GPU shaders, the CPU code doesn’t really matter much (other than for data loading and manipulation).

I also had Claude write me a subword/BPE tokenizer. Originally the tokenizer training was incredibly slow, but I benchmarked and iterated until both training the tokenizer on a 10 MB text sample and then tokenizing the training data ran within 1-2 minutes in the browser.
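
The heart of BPE training is small - repeatedly merge the most frequent adjacent pair of tokens. A heavily simplified sketch (not Claude’s actual code; a fast version needs incremental pair counts and pre-splitting on whitespace to get anywhere near that 1-2 minute budget):

```rust
use std::collections::HashMap;

// Start from one token per character and repeatedly merge the most frequent
// adjacent pair into a new token, recording the merge rules in order.
fn train_bpe(text: &str, num_merges: usize) -> Vec<(String, String)> {
    let mut tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();
    let mut merges = Vec::new();

    for _ in 0..num_merges {
        // Count every adjacent pair in the current token stream.
        let mut counts: HashMap<(String, String), usize> = HashMap::new();
        for pair in tokens.windows(2) {
            *counts.entry((pair[0].clone(), pair[1].clone())).or_insert(0) += 1;
        }
        // Pick the most frequent pair; stop if nothing repeats.
        let Some((best, count)) = counts.into_iter().max_by_key(|(_, c)| *c) else { break };
        if count < 2 {
            break;
        }
        // Replace every occurrence of the pair with the merged token.
        let merged = format!("{}{}", best.0, best.1);
        let mut next = Vec::with_capacity(tokens.len());
        let mut i = 0;
        while i < tokens.len() {
            if i + 1 < tokens.len() && tokens[i] == best.0 && tokens[i + 1] == best.1 {
                next.push(merged.clone());
                i += 2;
            } else {
                next.push(tokens[i].clone());
                i += 1;
            }
        }
        tokens = next;
        merges.push(best);
    }
    merges
}
```

The naive rescan of the whole token stream on every merge is generally what makes an untuned implementation slow; most of the speedup in any practical version comes from updating pair counts incrementally instead.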

Finally I had it working! I ran a training run overnight, the loss went down, and I played with the completion results and they were… awful. They only occasionally made any sense, and for a couple of characters at most.

At this point I created an eval framework and ran it on my model. The scores were low.

You can try it out here if you are curious - LLM Training Visualizer.

Need more compute!

I realized the core problem was that my training compute was too small. I really was looking for GPT-2 level coherence. GPT-2 compute estimates vary, but let’s assume it took 10^19 FLOPs (source). 50 GFLOPS is 5×10^10 FLOPs per second, so training a GPT-2 level model at that rate would take 2×10^8 seconds - over 6 years. Even if I was able to get 1 TFLOPS out of my MacBook, it would take over 100 days to train. Far too long for a browser demo.
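
Spelled out, using that same assumed 10^19 FLOPs figure:

```rust
// Back-of-envelope: how long 10^19 FLOPs takes at browser-achievable rates.
fn main() {
    let gpt2_flops: f64 = 1e19; // assumed total training compute for GPT-2
    for (label, flops_per_sec) in [
        ("50 GFLOPS (M1, in-browser)", 50e9),
        ("1 TFLOPS (optimistic M1 ceiling)", 1e12),
    ] {
        let seconds = gpt2_flops / flops_per_sec;
        let days = seconds / 86_400.0;
        // Prints roughly 2315 days (6.3 years) and 116 days (0.3 years).
        println!("{label}: {days:.0} days ({:.1} years)", days / 365.0);
    }
}
```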

So GPT-2 level quality is out of the question, at least for a fun demo running on consumer MacBook hardware.

Pivoting

To train my own GPT-2, I need to do it on actual NVIDIA GPUs. Doing the math - GPT-2 requires roughly 10^19 FLOPs. An NVIDIA A100 can do 93 TFLOPS, and an 8×A100 node on Lambda Labs can thus do 744 TFLOPS (about 7×10^14 FLOPs per second). That works out to roughly 4 hours of training, around $50 on Lambda Labs. Very doable!

So this is the goal - improve the local performance of my Rust training framework, and rent a node on Lambda Labs to train my own GPT-2 class model.

References