Skip to content

Commit f87e51f

Browse files
committed
Update binary_logistic_regression_tf.ipynb
1 parent 3fe8a23 commit f87e51f

File tree

1 file changed

+61
-16
lines changed

binary_logistic_regression_tf.ipynb

+61-16
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
"from sklearn.datasets import make_classification\n",
4646
"from sklearn.model_selection import train_test_split\n",
4747
"import tensorflow as tf\n",
48-
"import tensorflow.keras as keras\n",
48+
"from tensorflow import keras\n",
4949
"\n",
5050
"\n",
5151
"print(\n",
@@ -98,12 +98,23 @@
9898
{
9999
"cell_type": "code",
100100
"execution_count": null,
101-
"id": "e99a940f",
101+
"id": "c795c19c-aa94-48b9-9385-390fd831a1fa",
102102
"metadata": {},
103103
"outputs": [],
104104
"source": [
105-
"def predict_class(y):\n",
106-
" y[y < 0.5], y[y >= 0.5] = 0, 1"
105+
"def predict_class_tf(y):\n",
106+
" y[y[:, 0] < 0.5, :], y[y[:, 0] >= 0.5, :] = 0, 1"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": null,
112+
"id": "8388f542-910c-408f-9df7-b9ef4ed6f02a",
113+
"metadata": {},
114+
"outputs": [],
115+
"source": [
116+
"def predict_class_my(y):\n",
117+
" y[:, y[0, :] < 0.5], y[:, y[0, :] >= 0.5] = 0, 1"
107118
]
108119
},
109120
{
@@ -116,7 +127,7 @@
116127
"def evaluate(y_true, y_pred):\n",
117128
" y_true_tmp = np.copy(y_true)\n",
118129
" y_pred_tmp = np.copy(y_pred)\n",
119-
" predict_class(y_pred_tmp)\n",
130+
" predict_class_my(y_pred_tmp)\n",
120131
"\n",
121132
" # https://www.tensorflow.org/api_docs/python/tf/math/confusion_matrix\n",
122133
" # The matrix columns represent the prediction labels.\n",
@@ -193,7 +204,7 @@
193204
")\n",
194205
"X_train, X_test, Y_train, Y_test = train_test_split(\n",
195206
" X, Y, train_size=train_size, random_state=None\n",
196-
")"
207+
")\n"
197208
]
198209
},
199210
{
@@ -221,6 +232,26 @@
221232
"print(\"X train dim\", X_train_our.shape, \"Y train dim\", Y_train_our.shape)"
222233
]
223234
},
235+
{
236+
"cell_type": "markdown",
237+
"id": "a3aedc53-87b0-4e3c-91a7-73be1df13b95",
238+
"metadata": {},
239+
"source": [
240+
"- prep for TF"
241+
]
242+
},
243+
{
244+
"cell_type": "code",
245+
"execution_count": null,
246+
"id": "3d4f778d-786a-4146-a662-4bbafa8577e0",
247+
"metadata": {},
248+
"outputs": [],
249+
"source": [
250+
"Y_train = Y_train[:, None] # newer TF needs (x,1) instead of (x) arrays\n",
251+
"Y_test = Y_test[:, None]\n",
252+
"X_train.shape, X_test.shape, Y_train.shape, Y_test.shape"
253+
]
254+
},
224255
{
225256
"cell_type": "markdown",
226257
"id": "f2b9c6e6",
@@ -439,13 +470,17 @@
439470
"source": [
440471
"# prediction after training finished\n",
441472
"Y_train_pred_tf = model.predict(X_train)\n",
442-
"predict_class(Y_train_pred_tf)\n",
473+
"predict_class_tf(Y_train_pred_tf)\n",
474+
"\n",
475+
"print(Y_train_pred_tf.shape, Y_train.shape)\n",
443476
"\n",
444477
"# confusion matrix\n",
445478
"cm_train_tf = tf.math.confusion_matrix(\n",
446-
" labels=Y_train, predictions=Y_train_pred_tf, num_classes=2\n",
479+
" labels=np.squeeze(Y_train), predictions=np.squeeze(Y_train_pred_tf), num_classes=2\n",
447480
")\n",
448481
"\n",
482+
"\n",
483+
"\n",
449484
"# get technical measures for the trained model on the training data set\n",
450485
"results_train_tf = model.evaluate(\n",
451486
" X_train, Y_train, batch_size=M_train, verbose=verbose\n",
@@ -552,7 +587,7 @@
552587
"print(\"\\nm_test\", M_test)\n",
553588
"# our implementation needs transposed data\n",
554589
"X_test_our = X_test.T\n",
555-
"Y_test_our = Y_test[None, :]\n",
590+
"Y_test_our = Y_test.T\n",
556591
"print(\"X test dim\", X_test_our.shape, \"Y test dim\", Y_test_our.shape)"
557592
]
558593
},
@@ -601,11 +636,11 @@
601636
"source": [
602637
"# prediction\n",
603638
"Y_test_pred_tf = model.predict(X_test)\n",
604-
"predict_class(Y_test_pred_tf)\n",
639+
"predict_class_tf(Y_test_pred_tf)\n",
605640
"\n",
606641
"# confusion matrix\n",
607642
"cm_test_tf = tf.math.confusion_matrix(\n",
608-
" labels=Y_test, predictions=Y_test_pred_tf, num_classes=2\n",
643+
" labels=np.squeeze(Y_test), predictions=np.squeeze(Y_test_pred_tf), num_classes=2\n",
609644
")\n",
610645
"\n",
611646
"# get technical measures for the trained model on the training data set\n",
@@ -668,6 +703,16 @@
668703
"print(\"TF confusion matrix in %\\n\", cm_test_tf / M_test * 100.0)"
669704
]
670705
},
706+
{
707+
"cell_type": "code",
708+
"execution_count": null,
709+
"id": "e37d308d-9be7-4394-a7fd-d08a1feb279b",
710+
"metadata": {},
711+
"outputs": [],
712+
"source": [
713+
"X_train.shape, Y_train.shape, X_test.shape, Y_test.shape"
714+
]
715+
},
671716
{
672717
"cell_type": "code",
673718
"execution_count": null,
@@ -684,8 +729,8 @@
684729
"\n",
685730
" plt.figure(figsize=(10, 10))\n",
686731
" plt.subplot(2, 1, 1)\n",
687-
" plt.plot(X_train[Y_train == 0, 0], X_train[Y_train == 0, 1], \"C0o\", ms=1)\n",
688-
" plt.plot(X_train[Y_train == 1, 0], X_train[Y_train == 1, 1], \"C1o\", ms=1)\n",
732+
" plt.plot(X_train[Y_train[:, 0] == 0, 0], X_train[Y_train[:, 0] == 0, 1], \"C0o\", ms=1)\n",
733+
" plt.plot(X_train[Y_train[:, 0] == 1, 0], X_train[Y_train[:, 0] == 1, 1], \"C1o\", ms=1)\n",
689734
" plt.contourf(f1, f2, tmp, cmap=\"RdBu_r\")\n",
690735
" plt.axis(\"equal\")\n",
691736
" plt.colorbar()\n",
@@ -694,8 +739,8 @@
694739
" plt.ylabel(\"feature 2\")\n",
695740
"\n",
696741
" plt.subplot(2, 1, 2)\n",
697-
" plt.plot(X_test[Y_test == 0, 0], X_test[Y_test == 0, 1], \"C0o\", ms=1)\n",
698-
" plt.plot(X_test[Y_test == 1, 0], X_test[Y_test == 1, 1], \"C1o\", ms=1)\n",
742+
" plt.plot(X_test[Y_test[:, 0] == 0, 0], X_test[Y_test[:, 0] == 0, 1], \"C0o\", ms=1)\n",
743+
" plt.plot(X_test[Y_test[:, 0] == 1, 0], X_test[Y_test[:, 0] == 1, 1], \"C1o\", ms=1)\n",
699744
" plt.contourf(f1, f2, tmp, cmap=\"RdBu_r\")\n",
700745
" plt.axis(\"equal\")\n",
701746
" plt.colorbar()\n",
@@ -735,7 +780,7 @@
735780
"name": "python",
736781
"nbconvert_exporter": "python",
737782
"pygments_lexer": "ipython3",
738-
"version": "3.10.6"
783+
"version": "3.12.3"
739784
}
740785
},
741786
"nbformat": 4,

0 commit comments

Comments (0)