I've been looking into the Monet/PEER sparse expert papers. I think there's a lot of potential in these ideas for interpretability-by-design.
Some of what I've done so far:
- Quantization experiments: PEER can be losslessly distilled to int8, and to int4 with only minor degradation. From int4, you can train PEER by keeping a second int4 tensor that works as a gradient accumulation buffer (allowing incremental steps between two int4 values), with stochastic rounding on the accumulation steps. You can pack these int4 values into int8 tensors and hold a heck of a lot of them in VRAM. This enables efficient training of large PEER models on relatively limited VRAM. The minor loss from going from int8 to int4 is more than made up for by the increase in experts per GB of VRAM.
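A minimal numpy sketch of how I understand the int4-training scheme (the scale, the accumulator resolution, and the function names are illustrative assumptions, not the exact setup from my runs):

```python
import numpy as np

SCALE = 0.05      # assumed per-tensor quantization scale: real weight = w * SCALE
SUBSTEPS = 16     # assumed accumulator resolution between two adjacent int4 levels

def stochastic_round(x, rng):
    """Round each entry down or up, up with probability equal to the fractional part."""
    floor = np.floor(x)
    return (floor + (rng.random(x.shape) < (x - floor))).astype(np.int64)

def int4_sgd_step(w, acc, grad, lr, rng):
    """One SGD step on int4 weights `w` (range [-8, 7]) with a second small-int
    accumulator `acc` (range [0, SUBSTEPS)) counting sub-steps between adjacent
    int4 weight values; stochastic rounding keeps updates unbiased in expectation."""
    delta = -lr * grad / (SCALE / SUBSTEPS)   # desired update, in accumulator sub-steps
    acc = acc + stochastic_round(delta, rng)
    carry = acc // SUBSTEPS                   # whole weight-steps to move into w
    w = np.clip(w + carry, -8, 7)
    acc = acc - carry * SUBSTEPS              # remainder stays in [0, SUBSTEPS)
    return w, acc
```

Both tensors stay in a 4-bit range, so two weights (or a weight plus its accumulator) pack into one int8.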
- Monet to PEER: Monet trains better than PEER because it distributes the gradient less sparsely. PEER has some interpretability benefits, though, so I was curious whether I could turn a trained Monet model into a PEER model. So far this seems to work quite well, although it requires some distillation compute overhead. Not yet sure how that overhead will scale.
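The distillation idea, at its simplest, is activation matching: freeze the Monet layer, sample hidden states, and fit the PEER layer to reproduce the same input-to-output map. A toy numpy sketch with a linear stand-in for each layer (the layer names are hypothetical; a real PEER student would be trained by SGD on the same loss rather than solved in closed form):

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy teacher: a fixed linear map standing in for a frozen Monet sub-layer.
d = 16
W_teacher = rng.normal(size=(d, d))

def monet_layer(x):
    return x @ W_teacher

# Distillation data: hidden states sampled from the training distribution.
X = rng.normal(size=(512, d))
Y = monet_layer(X)

# Fit the student by minimizing ||student(X) - Y||^2 over the sampled states.
# For this linear toy that's a least-squares solve; for an actual PEER layer
# it would be gradient descent on the same activation-matching objective.
W_student, *_ = np.linalg.lstsq(X, Y, rcond=None)

def peer_layer(x):
    return x @ W_student

err = np.abs(peer_layer(X) - Y).max()
```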
- PEER to logic/math functions: you can distill each PEER expert into a mix of logical statements and mathematical functions. To do this I've drawn on KAN 2.0 and Differentiable Logic Gates. Why would I do this? Well, I'm hopeful it might be useful for interpretability, and it also seems potentially possible to make the distilled model run inference quite efficiently on CPU.
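One way to picture the math-function half of this: fit an expert's response against a small dictionary of symbolic candidates and keep only the terms that survive. A rough numpy sketch (the dictionary and the toy "expert" are illustrative, and a KAN-2.0-style pipeline would use a richer candidate set plus the logic-gate side):

```python
import numpy as np

# Candidate symbolic terms; this tiny dictionary is purely illustrative.
BASIS = {
    "x":       lambda x: x,
    "x^2":     lambda x: x**2,
    "sin(x)":  np.sin,
    "relu(x)": lambda x: np.maximum(x, 0.0),
}

def fit_symbolic(f, xs, tol=1e-3):
    """Least-squares fit of scalar function `f` on points `xs` against the
    dictionary, keeping only terms with non-negligible coefficients."""
    A = np.stack([g(xs) for g in BASIS.values()], axis=1)
    coefs, *_ = np.linalg.lstsq(A, f(xs), rcond=None)
    return {name: c for name, c in zip(BASIS, coefs) if abs(c) > tol}

# Toy "expert": a 1-d response we pretend was extracted from a PEER expert.
expert = lambda x: 0.5 * x + np.sin(x)
terms = fit_symbolic(expert, np.linspace(-3, 3, 200))
```

The surviving `terms` are what you'd read off as the expert's symbolic description, which is also what makes cheap CPU inference plausible.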
- Attention sink: neither PEER nor Monet seems to handle the attention sink phenomenon particularly gracefully. I'm going to look into this more, but one hacky workaround that works well, at the cost of some interpretability, is to pair the PEER/Monet layers with a small feed-forward MLP. I've been using d=128 and tying the MLP weights across all model layers, with the intent that the restricted capacity minimize the interpretability cost of shifting learning from the sparse experts to the MLP. I will also experiment with other variants, including an even smaller MLP. I tried having a small pool of 4-16 always-on experts per layer, and this also sort of worked, but less well.
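Structurally the workaround is just a small always-on MLP, weight-tied across layers, summed with the sparse-expert output. A numpy sketch (d_model and the init scale are illustrative; d_mlp=128 is the width from the note):

```python
import numpy as np

rng = np.random.default_rng(3)

d_model, d_mlp = 64, 128   # d_mlp=128 as in the note; d_model is illustrative

# One shared (weight-tied) helper MLP used at every layer, kept narrow so it
# can't absorb much of what the sparse experts should be learning.
W_in = rng.normal(size=(d_model, d_mlp)) * 0.02
W_out = rng.normal(size=(d_mlp, d_model)) * 0.02

def helper_mlp(x):
    return np.maximum(x @ W_in, 0.0) @ W_out

def sparse_layer_with_helper(x, peer_layer):
    """Hacky attention-sink workaround: sum the sparse-expert output with the
    small always-on MLP shared across all layers."""
    return peer_layer(x) + helper_mlp(x)
```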
- Metarouter: a Mixture of Experts where each expert is itself a pool of sparse experts. I see clear topic differentiation among these metaexperts, but also a small performance hit at the small scales I've been testing. I think this is potentially promising at larger scales (e.g. >100B params). I have experimented with forcing the model to choose 1 (or k) metaexpert(s) for an entire short sequence (1024 tokens), with the thought that this might further push topic specialization (e.g. code vs. medical vs. creative writing). It seems to work OK, albeit with a small additional performance hit.
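The sequence-level variant can be sketched as: average the per-token router logits over the sequence, then commit to one metaexpert (k=1) for all tokens. A toy numpy version, with linear maps standing in for whole pools of sparse experts (all sizes illustrative):

```python
import numpy as np

rng = np.random.default_rng(4)

d, n_meta = 32, 4
W_router = rng.normal(size=(d, n_meta))

# Each "metaexpert" here is a stand-in for an entire pool of sparse experts.
metaexperts = [lambda x, W=rng.normal(size=(d, d)) * 0.1: x @ W
               for _ in range(n_meta)]

def route_sequence(x):
    """Sequence-level top-1 metarouting: average per-token router logits over
    the whole sequence, argmax, and run every token through that metaexpert."""
    logits = (x @ W_router).mean(axis=0)   # (n_meta,)
    choice = int(np.argmax(logits))
    return metaexperts[choice](x), choice
```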
- JumpReLU gates: using JumpReLU gates with PEER experts allows graceful adaptivity in the number of experts active per layer per token. So far this appears to be a win for both performance and interpretability.
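The gate itself is simple: pass an expert's score through unchanged above a threshold, output exactly zero below it, so the active-expert count adapts per token instead of being a fixed top-k. A numpy sketch (the scores and threshold are illustrative; in practice theta is learned):

```python
import numpy as np

def jumprelu(scores, theta):
    """JumpReLU gate: score passes through unchanged above threshold theta,
    and is exactly 0 below it (a hard jump, not a shifted ReLU)."""
    return np.where(scores > theta, scores, 0.0)

# Per-token expert scores: how many experts fire varies per token.
scores = np.array([[0.9, 0.1, 0.6],
                   [0.2, 0.05, 0.1]])
gates = jumprelu(scores, theta=0.5)
active_per_token = (gates > 0).sum(axis=1)   # e.g. 2 experts, then 0
```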
- Testing for superposition among experts: Logan pointed out that I should keep a watchful eye out for sparse experts that consistently activate together under certain circumstances, since they might hide extra functionality in their co-activation that wouldn't be measurable by testing each alone. This seems very important, and I'm open to ideas about how best to measure it, and potentially how to reduce the effect.
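One simple first-pass measurement: from a boolean token-by-expert activation matrix, compute the pairwise "lift" P(i,j) / (P(i)P(j)); pairs well above 1 fire together far more than chance and are candidates for hidden joint functionality. A numpy sketch (just one possible statistic; pointwise mutual information is the log of the same quantity):

```python
import numpy as np

rng = np.random.default_rng(5)

def coactivation_lift(active, eps=1e-9):
    """`active`: (n_tokens, n_experts) boolean activation matrix.
    Returns the lift matrix P(i,j) / (P(i) P(j)); entries well above 1
    flag expert pairs that co-fire far more often than independence predicts."""
    a = active.astype(float)
    p = a.mean(axis=0)                  # marginal activation rates
    joint = (a.T @ a) / a.shape[0]      # pairwise co-activation rates
    return joint / (np.outer(p, p) + eps)

# Toy check: experts 0 and 1 always fire together; expert 2 is independent.
n = 2000
together = rng.random(n) < 0.1
active = np.stack([together, together, rng.random(n) < 0.1], axis=1)
lift = coactivation_lift(active)       # lift[0,1] >> 1, lift[0,2] near 1
```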
Discuss