New Optimizer 🌹 Rose: low VRAM, easy to use, great results, Apache 2.0

Hello, World! I have finally publicly released a new PyTorch optimizer I've been researching and developing for the last couple of years. It's named "Rose" in memory of my mother, who loved to hear about my discoveries and progress with AI.

Without going into the technical details (which you can read about in the GitHub repo), here are some of its benefits:

- It's stateless, meaning it keeps no per-parameter optimizer state, so it uses less memory than even AdamW8bit. If it weren't for working memory, its memory use would be as low as plain vanilla SGD (without momentum).
- Fast convergence, low VRAM, and excellent generalization, along with resistance to overfitting. Yeah, I know... it sounds too good to be true. Try it for yourself and tell me what you think; I'd really love to hear everyone's experiences, good or bad.
- Apache 2.0 license.
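To give a feel for why statelessness matters, here's a back-of-envelope sketch of optimizer-state memory. The bytes-per-parameter figures are my own rough assumptions (AdamW keeps two fp32 moment buffers; the 8-bit variant quantizes them), not measurements of any particular implementation:

```python
# Rough optimizer-state memory estimate for a model with n_params parameters.
# Per-parameter byte counts are assumptions, not measurements:
#   adamw:     exp_avg + exp_avg_sq, both fp32 -> 8 bytes/param
#   adamw8bit: both moments quantized to 8 bits -> ~2 bytes/param
#   sgd (no momentum) / stateless: no per-parameter state at all

def state_bytes(n_params: int, optimizer: str) -> int:
    per_param = {
        "adamw": 8,
        "adamw8bit": 2,
        "sgd": 0,        # plain SGD without momentum
        "stateless": 0,  # e.g. a stateless optimizer such as Rose
    }
    return n_params * per_param[optimizer]

n = 12_000_000_000  # a hypothetical 12B-parameter model
print(state_bytes(n, "adamw") / 2**30)  # ~89 GiB of optimizer state
print(state_bytes(n, "stateless"))      # 0 bytes
```

Working memory (gradients, activations, temporaries) is the same story for every optimizer, which is why the bullet above carves it out.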

You can find the code and more information at: https://github.com/MatthewK78/Rose

Benchmarks can sometimes be misleading, which is why I haven't included any. For example, training loss is sometimes higher with Rose than with Adam while validation loss is lower. The actual output of the trained model is what really matters in the end, and even that can be subjective. I'd prefer to let the community decide.

Here's some quickstart help for getting it up and running in ostris/ai-toolkit.

Install with:

```
pip install git+https://github.com/MatthewK78/Rose
```


Add this alongside other optimizers in the toolkit/optimizer.py file:

```python
elif lower_type.startswith("rose"):
    from rose import Rose
    print(f"Using Rose optimizer, lr: {learning_rate:.2e}")
    optimizer = Rose(params, lr=learning_rate, **optimizer_params)
```
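If you'd rather kick the tires outside ai-toolkit first, here's a minimal training-loop sketch. It assumes Rose follows the standard `torch.optim.Optimizer` interface (an `lr` keyword plus `step()`/`zero_grad()`), which the dispatch snippet above suggests; `torch.optim.SGD` is used as a stand-in so the sketch runs without Rose installed:

```python
import torch

# Minimal sketch of dropping a new optimizer into a standard loop.
# torch.optim.SGD is a stand-in here; once Rose is pip-installed,
# swap in `from rose import Rose` and `Rose(model.parameters(), lr=...)`,
# assuming it exposes the usual torch.optim.Optimizer interface.

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)  # <- replace with Rose(...)

x = torch.randn(32, 4)
y = x.sum(dim=1, keepdim=True)  # toy regression target

for step in range(200):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()

print(f"final loss: {loss.item():.4f}")
```

Because the constructor is the only line that changes, this is also a quick way to A/B Rose against AdamW on your own data.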


Here's a config file example:

```yaml
optimizer: Rose
lr: 8e-4

lr_scheduler: cosine
lr_scheduler_params:
  eta_min: 1e-4

# all are default settings except `wd_schedule`
optimizer_params:
  weight_decay: 1e-4 # adamw-style decoupled weight decay
  wd_schedule: true # helps when using wd + lr_scheduler
  centralize: true # gradient centralization
  stabilize: true # disable for more aggressive training
  bf16_sr: true # bf16 stochastic rounding
  compute_dtype: fp64 # use fp32 only if you really need it

max_grad_norm: 65504 # effectively disables gradient clipping
ema_config:
  use_ema: false
timestep_type: weighted
```
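For reference, the cosine schedule named in the config anneals the learning rate from the base `lr` (8e-4) down to `eta_min` (1e-4). A small sketch of that curve, following the standard `CosineAnnealingLR` formula (the `total_steps` horizon is whatever your trainer passes to the scheduler):

```python
import math

# Cosine annealing from base_lr down to eta_min over total_steps,
# using the standard CosineAnnealingLR formula. Values match the
# example config above; total_steps is illustrative.

def cosine_lr(step: int, total_steps: int,
              base_lr: float = 8e-4, eta_min: float = 1e-4) -> float:
    return eta_min + 0.5 * (base_lr - eta_min) * (
        1 + math.cos(math.pi * step / total_steps)
    )

print(cosine_lr(0, 1000))     # 8e-4 at the start
print(cosine_lr(500, 1000))   # 4.5e-4 at the midpoint
print(cosine_lr(1000, 1000))  # 1e-4 at the end
```

Note the `wd_schedule: true` comment above: it exists precisely because weight decay and a decaying lr interact, so the decay is scheduled along with the learning rate.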


When first trying it, it may also help to set `sample_every` to something low, like 128 steps, so you can assess what it's doing early on.

If you try it, please let me know your thoughts and share your results. 😊


https://redd.it/1sokmqw
@rStableDiffusion