How to Write a Gradient Descent Algorithm in Python

Gradient descent is a first-order iterative optimization algorithm used to find the values of a function's parameters that minimize the function.

In this tutorial, I will teach you the steps involved in a gradient descent algorithm and how to write a gradient descent algorithm using Python.

Table of Contents

You can skip to any specific section of this tutorial using the table of contents below.

· What is a Gradient Descent Algorithm?

· Steps involved in a Gradient Descent Algorithm

· Writing a Gradient Descent Algorithm in Python

What is a Gradient Descent Algorithm?

As discussed above, the gradient descent algorithm is an optimization algorithm used to find a local minimum of a function. The minimum is found by iteratively moving in the direction of steepest descent. Gradient descent is commonly used when the coefficients of a function cannot be solved for directly using linear algebra.
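The core of the algorithm is a single update rule: move the current estimate a small step against the derivative. As a minimal sketch, using the example function f(x) = (x + 3)² from later in this tutorial:

```python
# One gradient descent update for f(x) = (x + 3)**2.
# The derivative is f'(x) = 2*(x + 3); function and starting
# point are taken from the worked example in this tutorial.
df = lambda x: 2 * (x + 3)  # derivative of f
x = 2.0                     # initial guess
r = 0.01                    # learning rate
x = x - r * df(x)           # step against the gradient
print(x)                    # first updated value: 1.9
```

Repeating this update moves x steadily toward the minimum at x = -3.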

Steps Involved in a Gradient Descent Algorithm

Let us now look at the standard procedure used to perform a gradient descent algorithm.

1. Set an initial value for the coefficients of the function. You can either set the initial value as zero or set it to any random number.

If y = f(x), set an initial value for x.

2. Calculate the derivative of the given function.

The derivative is the slope of the function at a given point. Calculating the slope tells you the direction in which the coefficient values must be moved to reach the local minimum of the function.

∆ = dy/dx

Note that the negative of the derivative points downhill, which is why the update subtracts the derivative from the current value.
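If you are unsure whether you have differentiated correctly, a central finite difference gives a quick numerical check. This is a sketch, not part of the tutorial's algorithm; the helper name is my own:

```python
# Sanity-check the analytic derivative f'(x) = 2*(x + 3)
# against a central finite-difference approximation.
def f(x):
    return (x + 3) ** 2

def numeric_derivative(f, x, h=1e-6):
    # (f(x+h) - f(x-h)) / 2h approximates f'(x) for small h
    return (f(x + h) - f(x - h)) / (2 * h)

for x in (-3.0, 0.0, 2.0):
    analytic = 2 * (x + 3)
    approx = numeric_derivative(f, x)
    print(x, analytic, round(approx, 4))
```

The two columns should agree to several decimal places at every test point.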

3. Specify a learning rate that controls how far the coefficients move along the gradient on each step.

4. Specify a precision: when the change in the coefficients between iterations falls below it, the algorithm stops.

5. Perform the iterations

The iterations should continue until the difference between xi+1 and xi is less than the precision that we specified.
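Before walking through the tutorial's line-by-line version in the next section, the five steps above can be sketched as one reusable function (the function name and defaults are my own; the loop mirrors the code shown below):

```python
# Steps 1-5 wrapped in one function: initialize, take the derivative
# as an argument, apply the learning rate, and stop at the precision.
def gradient_descent(df, x0, rate=0.01, precision=1e-6, max_iters=1000):
    x = x0
    for i in range(max_iters):
        x_new = x - rate * df(x)        # step 5: the update rule
        if abs(x_new - x) < precision:  # step 4: stopping criterion
            return x_new, i + 1
        x = x_new
    return x, max_iters

# Minimize f(x) = (x + 3)**2, whose derivative is 2*(x + 3)
minimum, iters = gradient_descent(lambda x: 2 * (x + 3), x0=2)
print(round(minimum, 4), iters)
```

With the tutorial's settings this converges to roughly x = -3 in a few hundred iterations.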

Writing a Gradient Descent Algorithm in Python

Let’s take the function f(x) = y = (x + 3)²

Step 1: Initialize the value of x. In our example, take x = 2

x = 2 # The algorithm starts at x = 2
r = 0.01 # Learning rate
p = 0.000001 # The precision at which the algorithm should be stopped
before_step_size = 1 # Initial step size; any value larger than p works
i_max = 1000 # Maximum number of iterations
i = 0 # Iteration counter
df = lambda x: 2*(x+3) # Gradient (first derivative) of our function

Step 2: Run the gradient descent algorithm in a loop

while before_step_size > p and i < i_max:
    x_old = x # Stores the current value of x in x_old
    x = x - r * df(x_old) # Gradient descent
    before_step_size = abs(x - x_old) #Difference between consecutive iterations
    i = i+1 #Iteration increasing with every loop
    print("Iteration",i,"\nX value is",x) 
print("The local minimum of the given function occurs at", x)

The above loop terminates when the difference between x and x_old is less than 0.000001 or when the total number of iterations reaches 1000.


From the output below, we can see that the algorithm runs for 571 iterations. The local minimum of the function f(x) = y = (x + 3)² is at x = -2.999951128099859, which is very close to the true minimum at x = -3.

Recomputing the value of x for the first three iterations by hand confirms that our values match the output obtained.
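As a quick check, the first three updates can be recomputed directly with the same update rule and settings used above:

```python
# Recompute the first three iterations by hand:
# x_{n+1} = x_n - r * f'(x_n), with r = 0.01 and f'(x) = 2*(x + 3)
df = lambda x: 2 * (x + 3)
x = 2
r = 0.01
for n in range(1, 4):
    x = x - r * df(x)
    print("Iteration", n, "\nX value is", x)
```

These three values (1.9, 1.802, 1.70596, up to floating-point noise) match the start of the printed output.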

A snippet of the output is shown below.

Iteration 1 
X value is 1.9
Iteration 2 
X value is 1.8019999999999998
Iteration 3 
X value is 1.70596
Iteration 4 
X value is 1.6118408
Iteration 5 
X value is 1.519603984
...
Iteration 565
X value is -2.999944830027105
Iteration 566
X value is -2.999945933426563
Iteration 567
X value is -2.999947014758032
Iteration 568
X value is -2.9999480744628713
Iteration 569
X value is -2.999949112973614
Iteration 570
X value is -2.999950130714142
Iteration 571
X value is -2.999951128099859
The local minimum of the given function occurs at -2.999951128099859

This was a guest contribution by Nick McCullum, who teaches JavaScript and Python development on his website.
