
Machine Learning Optimization - Stochastic Descent, Variance Reduction Techniques, and the Bias-Variance Tradeoff
Introduction
The intent of this post is to explore some examples of variance reduction techniques in gradient descent optimization methods for machine learning, with some focus on their underlying statistical concepts. We will use the ADAM optimizer, a popular algorithm in neural network optimization and deep learning, and the less prolific SARAH optimizer (a variant of SVRG) as examples of algorithms in practice, with additional background on the methodologies from which these stochastic optimizers are derived, including SGD, SAG, SAGA, and SVRG. In addition to exploring the techniques themselves, we will pay focused attention to bias-variance tradeoffs unique to variance reduction techniques, which are unrelated to the typical overfitting/overparameterization bias-variance tradeoff traditionally examined in machine learning. There are many resources available online for all of these topics, and I will provide links to my favorites, but my aim is to cover the basic underlying concepts needed to build some of the more common and contemporary algorithms.
This post will cover variance reduction techniques with some technical statistical details, and it assumes the reader has a basic understanding of machine learning principles and gradient descent (here's a 3Blue1Brown video if you need a refresher, but I also dedicated a section of this post to my own refresher). But my intent is to make these topics accessible for an audience looking for an enhanced understanding of techniques in machine learning optimization. I will include links and additional details wherever I feel it's warranted, and I am open to any feedback you may have.
Information for this post is largely drawn from my final project for an optimization course I took in the Spring of 2024, for which the paper and presentation are available below.
Please refer to these documents for additional formal details; the post may reference portions of them for clarity.
Background
Machine Learning Optimization and Stochastic Optimizers
Optimization as a field has become critical to recent progress in machine learning as the field expands its applications into ever greater swathes of data. We are using machine learning to approach increasingly large datasets, and to do so effectively we need model optimization techniques that are both fast and guaranteed to converge. Stochastic optimizers are a class of algorithms that have become prevalent in recent years because they pair desirable theoretical convergence properties with the efficiency of stochastic sampling. Being stochastic methods, they require special techniques to address and control the variance in their computations in order to retain any theoretical convergence advantage over other common methods.
When optimizing machine learning model parameters, first-order gradient descent algorithms utilizing stochastic gradients (an individual sample's gradient selected uniformly at random), or minibatch gradients (subsets of the full gradient whose elements are selected uniformly at random), have become core staples in machine learning frameworks due to their computational efficiency on datasets of both large sample size and high dimensionality. But the randomness (variance) that stochasticity introduces into the steps of these gradient descent methods tends to slow down or inhibit convergence of the algorithms, so it is in our interest to implement techniques that curb this variance.
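To make the distinction concrete, here is a minimal sketch in Python/NumPy (my own illustration on a contrived least squares problem; the names, dimensions, and step size are assumptions, not the post's published code) of a full gradient step, a stochastic gradient step, and a minibatch gradient step:

```python
import numpy as np

rng = np.random.default_rng(0)

# Contrived least squares problem: n samples, d features.
n, d = 1000, 10
A = rng.standard_normal((n, d))
b = A @ rng.standard_normal(d) + 0.1 * rng.standard_normal(n)

def full_gradient(x):
    # Gradient of f(x) = (1/2n)||Ax - b||^2 over all n samples.
    return A.T @ (A @ x - b) / n

def stochastic_gradient(x):
    # Gradient of a single f_i, with index i chosen uniformly at random.
    i = rng.integers(n)
    return A[i] * (A[i] @ x - b[i])

def minibatch_gradient(x, m=32):
    # Average gradient over a random subset of m samples.
    S = rng.choice(n, size=m, replace=False)
    return A[S].T @ (A[S] @ x - b[S]) / m

# One descent step with each estimator, using step size eta.
x, eta = np.zeros(d), 0.1
for grad in (full_gradient, stochastic_gradient, minibatch_gradient):
    print(grad.__name__, x - eta * grad(x))
```

The full gradient touches all $n$ samples per step, while the stochastic and minibatch versions touch one or $m$ samples, which is the source of both their efficiency and their variance.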
This is a visual example of the topics and algorithms we intend to cover. The graph depicts the parameters of a two-parameter model (training to correlated noise) as they iteratively update toward the true parameters under stochastic gradient descent (SGD, yellow line) and several variants that each employ different techniques to reduce variance in the parameter updates. Variance, as we will discuss in this post, is the quantity describing the magnitude of the randomness in an algorithm's steps toward the true solution. SGD clearly has the highest variance, while the variance reduction variants take visibly steadier steps.
This example uses some contrived data, but the settings for each algorithm are approximately analogous (we'll see what "approximately" means, as the algorithms differ in methodology), and they run for the same number of iterations. We can see that the variants of SGD progress closer to the solution, each with qualitatively smaller variance than SGD, which encapsulates our general motivation for utilizing variance reduction techniques in stochastic gradient descent style algorithms: if we reduce the variance in the steps, we can converge to the true solution more efficiently. We will dive into some of these techniques and their underlying theory to better understand how we can balance variance reduction with convergence properties, computational efficiency, and the unintended bias that can be a byproduct of these methodologies.
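As a rough quantitative companion to the figure, the sketch below (again my own illustration, reusing the setup and gradient functions from the previous snippet) estimates how far each gradient estimator scatters around the full gradient, showing the variance shrink as the batch size grows:

```python
# Empirical variance of gradient estimators around the full gradient,
# reusing A, b, d, and minibatch_gradient from the sketch above.
x0 = np.zeros(d)
g_full = full_gradient(x0)

for batch_size in (1, 8, 64):
    # Draw many batch gradients and measure their spread around g_full.
    draws = np.array([minibatch_gradient(x0, m=batch_size) for _ in range(2000)])
    variance = np.mean(np.sum((draws - g_full) ** 2, axis=1))
    print(f"batch size {batch_size:3d}: mean squared deviation ~ {variance:.4f}")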
NOTE
Updated January 2025: this post has been split into smaller posts; use the Table of Contents links on each page to navigate.
Table of Contents
1: Crash Course: Gradient Descent
2: Stochastic Gradient Descent
3: Variance Reduction Techniques
4: SARAH and ADAM Algorithms
Supplemental Resources
- Gradient descent, how neural networks learn | Chapter 2, Deep learning
- 3Blue1Brown video covering the basics of gradient descent in deep learning, from his Deep Learning series. Of course, I strongly recommend any of his excellent videos.
- Stochastic Gradient Descent, Clearly Explained!!!
- A fun StatQuest video about stochastic gradient descent.
- Adam Optimizer Explained in Detail | Deep Learning - Coding Lane
- A good simplified breakdown of ADAM optimization
- Goh, "Why Momentum Really Works", Distill, 2017. http://doi.org/10.23915/distill.00006
- A really nice (deep dive) post about Momentum, a stochastic variance reduction technique
- Lipschitz Functions: Intro and Simple Explanation for Usefulness in Machine Learning
- Learn about Lipschitz continuous functions and their uses in machine learning
- L20.4 On the Mean Squared Error of an Estimator
- MIT Open Courseware video about MSE
- Notes on "Big O" notation
- "Big O" is standard notation to describe convergence for machine learning algorithms
Reproducibility
All code for diagrams and accompanying paper is available here.
Appendix
Some supplemental information for the topics we have covered.
Convexity
Convex Function
For a function $f : \mathbb{R}^d \to \mathbb{R}$:
- $f$ is convex if and only if for all $x, y \in \mathbb{R}^d$, $\lambda \in [0, 1]$,
$$f(\lambda x + (1 - \lambda) y) \leq \lambda f(x) + (1 - \lambda) f(y)$$
- $f$ is strictly convex if and only if for all $x \neq y$, $\lambda \in (0, 1)$,
$$f(\lambda x + (1 - \lambda) y) < \lambda f(x) + (1 - \lambda) f(y)$$
- $f$ is $\mu$-strongly convex if and only if there exists a $\mu > 0$ such that for all $x, y \in \mathbb{R}^d$, $\lambda \in [0, 1]$,
$$f(\lambda x + (1 - \lambda) y) \leq \lambda f(x) + (1 - \lambda) f(y) - \frac{\mu}{2} \lambda (1 - \lambda) \|x - y\|_2^2$$
- If the objective function is convex, then any local minimum is a global minimum.
- If the objective function is strictly convex, then the minimum is unique.
- If the objective function is strongly convex, then the minimum is unique.
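As a concrete instance (a worked example of my own, using the least squares objective from the Gradient Calculation section below), these properties can be verified through the Hessian:
$$f(x) = \frac{1}{2n} \|Ax - b\|_2^2, \qquad \nabla^2 f(x) = \frac{1}{n} A^T A$$
$A^T A$ is always positive semidefinite, so $f$ is convex. If $A$ has full column rank, then $A^T A$ is positive definite, and $f$ is $\mu$-strongly convex with $\mu = \lambda_{\min}(A^T A) / n$, so its minimum is unique.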
Lipschitz Continuity
For the following convex optimization problem,
$$\min_{x \in \mathbb{R}^d} f(x)$$
- $f$ has a Lipschitz-continuous gradient if there exists $L > 0$ such that for all $x, y \in \mathbb{R}^d$,
$$\|\nabla f(x) - \nabla f(y)\|_2 \leq L \|x - y\|_2$$
- If $f$ is twice differentiable such that $\|\nabla^2 f(x)\|_2 \leq L$ for all $x \in \mathbb{R}^d$ and some $L > 0$, then $f$ has a Lipschitz continuous gradient with constant $L$.
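For the least squares objective used elsewhere in this appendix, the Hessian is constant, so the second bullet yields the Lipschitz constant directly. A minimal sketch in Python/NumPy (my own illustration, with assumed problem dimensions):

```python
import numpy as np

rng = np.random.default_rng(0)

# For f(x) = (1/2n)||Ax - b||^2 the Hessian is (1/n) A^T A at every x,
# so the gradient is Lipschitz with L = lambda_max(A^T A) / n.
n, d = 1000, 10
A = rng.standard_normal((n, d))
L = np.linalg.eigvalsh(A.T @ A).max() / n

# 1/L is the classic "safe" step size for gradient descent on f.
print(f"L = {L:.4f}, step size 1/L = {1 / L:.4f}")
```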
Consistency
Informally, a consistent estimator is one whose estimate is guaranteed to converge to the true value of the parameter as we use more and more data.
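As a quick illustration (my own example, not from the post), the sample mean is a consistent estimator of a distribution's mean:

```python
import numpy as np

rng = np.random.default_rng(0)

# The sample mean is a consistent estimator of the true mean:
# its error shrinks as the sample size grows.
true_mean = 3.0
for n in (10, 1_000, 100_000):
    samples = rng.normal(loc=true_mean, scale=2.0, size=n)
    print(f"n = {n:>7,}: sample mean = {samples.mean():.4f}")
```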
Gradient Calculation
For simplicity, we are using the least squares loss objective function
$$f(x) = \frac{1}{2n} \|Ax - b\|_2^2 = \frac{1}{n} \sum_{i=1}^{n} f_i(x), \qquad f_i(x) = \frac{1}{2} (a_i^T x - b_i)^2$$
Recall that our data has $d$ features and $n$ total samples, so our matrix $A$ (with rows $a_i^T$) has dimensions $n \times d$. Our labels $b$ (the "response" vector) have a corresponding entry for each data sample, so its dimensions are $n \times 1$. Lastly, our parameter vector $x$ has a parameter for each of the $d$ features, so its dimensions are $d \times 1$.
Full gradient calculation
$$\nabla f(x) = \frac{1}{n} A^T (Ax - b) = \frac{1}{n} \sum_{i=1}^{n} a_i (a_i^T x - b_i)$$
Stochastic gradient calculation
Select a random index $i$ uniformly from $\{1, \dots, n\}$.
Compute the stochastic gradient (an unbiased approximation of the full gradient):
$$\nabla f_i(x) = a_i (a_i^T x - b_i)$$
Batch gradient calculation
We first select a subset of indices $S \subseteq \{1, \dots, n\}$ where $|S| = m < n$. In this example we select $m = 3$ indices $i_1$, $i_2$, $i_3$.
We compute the batch gradient:
$$\nabla f_S(x) = \frac{1}{m} \sum_{i \in S} \nabla f_i(x) = \frac{1}{m} \sum_{i \in S} a_i (a_i^T x - b_i)$$
Which results in the average of the selected gradients:
$$\nabla f_S(x) = \frac{1}{3} \left( \nabla f_{i_1}(x) + \nabla f_{i_2}(x) + \nabla f_{i_3}(x) \right)$$
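The three calculations above translate directly to code. Here is a self-contained sketch in Python/NumPy (my own illustration, not the post's published code) that also verifies that averaging the per-sample gradients over all indices recovers the full gradient, which is why the stochastic and batch gradients are unbiased estimators:

```python
import numpy as np

rng = np.random.default_rng(0)

# Least squares setup: A is n x d, b is n x 1, x is d x 1.
n, d = 8, 3
A = rng.standard_normal((n, d))
b = rng.standard_normal(n)
x = rng.standard_normal(d)

# Full gradient: (1/n) A^T (Ax - b).
full = A.T @ (A @ x - b) / n

# Stochastic gradient: grad f_i(x) = a_i (a_i^T x - b_i) for a random i.
i = rng.integers(n)
stoch = A[i] * (A[i] @ x - b[i])

# Batch gradient: average of grad f_i over a random subset S with |S| = m.
m = 3
S = rng.choice(n, size=m, replace=False)
batch = A[S].T @ (A[S] @ x - b[S]) / m

# Averaging grad f_i over ALL i recovers the full gradient exactly,
# so the stochastic and batch estimators are unbiased.
all_grads = np.array([A[j] * (A[j] @ x - b[j]) for j in range(n)])
assert np.allclose(all_grads.mean(axis=0), full)
print(full, stoch, batch, sep="\n")
```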