java - Gradient descent for linear regression is not working
I tried making a program that performs linear regression with gradient descent on some sample data, but the theta values it produces do not give the best fit for the data. I have already normalized the data.
    public class OneVariableRegression {
        public static void main(String[] args) {
            double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531,
                           -0.92966968, -0.760638829, -0.591607978, -0.422577127,
                           -0.253546276, -0.084515425, 0.084515425, 0.253546276,
                           0.422577127, 0.591607978, 0.760638829, 0.92966968,
                           1.098700531, 1.267731382, 1.436762233, 1.605793084};
            double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42,
                          0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
            double theta0 = 0.5;
            double theta1 = 0.5;
            double temp0;
            double temp1;
            double alpha = 1.5;
            double m = x1.length;
            System.out.println(m);
            double derivative0 = 0;
            double derivative1 = 0;
            do {
                for (int i = 0; i < x1.length; i++) {
                    derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
                    derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
                }
                temp0 = theta0 - (alpha * derivative0);
                temp1 = theta1 - (alpha * derivative1);
                theta0 = temp0;
                theta1 = temp1;
                //System.out.println("derivative0 = " + derivative0);
                //System.out.println("derivative1 = " + derivative1);
            } while (derivative0 > 0.0001 || derivative1 > 0.0001);
            System.out.println();
            System.out.println("theta 0 = " + theta0);
            System.out.println("theta 1 = " + theta1);
        }
    }
Yes, it's convex.

The derivative you're using comes from the squared error function, which is convex and hence has no local minima other than the single global minimum. (In fact, this type of problem even admits a closed-form solution, called the normal equation; it's just not numerically tractable for large problems, hence the use of gradient descent.)
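As an aside, here is a minimal sketch of that closed-form solution for the one-variable case (the class and method names are mine, not from the original post): the slope is cov(x, y) / var(x) and the intercept is mean(y) - slope * mean(x). On the data above it should reproduce theta0 ≈ 0.4895 and theta1 ≈ 0.1652.

    // Closed-form (normal equation) least-squares fit for y = theta0 + theta1 * x:
    // theta1 = cov(x, y) / var(x), theta0 = mean(y) - theta1 * mean(x).
    public class NormalEquationCheck {
        static double[] fit(double[] x, double[] y) {
            double n = x.length;
            double meanX = 0, meanY = 0;
            for (int i = 0; i < x.length; i++) {
                meanX += x[i] / n;
                meanY += y[i] / n;
            }
            double sxy = 0, sxx = 0;   // centered cross products
            for (int i = 0; i < x.length; i++) {
                sxy += (x[i] - meanX) * (y[i] - meanY);
                sxx += (x[i] - meanX) * (x[i] - meanX);
            }
            double theta1 = sxy / sxx;
            double theta0 = meanY - theta1 * meanX;
            return new double[] { theta0, theta1 };
        }

        public static void main(String[] args) {
            double[] x = {-1.605793084, -1.436762233, -1.267731382, -1.098700531,
                          -0.92966968, -0.760638829, -0.591607978, -0.422577127,
                          -0.253546276, -0.084515425, 0.084515425, 0.253546276,
                          0.422577127, 0.591607978, 0.760638829, 0.92966968,
                          1.098700531, 1.267731382, 1.436762233, 1.605793084};
            double[] y = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42,
                          0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
            double[] theta = fit(x, y);
            System.out.println("theta 0 = " + theta[0]);
            System.out.println("theta 1 = " + theta[1]);
        }
    }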
And the correct answer is around theta0 = 0.4895, theta1 = 0.1652, which is trivial to check in a statistical computing environment. (See the bottom of this answer if you're skeptical.)
Below I point out the mistakes in your code; after fixing them, you'll get the correct answer above to within 4 decimal places.
Problems with your implementation:

So you're right to expect it to converge to the global minimum, but your implementation has a few problems.
Each time you recalculate derivative0 and derivative1, you forgot to reset them to 0 (what you were doing was accumulating the derivatives across iterations of the do {} while()). You need this inside your loop:

    do {
        derivative0 = 0;
        derivative1 = 0;
        ...
    }
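As a side note, one way to make that reset impossible to forget is to compute the gradient with fresh local variables in a helper method. A minimal sketch (the gradient method is my own addition, not part of the original code):

    // Returns {d/dtheta0, d/dtheta1} of the mean squared error.
    // The accumulators are locals, so nothing can carry over between
    // gradient descent iterations.
    static double[] gradient(double theta0, double theta1, double[] x, double[] y) {
        double m = x.length;
        double d0 = 0, d1 = 0;
        for (int i = 0; i < x.length; i++) {
            double residual = theta0 + (theta1 * x[i]) - y[i];
            d0 += residual / m;
            d1 += residual * x[i] / m;
        }
        return new double[] { d0, d1 };
    }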
Next, there is this:

    derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
    derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];

The x1[i] factor should be applied to (theta0 + (theta1 * x1[i]) - y[i]) alone.
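To see why the parenthesization matters, write r_i = theta0 + (theta1 * x1[i]) - y[i] and unroll what the first line actually computes: repeatedly applying derivative0 = (derivative0 + r_i) * (1/m) over i = 1..n gives

    derivative0 = r_n/m + r_(n-1)/m^2 + ... + r_1/m^n

so earlier residuals get crushed by extra factors of 1/m, instead of the intended (1/m) * (r_1 + ... + r_n). The second line has the same problem, and additionally multiplies the whole accumulated sum, not just the current residual, by x1[i].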
Your attempt is confusing, so let's write it in a clearer manner below, much closer to the mathematical equation (1/m) * sum((y_hat_i - y_i) * x_i):
    // need fresh vars, don't accumulate derivatives across gradient descent iterations
    derivative0 = 0;
    derivative1 = 0;
    for (int i = 0; i < m; i++) {
        derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
        derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]) * x1[i];
    }
    // (m is declared as a double, so 1/m is floating-point division;
    // with an int m, 1/m would evaluate to 0 in Java)
That should get you close enough. However, I find your learning rate alpha a tad too big. When it's too big, gradient descent has trouble zeroing in on your global optimum: it hangs around it, but won't quite get there. Use:

    double alpha = 0.5;
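For reference, here is the whole program with the three fixes applied. This is my reconstruction from the snippets above, so it may differ cosmetically from the gist mentioned below, but it should reproduce the output shown there:

    public class OneVariableRegression {
        public static void main(String[] args) {
            double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531,
                           -0.92966968, -0.760638829, -0.591607978, -0.422577127,
                           -0.253546276, -0.084515425, 0.084515425, 0.253546276,
                           0.422577127, 0.591607978, 0.760638829, 0.92966968,
                           1.098700531, 1.267731382, 1.436762233, 1.605793084};
            double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42,
                          0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
            double theta0 = 0.5;
            double theta1 = 0.5;
            double alpha = 0.5;                     // fix 3: smaller learning rate
            double m = x1.length;
            System.out.println(m);
            double derivative0 = 0;
            double derivative1 = 0;
            do {
                derivative0 = 0;                    // fix 1: reset the accumulators
                derivative1 = 0;                    // on every iteration
                for (int i = 0; i < m; i++) {
                    // fix 2: scale each residual by (1/m), and multiply
                    // only the current residual by x1[i]
                    derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
                    derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]) * x1[i];
                }
                theta0 = theta0 - (alpha * derivative0);
                theta1 = theta1 - (alpha * derivative1);
            } while (derivative0 > 0.0001 || derivative1 > 0.0001);
            System.out.println();
            System.out.println("theta 0 = " + theta0);
            System.out.println("theta 1 = " + theta1);
        }
    }

(The stopping test keeps the original's unsigned comparison; it works here because both derivatives approach zero from above on this data, but Math.abs would be safer in general.)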
Confirm the results

Run it and compare your answer with statistical software.
Here's a gist on GitHub of the .java file.
    ➜  ~ javac OneVariableRegression.java && java OneVariableRegression
    20.0

    theta 0 = 0.48950064086914064
    theta 1 = 0.16520139788757973
I compared it with R:
    > x
     [1] -1.60579308 -1.43676223 -1.26773138 -1.09870053 -0.92966968 -0.76063883
     [7] -0.59160798 -0.42257713 -0.25354628 -0.08451543  0.08451543  0.25354628
    [13]  0.42257713  0.59160798  0.76063883  0.92966968  1.09870053  1.26773138
    [19]  1.43676223  1.60579308
    > y
     [1] 0.30 0.20 0.24 0.33 0.35 0.28 0.61 0.38 0.38 0.42 0.51 0.60 0.55 0.56 0.53
    [16] 0.61 0.65 0.68 0.74 0.87
    > lm(y ~ x)

    Call:
    lm(formula = y ~ x)

    Coefficients:
    (Intercept)            x
         0.4895       0.1652
Now your code gives the correct answer to at least 4 decimal places.