Gradient Descent - Two Variables

 

 

The following is the Octave code I wrote to create all the plots shown on this page. Feel free to copy this code and experiment with it: change the variables and try it yourself until you develop your own intuitive understanding.

 

 

< Code 1 >

 

function main
  % Plot 50 contour levels of Z = X.^2 + 0.5*Y.^2 on the grid [-2,2] x [-2,2]
  % (spacing 0.025) and overlay gradient-descent iterates starting at (1,1).
  %
  % Every pass marks the current point in red.  From the second pass on, it
  % also draws a short line along the negative gradient (scaled by `scale`),
  % takes one descent step with learning factor Lf, and marks the new point
  % in blue.

  w1 = -2:.025:2;
  w2 = -2:.025:2;
  [X,Y] = meshgrid(w1,w2);

  Z = (X .^2 + 0.5 .* Y .^2);

  % Starting point of the descent.
  vx = 1.0;
  vy = 1.0;

  Lf = 0.75;    % Learning Factor (step size of each descent step)
  scale = 0.1;  % Length factor of the drawn gradient line; loop-invariant

  figure(1,'Position',[300 300 600 500]);

  contour(X,Y,Z,50);
  hold on;

  for i = 1:40
    plot(vx,vy,'ro','MarkerFaceColor',[1 0 0]);

    % The first pass only marks the starting point; steps begin at pass 2.
    if i > 1
      % Trailing semicolons keep the per-iteration results off the console.
      [gx1,gx2,gy1,gy2] = GetGradientLineAt_2Var(X,Y,Z,vx,vy,scale);
      line([gx1 gx2],[gy1 gy2]);

      [gnx,gny] = GetGradientNextAt_2Var(X,Y,Z,vx,vy,Lf);
      plot(gnx,gny,'bo','MarkerFaceColor',[0 0 1]);

      vx = gnx;
      vy = gny;
    end
  end

  hold off;

endfunction

 

function [gnx,gny] = GetGradientNextAt_2Var(x,y,z,vx,vy,Lf)
  % Take one gradient-descent step on the gridded surface z.
  %
  %   x, y   : coordinate matrices (row 1 of x and column 1 of y are read as
  %            the grid axes by GetSlopAt_2Var; main passes meshgrid output)
  %   z      : surface values on that grid, indexed z(row_y, col_x)
  %   vx, vy : current point
  %   Lf     : learning factor (step size)
  %
  % Returns (gnx, gny) = (vx, vy) - Lf * (sx, sy), where (sx, sy) is the
  % finite-difference slope from GetSlopAt_2Var.

  [sx,sy] = GetSlopAt_2Var(x,y,z,vx,vy);

  gnx = vx - Lf*sx;
  gny = vy - Lf*sy;

endfunction

 

function [gx1,gx2,gy1,gy2] = GetGradientLineAt_2Var(x,y,z,vx,vy,scale)
  % Return the end points of a line segment drawn from (vx,vy) along the
  % negative slope of z, for plotting with line([gx1 gx2],[gy1 gy2]).
  %
  %   x, y, z : grid and surface, as passed to GetSlopAt_2Var
  %   vx, vy  : start point of the segment
  %   scale   : multiplier applied to the slope to set the segment length
  %
  % (gx1,gy1) is the start point (vx,vy).
  % (gx2,gy2) = (vx,vy) - scale * (sx,sy).

  [sx,sy] = GetSlopAt_2Var(x,y,z,vx,vy);

  gx1 = vx;
  gx2 = vx - scale*sx;

  gy1 = vy;
  gy2 = vy - scale*sy;

endfunction

 

function [sx,sy] = GetSlopAt_2Var(x,y,z,vx,vy)
  % Forward-difference estimate of the partial slopes of z at (vx,vy).
  %
  % The grid cell comes from GetLowerMaxIndex_2Var.
  %   sx = (z(j,i+1) - z(j,i)) / (x(i+1) - x(i))   along row j
  %   sy = (z(j+1,i) - z(j,i)) / (y(j+1) - y(j))   along column i
  % The grid axes are read from row 1 of x and column 1 of y.

  [i,j] = GetLowerMaxIndex_2Var(x,y,vx,vy);

  xs = x(1,:);
  ys = y(:,1);

  % Keep the two-point stencil inside the grid.  Near the grid edges the
  % index lookup can yield an out-of-range cell (e.g. 0 for a point below
  % the first coordinate), and i+1 / j+1 must also be valid.
  i = min(max(i,1), numel(xs)-1);
  j = min(max(j,1), numel(ys)-1);

  sx = (z(j,i+1) - z(j,i)) / (xs(i+1) - xs(i));
  sy = (z(j+1,i) - z(j,i)) / (ys(j+1) - ys(j));

endfunction

 

 

function [px,py,pz] = GetPoint3At_2Var(x,y,z,vx,vy)
  % Return the grid point at the lower corner of the cell containing
  % (vx,vy), together with its surface value.
  %
  %   px = x-axis coordinate at column index i
  %   py = y-axis coordinate at row index j
  %   pz = z(j,i)
  % (i,j) come from GetLowerMaxIndex_2Var.  The axes are read from row 1 of
  % x and column 1 of y.

  [i,j] = GetLowerMaxIndex_2Var(x,y,vx,vy);

  xs = x(1,:);
  ys = y(:,1);

  px = xs(i);
  py = ys(j);   % the y coordinate uses the row index j, not the column index i
  pz = z(j,i);

endfunction

 

 

function [idx,idy] = GetLowerMaxIndex_2Var(x,y,vx,vy)
  % Find the lower-corner indices of the grid cell that contains (vx,vy).
  %
  % The grid axes are read from row 1 of x and column 1 of y (meshgrid
  % layout).  The axes are assumed to be ascending; this is not checked.
  % For each axis, the result is the index just before the first coordinate
  % that exceeds the query value, i.e. the last coordinate <= the query.
  %
  % The result is clamped to [1, n-1] (or 1 when the axis has fewer than
  % two points), so index+1 is always valid for forward differences:
  %   - a query below the first coordinate maps to the first cell;
  %   - a query at or beyond the last coordinate maps to the last cell.

  xr = x(1,:);
  yr = y(:,1);

  kx = find(xr > vx, 1);   % first x coordinate strictly greater than vx
  if isempty(kx)
    idx = numel(xr);
  else
    idx = kx - 1;
  end
  idx = min(max(idx,1), max(numel(xr)-1,1));

  ky = find(yr > vy, 1);   % first y coordinate strictly greater than vy
  if isempty(ky)
    idy = numel(yr);
  else
    idy = ky - 1;
  end
  idy = min(max(idy,1), max(numel(yr)-1,1));

endfunction