A Mathematical Explanation of Gradient Descent
Example of a close-to-optimal Gradient descent for a quadratic function. Credit: Dylan
Introduction
Although I recently wrote an article explaining Gradient Descent, I feel a mathematical explanation would be beneficial, not only for understanding how it works, but also for seeing how the functions are derived.
Explanation
First, we need a way to calculate how far off the model is. For our example, we will use the L2 Loss Function: L = (1/N) Σᵢ (yᵢ − ŷᵢ)².
Let N represent the total number of data points, i the index of the data point we are on, yᵢ the expected output, and ŷᵢ the output predicted by the model.
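As a quick sketch, the L2 loss above can be written as a small function (the name `l2_loss` is mine, chosen for illustration):

```python
import numpy as np

def l2_loss(y, y_hat):
    """Mean of squared differences between expected outputs y and predictions y_hat."""
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    return np.mean((y - y_hat) ** 2)

# Predictions that are off by exactly 1 everywhere give a loss of 1.
print(l2_loss([1.0, 2.0, 3.0], [2.0, 3.0, 4.0]))  # → 1.0
```

Note that squaring makes the loss penalize large errors much more heavily than small ones, which is part of why L2 is a common default.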
Let's look more closely at ŷ. Since ŷ is the predicted value, it is the function whose parameters we need to change, and our goal is to change them enough to minimize the loss. For our example, let's set ŷ to be a linear function: ŷ = ax + b.
The values a and b are our parameters. These are the values we change within f(x) to change its output. Our goal is to minimize the difference between what f(x) (which is ŷ) outputs and yᵢ. To do this, we compute how the Loss Function changes as each parameter changes: since ŷ has two parameters, we compute the partial derivative of the Loss Function with respect to a, and the partial derivative of the Loss Function with respect to b.
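For the loss and line defined above, the two partial derivatives work out (by the chain rule) to:

```latex
L(a, b) = \frac{1}{N}\sum_{i=1}^{N}\bigl(y_i - (a x_i + b)\bigr)^2

\frac{\partial L}{\partial a} = -\frac{2}{N}\sum_{i=1}^{N} x_i \bigl(y_i - (a x_i + b)\bigr)
\qquad
\frac{\partial L}{\partial b} = -\frac{2}{N}\sum_{i=1}^{N} \bigl(y_i - (a x_i + b)\bigr)
```

Each derivative tells us which direction increases the loss, so moving the parameter the opposite way decreases it.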
The next thing to tackle is updating the model parameters. The derivatives alone don't change anything until we use them: we subtract each derivative, scaled by a small factor called the "learning rate," from the current values of a and b. Keeping the steps small prevents the model from overshooting, or from oscillating near the global minimum (the lowest point of the loss, where the model best fits the data) without ever reaching it, which would leave us short of the optimal solution.
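Putting the pieces together, here is a minimal sketch of the full loop, assuming the L2 loss, the line ŷ = ax + b, and the update rule described above (the function name and default values are mine, for illustration):

```python
import numpy as np

def gradient_descent(x, y, learning_rate=0.01, steps=1000):
    """Fit y ≈ a*x + b by gradient descent on the L2 loss."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    a, b = 0.0, 0.0  # initial parameter guesses
    n = len(x)
    for _ in range(steps):
        y_hat = a * x + b
        # Partial derivatives of the L2 loss with respect to a and b.
        grad_a = (-2.0 / n) * np.sum(x * (y - y_hat))
        grad_b = (-2.0 / n) * np.sum(y - y_hat)
        # Step each parameter opposite its gradient, scaled by the learning rate.
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
    return a, b

# Data drawn from the line y = 2x + 1; the fit should recover a ≈ 2, b ≈ 1.
a, b = gradient_descent([0, 1, 2, 3, 4], [1, 3, 5, 7, 9],
                        learning_rate=0.05, steps=5000)
print(a, b)
```

Try raising the learning rate toward 0.2 on this data and the parameters start to oscillate instead of settling, which is exactly the overshooting problem the small step size guards against.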
Pros and Cons
Pros
1. Widely Used Algorithm in ML/AI.
2. Can be used for a wide variety of datasets and different types of values (vector databases, images, etc.).
3. Can be scaled to larger and larger datasets (the equation stays the same, and you can still find the global minimum).
4. Can be scaled to more and more parameters.
Cons
1. It can be very time-consuming. Depending on the learning rate, the model can learn very slowly.
2. Cost. As mentioned above, since the parameters move by a very small amount each step, it takes many iterations to reach the optimal solution (the global minimum). This can cost a lot to run.
3. Local minima. Instead of finding the global minimum, gradient descent can get stuck in a local minimum, because from there every small nudge to the parameters actually takes the model further from the desired output.
Important Note
Although there are a lot of cons, the variants of Gradient Descent (such as the Adam optimizer, which adapts a separate learning rate for each parameter, Stochastic Gradient Descent, etc.) can address some, if not most, of them.
Works Referenced
These are works I used that were not already listed above.
"What is learning rate in machine learning?" by Belcic and Stryker (IBM).
"What is gradient descent?" by IBM.
"Gradient Descent Algorithm in Machine Learning" by GeeksForGeeks (last updated January 23, 2025).