A Mathematical Explanation of Gradient Descent

Example of a close-to-optimal Gradient descent for a quadratic function. Credit: Dylan

Introduction

Although I recently wrote an article explaining Gradient Descent, I feel a mathematical explanation would be beneficial, not only for understanding how the algorithm works, but also for seeing how its equations are derived. 

Explanation

First, we need a way to calculate how far off the model is. For our example, we will use the L2 Loss Function:

L = (1/N) · Σᵢ₌₁ᴺ (yᵢ − ŷᵢ)²

Let N represent the total number of data points we have, i the index of the data point we are on, yᵢ the expected output, and ŷᵢ the output predicted by the model. 
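In code, the loss above can be sketched like this (the data values here are made up purely for illustration):

```python
# L2 loss: the average of the squared differences between the
# expected outputs y_i and the predicted outputs ŷ_i.
def l2_loss(y, y_hat):
    n = len(y)
    return sum((y[i] - y_hat[i]) ** 2 for i in range(n)) / n

y = [3.0, 5.0, 7.0]       # expected outputs y_i
y_hat = [2.5, 5.5, 6.0]   # model predictions ŷ_i
print(l2_loss(y, y_hat))  # prints 0.5
```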

Let's look more closely at ŷ. Since ŷ is the predicted value, it is the function whose parameters need to be changed. Our goal is to change these parameters enough to minimize the loss. For our example, let's make ŷ a linear function:

ŷ = f(x) = ax + b

The values a and b are our parameters. These are the values we change within f(x) to change its output. Our goal is to minimize the difference between the value f(x) (which is ŷ) outputs and yᵢ. To do that, we compute the rate of change of the Loss Function with respect to each parameter. Since ŷ has two parameters, we compute the rate of change of the Loss Function with respect to a, and the rate of change of the Loss Function with respect to b. 
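For ŷ = ax + b and the L2 loss, those two rates of change work out to ∂L/∂a = −(2/N) · Σᵢ xᵢ(yᵢ − ŷᵢ) and ∂L/∂b = −(2/N) · Σᵢ (yᵢ − ŷᵢ). A small sketch of computing them (the data here is illustrative):

```python
# Partial derivatives of the L2 loss with respect to the
# parameters a and b of the model ŷ = a·x + b.
def gradients(x, y, a, b):
    n = len(x)
    grad_a = 0.0
    grad_b = 0.0
    for i in range(n):
        error = y[i] - (a * x[i] + b)    # y_i − ŷ_i
        grad_a += -2 * x[i] * error / n  # contribution to ∂L/∂a
        grad_b += -2 * error / n         # contribution to ∂L/∂b
    return grad_a, grad_b
```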
The next thing to tackle is updating the model parameters. Although we calculated these rates of change, they don't do anything until we use them to adjust a and b: we scale each one by a small factor, called the "learning rate," and subtract the result from the current parameter value. This helps prevent the model from taking steps that are too large, overshooting, or even oscillating near the global minimum (the set of parameter values that gives the smallest possible loss) without ever reaching it. If that happened, we would not get the most optimal solution. 

Difference between large vs small learning rate. Credit: Mishra on Medium

Let 𝛼 represent the learning rate. The update rules are then:

a ← a − 𝛼 · ∂L/∂a
b ← b − 𝛼 · ∂L/∂b

Now, we just repeat this process several times (or, in other words, for several epochs). We could also make N a smaller "batch size" instead of the full dataset, letting the model fit on fewer data points per update. The parameters then update more often, moving closer and closer toward the global minimum in small, frequent steps instead of "shooting a shot in the dark" and hoping it works. 
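Putting everything together, one possible sketch of the full training loop (the data, learning rate, and epoch count are all illustrative choices):

```python
# Full gradient-descent loop fitting ŷ = a·x + b to data with the L2 loss.
def fit(x, y, alpha=0.01, epochs=1000):
    a, b = 0.0, 0.0
    n = len(x)
    for _ in range(epochs):       # one pass over the data per epoch
        grad_a = grad_b = 0.0
        for i in range(n):
            error = y[i] - (a * x[i] + b)
            grad_a += -2 * x[i] * error / n
            grad_b += -2 * error / n
        a -= alpha * grad_a       # a ← a − 𝛼·∂L/∂a
        b -= alpha * grad_b       # b ← b − 𝛼·∂L/∂b
    return a, b

# Points generated from y = 2x + 1, so a should approach 2 and b should approach 1
a, b = fit([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0], alpha=0.05, epochs=2000)
print(round(a, 2), round(b, 2))
```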

Pros and Cons

Pros

1. Widely used algorithm in ML/AI. 
2. Can be used for a wide variety of datasets and different types of values (vector databases, images, etc.). 
3. Can be scaled to larger and larger datasets (the equation stays the same, and you can still find the global minimum).
4. Can be scaled to more and more parameters. 

Gradient Descent for a complex function. Credit: Ahmed Fawzy Gad (Digital Ocean). 

Cons

1. It can be very time consuming. Depending on the learning rate, the model can learn very slowly. 
2. Cost. As mentioned above, since the parameters move by a very small amount each step, it takes many runs to reach the optimal solution (the global minimum). This can cost a lot to run. 
3. Local minima. Instead of finding the global minimum, gradient descent can get stuck in a local minimum, because from there, every small nudge to the parameters actually takes the model further from the desired output, even though a better solution exists elsewhere. 

Important Note

Although there are a lot of cons, variants of Gradient Descent (such as the Adam Optimizer, which maintains a different effective learning rate for each parameter, Stochastic Gradient Descent, and others) can address some, if not most, of them. 
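As a rough illustration of the per-parameter learning rates mentioned above, here is a simplified Adam-style update for the same two parameters a and b. This is a sketch, not a full implementation: the hyperparameter values are the commonly used defaults, and the data is made up.

```python
import math

# Simplified Adam-style updates for ŷ = a·x + b with the L2 loss.
# m tracks a running mean of each gradient, v a running mean of its
# square; their ratio gives each parameter its own effective step size.
def adam_fit(x, y, alpha=0.01, beta1=0.9, beta2=0.999, eps=1e-8, epochs=3000):
    a = b = 0.0
    m = [0.0, 0.0]   # first-moment estimates for a and b
    v = [0.0, 0.0]   # second-moment estimates for a and b
    n = len(x)
    for t in range(1, epochs + 1):
        # Same gradients of the L2 loss as in plain gradient descent
        grad = [0.0, 0.0]
        for i in range(n):
            error = y[i] - (a * x[i] + b)
            grad[0] += -2 * x[i] * error / n
            grad[1] += -2 * error / n
        params = [a, b]
        for j in range(2):
            m[j] = beta1 * m[j] + (1 - beta1) * grad[j]
            v[j] = beta2 * v[j] + (1 - beta2) * grad[j] ** 2
            m_hat = m[j] / (1 - beta1 ** t)   # bias-corrected moments
            v_hat = v[j] / (1 - beta2 ** t)
            params[j] -= alpha * m_hat / (math.sqrt(v_hat) + eps)
        a, b = params
    return a, b
```

Notice that each parameter j divides by its own √v̂, so a steep direction and a shallow direction end up taking differently sized steps, which is the "different learning rate for each parameter" behavior described above.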

Works Referenced

These are works I used that were not already listed above. 

"What is learning rate in machine learning?" by Belcic and Stryker (IBM). 
"What is gradient descent?" by IBM.
"Gradient Descent Algorithm in Machine Learning" by GeeksForGeeks (last updated January 23, 2025).
