softmax is not enough (for sharp out-of-distribution)
Developments: The authors show that for robust circuits to exist, they must generalize to arbitrarily many valid inputs. They prove that as the number of out-of-distribution tokens increases, sharpness is lost: the attention coefficients must disperse toward uniform, even if they were appropriately sharp on in-distribution examples. To counter this they propose an adaptive temperature that modifies θ based on the Shannon entropy of the attention coefficients, motivated by the fact that decreasing the temperature provably decreases the entropy.
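As a quick illustration of the dispersion effect (a minimal sketch with made-up logits, not from the paper): a single item favoured by a fixed logit gap still loses attention mass as the number of distractor tokens grows, while scaling the logits by a fixed β = 1/θ > 1 recovers some sharpness by lowering the entropy.

import jax
import jax.numpy as jnp

# A single "important" item with a fixed logit gap over n-1 distractors:
# the maximal softmax coefficient still decays toward zero as n grows.
for n in [16, 256, 4096, 65536]:
    logits = jnp.concatenate([jnp.array([5.0]), jnp.zeros(n - 1)])
    probs = jax.nn.softmax(logits)
    sharper = jax.nn.softmax(logits * 2.0)  # beta = 2, i.e. theta = 0.5
    print(f"n={n}: max coeff {float(probs[0]):.3f} -> "
          f"{float(sharper[0]):.3f} after sharpening")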
Their ad-hoc algorithm is presented as an example below, using JAX.
import jax
import jax.numpy as jnp

def adaptive_temperature_softmax(logits):
    original_probs = jax.nn.softmax(logits)
    poly_fit = jnp.array([-0.037, 0.481, -2.3, 4.917, -1.791])  # see Figure 5
    entropy = jnp.sum(-original_probs * jnp.log(original_probs + 1e-9),
                      axis=-1, keepdims=True)  # compute the Shannon entropy
    beta = jnp.where(  # beta = 1 / theta
        entropy > 0.5,  # don't overcorrect low-entropy heads
        jnp.maximum(jnp.polyval(poly_fit, entropy), 1.0),  # never increase entropy
        1.0)
    return jax.nn.softmax(logits * beta)
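A quick check of the behaviour (made-up logits, reusing the function and imports above): a row where one relevant item is drowned out by 255 distractors.

# Plain softmax has dispersed over the 255 distractors; the entropy (~4.2 nats)
# yields a beta > 1 from the polynomial fit, re-sharpening the head.
logits = jnp.concatenate([jnp.array([5.0]), jnp.zeros(255)])
print(float(jax.nn.softmax(logits)[0]))                # ~0.37, dispersed
print(float(adaptive_temperature_softmax(logits)[0]))  # ~0.999, re-sharpened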
An example Colab notebook is here.