Skip to content

Commit d16eb96

Browse files
author
Chetanfs
committed
implemented PLA initalized by linear regression
1 parent 8ea5fd9 commit d16eb96

File tree

9 files changed

+154
-32
lines changed

9 files changed

+154
-32
lines changed

HW1 - PLA Implementation/.idea/workspace.xml

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

HW1 - PLA Implementation/src/PLA.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ public static void main(String[] args) {
8080
/estimGError.size());
8181
}
8282
System.out.println(counters);
83-
System.out.println(counters.stream().mapToInt(Integer::intValue).average().getAsDouble());
84-
System.out.println(propDiffs.stream().mapToDouble(Double::doubleValue).average().getAsDouble());
83+
System.out.println("Average iterations to convergence: " + counters.stream().mapToInt(Integer::intValue).average().getAsDouble());
84+
System.out.println("Average probability that convergence hypothesis and real function differ: " + propDiffs.stream().mapToDouble(Double::doubleValue).average().getAsDouble());
8585
}
8686

8787
public static Point genRanPoint() {

The Learning Problem HW2.docx

84.8 KB
Binary file not shown.

hw2Programs/.idea/workspace.xml

Lines changed: 63 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
57 Bytes
Binary file not shown.
3.97 KB
Binary file not shown.

hw2Programs/src/LinRegImplementation.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ public static void main(String[] args) {
3434
Matrix output = Matrix.constructWithCopy(realClasses);
3535
Matrix weights = inputs.inverse().times(output);
3636

37+
// How to estimate parameters without .inverse() pseudoinverse shortcut
38+
// Matrix inputsT = inputs.transpose();
39+
// Matrix weights = inputsT.times(inputs).inverse().times(inputsT).times(output);
40+
3741
// Estimating error in training/test set
3842
propIncorrectTrain.add(evaluate(inputs, realClasses, weights));
3943
Matrix inputTest = Matrix.constructWithCopy(testData);
@@ -47,10 +51,9 @@ public static void main(String[] args) {
4751
public static double[][] genRandomPoints(int n, int cols) {
4852
double[][] inputs = new double [n][cols+1];
4953
for (int i = 0; i < n; i++) {
50-
double[] newPoint = {Math.random() * 2 - 1, Math.random() * 2 - 1};
51-
inputs[i][0] = 1;
52-
inputs[i][1] = newPoint[0];
53-
inputs[i][2] = newPoint[1];
54+
inputs[i][0] = 1; // added so we can treat threshold as w_0
55+
inputs[i][1] = Math.random() * 2 - 1;
56+
inputs[i][2] = Math.random() * 2 - 1;
5457
}
5558
return inputs;
5659
}
@@ -70,9 +73,13 @@ public static double[] genRandomLine() {
7073
return new double [] {m, b};
7174
}
7275

73-
public static double evaluate(Matrix input, double[][] realClasses, Matrix weights) {
76+
public static double[] predict(Matrix input, Matrix weights) {
7477
double scores[][] = input.times(weights).getArray();
75-
double estimClass[] = Arrays.stream(scores).map(n -> n[0] > 0 ? 1.0 : 0.0).mapToDouble(Double::doubleValue).toArray();
78+
return Arrays.stream(scores).map(n -> n[0] > 0 ? 1.0 : 0.0).mapToDouble(Double::doubleValue).toArray();
79+
}
80+
81+
public static double evaluate(Matrix input, double[][] realClasses, Matrix weights) {
82+
double estimClass[] = predict(input, weights);
7683
int numIncorrect = 0;
7784
for (int i = 0; i < estimClass.length; i++)
7885
if (estimClass[i] != realClasses[i][0])

hw2Programs/src/PLA.java

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import Jama.Matrix;
2+
3+
import java.util.ArrayList;
4+
5+
/**
6+
* Created by cheta_000 on 5/22/2015.
7+
* Implements the PLA algorithm with weights initialized by linear regression
8+
*/
9+
10+
public class PLA {
11+
12+
public static void main(String[] args) {
13+
int N = 10;
14+
ArrayList<Double> errorRatesInitially = new ArrayList<>();
15+
ArrayList<Integer> iterations = new ArrayList<>();
16+
for (int j = 0; j < 1000; j++) {
17+
// Generate f(x)
18+
double[] realLine = LinRegImplementation.genRandomLine();
19+
20+
// Generate training data
21+
double[][] training = LinRegImplementation.genRandomPoints(N, 2);
22+
double[][] trainingClass = LinRegImplementation.getPointClasses(training, realLine);
23+
24+
// Initialize weights
25+
Matrix inputs = Matrix.constructWithCopy(training);
26+
Matrix output = Matrix.constructWithCopy(trainingClass);
27+
Matrix weights = inputs.inverse().times(output);
28+
29+
// System.out.println("Initial error: " + LinRegImplementation.evaluate(inputs, trainingClass, weights) + ", ");
30+
errorRatesInitially.add(LinRegImplementation.evaluate(inputs, trainingClass, weights));
31+
32+
// Get error and start iterations
33+
double[] error = getErrorArray(weights, inputs, output);
34+
int wrongIndex = findWrongIndex(error);
35+
int counter = 0;
36+
while(wrongIndex != -1) {
37+
counter++;
38+
weights.getArray()[0][0] += error[wrongIndex]*inputs.getArray()[wrongIndex][0];
39+
weights.getArray()[1][0] += error[wrongIndex]*inputs.getArray()[wrongIndex][1];
40+
weights.getArray()[2][0] += error[wrongIndex]*inputs.getArray()[wrongIndex][2];
41+
error = getErrorArray(weights, inputs, output);
42+
wrongIndex = findWrongIndex(error);
43+
}
44+
iterations.add(counter);
45+
// System.out.println("Iterations for PLA to perfect: " + counter);
46+
}
47+
System.out.println("Average initial error: " + errorRatesInitially.stream().mapToDouble(Double::doubleValue).average().getAsDouble());
48+
System.out.println("Average num of iterations: " + iterations.stream().mapToInt(Integer::intValue).average().getAsDouble());
49+
50+
}
51+
52+
public static double[] getErrorArray(Matrix weights, Matrix inputs, Matrix outputs) {
53+
double[][] realClass = outputs.getArray();
54+
double[] estimOutput = LinRegImplementation.predict(inputs, weights);
55+
double[] error = new double[realClass.length];
56+
for (int i = 0; i < realClass.length; i++) {
57+
error[i] = realClass[i][0] - estimOutput[i];
58+
}
59+
return error;
60+
}
61+
62+
public static int findWrongIndex(double[] error) {
63+
for (int i = 0; i < error.length; i++) {
64+
if (error[i] != 0) {
65+
return i;
66+
}
67+
}
68+
return -1;
69+
}
70+
71+
72+
}

~WRL2006.tmp

274 KB
Binary file not shown.

0 commit comments

Comments
 (0)