
We have a few technical issues that we still need to address:

1) This entire fine-tuning run was done in JAX eager mode. I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even applying `jax.jit` gradually, to one sub-function at a time, didn't work.

2) The current version doesn't have gradient accumulation, and with a batch size of just 16, that’s not ideal. I'm working on implementing gradient accumulation next.
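For what it's worth, a minimal sketch of gradient accumulation in JAX (toy loss and names are illustrative, not your actual model): split the large batch into micro-batches, compute gradients per micro-batch, and average them, so peak activation memory depends only on the micro-batch size.

```python
import jax
import jax.numpy as jnp

# Toy model for illustration: loss = mean((x @ w - y)^2)
def loss_fn(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)

def accumulated_grads(w, batch_x, batch_y, micro_batch_size):
    """Average gradients over equal-sized micro-batches.

    For a mean-reduction loss this matches the full-batch gradient,
    but only one micro-batch of activations is live at a time.
    """
    n = batch_x.shape[0]
    num_micro = n // micro_batch_size
    acc = jnp.zeros_like(w)
    for i in range(num_micro):
        s = slice(i * micro_batch_size, (i + 1) * micro_batch_size)
        acc = acc + grad_fn(w, batch_x[s], batch_y[s])
    return acc / num_micro
```

With an effective batch of 16 you could, e.g., run 4 micro-batches of 4. The Python loop here stays outside `jit`; each `grad_fn` call can still be jitted on its own.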

3) We still haven't found a good way to load large sequence-length data (like 32k sequence length). Currently, before the training batch is sharded across GPUs, the entire batch gets materialized in a single GPU's VRAM, which causes OOM.
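One way around materializing the full batch on one device is `jax.make_array_from_callback`, which builds a globally sharded array by asking a callback for each device's shard, so no single GPU ever holds the whole batch. A minimal sketch (toy shapes, and the loader is a stand-in; a real one would read only the requested rows from disk):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Shard over the batch axis. With N devices this splits the batch N ways;
# on a single-device host it still runs, just with one shard.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Toy sizes standing in for (batch, 32k-token sequences).
global_shape = (len(devices) * 2, 8)

def load_shard(index):
    # `index` is a tuple of slices selecting this device's piece of the
    # global batch. A real loader would read just those rows, instead of
    # building the full batch first as done here for illustration.
    full = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)
    return full[index]

batch = jax.make_array_from_callback(global_shape, sharding, load_shard)
```

Each device's VRAM then only ever sees `batch_size / num_devices` sequences.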



> I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.

Were you using activation checkpointing? https://jax.readthedocs.io/en/latest/_autosummary/jax.checkp... is very important for keeping memory usage reasonable when training large models.
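A minimal sketch of what that looks like (toy layer stack, not the poster's model): wrapping each layer in `jax.checkpoint` (a.k.a. `jax.remat`) drops its activations after the forward pass and recomputes them during the backward pass, trading compute for memory under `jit`.

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(x @ w)

@jax.jit
def loss(params, x):
    for w in params:
        # Activations of this layer are not stored for backprop;
        # they are rematerialized when gradients are computed.
        x = jax.checkpoint(layer)(x, w)
    return jnp.sum(x ** 2)

grad_loss = jax.grad(loss)
```

The gradients are identical to the un-checkpointed version; only peak memory changes.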



