|
45 | 45 | "from sklearn.datasets import make_classification\n",
|
46 | 46 | "from sklearn.model_selection import train_test_split\n",
|
47 | 47 | "import tensorflow as tf\n",
|
48 |
| - "import tensorflow.keras as keras\n", |
| 48 | + "from tensorflow import keras\n", |
49 | 49 | "\n",
|
50 | 50 | "\n",
|
51 | 51 | "print(\n",
|
|
98 | 98 | {
|
99 | 99 | "cell_type": "code",
|
100 | 100 | "execution_count": null,
|
101 |
| - "id": "e99a940f", |
| 101 | + "id": "c795c19c-aa94-48b9-9385-390fd831a1fa", |
102 | 102 | "metadata": {},
|
103 | 103 | "outputs": [],
|
104 | 104 | "source": [
|
105 |
| - "def predict_class(y):\n", |
106 |
| - " y[y < 0.5], y[y >= 0.5] = 0, 1" |
| 105 | + "def predict_class_tf(y):\n", |
| 106 | + " y[y[:, 0] < 0.5, :], y[y[:, 0] >= 0.5, :] = 0, 1" |
| 107 | + ] |
| 108 | + }, |
| 109 | + { |
| 110 | + "cell_type": "code", |
| 111 | + "execution_count": null, |
| 112 | + "id": "8388f542-910c-408f-9df7-b9ef4ed6f02a", |
| 113 | + "metadata": {}, |
| 114 | + "outputs": [], |
| 115 | + "source": [ |
| 116 | + "def predict_class_my(y):\n", |
| 117 | + " y[:, y[0, :] < 0.5], y[:, y[0, :] >= 0.5] = 0, 1" |
107 | 118 | ]
|
108 | 119 | },
|
109 | 120 | {
|
|
116 | 127 | "def evaluate(y_true, y_pred):\n",
|
117 | 128 | " y_true_tmp = np.copy(y_true)\n",
|
118 | 129 | " y_pred_tmp = np.copy(y_pred)\n",
|
119 |
| - " predict_class(y_pred_tmp)\n", |
| 130 | + " predict_class_my(y_pred_tmp)\n", |
120 | 131 | "\n",
|
121 | 132 | " # https://www.tensorflow.org/api_docs/python/tf/math/confusion_matrix\n",
|
122 | 133 | " # The matrix columns represent the prediction labels.\n",
|
|
193 | 204 | ")\n",
|
194 | 205 | "X_train, X_test, Y_train, Y_test = train_test_split(\n",
|
195 | 206 | " X, Y, train_size=train_size, random_state=None\n",
|
196 |
| - ")" |
| 207 | + ")\n" |
197 | 208 | ]
|
198 | 209 | },
|
199 | 210 | {
|
|
221 | 232 | "print(\"X train dim\", X_train_our.shape, \"Y train dim\", Y_train_our.shape)"
|
222 | 233 | ]
|
223 | 234 | },
|
| 235 | + { |
| 236 | + "cell_type": "markdown", |
| 237 | + "id": "a3aedc53-87b0-4e3c-91a7-73be1df13b95", |
| 238 | + "metadata": {}, |
| 239 | + "source": [ |
| 240 | + "- prepare the label arrays for TF" |
| 241 | + ] |
| 242 | + }, |
| 243 | + { |
| 244 | + "cell_type": "code", |
| 245 | + "execution_count": null, |
| 246 | + "id": "3d4f778d-786a-4146-a662-4bbafa8577e0", |
| 247 | + "metadata": {}, |
| 248 | + "outputs": [], |
| 249 | + "source": [ |
| 250 | + "Y_train = Y_train[:, None] # newer TF needs label arrays of shape (x, 1) instead of (x,)\n", |
| 251 | + "Y_test = Y_test[:, None]\n", |
| 252 | + "X_train.shape, X_test.shape, Y_train.shape, Y_test.shape" |
| 253 | + ] |
| 254 | + }, |
224 | 255 | {
|
225 | 256 | "cell_type": "markdown",
|
226 | 257 | "id": "f2b9c6e6",
|
|
439 | 470 | "source": [
|
440 | 471 | "# prediction after training finished\n",
|
441 | 472 | "Y_train_pred_tf = model.predict(X_train)\n",
|
442 |
| - "predict_class(Y_train_pred_tf)\n", |
| 473 | + "predict_class_tf(Y_train_pred_tf)\n", |
| 474 | + "\n", |
| 475 | + "print(Y_train_pred_tf.shape, Y_train.shape)\n", |
443 | 476 | "\n",
|
444 | 477 | "# confusion matrix\n",
|
445 | 478 | "cm_train_tf = tf.math.confusion_matrix(\n",
|
446 |
| - " labels=Y_train, predictions=Y_train_pred_tf, num_classes=2\n", |
| 479 | + " labels=np.squeeze(Y_train), predictions=np.squeeze(Y_train_pred_tf), num_classes=2\n", |
447 | 480 | ")\n",
|
448 | 481 | "\n",
|
| 482 | + "\n", |
| 483 | + "\n", |
449 | 484 | "# get technical measures for the trained model on the training data set\n",
|
450 | 485 | "results_train_tf = model.evaluate(\n",
|
451 | 486 | " X_train, Y_train, batch_size=M_train, verbose=verbose\n",
|
|
552 | 587 | "print(\"\\nm_test\", M_test)\n",
|
553 | 588 | "# our implementation needs transposed data\n",
|
554 | 589 | "X_test_our = X_test.T\n",
|
555 |
| - "Y_test_our = Y_test[None, :]\n", |
| 590 | + "Y_test_our = Y_test.T\n", |
556 | 591 | "print(\"X test dim\", X_test_our.shape, \"Y test dim\", Y_test_our.shape)"
|
557 | 592 | ]
|
558 | 593 | },
|
|
601 | 636 | "source": [
|
602 | 637 | "# prediction\n",
|
603 | 638 | "Y_test_pred_tf = model.predict(X_test)\n",
|
604 |
| - "predict_class(Y_test_pred_tf)\n", |
| 639 | + "predict_class_tf(Y_test_pred_tf)\n", |
605 | 640 | "\n",
|
606 | 641 | "# confusion matrix\n",
|
607 | 642 | "cm_test_tf = tf.math.confusion_matrix(\n",
|
608 |
| - " labels=Y_test, predictions=Y_test_pred_tf, num_classes=2\n", |
| 643 | + " labels=np.squeeze(Y_test), predictions=np.squeeze(Y_test_pred_tf), num_classes=2\n", |
609 | 644 | ")\n",
|
610 | 645 | "\n",
|
611 | 646 | "# get technical measures for the trained model on the test data set\n",
|
|
668 | 703 | "print(\"TF confusion matrix in %\\n\", cm_test_tf / M_test * 100.0)"
|
669 | 704 | ]
|
670 | 705 | },
|
| 706 | + { |
| 707 | + "cell_type": "code", |
| 708 | + "execution_count": null, |
| 709 | + "id": "e37d308d-9be7-4394-a7fd-d08a1feb279b", |
| 710 | + "metadata": {}, |
| 711 | + "outputs": [], |
| 712 | + "source": [ |
| 713 | + "X_train.shape, Y_train.shape, X_test.shape, Y_test.shape" |
| 714 | + ] |
| 715 | + }, |
671 | 716 | {
|
672 | 717 | "cell_type": "code",
|
673 | 718 | "execution_count": null,
|
|
684 | 729 | "\n",
|
685 | 730 | " plt.figure(figsize=(10, 10))\n",
|
686 | 731 | " plt.subplot(2, 1, 1)\n",
|
687 |
| - " plt.plot(X_train[Y_train == 0, 0], X_train[Y_train == 0, 1], \"C0o\", ms=1)\n", |
688 |
| - " plt.plot(X_train[Y_train == 1, 0], X_train[Y_train == 1, 1], \"C1o\", ms=1)\n", |
| 732 | + " plt.plot(X_train[Y_train[:, 0] == 0, 0], X_train[Y_train[:, 0] == 0, 1], \"C0o\", ms=1)\n", |
| 733 | + " plt.plot(X_train[Y_train[:, 0] == 1, 0], X_train[Y_train[:, 0] == 1, 1], \"C1o\", ms=1)\n", |
689 | 734 | " plt.contourf(f1, f2, tmp, cmap=\"RdBu_r\")\n",
|
690 | 735 | " plt.axis(\"equal\")\n",
|
691 | 736 | " plt.colorbar()\n",
|
|
694 | 739 | " plt.ylabel(\"feature 2\")\n",
|
695 | 740 | "\n",
|
696 | 741 | " plt.subplot(2, 1, 2)\n",
|
697 |
| - " plt.plot(X_test[Y_test == 0, 0], X_test[Y_test == 0, 1], \"C0o\", ms=1)\n", |
698 |
| - " plt.plot(X_test[Y_test == 1, 0], X_test[Y_test == 1, 1], \"C1o\", ms=1)\n", |
| 742 | + " plt.plot(X_test[Y_test[:, 0] == 0, 0], X_test[Y_test[:, 0] == 0, 1], \"C0o\", ms=1)\n", |
| 743 | + " plt.plot(X_test[Y_test[:, 0] == 1, 0], X_test[Y_test[:, 0] == 1, 1], \"C1o\", ms=1)\n", |
699 | 744 | " plt.contourf(f1, f2, tmp, cmap=\"RdBu_r\")\n",
|
700 | 745 | " plt.axis(\"equal\")\n",
|
701 | 746 | " plt.colorbar()\n",
|
|
735 | 780 | "name": "python",
|
736 | 781 | "nbconvert_exporter": "python",
|
737 | 782 | "pygments_lexer": "ipython3",
|
738 |
| - "version": "3.10.6" |
| 783 | + "version": "3.12.3" |
739 | 784 | }
|
740 | 785 | },
|
741 | 786 | "nbformat": 4,
|
|
0 commit comments