Skip to content

Commit c42f8bb

Browse files
committed
add train and predict interfaces to IExample.
1 parent 09d77e2 commit c42f8bb

25 files changed

+503
-87
lines changed

Diff for: test/TensorFlowNET.Examples/BasicEagerApi.cs

+21-2
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ namespace TensorFlowNET.Examples
1111
/// </summary>
1212
public class BasicEagerApi : IExample
1313
{
14-
public int Priority => 100;
1514
public bool Enabled { get; set; } = false;
1615
public string Name => "Basic Eager";
17-
public bool ImportGraph { get; set; } = false;
16+
public bool IsImportingGraph { get; set; } = false;
1817

1918
private Tensor a, b, c, d;
2019

@@ -46,5 +45,25 @@ public bool Run()
4645
public void PrepareData()
4746
{
4847
}
48+
49+
public Graph ImportGraph()
50+
{
51+
throw new NotImplementedException();
52+
}
53+
54+
public Graph BuildGraph()
55+
{
56+
throw new NotImplementedException();
57+
}
58+
59+
public bool Predict()
60+
{
61+
throw new NotImplementedException();
62+
}
63+
64+
public bool Train()
65+
{
66+
throw new NotImplementedException();
67+
}
4968
}
5069
}

Diff for: test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs

+21-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples
1818
/// </summary>
1919
public class KMeansClustering : IExample
2020
{
21-
public int Priority => 8;
2221
public bool Enabled { get; set; } = true;
2322
public string Name => "K-means Clustering";
24-
public bool ImportGraph { get; set; } = true;
23+
public bool IsImportingGraph { get; set; } = true;
2524

2625
public int? train_size = null;
2726
public int validation_size = 5000;
@@ -127,5 +126,25 @@ public void PrepareData()
127126
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta";
128127
Web.Download(url, "graph", "kmeans.meta");
129128
}
129+
130+
public Graph ImportGraph()
131+
{
132+
throw new NotImplementedException();
133+
}
134+
135+
public Graph BuildGraph()
136+
{
137+
throw new NotImplementedException();
138+
}
139+
140+
public bool Train()
141+
{
142+
throw new NotImplementedException();
143+
}
144+
145+
public bool Predict()
146+
{
147+
throw new NotImplementedException();
148+
}
130149
}
131150
}

Diff for: test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs

+21-3
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@ namespace TensorFlowNET.Examples
1313
/// </summary>
1414
public class LinearRegression : IExample
1515
{
16-
public int Priority => 3;
1716
public bool Enabled { get; set; } = true;
1817
public string Name => "Linear Regression";
19-
public bool ImportGraph { get; set; } = false;
20-
18+
public bool IsImportingGraph { get; set; } = false;
2119

2220
public int training_epochs = 1000;
2321

@@ -113,5 +111,25 @@ public void PrepareData()
113111
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
114112
n_samples = train_X.shape[0];
115113
}
114+
115+
public Graph ImportGraph()
116+
{
117+
throw new NotImplementedException();
118+
}
119+
120+
public Graph BuildGraph()
121+
{
122+
throw new NotImplementedException();
123+
}
124+
125+
public bool Train()
126+
{
127+
throw new NotImplementedException();
128+
}
129+
130+
public bool Predict()
131+
{
132+
throw new NotImplementedException();
133+
}
116134
}
117135
}

Diff for: test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs

+21-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples
1818
/// </summary>
1919
public class LogisticRegression : IExample
2020
{
21-
public int Priority => 4;
2221
public bool Enabled { get; set; } = true;
2322
public string Name => "Logistic Regression";
24-
public bool ImportGraph { get; set; } = false;
23+
public bool IsImportingGraph { get; set; } = false;
2524

2625

2726
public int training_epochs = 10;
@@ -158,5 +157,25 @@ public void Predict()
158157
throw new ValueError("predict error, should be 90% accuracy");
159158
});
160159
}
160+
161+
public Graph ImportGraph()
162+
{
163+
throw new NotImplementedException();
164+
}
165+
166+
public Graph BuildGraph()
167+
{
168+
throw new NotImplementedException();
169+
}
170+
171+
public bool Train()
172+
{
173+
throw new NotImplementedException();
174+
}
175+
176+
bool IExample.Predict()
177+
{
178+
throw new NotImplementedException();
179+
}
161180
}
162181
}

Diff for: test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs

+22-3
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@ namespace TensorFlowNET.Examples
1313
/// </summary>
1414
public class NaiveBayesClassifier : IExample
1515
{
16-
public int Priority => 6;
1716
public bool Enabled { get; set; } = true;
1817
public string Name => "Naive Bayes Classifier";
19-
public bool ImportGraph { get; set; } = false;
18+
public bool IsImportingGraph { get; set; } = false;
2019

2120
public NDArray X, y;
2221
public Normal dist { get; set; }
@@ -96,7 +95,7 @@ public void fit(NDArray X, NDArray y)
9695
this.dist = dist;
9796
}
9897

