java - Gradient descent for linear regression is not working


I tried making a program for linear regression using gradient descent on some sample data. The theta values it produces do not give the best fit for the data. I have normalized the data.

public class onevariableregression {
    public static void main(String[] args) {
        double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double theta0 = 0.5;
        double theta1 = 0.5;
        double temp0;
        double temp1;
        double alpha = 1.5;
        double m = x1.length;
        System.out.println(m);
        double derivative0 = 0;
        double derivative1 = 0;
        do {
            for (int i = 0; i < x1.length; i++) {
                derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
                derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
            }
            temp0 = theta0 - (alpha * derivative0);
            temp1 = theta1 - (alpha * derivative1);
            theta0 = temp0;
            theta1 = temp1;
            //System.out.println("derivative0 = " + derivative0);
            //System.out.println("derivative1 = " + derivative1);
        } while (derivative0 > 0.0001 || derivative1 > 0.0001);
        System.out.println();
        System.out.println("theta 0 = " + theta0);
        System.out.println("theta 1 = " + theta1);
    }
}

Yes, it's convex.

The derivative you're using comes from the squared error function, which is convex and hence has no local minima other than the one global minimum. (In fact, this type of problem even admits a closed-form solution called the normal equation; it's just not numerically tractable for large problems, hence the use of gradient descent.)
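(As an aside: for one variable, that normal equation reduces to the textbook least-squares formulas slope = sum((x - xbar) * (y - ybar)) / sum((x - xbar)^2) and intercept = ybar - slope * xbar. The sketch below is not from the original answer; the class name NormalEquationCheck and the fit helper are just illustrative names, and it runs on the question's data.)

public class NormalEquationCheck {
    // Closed-form least-squares fit for one variable; returns {theta0, theta1}.
    static double[] fit(double[] x, double[] y) {
        int n = x.length;
        double xbar = 0, ybar = 0;
        for (int i = 0; i < n; i++) {
            xbar += x[i] / n;
            ybar += y[i] / n;
        }
        double sxy = 0, sxx = 0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - xbar) * (y[i] - ybar);
            sxx += (x[i] - xbar) * (x[i] - xbar);
        }
        double slope = sxy / sxx;
        return new double[] {ybar - slope * xbar, slope};
    }

    public static void main(String[] args) {
        // Same data as in the question.
        double[] x1 = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double[] y = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double[] theta = fit(x1, y);
        System.out.println("theta 0 = " + theta[0] + ", theta 1 = " + theta[1]);
    }
}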

And the correct answer is around theta0 = 0.4895, theta1 = 0.1652, which is trivial to check in any statistical computing environment. (See the bottom of this answer if you're skeptical.)

Below I point out the mistakes in your code; after fixing those mistakes, you'll get the correct answer above to within 4 decimal places.

Problems in your implementation:

So you're right to expect it to converge to the global minimum, but there are problems in your implementation.

Each time you recalculate derivative_i, you forgot to reset it to 0; what you were doing was accumulating the derivative across iterations of the do{}while().

You need this in your while loop:

do {
    derivative0 = 0;
    derivative1 = 0;
    ...
}

Next, this:

derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];

The x1[i] factor should be applied to (theta0 + (theta1 * x1[i]) - y[i]) alone.

Your attempt is confusing, so let's write it in a clearer manner below, a lot closer to the mathematical equations

    d/d(theta0) J = (1/m) * sum(y_hat_i - y_i)
    d/d(theta1) J = (1/m) * sum((y_hat_i - y_i) * x_i)

where y_hat_i = theta0 + theta1 * x1[i]:

// need fresh vars, don't accumulate derivatives across gradient descent iterations
derivative0 = 0;
derivative1 = 0;
for (int i = 0; i < m; i++) {
    derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
    derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]) * x1[i];
}

That should be close enough. However, I find your learning rate alpha a tad big. When it's too big, gradient descent has trouble zeroing in on the global optimum: it will hang around near it, but won't quite get there. Try:

double alpha = 0.5;                                                  
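Putting the three fixes together (reset the derivatives each iteration, move the x1[i] factor, lower alpha), the corrected program could look like the sketch below. It keeps the structure and variable names of the question's code; the actual gist may differ in details.

public class onevariableregression {
    public static void main(String[] args) {
        double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double theta0 = 0.5;
        double theta1 = 0.5;
        double temp0;
        double temp1;
        double alpha = 0.5;   // smaller learning rate
        double m = x1.length;
        double derivative0 = 0;
        double derivative1 = 0;
        do {
            // reset the accumulators on every gradient descent iteration
            derivative0 = 0;
            derivative1 = 0;
            for (int i = 0; i < m; i++) {
                derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
                derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]) * x1[i];
            }
            // simultaneous update of both parameters via temporaries
            temp0 = theta0 - (alpha * derivative0);
            temp1 = theta1 - (alpha * derivative1);
            theta0 = temp0;
            theta1 = temp1;
        } while (derivative0 > 0.0001 || derivative1 > 0.0001);
        System.out.println("theta 0 = " + theta0);
        System.out.println("theta 1 = " + theta1);
    }
}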

Confirm the results

Run it, and compare your answer with statistics software.

Here's a gist on GitHub of the .java file.

➜  ~ javac onevariableregression.java && java onevariableregression
20.0

theta 0 = 0.48950064086914064
theta 1 = 0.16520139788757973

I compared it with R:

> x
 [1] -1.60579308 -1.43676223 -1.26773138 -1.09870053 -0.92966968 -0.76063883
 [7] -0.59160798 -0.42257713 -0.25354628 -0.08451543  0.08451543  0.25354628
[13]  0.42257713  0.59160798  0.76063883  0.92966968  1.09870053  1.26773138
[19]  1.43676223  1.60579308
> y
 [1] 0.30 0.20 0.24 0.33 0.35 0.28 0.61 0.38 0.38 0.42 0.51 0.60 0.55 0.56 0.53
[16] 0.61 0.65 0.68 0.74 0.87
> lm(y ~ x)

Call:
lm(formula = y ~ x)

Coefficients:
(Intercept)            x
     0.4895       0.1652

Now your code gives the correct answer to at least 4 decimal places.

