From 72d861074393d7880150674a3ec0f1856e9e7d7b Mon Sep 17 00:00:00 2001 From: oskar Date: Tue, 4 Nov 2025 23:06:08 +0100 Subject: [PATCH] my version --- work-sc.ipynb | 322 ++++++++++++++++++++++++++++---------------------- 1 file changed, 181 insertions(+), 141 deletions(-) diff --git a/work-sc.ipynb b/work-sc.ipynb index 95b123a..f67362f 100644 --- a/work-sc.ipynb +++ b/work-sc.ipynb @@ -5,19 +5,19 @@ "id": "initial_id", "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:32.795850Z", - "start_time": "2025-11-04T21:43:32.794457Z" + "end_time": "2025-11-04T22:00:00.816816Z", + "start_time": "2025-11-04T22:00:00.813630Z" } }, "source": "import numpy as np", "outputs": [], - "execution_count": 45 + "execution_count": 84 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:32.811210Z", - "start_time": "2025-11-04T21:43:32.809638Z" + "end_time": "2025-11-04T22:00:00.830847Z", + "start_time": "2025-11-04T22:00:00.829329Z" } }, "cell_type": "code", @@ -32,13 +32,13 @@ ], "id": "48cafaf4b64967bb", "outputs": [], - "execution_count": 46 + "execution_count": 85 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:32.862226Z", - "start_time": "2025-11-04T21:43:32.860368Z" + "end_time": "2025-11-04T22:00:00.885348Z", + "start_time": "2025-11-04T22:00:00.880961Z" } }, "cell_type": "code", @@ -62,13 +62,13 @@ ], "id": "d13137630b41b756", "outputs": [], - "execution_count": 47 + "execution_count": 86 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:46:21.643740Z", - "start_time": "2025-11-04T21:46:21.639693Z" + "end_time": "2025-11-04T22:00:00.944688Z", + "start_time": "2025-11-04T22:00:00.939752Z" } }, "cell_type": "code", @@ -78,13 +78,13 @@ ], "id": "31f205147667dea6", "outputs": [], - "execution_count": 64 + "execution_count": 87 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:32.957461Z", - "start_time": "2025-11-04T21:43:32.955675Z" + "end_time": "2025-11-04T22:00:00.994063Z", + "start_time": "2025-11-04T22:00:00.990969Z" } }, "cell_type": "code", @@ -106,13 +106,13 @@ ], "id": "c1b960e7dcf09d91", "outputs": [], - "execution_count": 49 + "execution_count": 88 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:44:22.909895Z", - "start_time": "2025-11-04T21:44:22.906363Z" + "end_time": "2025-11-04T22:00:01.051837Z", + "start_time": "2025-11-04T22:00:01.046197Z" } }, "cell_type": "code", @@ -131,13 +131,13 @@ ], "id": "efae2e184daf2fce", "outputs": [], - "execution_count": 61 + "execution_count": 89 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:33.055558Z", - "start_time": "2025-11-04T21:43:33.053594Z" + "end_time": "2025-11-04T22:00:01.101365Z", + "start_time": "2025-11-04T22:00:01.097608Z" } }, "cell_type": "code", @@ -162,13 +162,13 @@ ], "id": "c3cd9e8f51dbe967", "outputs": [], - "execution_count": 51 + "execution_count": 90 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:33.103372Z", - "start_time": "2025-11-04T21:43:33.101510Z" + "end_time": "2025-11-04T22:00:01.147862Z", + "start_time": "2025-11-04T22:00:01.146127Z" } }, "cell_type": "code", @@ -191,13 +191,13 @@ ], "id": "121416e7bbab57bb", "outputs": [], - "execution_count": 52 + "execution_count": 91 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:33.176375Z", - "start_time": "2025-11-04T21:43:33.169411Z" + "end_time": "2025-11-04T22:00:01.200653Z", + "start_time": "2025-11-04T22:00:01.198951Z" } }, "cell_type": "code", @@ -205,9 +205,9 @@ "def single_layer_backward_propagation(dA_curr, W_curr, b_curr, Z_curr, A_prev, activation=\"relu\"):\n", " m = A_prev.shape[1]\n", "\n", - " if activation is \"relu\":\n", + " if activation == \"relu\":\n", " backward_activation_func = relu_backward\n", - " elif activation is \"sigmoid\":\n", + " elif activation == \"sigmoid\":\n", " backward_activation_func = sigmoid_backward\n", " else:\n", " raise Exception('Non-supported activation function')\n", @@ -221,13 +221,13 @@ ], "id": "92e4b87664f18a63", "outputs": [], - "execution_count": 53 + "execution_count": 92 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:33.243823Z", - "start_time": "2025-11-04T21:43:33.234283Z" + "end_time": "2025-11-04T22:00:01.259385Z", + "start_time": "2025-11-04T22:00:01.253050Z" } }, "cell_type": "code", @@ -260,13 +260,13 @@ ], "id": "2c8e4eed1846f003", "outputs": [], - "execution_count": 54 + "execution_count": 93 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:47:33.615104Z", - "start_time": "2025-11-04T21:47:33.610483Z" + "end_time": "2025-11-04T22:00:01.319868Z", + "start_time": "2025-11-04T22:00:01.312729Z" } }, "cell_type": "code", @@ -281,13 +281,13 @@ ], "id": "16320b953a183511", "outputs": [], - "execution_count": 66 + "execution_count": 94 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:47:37.239308Z", - "start_time": "2025-11-04T21:47:37.236527Z" + "end_time": "2025-11-04T22:00:01.380430Z", + "start_time": "2025-11-04T22:00:01.373966Z" } }, "cell_type": "code", @@ -326,13 +326,13 @@ ], "id": "fce33f70bba3898", "outputs": [], - "execution_count": 67 + "execution_count": 95 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:33.422252Z", - "start_time": "2025-11-04T21:43:33.417262Z" + "end_time": "2025-11-04T22:00:01.444163Z", + "start_time": "2025-11-04T22:00:01.436199Z" } }, "cell_type": "code", @@ -359,13 +359,13 @@ ], "id": "cccd73b5018799d4", "outputs": [], - "execution_count": 57 + "execution_count": 96 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:33.472509Z", - "start_time": "2025-11-04T21:43:33.470657Z" + "end_time": "2025-11-04T22:00:01.500700Z", + "start_time": "2025-11-04T22:00:01.497537Z" } }, "cell_type": "code", @@ -377,13 +377,13 @@ ], "id": "4f66ffa878f01c02", "outputs": [], - "execution_count": 58 + "execution_count": 97 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:33.520603Z", - "start_time": "2025-11-04T21:43:33.518562Z" + "end_time": "2025-11-04T22:00:01.560294Z", + "start_time": "2025-11-04T22:00:01.553505Z" } }, "cell_type": "code", @@ -393,13 +393,13 @@ ], "id": "bebe0ed00a2d514", "outputs": [], - "execution_count": 59 + "execution_count": 98 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:51:26.021417Z", - "start_time": "2025-11-04T21:51:23.520284Z" + "end_time": "2025-11-04T22:00:04.165839Z", + "start_time": "2025-11-04T22:00:01.614181Z" } }, "cell_type": "code", @@ -409,13 +409,13 @@ ], "id": "ce04892d496c5147", "outputs": [], - "execution_count": 77 + "execution_count": 99 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:51:27.733451Z", - "start_time": "2025-11-04T21:51:27.727264Z" + "end_time": "2025-11-04T22:00:11.428146Z", + "start_time": "2025-11-04T22:00:11.422370Z" } }, "cell_type": "code", @@ -435,83 +435,13 @@ ] } ], - "execution_count": 78 + "execution_count": 105 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:43:33.666607121Z", - "start_time": "2025-11-04T20:21:26.059140Z" - } - }, - "cell_type": "code", - "source": [ - "startA = np.random.randn(nn_architecture[0][\"input_dim\"],1) * 0.1\n", - "full_forward_propagation(startA, params, nn_architecture)" - ], - "id": "8b672c5fd5832cc", - "outputs": [ - { - "data": { - "text/plain": [ - "(array([[0.51608074]]),\n", - " {'A0': array([[-0.10166672],\n", - " [ 0.14706683]]),\n", - " 'Z1': array([[ 0.0203953 ],\n", - " [-0.22010647],\n", - " [-0.01614817],\n", - " [ 0.07300465]]),\n", - " 'A1': array([[0.0203953 ],\n", - " [0. ],\n", - " [0. ],\n", - " [0.07300465]]),\n", - " 'Z2': array([[-0.18085747],\n", - " [-0.01827604],\n", - " [-0.21683156],\n", - " [ 0.08504111],\n", - " [ 0.17066065],\n", - " [-0.04521306]]),\n", - " 'A2': array([[0. ],\n", - " [0. ],\n", - " [0. ],\n", - " [0.08504111],\n", - " [0.17066065],\n", - " [0. ]]),\n", - " 'Z3': array([[-0.17707529],\n", - " [ 0.0237745 ],\n", - " [-0.07487052],\n", - " [-0.02497606],\n", - " [ 0.12622027],\n", - " [ 0.02613133]]),\n", - " 'A3': array([[0. ],\n", - " [0.0237745 ],\n", - " [0. ],\n", - " [0. ],\n", - " [0.12622027],\n", - " [0.02613133]]),\n", - " 'Z4': array([[-0.09066425],\n", - " [ 0.05792425],\n", - " [ 0.07822296],\n", - " [ 0.07317913]]),\n", - " 'A4': array([[0. ],\n", - " [0.05792425],\n", - " [0.07822296],\n", - " [0.07317913]]),\n", - " 'Z5': array([[0.06434517]])})" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 24 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-11-04T21:52:07.296371Z", - "start_time": "2025-11-04T21:52:01.384867Z" + "end_time": "2025-11-04T22:00:29.176357Z", + "start_time": "2025-11-04T22:00:23.282276Z" } }, "cell_type": "code", @@ -535,47 +465,157 @@ "output_type": "stream", "text": [ "/home/oskar/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/layers/core/dense.py:95: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", - " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n", - "2025-11-04 22:52:01.409083: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW\n", - "2025-11-04 22:52:01.409097: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:171] verbose logging is disabled. Rerun with verbose logging (usually --v=1 or --vmodule=cuda_diagnostics=1) to get more diagnostic output from this module\n", - "2025-11-04 22:52:01.409099: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:176] retrieving CUDA diagnostic information for host: solaria\n", - "2025-11-04 22:52:01.409101: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:183] hostname: solaria\n", - "2025-11-04 22:52:01.409176: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:190] libcuda reported version is: 580.95.5\n", - "2025-11-04 22:52:01.409184: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:194] kernel reported version is: 570.195.3\n", - "2025-11-04 22:52:01.409185: E external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:287] kernel version 570.195.3 does not match DSO version 580.95.5 -- cannot find working devices in this configuration\n" + " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n" ] } ], - "execution_count": 79 + "execution_count": 106 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-04T21:53:11.479872Z", - "start_time": "2025-11-04T21:53:11.455625Z" + "end_time": "2025-11-04T22:00:33.380478Z", + "start_time": "2025-11-04T22:00:33.309269Z" } }, "cell_type": "code", "source": [ - "Y_test_hat = model.predict_classes(X_test)\n", + "Y_test_prob = model.predict(X_test)\n", + "Y_test_hat = (Y_test_prob > 0.5).astype(int).ravel()\n", "acc_test = accuracy_score(y_test, Y_test_hat)\n", "print(\"Test set accuracy: {:.2f} - Goliath\".format(acc_test))" ], "id": "ef52bee9c93081d3", "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:6 out of the last 10 calls to .one_step_on_data_distributed at 0x7e21f476c900> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", + "\u001B[1m4/4\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m0s\u001B[0m 8ms/step \n", + "Test set accuracy: 0.99 - Goliath\n" + ] + } + ], + "execution_count": 107 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-04T22:03:33.972219Z", + "start_time": "2025-11-04T22:03:33.966407Z" + } + }, + "cell_type": "code", + "source": [ + "def make_plot(X, y, plot_name, file_name=None, XX=None, YY=None, preds=None, dark=False):\n", + " if (dark):\n", + " plt.style.use('dark_background')\n", + " else:\n", + " sns.set_style(\"whitegrid\")\n", + " plt.figure(figsize=(16,12))\n", + " axes = plt.gca()\n", + " axes.set(xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n", + " plt.title(plot_name, fontsize=30)\n", + " plt.subplots_adjust(left=0.20)\n", + " plt.subplots_adjust(right=0.80)\n", + " if(XX is not None and YY is not None and preds is not None):\n", + " plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha = 1, cmap=cm.Spectral)\n", + " plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap=\"Greys\", vmin=0, vmax=.6)\n", + " plt.scatter(X[:, 0], X[:, 1], c=y.ravel(), s=40, cmap=plt.cm.Spectral, edgecolors='black')\n", + " if(file_name):\n", + " plt.savefig(file_name)\n", + " plt.close()" + ], + "id": "9535365d1da72395", + "outputs": [], + "execution_count": 109 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-04T22:02:51.938430Z", + "start_time": "2025-11-04T22:02:51.934316Z" + } + }, + "cell_type": "code", + "source": [ + "# boundary of the graph\n", + "GRID_X_START = -1.5\n", + "GRID_X_END = 2.5\n", + "GRID_Y_START = -1.0\n", + "GRID_Y_END = 2\n", + "# output directory (the folder must be created on the drive)\n", + "OUTPUT_DIR = \"./binary_classification_vizualizations/\"\n", + "### Definition of grid boundaries\n", + "grid = np.mgrid[GRID_X_START:GRID_X_END:100j, GRID_X_START:GRID_Y_END:100j]\n", + "grid_2d = grid.reshape(2, -1).T\n", + "XX, YY = grid" + ], + "id": "b070f03d55981894", + "outputs": [], + "execution_count": 108 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-04T22:05:02.290039Z", + "start_time": "2025-11-04T22:05:02.042691Z" + } + }, + "cell_type": "code", + "source": [ + "def callback_keras_plot(epoch, logs):\n", + " plot_title = \"Keras Model - It: {:05}\".format(epoch)\n", + " file_name = \"keras_model_{:05}.png\".format(epoch)\n", + " file_path = os.path.join(OUTPUT_DIR, file_name)\n", + " prediction_probs = model.predict_proba(grid_2d, batch_size=32, verbose=0)\n", + " make_plot(X_test, y_test, plot_title, file_name=file_path, XX=XX, YY=YY, preds=prediction_probs)\n", + "\n", + "\n", + "# Adding callback functions that they will run in every epoch\n", + "testmodelcb = keras.callbacks.LambdaCallback(on_epoch_end=callback_keras_plot)\n", + "\n", + "# Building a model\n", + "model = Sequential()\n", + "model.add(Dense(25, input_dim=2, activation='relu'))\n", + "model.add(Dense(50, activation='relu'))\n", + "model.add(Dense(50, activation='relu'))\n", + "model.add(Dense(25, activation='relu'))\n", + "model.add(Dense(1, activation='sigmoid'))\n", + "\n", + "model.compile(loss='binary_crossentropy', optimizer=\"sgd\", metrics=['accuracy'])\n", + "\n", + "# Training\n", + "history = model.fit(X_train, y_train, epochs=200, verbose=0, callbacks=[testmodelcb])\n", + "rediction_probs = model.predict_proba(grid_2d, batch_size=32, verbose=0)\n", + "make_plot(X_test, y_test, \"Keras Model\", file_name=None, XX=XX, YY=YY, preds=prediction_probs)" + ], + "id": "6feab7da06e7a828", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/oskar/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/layers/core/dense.py:95: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", + " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n" + ] + }, { "ename": "AttributeError", - "evalue": "'Sequential' object has no attribute 'predict_classes'", + "evalue": "'Sequential' object has no attribute 'predict_proba'", "output_type": "error", "traceback": [ "\u001B[31m---------------------------------------------------------------------------\u001B[39m", "\u001B[31mAttributeError\u001B[39m Traceback (most recent call last)", - "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[83]\u001B[39m\u001B[32m, line 1\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m1\u001B[39m Y_test_hat = \u001B[43mmodel\u001B[49m\u001B[43m.\u001B[49m\u001B[43mpredict_classes\u001B[49m(X_test)\n\u001B[32m 2\u001B[39m acc_test = accuracy_score(y_test, Y_test_hat)\n\u001B[32m 3\u001B[39m \u001B[38;5;28mprint\u001B[39m(\u001B[33m\"\u001B[39m\u001B[33mTest set accuracy: \u001B[39m\u001B[38;5;132;01m{:.2f}\u001B[39;00m\u001B[33m - Goliath\u001B[39m\u001B[33m\"\u001B[39m.format(acc_test))\n", - "\u001B[31mAttributeError\u001B[39m: 'Sequential' object has no attribute 'predict_classes'" + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[110]\u001B[39m\u001B[32m, line 23\u001B[39m\n\u001B[32m 20\u001B[39m model.compile(loss=\u001B[33m'\u001B[39m\u001B[33mbinary_crossentropy\u001B[39m\u001B[33m'\u001B[39m, optimizer=\u001B[33m\"\u001B[39m\u001B[33msgd\u001B[39m\u001B[33m\"\u001B[39m, metrics=[\u001B[33m'\u001B[39m\u001B[33maccuracy\u001B[39m\u001B[33m'\u001B[39m])\n\u001B[32m 22\u001B[39m \u001B[38;5;66;03m# Training\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m23\u001B[39m history = \u001B[43mmodel\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX_train\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my_train\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mepochs\u001B[49m\u001B[43m=\u001B[49m\u001B[32;43m200\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mverbose\u001B[49m\u001B[43m=\u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcallbacks\u001B[49m\u001B[43m=\u001B[49m\u001B[43m[\u001B[49m\u001B[43mtestmodelcb\u001B[49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 24\u001B[39m rediction_probs = model.predict_proba(grid_2d, batch_size=\u001B[32m32\u001B[39m, verbose=\u001B[32m0\u001B[39m)\n\u001B[32m 25\u001B[39m make_plot(X_test, y_test, \u001B[33m\"\u001B[39m\u001B[33mKeras Model\u001B[39m\u001B[33m\"\u001B[39m, file_name=\u001B[38;5;28;01mNone\u001B[39;00m, XX=XX, YY=YY, preds=prediction_probs)\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/utils/traceback_utils.py:122\u001B[39m, in \u001B[36mfilter_traceback..error_handler\u001B[39m\u001B[34m(*args, **kwargs)\u001B[39m\n\u001B[32m 119\u001B[39m filtered_tb = _process_traceback_frames(e.__traceback__)\n\u001B[32m 120\u001B[39m \u001B[38;5;66;03m# To get the full stack trace, call:\u001B[39;00m\n\u001B[32m 121\u001B[39m \u001B[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001B[39;00m\n\u001B[32m--> \u001B[39m\u001B[32m122\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m e.with_traceback(filtered_tb) \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 123\u001B[39m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[32m 124\u001B[39m \u001B[38;5;28;01mdel\u001B[39;00m filtered_tb\n", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[110]\u001B[39m\u001B[32m, line 5\u001B[39m, in \u001B[36mcallback_keras_plot\u001B[39m\u001B[34m(epoch, logs)\u001B[39m\n\u001B[32m 3\u001B[39m file_name = \u001B[33m\"\u001B[39m\u001B[33mkeras_model_\u001B[39m\u001B[38;5;132;01m{:05}\u001B[39;00m\u001B[33m.png\u001B[39m\u001B[33m\"\u001B[39m.format(epoch)\n\u001B[32m 4\u001B[39m file_path = os.path.join(OUTPUT_DIR, file_name)\n\u001B[32m----> \u001B[39m\u001B[32m5\u001B[39m prediction_probs = \u001B[43mmodel\u001B[49m\u001B[43m.\u001B[49m\u001B[43mpredict_proba\u001B[49m(grid_2d, batch_size=\u001B[32m32\u001B[39m, verbose=\u001B[32m0\u001B[39m)\n\u001B[32m 6\u001B[39m make_plot(X_test, y_test, plot_title, file_name=file_path, XX=XX, YY=YY, preds=prediction_probs)\n", + "\u001B[31mAttributeError\u001B[39m: 'Sequential' object has no attribute 'predict_proba'" ] } ], - "execution_count": 83 + "execution_count": 110 } ], "metadata": {