99-
public Tensor predict (NDArray X)
98+
public Tensor predict(NDArray X)
10099
{
101100
if (dist == null)
102101
{
@@ -170,5 +169,25 @@ public void PrepareData()
170169
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2);
171170
#endregion
172171
}
172+
173+
public Graph ImportGraph()
174+
{
175+
throw new NotImplementedException();
176+
}
177+
178+
public Graph BuildGraph()
179+
{
180+
throw new NotImplementedException();
181+
}
182+
183+
public bool Train()
184+
{
185+
throw new NotImplementedException();
186+
}
187+
188+
public bool Predict()
189+
{
190+
throw new NotImplementedException();
191+
}
173192
}
174193
}

Diff for: test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs

+21-2
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@ namespace TensorFlowNET.Examples
1515
/// </summary>
1616
public class NearestNeighbor : IExample
1717
{
18-
public int Priority => 5;
1918
public bool Enabled { get; set; } = true;
2019
public string Name => "Nearest Neighbor";
2120
Datasets mnist;
2221
NDArray Xtr, Ytr, Xte, Yte;
2322
public int? TrainSize = null;
2423
public int ValidationSize = 5000;
2524
public int? TestSize = null;
26-
public bool ImportGraph { get; set; } = false;
25+
public bool IsImportingGraph { get; set; } = false;
2726

2827

2928
public bool Run()
@@ -76,5 +75,25 @@ public void PrepareData()
7675
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
7776
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
7877
}
78+
79+
public Graph ImportGraph()
80+
{
81+
throw new NotImplementedException();
82+
}
83+
84+
public Graph BuildGraph()
85+
{
86+
throw new NotImplementedException();
87+
}
88+
89+
public bool Train()
90+
{
91+
throw new NotImplementedException();
92+
}
93+
94+
public bool Predict()
95+
{
96+
throw new NotImplementedException();
97+
}
7998
}
8099
}

Diff for: test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs

+23-4
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@ namespace TensorFlowNET.Examples
1414
/// </summary>
1515
public class NeuralNetXor : IExample
1616
{
17-
public int Priority => 10;
1817
public bool Enabled { get; set; } = true;
1918
public string Name => "NN XOR";
20-
public bool ImportGraph { get; set; } = false;
19+
public bool IsImportingGraph { get; set; } = false;
2120

2221
public int num_steps = 10000;
2322

@@ -54,7 +53,7 @@ public bool Run()
5453
{
5554
PrepareData();
5655
float loss_value = 0;
57-
if (ImportGraph)
56+
if (IsImportingGraph)
5857
loss_value = RunWithImportedGraph();
5958
else
6059
loss_value = RunWithBuiltGraph();
@@ -145,12 +144,32 @@ public void PrepareData()
145144
{0, 1 }
146145
};
147146

148-
if (ImportGraph)
147+
if (IsImportingGraph)
149148
{
150149
// download graph meta data
151150
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta";
152151
Web.Download(url, "graph", "xor.meta");
153152
}
154153
}
154+
155+
public Graph ImportGraph()
156+
{
157+
throw new NotImplementedException();
158+
}
159+
160+
public Graph BuildGraph()
161+
{
162+
throw new NotImplementedException();
163+
}
164+
165+
public bool Train()
166+
{
167+
throw new NotImplementedException();
168+
}
169+
170+
public bool Predict()
171+
{
172+
throw new NotImplementedException();
173+
}
155174
}
156175
}

Diff for: test/TensorFlowNET.Examples/BasicOperations.cs

+21-3
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples
1414
public class BasicOperations : IExample
1515
{
1616
public bool Enabled { get; set; } = true;
17-
public int Priority => 2;
1817
public string Name => "Basic Operations";
19-
public bool ImportGraph { get; set; } = false;
20-
18+
public bool IsImportingGraph { get; set; } = false;
2119

2220
private Session sess;
2321

@@ -104,5 +102,25 @@ public bool Run()
104102
public void PrepareData()
105103
{
106104
}
105+
106+
public Graph ImportGraph()
107+
{
108+
throw new NotImplementedException();
109+
}
110+
111+
public Graph BuildGraph()
112+
{
113+
throw new NotImplementedException();
114+
}
115+
116+
public bool Train()
117+
{
118+
throw new NotImplementedException();
119+
}
120+
121+
public bool Predict()
122+
{
123+
throw new NotImplementedException();
124+
}
107125
}
108126
}

Diff for: test/TensorFlowNET.Examples/HelloWorld.cs

+21-3
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@ namespace TensorFlowNET.Examples
1212
/// </summary>
1313
public class HelloWorld : IExample
1414
{
15-
public int Priority => 1;
1615
public bool Enabled { get; set; } = true;
1716
public string Name => "Hello World";
18-
public bool ImportGraph { get; set; } = false;
19-
17+
public bool IsImportingGraph { get; set; } = false;
2018

2119
public bool Run()
2220
{
@@ -41,5 +39,25 @@ of the Constant op. */
4139
public void PrepareData()
4240
{
4341
}
42+
43+
public Graph ImportGraph()
44+
{
45+
throw new NotImplementedException();
46+
}
47+
48+
public Graph BuildGraph()
49+
{
50+
throw new NotImplementedException();
51+
}
52+
53+
public bool Train()
54+
{
55+
throw new NotImplementedException();
56+
}
57+
58+
public bool Predict()
59+
{
60+
throw new NotImplementedException();
61+
}
4462
}
4563
}

0 commit comments

Comments
 (0)