Escaping the MAB prison: fast KL-projection for linearly constrained continuous actions
I am tired of MAB. I want something fancy.
This post comes from a random thought I had on a so-and-so scenario with a linear constraint. Doesn't sound like a big deal. Linear constraints are mathematically nice. But when the constraint is tied directly to a massive budget in production, you sometimes just fail at fear management and want to take the coward's way out.
That's exactly what we did. We created discrete presets of actions that safely satisfied the constraint. The agent (or algorithm, or model, whatever you want to call this probabilistic decision engine) just picked the best one based on context.
If you're an RL practitioner, this probably sounds familiar. It's an old friend from 1952 called the multi-armed bandit (MAB). Staring down this old friend, one naturally wonders: is there a fancy way out? Could we earn the model a bit more freedom? Can we do things safely in the continuum?
1. Formulation of the constrained continuous action space
To grant the policy network full expressiveness, we have to let it output a raw, continuous probability distribution over a high-dimensional action space. But this output has to strictly adhere to the linear budget constraints.
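Suppose we have $K$ strict resource limits. In matrix form, the realized distribution $p$ over the $N$ candidate actions must satisfy $A p = b$. Here $A \in \mathbb{R}^{K \times N}$ is our feature or cost matrix, and $b \in \mathbb{R}^{K}$ is the vector of target expected values.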
The challenge is constructing a projection that maps the network's unconstrained output $q$ onto the hyperplane $A p = b$.
2. Think twice before you scream quadratic programming (QP)
When projecting a point onto a linear subspace, the standard move is to minimize Euclidean distance. This gives a quadratic programming (QP) problem.
This is the Pavlovian response of anyone working with optimization. It's very straightforward.
However, applying Euclidean projection to a probability distribution relies on a fundamentally flawed geometric assumption: it measures distances between distributions as if the probability simplex were flat Euclidean space.
The Euclidean metric is insensitive to the boundaries of the simplex. When the constraint hyperplane lies far enough from the unconstrained prior, the Euclidean projection naturally lands outside the positive orthant and yields negative values. To restore validity ($p_i \ge 0$), the QP solver (an active-set method, for instance) forcibly pushes the solution onto the boundary of the simplex, driving smaller and long-tail probabilities to absolute zero ($p_i = 0$). In the optimization literature, these are known as "corner solutions."
Corner solutions are acceptable in standard operations research, but in reinforcement learning, gradient sparsity is bad bad news. The policy gradient estimator relies fundamentally on the gradient of the log-probability of an action, $\nabla_\theta \log \pi_\theta(a \mid s)$. If the QP projection layer deterministically forces an action's probability to $0$, it triggers the following downward spiral:
- Zero Probability: Some actions may never get sampled.
- Absence of Reward Signal: Because the action is never selected, the agent collects no environmental transitions for that specific state-action pair.
- Vanishing Gradients: Without a reward signal, and because $\log 0 = -\infty$ disrupts the gradient calculation, no meaningful gradient flows backward through the network for that dimension.
The policy network becomes permanently blind to that region of the action space.
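To make the corner-solution story concrete, here is a toy example (numbers chosen for illustration, not from the post): four actions, a uniform prior, one cost feature $[0, 1, 2, 3]$, and a target expected cost of $2.8$. The sketch computes the closed-form Euclidean projection onto $\{p : Ap = b, \sum_i p_i = 1\}$ with the nonnegativity constraint dropped, which is the point a QP solver starts from before its active set kicks in. The helper name is hypothetical.

```python
import numpy as np

def euclidean_affine_projection(q, A, b):
    """Closed-form argmin ||p - q||^2 s.t. A p = b and sum(p) = 1 (no p >= 0)."""
    C = np.vstack([A, np.ones(len(q))])    # budget rows plus the simplex row
    d = np.append(b, 1.0)
    # Orthogonal projection onto the affine set: p = q - C^T (C C^T)^{-1} (C q - d)
    return q - C.T @ np.linalg.solve(C @ C.T, C @ q - d)

q = np.full(4, 0.25)                   # uniform prior over 4 actions
A = np.array([[0.0, 1.0, 2.0, 3.0]])   # one cost feature per action
b = np.array([2.8])                    # target expected cost, near the top of the range
p = euclidean_affine_projection(q, A, b)
print(p)  # p[0] comes out negative (about -0.14)
```

Putting $p \ge 0$ back, the full QP solution for this toy is $[0, 0, 0.2, 0.8]$ (it satisfies the KKT conditions with nonnegative multipliers on the two clipped entries): two of the four actions are zeroed outright, even though the prior gave each of them a quarter of the mass.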
3. Information geometry and the analytical proof of the Esscher transform
To resolve the vanishing gradient problem inherent to Euclidean projections, we must analyze the problem on the statistical manifold of the exponential family. The natural measure of discrepancy between probability distributions is not a straight line, but the Kullback-Leibler (KL) divergence (strictly a divergence rather than a metric: it is asymmetric and does not satisfy the triangle inequality).
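Information geometry (Amari, 2016) shows that, locally, the KL divergence is a quadratic form in the Fisher Information Matrix $F$: $\mathrm{KL}(p + \delta \,\|\, p) \approx \tfrac{1}{2}\, \delta^\top F \delta$. The squared Euclidean distance is merely this second-order Taylor expansion under the assumption of an isotropic Fisher Information Matrix ($F \propto I$). For a categorical distribution, however, $F = \mathrm{diag}(1/p_1, \dots, 1/p_N)$, so the expansion reads $\tfrac{1}{2}\sum_i \delta_i^2 / p_i$ and moving mass off a small $p_i$ gets arbitrarily expensive. The isotropic approximation catastrophically breaks down near the boundaries of the probability simplex, which is precisely why QP solvers force probabilities to zero.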
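A quick numerical check of the expansion, using an arbitrary three-action distribution with one action near the boundary (values chosen here for illustration). It compares the exact KL divergence with the Fisher-weighted quadratic $\tfrac12\sum_i \delta_i^2/p_i$ and with the isotropic $\tfrac12\|\delta\|^2$.

```python
import numpy as np

def kl(p, q):
    # KL(p || q) for strictly positive categorical distributions
    return float(np.sum(p * np.log(p / q)))

p = np.array([0.5, 0.49, 0.01])        # one action close to the simplex boundary
delta = np.array([1e-3, 0.0, -1e-3])   # small move that sums to zero (stays on the simplex)

exact = kl(p + delta, p)
fisher = 0.5 * np.sum(delta**2 / p)    # second-order term with F = diag(1/p)
isotropic = 0.5 * np.sum(delta**2)     # second-order term with F = I (Euclidean)
print(exact, fisher, isotropic)
```

The Fisher-weighted form lands within a few percent of the exact value, while the isotropic one underestimates it by more than an order of magnitude, almost entirely because of the $0.01$ entry.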
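To preserve the topological structure of the prior distribution $q$, we formulate the primal optimization problem as minimizing the KL divergence subject to our generalized linear constraints:
$$\min_{p} \; \mathrm{KL}(p \,\|\, q) = \sum_{i=1}^{N} p_i \log\frac{p_i}{q_i} \quad \text{s.t.} \quad A p = b, \quad \sum_{i=1}^{N} p_i = 1$$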
Proof of the primal solution
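We construct the Lagrangian. We introduce a multiplier vector $\lambda \in \mathbb{R}^{K}$ for the strict budget constraints and a scalar $\nu$ for the simplex constraint:
$$\mathcal{L}(p, \lambda, \nu) = \sum_{i=1}^{N} p_i \log\frac{p_i}{q_i} - \lambda^\top (A p - b) + \nu \left(\sum_{i=1}^{N} p_i - 1\right)$$
Taking the partial derivative with respect to each probability component $p_i$ and setting it to zero yields:
$$\frac{\partial \mathcal{L}}{\partial p_i} = \log\frac{p_i}{q_i} + 1 - \sum_{k=1}^{K} \lambda_k A_{ki} + \nu = 0$$
Solving for $p_i$, we obtain:
$$p_i = q_i \exp\left(\sum_{k=1}^{K} \lambda_k A_{ki}\right) \exp(-1 - \nu)$$
We can extract the terms independent of $i$ into a normalization constant. Let us define the partition function $Z(\lambda) = \sum_{j=1}^{N} q_j \exp\left(\sum_{k=1}^{K} \lambda_k A_{kj}\right)$. The resulting optimal distribution is the Generalized Esscher Transform, also known as exponential tilting:
$$p_i^* = \frac{q_i \exp\left(\sum_{k=1}^{K} \lambda_k A_{ki}\right)}{Z(\lambda)}$$
Or, expressed in vectorized form:
$$p^* = \frac{q \odot \exp\left(A^\top \lambda\right)}{Z(\lambda)} = \mathrm{softmax}\left(\log q + A^\top \lambda\right)$$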
This proof reveals a profound geometric property. Because the transformation relies strictly on an exponential scaling of the prior, any candidate action with a strictly positive prior probability ($q_i > 0$) is mathematically guaranteed to retain a strictly positive projected probability ($p_i^* > 0$) for any finite $\lambda$. The projected distribution never touches the boundary of the positive orthant. Zero-probability clipping is entirely eliminated, and policy gradients survive globally.
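A minimal sketch of the tilt itself, assuming an arbitrary prior, feature matrix, and multiplier vector (all made up for illustration). It checks two claims from the derivation: every projected probability stays strictly positive, and $\log(p_i^*/q_i) - (A^\top\lambda)_i$ is the same constant ($-\log Z$) for every $i$, which is the stationarity condition of the Lagrangian. Working in log space avoids overflow when $\lambda$ is large.

```python
import numpy as np

def esscher_tilt(q, A, lam):
    # p* = q * exp(A^T lam) / Z(lam), computed in log space for stability
    z = np.log(q) + A.T @ lam
    p = np.exp(z - z.max())
    return p / p.sum()

q = np.array([0.7, 0.2, 0.09, 0.01])
A = np.array([[1.0, 2.0, 3.0, 4.0],
              [0.0, 1.0, 0.0, 1.0]])
lam = np.array([-3.0, 2.5])            # an arbitrary, fairly aggressive tilt

p = esscher_tilt(q, A, lam)
# Stationarity: log(p_i / q_i) - (A^T lam)_i should be the same constant for all i
residual = np.log(p / q) - A.T @ lam
print(p.min() > 0, residual.max() - residual.min())
```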
4. The dual proof: strict convexity and global convergence
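While we have derived the analytical functional form of $p^*$, we must still compute the unique Lagrange multiplier vector $\lambda^*$ that ensures the tilted distribution exactly satisfies $A p^* = b$. We achieve this by optimizing the unconstrained Lagrangian dual function.
Let us define the log-partition function (the cumulant-generating function) $\Psi(\lambda) = \log Z(\lambda)$. The dual objective function we wish to minimize is:
$$g(\lambda) = \Psi(\lambda) - \lambda^\top b$$
Proof of the gradient (first moment)
By the fundamental properties of the exponential family, the derivative of the log-partition function yields the expected value of the sufficient statistics. Let's prove this explicitly by taking the gradient of $\Psi$ with respect to the $k$-th multiplier $\lambda_k$:
$$\frac{\partial \Psi}{\partial \lambda_k} = \frac{1}{Z(\lambda)} \frac{\partial Z}{\partial \lambda_k} = \frac{1}{Z(\lambda)} \sum_{i=1}^{N} q_i A_{ki} \exp\left(\sum_{l=1}^{K} \lambda_l A_{li}\right) = \sum_{i=1}^{N} A_{ki}\, p_i^* = (A p^*)_k$$
Therefore, the full gradient of our dual objective represents the exact residual error between the current expected values and the target budgets:
$$\nabla g(\lambda) = A p^*(\lambda) - b$$
To minimize $g$, we must find the root where the gradient is zero, i.e., $A p^*(\lambda) = b$.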
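The first-moment identity is easy to sanity-check with central finite differences. This sketch uses random data (an assumption, purely for testing) and compares the analytic gradient $Ap^*(\lambda) - b$ against a numerical derivative of $g(\lambda) = \log Z(\lambda) - \lambda^\top b$.

```python
import numpy as np

def dual(lam, log_q, A, b):
    # g(lam) = log Z(lam) - lam^T b, with a stable log-sum-exp
    z = log_q + A.T @ lam
    return z.max() + np.log(np.exp(z - z.max()).sum()) - lam @ b

def dual_grad(lam, log_q, A, b):
    # Analytic gradient from the derivation: A p*(lam) - b
    z = log_q + A.T @ lam
    p = np.exp(z - z.max())
    p /= p.sum()
    return A @ p - b

rng = np.random.default_rng(0)
N, K = 6, 2
log_q = np.log(rng.dirichlet(np.ones(N)))
A = rng.normal(size=(K, N))
b = 0.1 * rng.normal(size=K)
lam = rng.normal(size=K)

eps = 1e-6
numeric = np.array([(dual(lam + eps * e, log_q, A, b) - dual(lam - eps * e, log_q, A, b)) / (2 * eps)
                    for e in np.eye(K)])
print(np.max(np.abs(numeric - dual_grad(lam, log_q, A, b))))
```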
Proof of the Hessian and strict convexity (second moment)
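To construct a Newton-Raphson solver, we require the Hessian matrix. Taking the second partial derivative of $g$ with respect to $\lambda_k$ and $\lambda_l$:
$$\frac{\partial^2 g}{\partial \lambda_k \, \partial \lambda_l} = \frac{\partial}{\partial \lambda_l} \sum_{i=1}^{N} A_{ki}\, p_i^* = \sum_{i=1}^{N} A_{ki} \frac{\partial p_i^*}{\partial \lambda_l}$$
Applying the quotient rule to $p_i^* = q_i \exp\left(\sum_{m} \lambda_m A_{mi}\right) / Z(\lambda)$ reveals its derivative with respect to $\lambda_l$:
$$\frac{\partial p_i^*}{\partial \lambda_l} = \frac{q_i A_{li} \exp\left(\sum_{m} \lambda_m A_{mi}\right)}{Z(\lambda)} - \frac{q_i \exp\left(\sum_{m} \lambda_m A_{mi}\right)}{Z(\lambda)} \cdot \frac{1}{Z(\lambda)} \frac{\partial Z}{\partial \lambda_l}$$
Since we established in the gradient proof that $\frac{1}{Z}\frac{\partial Z}{\partial \lambda_l} = (A p^*)_l$, this simplifies to:
$$\frac{\partial p_i^*}{\partial \lambda_l} = p_i^* \left(A_{li} - (A p^*)_l\right)$$
Substituting this back into our Hessian equation:
$$\frac{\partial^2 g}{\partial \lambda_k \, \partial \lambda_l} = \sum_{i=1}^{N} p_i^* A_{ki} A_{li} - (A p^*)_k (A p^*)_l$$
Lifting this from element-wise to matrix form:
$$\nabla^2 g(\lambda) = A \,\mathrm{diag}(p^*)\, A^\top - (A p^*)(A p^*)^\top = \mathrm{Cov}_{i \sim p^*}\left[A_{\cdot i}\right]$$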
This is the exact definition of covariance. The Hessian of the dual objective is the Covariance Matrix of the constraint features under the tilted distribution.
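The same finite-difference trick verifies the second-moment result: the numerical Jacobian of the gradient should match $A\,\mathrm{diag}(p^*)A^\top - (Ap^*)(Ap^*)^\top$. The sketch (random test data, hypothetical helper names) also shows what happens when one constraint row is constant across actions: that row has zero variance, and the Hessian becomes singular.

```python
import numpy as np

def tilt(lam, log_q, A):
    z = log_q + A.T @ lam
    p = np.exp(z - z.max())
    return p / p.sum()

def dual_grad(lam, log_q, A, b):
    return A @ tilt(lam, log_q, A) - b

def dual_hessian(lam, log_q, A):
    # Covariance of the constraint features under p*: A diag(p) A^T - (A p)(A p)^T
    p = tilt(lam, log_q, A)
    ev = A @ p
    return (A * p) @ A.T - np.outer(ev, ev)

rng = np.random.default_rng(1)
N, K = 6, 3
log_q = np.log(rng.dirichlet(np.ones(N)))
A = rng.normal(size=(K, N))
b = np.zeros(K)
lam = rng.normal(size=K)

H = dual_hessian(lam, log_q, A)
eps = 1e-6
H_fd = np.column_stack([(dual_grad(lam + eps * e, log_q, A, b) - dual_grad(lam - eps * e, log_q, A, b)) / (2 * eps)
                        for e in np.eye(K)])
print(np.max(np.abs(H - H_fd)), np.linalg.eigvalsh(H).min())

# A constant constraint row (plain probability mass) has zero variance: singular Hessian
A_bad = np.vstack([A[:2], np.ones(N)])
print(np.linalg.eigvalsh(dual_hessian(lam, log_q, A_bad)).min())
```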
The convergence theorem
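Because a covariance matrix is inherently positive semi-definite ($\nabla^2 g \succeq 0$), the dual is convex. The Hessian is strictly positive definite ($\nabla^2 g \succ 0$) as long as no nonzero combination $v^\top A$ of the constraint rows is constant across the actions in the prior's support; for a full-support prior, that is equivalent to the rows of $A$ together with an all-ones row being linearly independent. A constraint row that is itself constant (such as a plain probability-mass constraint) has zero variance and makes the Hessian singular.
This guarantees that the dual objective is strictly convex, so a root of its gradient, if one exists, is the unique global optimum. One exists whenever $b$ lies strictly inside the convex hull of the columns of $A$ (restricted to the prior's support); targets on or outside that hull push $\lambda$ to infinity. The multivariate Newton-Raphson step is:
$$\lambda_{t+1} = \lambda_t - \left[\nabla^2 g(\lambda_t)\right]^{-1} \nabla g(\lambda_t)$$
Strict convexity alone does not make the pure step converge from any arbitrary initial vector $\lambda_0$: far from the root, the covariance can be tiny while the residual is not, and the step overshoots wildly. Damped with a backtracking line search on $g$, Newton's method converges to the unique global optimum from any starting point (under the conditions above), takes full steps near the root, and achieves a quadratic rate of convergence there.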
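A 1-D toy (numbers chosen for illustration) shows why the damping matters. Two actions with prior $[0.5, 0.5]$, one feature $[0, 1]$, and target $b = 0.99$, so the root is $\lambda^* = \log 99$. From $\lambda_0 = 10$ the pure Newton step lands near $-209$, where the covariance is on the order of $10^{-91}$, and the next step is astronomically large. The same start with Armijo backtracking on $g$ converges. The function names and the line-search constants ($0.25$, halving) are conventional choices, not from the post.

```python
import numpy as np

# Toy 1-D dual (illustrative numbers): prior [0.5, 0.5], feature [0, 1], target 0.99
Q = np.array([0.5, 0.5])
A = np.array([[0.0, 1.0]])
B = np.array([0.99])

def dual(lam):
    # g(lam) = log Z(lam) - lam^T b, with a stable log-sum-exp
    z = np.log(Q) + A.T @ lam
    return z.max() + np.log(np.exp(z - z.max()).sum()) - lam @ B

def grad_hess(lam):
    # Gradient A p* - b and Hessian Cov_{p*} of the dual
    z = np.log(Q) + A.T @ lam
    p = np.exp(z - z.max())
    p /= p.sum()
    ev = A @ p
    return ev - B, (A * p) @ A.T - np.outer(ev, ev)

def newton(lam0, damped, iters=50):
    lam = np.array([lam0], dtype=float)
    for _ in range(iters):
        grad, H = grad_hess(lam)
        if np.linalg.norm(grad) < 1e-10:
            break
        step = np.linalg.solve(H, grad)
        t = 1.0
        if damped:
            # Armijo backtracking: halve t until g decreases enough
            g0, slope = dual(lam), grad @ step
            while t > 1e-10 and dual(lam - t * step) > g0 - 0.25 * t * slope:
                t *= 0.5
        lam = lam - t * step
        if not np.all(np.isfinite(lam)) or np.abs(lam).max() > 1e6:
            break  # diverged
    return lam

print(newton(10.0, damped=False))  # diverges: |lambda| ends up astronomically large
print(newton(10.0, damped=True))   # close to log(99), about 4.595
```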
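Here is the algorithmic loop in Python (NumPy). A backtracking line search damps the Newton step whenever a full step would overshoot:
```python
import numpy as np

def esscher_projection(prior_p, A, target_b, max_iters=50, tol=1e-7):
    """
    prior_p:  (N,)   Raw prior probabilities from the policy network (all > 0)
    A:        (K, N) Constraint feature matrix
    target_b: (K,)   Target budgets/expected values
    """
    log_q = np.log(prior_p)

    def tilt(lam):
        # 1. Exponential tilting: p* = softmax(log q + A^T lam), computed stably
        z = log_q + A.T @ lam
        p = np.exp(z - z.max())
        return p / p.sum()

    def dual(lam):
        # Dual objective g(lam) = log Z(lam) - lam^T b (stable log-sum-exp)
        z = log_q + A.T @ lam
        return z.max() + np.log(np.exp(z - z.max()).sum()) - lam @ target_b

    lam = np.zeros(A.shape[0])  # Initialize K-dimensional Lagrange multipliers
    for _ in range(max_iters):
        p_star = tilt(lam)
        # 2. Gradient of the dual: expected values minus targets
        ev = A @ p_star
        error = ev - target_b
        if np.linalg.norm(error) < tol:
            break
        # 3. Hessian of the dual: covariance matrix of the features under p*
        cov = (A * p_star) @ A.T - np.outer(ev, ev)
        # 4. Newton direction via a linear solve (no explicit inverse)
        step = np.linalg.solve(cov, error)
        # 5. Backtracking (Armijo) line search on the dual: damped Newton
        t, g0, slope = 1.0, dual(lam), error @ step
        while t > 1e-10 and dual(lam - t * step) > g0 - 0.25 * t * slope:
            t *= 0.5
        lam = lam - t * step
    return tilt(lam)
```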
The convergence analysis from the last section already does most of the work. Quadratic convergence near the root means a handful of iterations. Strict convexity means there is exactly one minimum, and a damped Newton step reaches it from any initial condition, provided the target is feasible. You can now put Adam, RMSprop, SGD, and everything else built for shit objective functions back on the shelf. There is no learning rate to tune.
5. Making it faster if you really need to
A couple of ways to push this further if you are one of those runtime freaks:
- JIT it. Wrap esscher_projection with @jit (Numba, JAX) or torch.compile, whichever your stack already uses. Tight numerical kernel, no Python overhead. Easiest path.
- Drop to C++/CUDA. If JIT isn't enough, push the whole projection layer down to a raw extension. Nuclear option, but then you probably have to write the entire thing in C++ too, because a Python wrapper can make things slow.
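- FP32, not FP64. 32-bit floats hold about 7 significant decimal digits, more than any physical constraint cares about. Half the memory bandwidth. Loosen tol to match, though: 1e-7 sits right at the edge of FP32 resolution, so the loop may never reach it.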
- Single-pass covariance. Keep precomputed buffers around so one weighted reduction yields both the mean and the second moment. Skip the redundant matmuls.
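One possible reading of the single-pass idea (an assumption; the post doesn't spell it out): precompute an augmented matrix with an extra row of ones once, then a single weighted matmul per Newton iteration returns the second-moment block and, in its last column, the mean vector. The function names are hypothetical, and the sketch assumes p sums to 1.

```python
import numpy as np

def augment(A):
    # Precompute once per constraint set: append a row of ones to A (K, N) -> (K+1, N)
    return np.vstack([A, np.ones((1, A.shape[1]))])

def single_pass_moments(A_aug, p):
    K = A_aug.shape[0] - 1
    S = (A_aug * p) @ A_aug.T           # (K+1, K+1) weighted second-moment matrix
    ev = S[:K, K]                       # E_p[a] sits in the last column (ones row)
    cov = S[:K, :K] - np.outer(ev, ev)  # Cov_p[a] = E[a a^T] - E[a] E[a]^T
    return ev, cov

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 5))
p = rng.dirichlet(np.ones(5))
A_aug = augment(A)                      # built once, reused every Newton iteration
ev, cov = single_pass_moments(A_aug, p)
```

Each iteration then pays for one $(K+1) \times N \times (K+1)$ product instead of a matrix-vector product plus a separate matmul.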
The projection ends up cheap enough to live inside a high-frequency on-policy training loop. Nobody notices.
6. System architecture and agent integration
In production, this is a standard contextual bandit with a strict mathematical projection layer bolted on.
The pipeline begins with the Context. The system observes the environment state. User features, session history, market state, whatever fits.
The decision maker, the Brain, is a standard policy network. It takes the context and outputs unconstrained logits. It dreams of absolute freedom and maximum reward. Out comes the raw prior distribution $q$.
Before the environment receives the action, $q$ gets intercepted by the Esscher transform layer. Many constrained-RL methods assume access to an Oracle: a procedure that returns the constraint-compliant projection of any proposed policy. Here, the Oracle is not assumed. It is computed. The Newton-Raphson solver wakes up. It computes the exact $\lambda^*$, tilts the distribution, and outputs the strictly compliant $p^*$.
The system samples the final continuous action from $p^*$ and executes it.
Then the environment returns a reward. Gradients flow backward through the differentiable Oracle to update the Brain.
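The post doesn't say how gradients get through the Oracle. One option (an assumption here, not necessarily the author's) is implicit differentiation: since $p^* = \mathrm{softmax}(s + A^\top\lambda^*)$ and $Ap^* = b$ holds at the solution, the Jacobian with respect to the Brain's logits $s$ is $J - J A^\top \Sigma^{-1} A J$, with $J = \mathrm{diag}(p^*) - p^* p^{*\top}$ and $\Sigma = A J A^\top$, the same covariance the Newton solver already builds. No backprop through the Newton iterations is needed; unrolling them in an autodiff framework is the other common route. Function names and toy numbers are hypothetical.

```python
import numpy as np

def project(logits, A, b, iters=30):
    """Esscher projection of softmax(logits) onto {p : A p = b}.

    Plain (undamped) Newton on the dual: fine for this small, mild toy problem;
    use a line search for targets far from the prior.
    """
    lam = np.zeros(A.shape[0])
    for _ in range(iters):
        z = logits + A.T @ lam
        p = np.exp(z - z.max())
        p /= p.sum()
        ev = A @ p
        cov = (A * p) @ A.T - np.outer(ev, ev)
        lam = lam - np.linalg.solve(cov, ev - b)
    z = logits + A.T @ lam
    p = np.exp(z - z.max())
    return p / p.sum()

def projection_vjp(p, A, v):
    """Map an upstream gradient v = dL/dp* to dL/dlogits (implicit differentiation).

    dp*/dlogits = J - J A^T Cov^{-1} A J, with J = diag(p*) - p* p*^T and
    Cov = A J A^T. The Jacobian is symmetric, so the VJP is a plain product.
    """
    def jmul(u):
        # J u without forming J explicitly
        return p * u - p * (p @ u)

    Jv = jmul(v)
    ev = A @ p
    cov = (A * p) @ A.T - np.outer(ev, ev)      # equals A J A^T
    w = A.T @ np.linalg.solve(cov, A @ Jv)      # A^T Cov^{-1} A J v
    return Jv - jmul(w)

# Usage on a toy problem (numbers made up for illustration)
logits = np.array([0.1, -0.2, 0.3, 0.0])
A = np.array([[0.0, 1.0, 2.0, 3.0],
              [1.0, 0.0, 1.0, 0.0]])
b = np.array([1.5, 0.5])
p_star = project(logits, A, b)
grad_logits = projection_vjp(p_star, A, v=np.array([0.3, -1.0, 0.5, 2.0]))
```

A side effect worth knowing: $A \cdot$ grad_logits is zero, i.e. the backward pass only pushes the logits in directions that keep the budgets satisfied.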
7. Practical considerations and architectural adaptations
A few practical approaches/considerations when you actually use this thing.
Dimensionality reduction. If the action space has tens of thousands of items, the policy head has to emit tens of thousands of logits and every Newton iteration sweeps all of them (the $K \times K$ covariance itself stays small). So we compress the raw actions into a lower-dimensional basis registry (of size $M$, with $M \ll N$). The Brain outputs weights for these basis functions. The Oracle tilts the basis functions, not the raw items.
Discretization bias. Evaluating continuous functions on discrete lookup tables introduces numerical drift. We fix this by computing the effective empirical means of the basis functions against the actual discrete payload. Those calibrated values feed into the Newton solver. Adherence stays exact.
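A sketch of the calibration step, assuming (hypothetically) that the continuous action is a scalar on $[0, 1]$, the discrete payload is a 21-point grid, each basis function is a Gaussian bump normalized over that grid, and the only constraint feature is the item's own value. The naive feature reads the cost at each bump's nominal center; the calibrated feature is the bump's empirical mean over the actual grid. Only the calibrated one reproduces the realized expected cost exactly.

```python
import numpy as np

def basis_registry(grid, centers, width):
    # Hypothetical basis: Gaussian bumps evaluated on the discrete payload grid,
    # each row normalized into a distribution over the N raw items, shape (M, N)
    W = np.exp(-0.5 * ((grid[None, :] - centers[:, None]) / width) ** 2)
    return W / W.sum(axis=1, keepdims=True)

grid = np.linspace(0.0, 1.0, 21)         # discrete payload: 21 raw actions
cost = grid.copy()                       # per-item cost feature (K = 1)
centers = np.linspace(0.0, 1.0, 5)       # M = 5 basis functions
W = basis_registry(grid, centers, width=0.15)

A_naive = centers[None, :]               # cost read off at each bump's nominal center
A_eff = (W @ cost)[None, :]              # calibrated: empirical mean against the payload

w = np.array([0.5, 0.3, 0.1, 0.05, 0.05])  # some basis weights from the Brain
realized = (w @ W) @ cost                # expected cost actually paid on the raw items
print(A_naive @ w, A_eff @ w, realized)
```

The drift comes from the bumps at the edges: truncated by the grid, their means are pulled inward, so feeding the naive matrix to the Newton solver would hit the wrong budget.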
0 times anything is zero. The Esscher transform is multiplicative scaling ($p_i^* \propto q_i \, e^{(A^\top \lambda)_i}$). It cannot project UP from zero: $q_i = 0 \Rightarrow p_i^* = 0$ for every $\lambda$. If the Brain outputs an exact zero prior for an action, it stays zero after projection. Even when resurrecting that probability is necessary to satisfy the constraint. So no ReLU. Nothing that produces hard zeros. Route all raw network outputs through a softmax. Strict positivity before the Oracle sees them.
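A small illustration of the zero trap, plus a related floating-point caveat: even a softmax can return an exact zero once logits differ by more than about 745 in float64 (far less in FP32), because $e^{-800}$ underflows. One option, an assumption here rather than something the post prescribes, is to hand the Oracle log-probabilities and tilt in log space, so nothing is exponentiated before the shift.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([0.0, -800.0])
A = np.array([[0.0, 1.0]])
lam = np.array([805.0])                # a tilt that strongly favors action 1

q = softmax(logits)                    # exp(-800) underflows: q[1] is exactly 0.0
with np.errstate(divide="ignore"):
    p_from_probs = softmax(np.log(q) + A.T @ lam)  # log(0) = -inf: action 1 stays dead
p_from_logits = softmax(logits + A.T @ lam)        # tilt in log space: action 1 recovers
print(q, p_from_probs, p_from_logits)
```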
Gradient alignment. The projection may greatly change the direction of the gradient. We can potentially fix it by adding a KL penalty to the loss (projection_kl) that compares pre and post-projection distributions. The alignment loss forces the Brain to feel the pain of the projection. It learns to output distributions that naturally sit closer to the constraint hyperplane.
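A sketch of the alignment penalty, assuming the KL direction is $\mathrm{KL}(p^* \,\|\, q)$, the same divergence the projection minimizes (the post doesn't pin the direction down). With the Esscher solution it also has a closed form: since $\log p_i^* = \log q_i + (A^\top\lambda^*)_i - \log Z(\lambda^*)$, the penalty equals $\lambda^{*\top} b - \log Z(\lambda^*)$, the negated dual optimum. The example picks an arbitrary $\lambda$ and defines $b$ as the budget it satisfies, so no solver is needed.

```python
import numpy as np

def projection_kl(p_star, q):
    # KL(p* || q): how far the Oracle had to move the Brain's distribution
    return float(np.sum(p_star * (np.log(p_star) - np.log(q))))

q = np.array([0.5, 0.3, 0.15, 0.05])
A = np.array([[1.0, 2.0, 3.0, 4.0]])
lam = np.array([0.8])                    # stand-in for the solver's lambda*

z = np.log(q) + A.T @ lam
log_Z = z.max() + np.log(np.exp(z - z.max()).sum())
p_star = np.exp(z - log_Z)               # tilted distribution
b = A @ p_star                           # the budget this lambda satisfies exactly

kl_direct = projection_kl(p_star, q)
kl_dual = float(lam @ b - log_Z)         # closed form: lambda^T b - log Z(lambda)
print(kl_direct, kl_dual)
```

In training you would add something like beta * projection_kl(p_star, q) to the task loss, with beta a weight you choose (hypothetical here).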
8. What can this be used for?
Come on, it's a mechanism for continuous decision-making under strict linear constraints. It applies broadly.
The examples below were suggested by AI because I don't really care, but I think they make sense.
In computational advertising, a policy network can dynamically bid to maximize click-through rate. The Esscher projection guarantees the expected cost per action (CPA) across the campaign matches the advertiser's budget exactly.
In quant finance, an agent can allocate portfolio weights to maximize predicted alpha. The projection layer keeps the expected market beta within risk compliance limits.
In LLM alignment (RLHF), the policy can tune a model to maximize a reward function. Multidimensional constraints keep metrics like toxicity or verbosity below safety thresholds in expectation.
References
- Amari, S. I. (2016). Information Geometry and Its Applications (Vol. 194). Springer.