Safeguarded Learned Convex Optimization

Howard Heaton, Xiaohan Chen, Zhangyang Wang, and Wotao Yin



Many applications require repeatedly solving a certain type of optimization problem, each time with new but similar data. Data-driven algorithms can “learn to optimize” (L2O), using far fewer iterations than general-purpose optimization algorithms at a similar cost per iteration.

L2O algorithms are often derived from general-purpose algorithms, but with the inclusion of (possibly many) tunable parameters. Exceptional performance has been demonstrated when the parameters are optimized for a particular distribution of data. Unfortunately, there is no general guarantee that an L2O algorithm converges to a solution, especially on new, unseen data.

However, we present a framework that combines L2O updates with a safeguard to guarantee convergence for convex problems with proximal and/or gradient oracles. The safeguard is simple and computationally cheap to implement, and it activates only when the current L2O update would perform poorly or appears to diverge. This approach retains the numerical benefits of applying machine learning to create rapid L2O algorithms while still rigorously guaranteeing convergence.
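The safeguard idea can be sketched as follows: at each iteration, a candidate L2O update is accepted only if it passes a cheap test, and otherwise a provably convergent fallback step is applied instead. The sketch below is a simplified stand-in, not the paper's exact rule; here a plain energy-descent test plays the role of the safeguard condition, and the function names (`safe_l2o`, `l2o_step`, `fallback_step`) are illustrative.

```python
import numpy as np


def safe_l2o(x0, l2o_step, fallback_step, energy, max_iter=50):
    """Sketch of a safeguarded L2O loop.

    l2o_step(x, k) returns a learned candidate update (or None once the
    learned iterations are exhausted); fallback_step(x) is one step of a
    provably convergent method. A simple descent test on `energy` stands
    in for the safeguard condition.
    """
    x = x0
    for k in range(max_iter):
        cand = l2o_step(x, k)
        if cand is not None and energy(cand) <= energy(x):
            x = cand              # accept the learned update
        else:
            x = fallback_step(x)  # safeguard: take the convergent step
    return x
```

If the learned candidate is always rejected, this loop reduces exactly to the fallback method, which is what guarantees convergence in the worst case.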

Our numerical examples demonstrate the efficacy of this approach for several existing and new L2O schemes.

For example, we compared

  • ISTA: an analytic method, the iterative soft-thresholding algorithm, for recovering sparse signals;

  • LISTA: similar to ISTA, but fixed at 20 iterations, with step sizes and certain matrices learned from instances drawn from a training distribution;

  • Safe-L2O LISTA: our safeguard framework that wraps LISTA and ISTA.
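For reference, ISTA applied to the standard sparse-recovery (LASSO) problem min_x 0.5‖Ax − b‖² + λ‖x‖₁ alternates a gradient step with soft-thresholding. A minimal sketch (the 1/L step size and function names are standard choices, not taken from the report):

```python
import numpy as np


def soft_threshold(x, t):
    # Proximal operator of t * ||.||_1 (soft-thresholding).
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)


def ista(A, b, lam, num_iter=100):
    """Plain ISTA for min_x 0.5*||Ax - b||^2 + lam*||x||_1."""
    step = 1.0 / np.linalg.norm(A, 2) ** 2  # 1/L, L = ||A||^2 (gradient Lipschitz constant)
    x = np.zeros(A.shape[1])
    for _ in range(num_iter):
        grad = A.T @ (A @ x - b)
        x = soft_threshold(x - step * grad, step * lam)
    return x
```

LISTA replaces the fixed matrices and step sizes in this iteration with learned, per-layer parameters, which is what yields its speedup on in-distribution data.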

In the left plot below, ISTA and LISTA are compared on instances drawn from the same (seen) distribution used to generate the LISTA training instances. The plot shows that LISTA converges much faster.

In the right plot below, ISTA, LISTA, and Safe-L2O LISTA are compared on data from a similar but different (unseen) distribution than the one used to train LISTA. LISTA (red curve) reduces the relative error over the first 10 iterations but then starts to diverge. ISTA (dashed black curve) converges, though slowly. Safe-L2O LISTA (blue curve) follows LISTA until it shows signs of divergence and then falls back to ISTA, which further decreases the relative error. The blue curve has a solid segment and a dashed segment: in the solid segment, the choice between LISTA and ISTA is decided by the safeguard rule; after that, only ISTA is applied, since LISTA has only 20 iterations. See our report for more details.



H. Heaton, X. Chen, Z. Wang, and W. Yin, Safeguarded learned convex optimization, arXiv:2003.01880, 2020.
