
Commit d16eb96

Author: Chetanfs
Commit message: implemented PLA initialized by linear regression
1 parent 8ea5fd9 commit d16eb96

File tree

9 files changed: +154 -32 lines changed


HW1 - PLA Implementation/.idea/workspace.xml

+4-4
Generated file; diff not rendered by default.

HW1 - PLA Implementation/src/PLA.java

+2-2

@@ -80,8 +80,8 @@ public static void main(String[] args) {
                 /estimGError.size());
         }
         System.out.println(counters);
-        System.out.println(counters.stream().mapToInt(Integer::intValue).average().getAsDouble());
-        System.out.println(propDiffs.stream().mapToDouble(Double::doubleValue).average().getAsDouble());
+        System.out.println("Average iterations to convergence: " + counters.stream().mapToInt(Integer::intValue).average().getAsDouble());
+        System.out.println("Average probability that convergence hypothesis and real function differ: " + propDiffs.stream().mapToDouble(Double::doubleValue).average().getAsDouble());
     }
 
     public static Point genRanPoint() {
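
The second added print above reports an estimate of P[f(x) != g(x)], the probability that the converged PLA hypothesis g disagrees with the target line f. Only a fragment of HW1's PLA.java appears in this diff, so the following is a hypothetical sketch of how that quantity is typically estimated by sampling fresh random points; classify(...), realLine, and finalWeights are assumed names, not taken from the file.

    // Hypothetical Monte Carlo estimate of P[f(x) != g(x)] for one experiment run.
    // classify(...), realLine, and finalWeights are assumptions about the rest of the file.
    int disagreements = 0;
    int samples = 10000;
    for (int i = 0; i < samples; i++) {
        Point p = genRanPoint();                           // fresh point, independent of the training set
        if (classify(p, realLine) != classify(p, finalWeights)) {
            disagreements++;                               // target f and hypothesis g disagree here
        }
    }
    double propDiff = (double) disagreements / samples;    // the value that would go into propDiffs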

The Learning Problem HW2.docx

84.8 KB
Binary file not shown.

hw2Programs/.idea/workspace.xml

+63-20
Generated file; diff not rendered by default.
Binary file not shown.
3.97 KB
Binary file not shown.

hw2Programs/src/LinRegImplementation.java

+13-6

@@ -34,6 +34,10 @@ public static void main(String[] args) {
             Matrix output = Matrix.constructWithCopy(realClasses);
             Matrix weights = inputs.inverse().times(output);
 
+            // How to estimate parameters without .inverse() pseudoinverse shortcut
+            // Matrix inputsT = inputs.transpose();
+            // Matrix weights = inputsT.times(inputs).inverse().times(inputsT).times(output);
+
             // Estimating error in training/test set
             propIncorrectTrain.add(evaluate(inputs, realClasses, weights));
             Matrix inputTest = Matrix.constructWithCopy(testData);
@@ -47,10 +51,9 @@ public static void main(String[] args) {
     public static double[][] genRandomPoints(int n, int cols) {
         double[][] inputs = new double [n][cols+1];
         for (int i = 0; i < n; i++) {
-            double[] newPoint = {Math.random() * 2 - 1, Math.random() * 2 - 1};
-            inputs[i][0] = 1;
-            inputs[i][1] = newPoint[0];
-            inputs[i][2] = newPoint[1];
+            inputs[i][0] = 1; // added so we can treat threshold as w_0
+            inputs[i][1] = Math.random() * 2 - 1;
+            inputs[i][2] = Math.random() * 2 - 1;
         }
         return inputs;
     }
@@ -70,9 +73,13 @@ public static double[] genRandomLine() {
         return new double [] {m, b};
     }
 
-    public static double evaluate(Matrix input, double[][] realClasses, Matrix weights) {
+    public static double[] predict(Matrix input, Matrix weights) {
         double scores[][] = input.times(weights).getArray();
-        double estimClass[] = Arrays.stream(scores).map(n -> n[0] > 0 ? 1.0 : 0.0).mapToDouble(Double::doubleValue).toArray();
+        return Arrays.stream(scores).map(n -> n[0] > 0 ? 1.0 : 0.0).mapToDouble(Double::doubleValue).toArray();
+    }
+
+    public static double evaluate(Matrix input, double[][] realClasses, Matrix weights) {
+        double estimClass[] = predict(input, weights);
         int numIncorrect = 0;
         for (int i = 0; i < estimClass.length; i++)
             if (estimClass[i] != realClasses[i][0])
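
The commented-out alternative added in the first hunk spells out the ordinary least-squares weights w = (X^T X)^(-1) X^T y, rather than relying on the one-line shortcut inputs.inverse().times(output); per Jama's documentation, Matrix.inverse() returns a pseudoinverse when the matrix is not square, which is why both routes give the same result for full-column-rank inputs. A minimal self-contained sketch of that normal-equation route, with trainingData and trainingLabels as assumed placeholder arrays:

    import Jama.Matrix;

    public class NormalEquationSketch {
        public static void main(String[] args) {
            // Assumed placeholder data: n x 3 inputs (bias column of 1s) and n x 1 labels.
            double[][] trainingData = { {1, 0.2, -0.5}, {1, -0.7, 0.1}, {1, 0.4, 0.9}, {1, -0.1, -0.3} };
            double[][] trainingLabels = { {1}, {0}, {1}, {0} };

            Matrix X = Matrix.constructWithCopy(trainingData);
            Matrix y = Matrix.constructWithCopy(trainingLabels);

            // Normal equations: w = (X^T X)^(-1) X^T y, valid when X has full column rank.
            Matrix Xt = X.transpose();
            Matrix w = Xt.times(X).inverse().times(Xt).times(y);

            // Equivalent least-squares solve, without forming X^T X explicitly.
            Matrix wSolve = X.solve(y);

            w.print(8, 4);      // 3 x 1 weight column
            wSolve.print(8, 4); // should match w up to numerical error
        }
    }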

hw2Programs/src/PLA.java

+72

@@ -0,0 +1,72 @@
+import Jama.Matrix;
+
+import java.util.ArrayList;
+
+/**
+ * Created by cheta_000 on 5/22/2015.
+ * Implements the PLA algorithm with weights initialized by linear regression
+ */
+
+public class PLA {
+
+    public static void main(String[] args) {
+        int N = 10;
+        ArrayList<Double> errorRatesInitially = new ArrayList<>();
+        ArrayList<Integer> iterations = new ArrayList<>();
+        for (int j = 0; j < 1000; j++) {
+            // Generate f(x)
+            double[] realLine = LinRegImplementation.genRandomLine();
+
+            // Generate training data
+            double[][] training = LinRegImplementation.genRandomPoints(N, 2);
+            double[][] trainingClass = LinRegImplementation.getPointClasses(training, realLine);
+
+            // Initialize weights
+            Matrix inputs = Matrix.constructWithCopy(training);
+            Matrix output = Matrix.constructWithCopy(trainingClass);
+            Matrix weights = inputs.inverse().times(output);
+
+            // System.out.println("Initial error: " + LinRegImplementation.evaluate(inputs, trainingClass, weights) + ", ");
+            errorRatesInitially.add(LinRegImplementation.evaluate(inputs, trainingClass, weights));
+
+            // Get error and start iterations
+            double[] error = getErrorArray(weights, inputs, output);
+            int wrongIndex = findWrongIndex(error);
+            int counter = 0;
+            while (wrongIndex != -1) {
+                counter++;
+                weights.getArray()[0][0] += error[wrongIndex] * inputs.getArray()[wrongIndex][0];
+                weights.getArray()[1][0] += error[wrongIndex] * inputs.getArray()[wrongIndex][1];
+                weights.getArray()[2][0] += error[wrongIndex] * inputs.getArray()[wrongIndex][2];
+                error = getErrorArray(weights, inputs, output);
+                wrongIndex = findWrongIndex(error);
+            }
+            iterations.add(counter);
+            // System.out.println("Iterations for PLA to perfect: " + counter);
+        }
+        System.out.println("Average initial error: " + errorRatesInitially.stream().mapToDouble(Double::doubleValue).average().getAsDouble());
+        System.out.println("Average num of iterations: " + iterations.stream().mapToInt(Integer::intValue).average().getAsDouble());
+
+    }
+
+    public static double[] getErrorArray(Matrix weights, Matrix inputs, Matrix outputs) {
+        double[][] realClass = outputs.getArray();
+        double[] estimOutput = LinRegImplementation.predict(inputs, weights);
+        double[] error = new double[realClass.length];
+        for (int i = 0; i < realClass.length; i++) {
+            error[i] = realClass[i][0] - estimOutput[i];
+        }
+        return error;
+    }
+
+    public static int findWrongIndex(double[] error) {
+        for (int i = 0; i < error.length; i++) {
+            if (error[i] != 0) {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+
+}
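
In the new PLA.java above, the perceptron step w <- w + (y - y_hat) * x is applied by writing through the array returned by getArray(), which works because Jama hands back its internal storage there. As a minimal sketch, the same update for one misclassified point can be expressed through Jama's matrix operations instead; it assumes the weights, inputs, error, and wrongIndex variables from the loop above.

    // Same perceptron step via the Matrix API: x is the 3 x 1 column for row wrongIndex
    // of inputs, and error[wrongIndex] is y - y_hat for that point.
    Matrix x = inputs.getMatrix(wrongIndex, wrongIndex, 0, 2).transpose();
    weights.plusEquals(x.times(error[wrongIndex])); // in-place: weights = weights + (y - y_hat) * x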

~WRL2006.tmp

274 KB
Binary file not shown.

0 commit comments