From 81c65d751372cdd90357b1ca07ec6209d8f0c7b4 Mon Sep 17 00:00:00 2001 From: oskar Date: Wed, 5 Nov 2025 23:31:08 +0100 Subject: [PATCH] uh --- add-sc.ipynb | 796 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 528 insertions(+), 268 deletions(-) diff --git a/add-sc.ipynb b/add-sc.ipynb index 4439a4f..7f1d2eb 100644 --- a/add-sc.ipynb +++ b/add-sc.ipynb @@ -5,8 +5,8 @@ "id": "initial_id", "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T14:14:33.168058Z", - "start_time": "2025-11-05T14:14:33.163995Z" + "end_time": "2025-11-05T16:57:58.696888Z", + "start_time": "2025-11-05T16:57:58.695703Z" } }, "source": [ @@ -15,13 +15,13 @@ "import numpy as np\n" ], "outputs": [], - "execution_count": 23 + "execution_count": 1 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:55:56.162948Z", - "start_time": "2025-11-05T13:55:56.157819Z" + "end_time": "2025-11-05T16:57:58.710359Z", + "start_time": "2025-11-05T16:57:58.708927Z" } }, "cell_type": "code", @@ -43,18 +43,18 @@ ], "id": "48cafaf4b64967bb", "outputs": [], - "execution_count": 5 + "execution_count": 2 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:55:58.151744Z", - "start_time": "2025-11-05T13:55:58.145779Z" + "end_time": "2025-11-05T16:57:58.767674Z", + "start_time": "2025-11-05T16:57:58.760718Z" } }, "cell_type": "code", "source": [ - "def init_layers(nn_architecture, seed = 99):\n", + "def init_layers(nn_architecture, seed=99):\n", " np.random.seed(seed)\n", " number_of_layers = len(nn_architecture)\n", " params_values = {}\n", @@ -73,13 +73,13 @@ ], "id": "d13137630b41b756", "outputs": [], - "execution_count": 6 + "execution_count": 3 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:00.540357Z", - "start_time": "2025-11-05T13:56:00.536032Z" + "end_time": "2025-11-05T16:57:58.824122Z", + "start_time": "2025-11-05T16:57:58.819526Z" } }, "cell_type": "code", @@ -89,41 +89,44 @@ ], "id": "31f205147667dea6", "outputs": [], - "execution_count": 7 + "execution_count": 4 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:04.702145Z", - "start_time": "2025-11-05T13:56:04.696109Z" + "end_time": "2025-11-05T16:57:58.876505Z", + "start_time": "2025-11-05T16:57:58.871388Z" } }, "cell_type": "code", "source": [ "def sigmoid(Z):\n", - " return 1/(1+np.exp(-Z))\n", + " return 1 / (1 + np.exp(-Z))\n", + "\n", "\n", "def relu(Z):\n", - " return np.maximum(0,Z)\n", + " return np.maximum(0, Z)\n", + "\n", "\n", "def sigmoid_backward(dA, Z):\n", " sig = sigmoid(Z)\n", " return dA * sig * (1 - sig)\n", "\n", + "\n", "def relu_backward(dA, Z):\n", - " dZ = np.array(dA, copy = True)\n", + " dZ = np.array(dA, copy=True)\n", " dZ[Z <= 0] = 0;\n", " return dZ;" ], "id": "c1b960e7dcf09d91", "outputs": [], - "execution_count": 8 + "execution_count": 5 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:09.635495Z", - "start_time": "2025-11-05T13:56:09.630354Z" + "end_time": "2025-11-05T16:57:58.924888Z", + "start_time": "2025-11-05T16:57:58.921980Z" } }, "cell_type": "code", @@ -142,13 +145,13 @@ ], "id": "efae2e184daf2fce", "outputs": [], - "execution_count": 9 + "execution_count": 6 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:11.655081Z", - "start_time": "2025-11-05T13:56:11.649147Z" + "end_time": "2025-11-05T16:57:58.981719Z", + "start_time": "2025-11-05T16:57:58.976016Z" } }, "cell_type": "code", @@ -173,13 +176,13 @@ ], "id": "c3cd9e8f51dbe967", "outputs": [], - "execution_count": 10 + "execution_count": 7 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:14.173107Z", - "start_time": "2025-11-05T13:56:14.167007Z" + "end_time": "2025-11-05T21:38:03.035821Z", + "start_time": "2025-11-05T21:38:03.030450Z" } }, "cell_type": "code", @@ -189,6 +192,7 @@ " cost = -1 / m * (np.dot(Y, np.log(Y_hat).T) + np.dot(1 - Y, np.log(1 - Y_hat).T))\n", " return np.squeeze(cost)\n", "\n", + "\n", "# an auxiliary function that converts probability into class\n", "def convert_prob_into_class(probs):\n", " probs_ = np.copy(probs)\n", @@ -196,19 +200,28 @@ " probs_[probs_ <= 0.5] = 0\n", " return probs_\n", "\n", + "\n", "def get_accuracy_value(Y_hat, Y):\n", " Y_hat_ = convert_prob_into_class(Y_hat)\n", - " return (Y_hat_ == Y).all(axis=0).mean()" + " return (Y_hat_ == Y).all(axis=0).mean()\n", + "\n", + "def get_accuracy_vector(Y_hat, Y):\n", + " diff = np.subtract(Y, Y_hat)\n", + " dot = np.dot(diff.T, diff)\n", + " return dot.trace() / dot.shape[0]\n", + "\n", + "# u = np.random.rand(2, 9)\n", + "# v = np.random.rand(2, 9)" ], "id": "121416e7bbab57bb", "outputs": [], - "execution_count": 11 + "execution_count": 258 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:16.868763Z", - "start_time": "2025-11-05T13:56:16.862696Z" + "end_time": "2025-11-05T16:57:59.095180Z", + "start_time": "2025-11-05T16:57:59.092006Z" } }, "cell_type": "code", @@ -232,21 +245,21 @@ ], "id": "92e4b87664f18a63", "outputs": [], - "execution_count": 12 + "execution_count": 9 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:20.146436Z", - "start_time": "2025-11-05T13:56:20.139340Z" + "end_time": "2025-11-05T17:33:20.050712Z", + "start_time": "2025-11-05T17:33:20.045249Z" } }, "cell_type": "code", "source": [ "def full_backward_propagation(Y_hat, Y, memory, params_values, nn_architecture):\n", " grads_values = {}\n", - " m = Y.shape[1]\n", - " Y = Y.reshape(Y_hat.shape)\n", + " m = Y.shape\n", + " # Y = Y.reshape(Y_hat.shape)\n", "\n", " dA_prev = - (np.divide(Y, Y_hat) - np.divide(1 - Y, 1 - Y_hat));\n", "\n", @@ -271,13 +284,13 @@ ], "id": "2c8e4eed1846f003", "outputs": [], - "execution_count": 13 + "execution_count": 64 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:23.516827Z", - "start_time": "2025-11-05T13:56:23.511647Z" + "end_time": "2025-11-05T16:57:59.200900Z", + "start_time": "2025-11-05T16:57:59.195743Z" } }, "cell_type": "code", @@ -291,13 +304,13 @@ ], "id": "16320b953a183511", "outputs": [], - "execution_count": 14 + "execution_count": 11 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:29.549070Z", - "start_time": "2025-11-05T13:56:29.542074Z" + "end_time": "2025-11-05T22:00:52.345331Z", + "start_time": "2025-11-05T22:00:52.339775Z" } }, "cell_type": "code", @@ -318,31 +331,40 @@ " # calculating metrics and saving them in history\n", " cost = get_cost_value(Y_hat, Y)\n", " cost_history.append(cost)\n", - " accuracy = get_accuracy_value(Y_hat, Y)\n", + "\n", + " accuracy = get_accuracy_vector(Y_hat, Y)\n", + "\n", " accuracy_history.append(accuracy)\n", "\n", + " # print(\"Y_hat.shape: {}, Y.shape: {}\".format(Y_hat.shape, Y.shape))\n", + "\n", " # step backward - calculating gradient\n", " grads_values = full_backward_propagation(Y_hat, Y, cashe, params_values, nn_architecture)\n", - " # updating model state\n", + "\n", + " if (i % 50000 == 0):\n", + " print(\"Learning rate: {}\".format(learning_rate))\n", + " learning_rate = learning_rate / 10.0\n", + "\n", " params_values = update(params_values, grads_values, nn_architecture, learning_rate)\n", "\n", - " if(i % 1000 == 0):\n", - " if(verbose):\n", + " if (i % 1000 == 0):\n", + " print(\"dW1 norm 2 grad: {}, accu: {}\".format(np.linalg.norm(grads_values[\"dW1\"]), accuracy))\n", + " if (verbose):\n", " print(\"Iteration: {:05} - cost: {:.5f} - accuracy: {:.5f}\".format(i, cost, accuracy))\n", - " if(callback is not None):\n", + " if (callback is not None):\n", " callback(i, params_values)\n", "\n", " return params_values" ], "id": "fce33f70bba3898", "outputs": [], - "execution_count": 15 + "execution_count": 306 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:34.145803Z", - "start_time": "2025-11-05T13:56:33.471955Z" + "end_time": "2025-11-05T16:58:00.043457Z", + "start_time": "2025-11-05T16:57:59.317515Z" } }, "cell_type": "code", @@ -359,6 +381,7 @@ "import matplotlib.pyplot as plt\n", "from matplotlib import cm\n", "from mpl_toolkits.mplot3d import Axes3D\n", + "\n", "sns.set_style(\"whitegrid\")\n", "\n", "import keras\n", @@ -375,20 +398,20 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-11-05 14:56:33.582332: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2025-11-05 14:56:33.605140: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2025-11-05 17:57:59.430264: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2025-11-05 17:57:59.456029: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2025-11-05 14:56:34.038968: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n" + "2025-11-05 17:57:59.930030: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n" ] } ], - "execution_count": 16 + "execution_count": 13 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:56:48.371160Z", - "start_time": "2025-11-05T13:56:48.367011Z" + "end_time": "2025-11-05T16:58:02.921068Z", + "start_time": "2025-11-05T16:58:02.919524Z" } }, "cell_type": "code", @@ -400,13 +423,13 @@ ], "id": "4f66ffa878f01c02", "outputs": [], - "execution_count": 17 + "execution_count": 18 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T15:23:10.125780Z", - "start_time": "2025-11-05T15:23:10.110063Z" + "end_time": "2025-11-05T21:46:39.352413Z", + "start_time": "2025-11-05T21:46:39.344853Z" } }, "cell_type": "code", @@ -414,28 +437,38 @@ "def encode_add(i: int) -> float:\n", " return float(i) / 10.0\n", "\n", + "\n", "def decode_add(i) -> int:\n", " return int(i[0] * 10 + i[1])\n", "\n", - "def add(a:float,b:float):\n", + "\n", + "def add(a: float, b: float):\n", " r = a * 10.0 + b * 10.0\n", "\n", " r0 = floor(r % 10)\n", " r1 = floor(r / 10)\n", - " return r1,r0\n", + " return r1, r0\n", "\n", - "def encode_to_vector(x:float,y:float):\n", - " i,j = add(x,y)\n", + "\n", + "def encode_to_vector(x: float, y: float):\n", + " i, j = add(x, y)\n", " vector = np.zeros(20)\n", " vector[i] = 1\n", - " vector[j+10] = 1\n", + " vector[j + 10] = 1\n", " return vector\n", "\n", + "\n", + "def decode_from_vector(vector):\n", + " i = np.argmax(vector[0:10]) + 1\n", + " j = np.argmax(vector[10:20]) - 10\n", + " return decode_add((i, j))\n", + "\n", + "\n", "# add(encode_add(2),encode_add(3))\n", "# encode_to_vector(encode_add(2),encode_add(3))\n", "\n", "def make_sums(\n", - " n_samples=100, *, shuffle=True, noise=None, random_state=None, factor=0.8\n", + " n_samples=100, *, shuffle=False, noise=None, random_state=None, factor=0.8\n", "):\n", " X = []\n", " y = []\n", @@ -444,11 +477,11 @@ " for j in np.linspace(0, 9, 10):\n", " i_int = int(i)\n", " j_int = int(j)\n", - " X.append([i_int, j_int])\n", - " y.append(encode_to_vector( encode_add(i_int), encode_add(j_int) ))\n", + " X.append([encode_add(i_int), encode_add(j_int)])\n", + " y.append(encode_to_vector(encode_add(i_int), encode_add(j_int)))\n", "\n", - " X = np.array(X).T # Shape: (2, 100)\n", - " y = np.array(y).T # Shape: (20, 100)\n", + " # X = np.array(X).T # Shape: (2, 100)\n", + " # y = np.array(y).T # Shape: (20, 100)\n", "\n", " if shuffle and random_state is not None:\n", " np.random.seed(random_state)\n", @@ -458,67 +491,91 @@ "\n", " return X, y\n", "\n", - "make_sums()\n", - "\n" + "# decode_from_vector(encode_to_vector(encode_add(9), encode_add(9)))\n" ], "id": "7ce930351bba500c", - "outputs": [ - { - "data": { - "text/plain": [ - "(array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,\n", - " 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,\n", - " 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,\n", - " 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,\n", - " 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],\n", - " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,\n", - " 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,\n", - " 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,\n", - " 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,\n", - " 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),\n", - " array([[1., 1., 1., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 1., 1., 1.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 1., 0.],\n", - " [0., 0., 0., ..., 0., 0., 1.],\n", - " [0., 0., 0., ..., 0., 0., 0.]], shape=(20, 100)))" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 50 + "outputs": [], + "execution_count": 272 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:38:02.731937Z", - "start_time": "2025-11-05T13:38:02.725571Z" + "end_time": "2025-11-05T22:27:36.512542Z", + "start_time": "2025-11-05T22:27:36.510154Z" } }, "cell_type": "code", "source": [ - "X, y = ds.make_circles(n_samples = N_SAMPLES, noise=0.2, random_state=100)\n", + "X, y = make_sums()\n", + "# X, y = ds.make_circles(n_samples = N_SAMPLES, noise=0.2, random_state=100)\n", "# X, y = make_moons(n_samples = N_SAMPLES, noise=0.2, random_state=100)\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)" ], "id": "bebe0ed00a2d514", "outputs": [], - "execution_count": 136 + "execution_count": 370 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:38:25.210667Z", - "start_time": "2025-11-05T13:38:02.745281Z" + "end_time": "2025-11-05T21:46:45.705860Z", + "start_time": "2025-11-05T21:46:45.702814Z" } }, "cell_type": "code", "source": [ - "params_values = train(np.transpose(X_train), np.transpose(y_train.reshape((y_train.shape[0], 1))), nn_architecture, 30000, 0.01, verbose=True)\n", + "X = np.transpose(X_train)\n", + "Y = np.transpose(y_train)\n", + "epochs = 30000\n", + "learning_rate = 0.01\n", + "verbose = True\n", + "callback = None\n", + "\n", + "params_values = init_layers(nn_architecture, 2)\n", + "# initiation of lists storing the history\n", + "# of metrics calculated during the learning process\n", + "cost_history = []\n", + "accuracy_history = []\n", + "\n", + "# performing calculations for subsequent iterations\n", + "# for i in range(epochs):\n", + "# step forward\n", + "Y_hat, cashe = full_forward_propagation(X, params_values, nn_architecture)\n", + "\n", + "# calculating metrics and saving them in history\n", + "cost = get_cost_value(Y_hat, Y)\n", + "cost_history.append(cost)\n", + "accuracy = get_accuracy_vector(Y_hat, Y)\n", + "accuracy_history.append(accuracy)\n", + "\n", + "print(\"Y_hat.shape: {}, Y.shape: {}\".format(Y_hat.shape, Y.shape))\n", + "\n", + "grads_values = full_backward_propagation(Y_hat, Y, cashe, params_values, nn_architecture)\n", + "\n", + "params_values = update(params_values, grads_values, nn_architecture, learning_rate)\n" + ], + "id": "f362f4b889723f85", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Y_hat.shape: (20, 90), Y.shape: (20, 90)\n" + ] + } + ], + "execution_count": 274 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-05T22:02:19.390360Z", + "start_time": "2025-11-05T22:01:41.944726Z" + } + }, + "cell_type": "code", + "source": [ + "params_values = train(np.transpose(X_train), np.transpose(y_train), nn_architecture, 210000, 0.0092, verbose=False)\n", "# params_values\n" ], "id": "ce04892d496c5147", @@ -527,54 +584,247 @@ "name": "stdout", "output_type": "stream", "text": [ - "Iteration: 00000 - cost: 0.69318 - accuracy: 0.50444\n", - "Iteration: 01000 - cost: 0.69315 - accuracy: 0.50444\n", - "Iteration: 02000 - cost: 0.69312 - accuracy: 0.50444\n", - "Iteration: 03000 - cost: 0.69310 - accuracy: 0.50444\n", - "Iteration: 04000 - cost: 0.69308 - accuracy: 0.50444\n", - "Iteration: 05000 - cost: 0.69306 - accuracy: 0.50444\n", - "Iteration: 06000 - cost: 0.69304 - accuracy: 0.50444\n", - "Iteration: 07000 - cost: 0.69301 - accuracy: 0.50444\n", - "Iteration: 08000 - cost: 0.69298 - accuracy: 0.50444\n", - "Iteration: 09000 - cost: 0.69296 - accuracy: 0.50444\n", - "Iteration: 10000 - cost: 0.69292 - accuracy: 0.50444\n", - "Iteration: 11000 - cost: 0.69288 - accuracy: 0.50444\n", - "Iteration: 12000 - cost: 0.69284 - accuracy: 0.50444\n", - "Iteration: 13000 - cost: 0.69278 - accuracy: 0.50444\n", - "Iteration: 14000 - cost: 0.69272 - accuracy: 0.50444\n", - "Iteration: 15000 - cost: 0.69265 - accuracy: 0.50444\n", - "Iteration: 16000 - cost: 0.69256 - accuracy: 0.50444\n", - "Iteration: 17000 - cost: 0.69244 - accuracy: 0.50444\n", - "Iteration: 18000 - cost: 0.69229 - accuracy: 0.50778\n", - "Iteration: 19000 - cost: 0.69210 - accuracy: 0.52778\n", - "Iteration: 20000 - cost: 0.69184 - accuracy: 0.54111\n", - "Iteration: 21000 - cost: 0.69148 - accuracy: 0.55778\n", - "Iteration: 22000 - cost: 0.69097 - accuracy: 0.58333\n", - "Iteration: 23000 - cost: 0.69021 - accuracy: 0.60222\n", - "Iteration: 24000 - cost: 0.68900 - accuracy: 0.62222\n", - "Iteration: 25000 - cost: 0.68693 - accuracy: 0.65111\n", - "Iteration: 26000 - cost: 0.68299 - accuracy: 0.66889\n", - "Iteration: 27000 - cost: 0.67457 - accuracy: 0.68222\n", - "Iteration: 28000 - cost: 0.65530 - accuracy: 0.67333\n", - "Iteration: 29000 - cost: 0.61861 - accuracy: 0.67111\n" + "Learning rate: 0.0092\n", + "dW1 norm 2 grad: 0.030587792178820513, accu: 4.941336157937026\n", + "dW1 norm 2 grad: 0.1646929550834701, accu: 1.9357215120084397\n", + "dW1 norm 2 grad: 0.019128990025118184, accu: 1.403621959589698\n", + "dW1 norm 2 grad: 0.010786917974214204, accu: 1.3994765844539645\n", + "dW1 norm 2 grad: 0.006973910701561598, accu: 1.398801328858226\n", + "dW1 norm 2 grad: 0.004804662293654282, accu: 1.398564108821892\n", + "dW1 norm 2 grad: 0.0037441349283433623, accu: 1.3984079114655157\n", + "dW1 norm 2 grad: 0.0033954272366276093, accu: 1.3982712361598055\n", + "dW1 norm 2 grad: 0.003206636031118767, accu: 1.3981389888275535\n", + "dW1 norm 2 grad: 0.003465722901160753, accu: 1.3980025782117487\n", + "dW1 norm 2 grad: 0.003626722252604109, accu: 1.397860855459184\n", + "dW1 norm 2 grad: 0.003897822935664194, accu: 1.3977138200524064\n", + "dW1 norm 2 grad: 0.004241887307465233, accu: 1.3975571070828419\n", + "dW1 norm 2 grad: 0.004657598488948141, accu: 1.3973921901095299\n", + "dW1 norm 2 grad: 0.005220580588807355, accu: 1.397215829967114\n", + "dW1 norm 2 grad: 0.005846951555849162, accu: 1.3970250436043212\n", + "dW1 norm 2 grad: 0.006524684719971642, accu: 1.3968167894331327\n", + "dW1 norm 2 grad: 0.007310990726029376, accu: 1.3965867688366898\n", + "dW1 norm 2 grad: 0.008200467149319903, accu: 1.3963309075297812\n", + "dW1 norm 2 grad: 0.009301152067821774, accu: 1.396042748753016\n", + "dW1 norm 2 grad: 0.010166278854072197, accu: 1.3957157327814165\n", + "dW1 norm 2 grad: 0.011425139837519494, accu: 1.3953405108268921\n", + "dW1 norm 2 grad: 0.012950014512970601, accu: 1.3949055628030786\n", + "dW1 norm 2 grad: 0.014355878579536059, accu: 1.394394841643409\n", + "dW1 norm 2 grad: 0.016137061178322212, accu: 1.393787043523792\n", + "dW1 norm 2 grad: 0.01826892984944323, accu: 1.393051951259501\n", + "dW1 norm 2 grad: 0.021143954900072977, accu: 1.3921339908391197\n", + "dW1 norm 2 grad: 0.02423583026714934, accu: 1.3909084263048683\n", + "dW1 norm 2 grad: 0.02787064236557585, accu: 1.3893470499486298\n", + "dW1 norm 2 grad: 0.03272810256003809, accu: 1.3873267538809206\n", + "dW1 norm 2 grad: 0.03832710708326575, accu: 1.3845979543009885\n", + "dW1 norm 2 grad: 0.045545571050636745, accu: 1.3807623420159982\n", + "dW1 norm 2 grad: 0.05624419801088249, accu: 1.375112751337416\n", + "dW1 norm 2 grad: 0.07016642479210054, accu: 1.3663006786836638\n", + "dW1 norm 2 grad: 0.0905448224502262, accu: 1.3515789855937521\n", + "dW1 norm 2 grad: 0.1232520179315588, accu: 1.3240964990083572\n", + "dW1 norm 2 grad: 0.16541711360085276, accu: 1.2706817647472526\n", + "dW1 norm 2 grad: 0.19353060647709652, accu: 1.1786673745786147\n", + "dW1 norm 2 grad: 0.18934424610944875, accu: 1.0755383964359753\n", + "dW1 norm 2 grad: 0.1408928760505227, accu: 1.0107582411953095\n", + "dW1 norm 2 grad: 0.10766369209801253, accu: 0.9758933289489259\n", + "dW1 norm 2 grad: 0.0866161481253012, accu: 0.9546471539965528\n", + "dW1 norm 2 grad: 0.07664582736675311, accu: 0.9382959188648609\n", + "dW1 norm 2 grad: 0.06378446361836485, accu: 0.92428916845143\n", + "dW1 norm 2 grad: 0.0608145095307446, accu: 0.9114837085330408\n", + "dW1 norm 2 grad: 0.06972571774543508, accu: 0.8973253645899818\n", + "dW1 norm 2 grad: 0.09983976688230996, accu: 0.8800802854305243\n", + "dW1 norm 2 grad: 0.11626330494728294, accu: 0.8619511316626135\n", + "dW1 norm 2 grad: 0.049351391928600205, accu: 0.8416075309665876\n", + "dW1 norm 2 grad: 0.03913529011211546, accu: 0.8213559372122169\n", + "Learning rate: 0.00092\n", + "dW1 norm 2 grad: 0.0698354824909347, accu: 0.8021223071821478\n", + "dW1 norm 2 grad: 0.06931343312498642, accu: 0.8002995625700543\n", + "dW1 norm 2 grad: 0.04781421567390308, accu: 0.7985017709141918\n", + "dW1 norm 2 grad: 0.04748953758278176, accu: 0.796723505677864\n", + "dW1 norm 2 grad: 0.05246500533030227, accu: 0.7949704071822465\n", + "dW1 norm 2 grad: 0.027031577043732952, accu: 0.7932488950712784\n", + "dW1 norm 2 grad: 0.05116794968101623, accu: 0.7915531315627051\n", + "dW1 norm 2 grad: 0.11213636049360046, accu: 0.7898833836960827\n", + "dW1 norm 2 grad: 0.06823964221603801, accu: 0.7882536651796254\n", + "dW1 norm 2 grad: 0.10714340854837753, accu: 0.7866551302317062\n", + "dW1 norm 2 grad: 0.04994649850404826, accu: 0.7850793043793554\n", + "dW1 norm 2 grad: 0.1099960451952437, accu: 0.7835262077531631\n", + "dW1 norm 2 grad: 0.0439103861538206, accu: 0.7819942812222288\n", + "dW1 norm 2 grad: 0.10874239209642468, accu: 0.7804808471860741\n", + "dW1 norm 2 grad: 0.11393738889197819, accu: 0.7789666328853002\n", + "dW1 norm 2 grad: 0.04495814658240637, accu: 0.7774557442802752\n", + "dW1 norm 2 grad: 0.04395459066772111, accu: 0.7759593873186026\n", + "dW1 norm 2 grad: 0.043967202328786065, accu: 0.7744766168993356\n", + "dW1 norm 2 grad: 0.04442609020118015, accu: 0.7730071428706276\n", + "dW1 norm 2 grad: 0.04458841504135935, accu: 0.7715629241882856\n", + "dW1 norm 2 grad: 0.10906288744417346, accu: 0.7701327907640183\n", + "dW1 norm 2 grad: 0.04608212489484624, accu: 0.7686767623882792\n", + "dW1 norm 2 grad: 0.04205250484603993, accu: 0.7671757348155648\n", + "dW1 norm 2 grad: 0.10858734045274064, accu: 0.7656854562107144\n", + "dW1 norm 2 grad: 0.04079451613320205, accu: 0.7641965674293068\n", + "dW1 norm 2 grad: 0.04119869852007931, accu: 0.7627151990621612\n", + "dW1 norm 2 grad: 0.022823050772103445, accu: 0.7612407621215996\n", + "dW1 norm 2 grad: 0.05988234296477944, accu: 0.7597502571796344\n", + "dW1 norm 2 grad: 0.03994801225051277, accu: 0.758263715093205\n", + "dW1 norm 2 grad: 0.02001563328832784, accu: 0.7567800459820742\n", + "dW1 norm 2 grad: 0.09389221519341912, accu: 0.7553161878877876\n", + "dW1 norm 2 grad: 0.036694546941468544, accu: 0.7538650579919823\n", + "dW1 norm 2 grad: 0.09363821631353306, accu: 0.7524245293792678\n", + "dW1 norm 2 grad: 0.08890011864334593, accu: 0.751074225488152\n", + "dW1 norm 2 grad: 0.08711660000940519, accu: 0.7497468566289912\n", + "dW1 norm 2 grad: 0.08433113821881433, accu: 0.7484255278238835\n", + "dW1 norm 2 grad: 0.08274136620344481, accu: 0.7471096846730314\n", + "dW1 norm 2 grad: 0.08263159774910961, accu: 0.7457989989854615\n", + "dW1 norm 2 grad: 0.07996019815679523, accu: 0.7444924666609721\n", + "dW1 norm 2 grad: 0.07590974689644236, accu: 0.74319290888101\n", + "dW1 norm 2 grad: 0.07671148899598569, accu: 0.7418974636517687\n", + "dW1 norm 2 grad: 0.07651282066883193, accu: 0.7406057881746861\n", + "dW1 norm 2 grad: 0.07194127532607032, accu: 0.7393175244984712\n", + "dW1 norm 2 grad: 0.046034760180805066, accu: 0.7380394106977187\n", + "dW1 norm 2 grad: 0.04574718764808875, accu: 0.7368222319175388\n", + "dW1 norm 2 grad: 0.025349022989991577, accu: 0.7356395543948386\n", + "dW1 norm 2 grad: 0.03560280602726019, accu: 0.7344641888766926\n", + "dW1 norm 2 grad: 0.09259080063105282, accu: 0.7332947744286727\n", + "dW1 norm 2 grad: 0.0911961020909095, accu: 0.7321292733139444\n", + "dW1 norm 2 grad: 0.08942583365863134, accu: 0.7309675839019333\n", + "Learning rate: 9.2e-05\n", + "dW1 norm 2 grad: 0.05081509253625319, accu: 0.7298165738413436\n", + "dW1 norm 2 grad: 0.08069013256772081, accu: 0.7297031031059652\n", + "dW1 norm 2 grad: 0.033465976546192586, accu: 0.7295897688790668\n", + "dW1 norm 2 grad: 0.0838071759382461, accu: 0.7294765059452227\n", + "dW1 norm 2 grad: 0.08369039264117137, accu: 0.7293633551904505\n", + "dW1 norm 2 grad: 0.025063043610271528, accu: 0.7292503109024019\n", + "dW1 norm 2 grad: 0.0250903067228421, accu: 0.7291372655826458\n", + "dW1 norm 2 grad: 0.024807189346951402, accu: 0.7290242245133101\n", + "dW1 norm 2 grad: 0.07990580710201337, accu: 0.7289111175614715\n", + "dW1 norm 2 grad: 0.024621601100137287, accu: 0.72879804031431\n", + "dW1 norm 2 grad: 0.023355088185910863, accu: 0.7286849775453704\n", + "dW1 norm 2 grad: 0.02274113942950061, accu: 0.7285719795169106\n", + "dW1 norm 2 grad: 0.049949659194981566, accu: 0.7284593451650372\n", + "dW1 norm 2 grad: 0.023248965569912117, accu: 0.7283468190487229\n", + "dW1 norm 2 grad: 0.023152204671378268, accu: 0.7282343171189314\n", + "dW1 norm 2 grad: 0.04969774869533673, accu: 0.7281217905460505\n", + "dW1 norm 2 grad: 0.08333052873343195, accu: 0.7280093859426575\n", + "dW1 norm 2 grad: 0.07823212307227495, accu: 0.727897047104545\n", + "dW1 norm 2 grad: 0.08257422509504603, accu: 0.7277847567346591\n", + "dW1 norm 2 grad: 0.02224991771397424, accu: 0.7276725529494683\n", + "dW1 norm 2 grad: 0.02711910610218876, accu: 0.7275603644520168\n", + "dW1 norm 2 grad: 0.07886200292692834, accu: 0.7274482275685278\n", + "dW1 norm 2 grad: 0.02198031237106057, accu: 0.7273361722482247\n", + "dW1 norm 2 grad: 0.02115994730449518, accu: 0.7272241086079306\n", + "dW1 norm 2 grad: 0.021285257622553382, accu: 0.7271122986359888\n", + "dW1 norm 2 grad: 0.02357392567635191, accu: 0.7270004888545573\n", + "dW1 norm 2 grad: 0.022688686003681072, accu: 0.7268886540360114\n", + "dW1 norm 2 grad: 0.0774804194383182, accu: 0.7267768612068347\n", + "dW1 norm 2 grad: 0.023296298417766592, accu: 0.726665878152204\n", + "dW1 norm 2 grad: 0.0268274942998492, accu: 0.7265561850412595\n", + "dW1 norm 2 grad: 0.06052005017135857, accu: 0.7264469435954387\n", + "dW1 norm 2 grad: 0.04945838108083401, accu: 0.7263377323697943\n", + "dW1 norm 2 grad: 0.027067935503537366, accu: 0.726229201502864\n", + "dW1 norm 2 grad: 0.026893122048369005, accu: 0.7261207218241621\n", + "dW1 norm 2 grad: 0.02174669349434738, accu: 0.7260123279798082\n", + "dW1 norm 2 grad: 0.027435227540288768, accu: 0.7259041066903453\n", + "dW1 norm 2 grad: 0.08375911784691781, accu: 0.7257962378376251\n", + "dW1 norm 2 grad: 0.022652301601058415, accu: 0.7256884416312775\n", + "dW1 norm 2 grad: 0.08289691349070971, accu: 0.7255806565534714\n", + "dW1 norm 2 grad: 0.08364752167205262, accu: 0.7254731199083936\n", + "dW1 norm 2 grad: 0.026692046014157863, accu: 0.7253656043834251\n", + "dW1 norm 2 grad: 0.04913025170849942, accu: 0.7252580886894395\n", + "dW1 norm 2 grad: 0.023689498661408678, accu: 0.7251506668124373\n", + "dW1 norm 2 grad: 0.02657062856609085, accu: 0.725042860431935\n", + "dW1 norm 2 grad: 0.023702493689244892, accu: 0.7249343664607922\n", + "dW1 norm 2 grad: 0.02342068477793123, accu: 0.724825581515084\n", + "dW1 norm 2 grad: 0.023512325200450875, accu: 0.7247168048663374\n", + "dW1 norm 2 grad: 0.026039992426218454, accu: 0.7246080322347013\n", + "dW1 norm 2 grad: 0.06301257689207454, accu: 0.7244992898894886\n", + "dW1 norm 2 grad: 0.08123844996098815, accu: 0.724390594267674\n", + "Learning rate: 9.2e-06\n", + "dW1 norm 2 grad: 0.05660752226816397, accu: 0.724281912921313\n", + "dW1 norm 2 grad: 0.026860772665232544, accu: 0.7242710584653618\n", + "dW1 norm 2 grad: 0.024496875972681183, accu: 0.7242602093657362\n", + "dW1 norm 2 grad: 0.025294084526166837, accu: 0.7242493410645534\n", + "dW1 norm 2 grad: 0.022663938572628706, accu: 0.724238471439105\n", + "dW1 norm 2 grad: 0.08092590829169215, accu: 0.7242275992665872\n", + "dW1 norm 2 grad: 0.021902604575931107, accu: 0.7242167354020932\n", + "dW1 norm 2 grad: 0.025736603559630894, accu: 0.7242058633500534\n", + "dW1 norm 2 grad: 0.0581973314338413, accu: 0.7241949951587566\n", + "dW1 norm 2 grad: 0.02344384040486643, accu: 0.7241841285869589\n", + "dW1 norm 2 grad: 0.021904167014846036, accu: 0.7241732618842978\n", + "dW1 norm 2 grad: 0.02494119041657796, accu: 0.7241623921154894\n", + "dW1 norm 2 grad: 0.07836203187703025, accu: 0.7241515204595299\n", + "dW1 norm 2 grad: 0.025193938991455386, accu: 0.7241406544392891\n", + "dW1 norm 2 grad: 0.07560872303681854, accu: 0.7241297849270177\n", + "dW1 norm 2 grad: 0.02556473409099411, accu: 0.7241189184717867\n", + "dW1 norm 2 grad: 0.07650664220580596, accu: 0.7241080494594142\n", + "dW1 norm 2 grad: 0.08074806296038219, accu: 0.7240971814540356\n", + "dW1 norm 2 grad: 0.024689714218072648, accu: 0.7240863166411479\n", + "dW1 norm 2 grad: 0.08164790148339035, accu: 0.7240754467496378\n", + "dW1 norm 2 grad: 0.023112710622302055, accu: 0.7240645839138531\n", + "dW1 norm 2 grad: 0.025532646872846385, accu: 0.7240537140180505\n", + "dW1 norm 2 grad: 0.08065985712188835, accu: 0.7240428454529251\n", + "dW1 norm 2 grad: 0.0471336793636139, accu: 0.7240319784959498\n", + "dW1 norm 2 grad: 0.08153184230557896, accu: 0.7240211117910492\n", + "dW1 norm 2 grad: 0.0815184663751406, accu: 0.7240102446800086\n", + "dW1 norm 2 grad: 0.025456334538617304, accu: 0.7239993797762518\n", + "dW1 norm 2 grad: 0.023264980480717054, accu: 0.7239885152111308\n", + "dW1 norm 2 grad: 0.07961769303874187, accu: 0.7239776454654755\n", + "dW1 norm 2 grad: 0.07625319759743494, accu: 0.7239667793159715\n", + "dW1 norm 2 grad: 0.02539726871451324, accu: 0.7239559145024411\n", + "dW1 norm 2 grad: 0.0237378349559094, accu: 0.7239450513078237\n", + "dW1 norm 2 grad: 0.02513608435427682, accu: 0.7239341840865249\n", + "dW1 norm 2 grad: 0.08137682892886208, accu: 0.7239233148786304\n", + "dW1 norm 2 grad: 0.04811297421618333, accu: 0.7239124656845819\n", + "dW1 norm 2 grad: 0.023193417673846712, accu: 0.7239016316445489\n", + "dW1 norm 2 grad: 0.0256038658999626, accu: 0.723890791326119\n", + "dW1 norm 2 grad: 0.021947534582446696, accu: 0.723879955190287\n", + "dW1 norm 2 grad: 0.08167978225793682, accu: 0.7238691117330809\n", + "dW1 norm 2 grad: 0.0231233474515702, accu: 0.7238582754116796\n", + "dW1 norm 2 grad: 0.08161827390446047, accu: 0.7238474330611652\n", + "dW1 norm 2 grad: 0.07974789689423355, accu: 0.7238365943467001\n", + "dW1 norm 2 grad: 0.022592499032915825, accu: 0.7238257594435297\n", + "dW1 norm 2 grad: 0.022504525579480045, accu: 0.7238149204937029\n", + "dW1 norm 2 grad: 0.08156846999929206, accu: 0.7238040775315533\n", + "dW1 norm 2 grad: 0.023618593965625455, accu: 0.7237932416255065\n", + "dW1 norm 2 grad: 0.08150850470928286, accu: 0.7237823997171771\n", + "dW1 norm 2 grad: 0.07696639270478159, accu: 0.7237715613280187\n", + "dW1 norm 2 grad: 0.025295856277240028, accu: 0.7237607248761024\n", + "dW1 norm 2 grad: 0.0796405362393356, accu: 0.72374988335138\n", + "Learning rate: 9.2e-07\n", + "dW1 norm 2 grad: 0.02533836373189624, accu: 0.7237390482440271\n", + "dW1 norm 2 grad: 0.07694830713060108, accu: 0.723737963788499\n", + "dW1 norm 2 grad: 0.07942747965329897, accu: 0.7237368813798639\n", + "dW1 norm 2 grad: 0.024989911441789227, accu: 0.7237357976831307\n", + "dW1 norm 2 grad: 0.025115561130679122, accu: 0.723734713740944\n", + "dW1 norm 2 grad: 0.024750581477314067, accu: 0.7237336301018029\n", + "dW1 norm 2 grad: 0.025057096033220413, accu: 0.7237325461940796\n", + "dW1 norm 2 grad: 0.08142729217753651, accu: 0.7237314619742643\n", + "dW1 norm 2 grad: 0.023236646366497984, accu: 0.7237303784489091\n", + "dW1 norm 2 grad: 0.02622057086597789, accu: 0.7237292943668882\n" ] } ], - "execution_count": 137 + "execution_count": 308 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:38:25.280709Z", - "start_time": "2025-11-05T13:38:25.278689Z" + "end_time": "2025-11-05T22:02:26.485044Z", + "start_time": "2025-11-05T22:02:26.478191Z" } }, "cell_type": "code", "source": [ "Y_test_hat, _ = full_forward_propagation(np.transpose(X_test), params_values, nn_architecture)\n", + "# X_test[0]\n", + "# Y_test_hat[0]\n", + "# Y_test_hat.T[0][10:20]\n", + "# np.argmax(Y_test_hat[0][0:10])\n", + "# Y_test_hat[0][0:10][8]\n", "\n", - "acc_test = get_accuracy_value(Y_test_hat, np.transpose(y_test.reshape((y_test.shape[0], 1))))\n", - "print(\"Test set accuracy: {:.2f} - David\".format(acc_test))\n" + "# decode_from_vector(Y_test_hat.T[0])\n", + "\n", + "for i in range(len(X_test)):\n", + " print(\"for: {}, pred: {}\".format(int(np.sum(X_test[i])*10), decode_from_vector(Y_test_hat.T[i])))\n", + "\n" ], "id": "26e7a2a8848714d9", "outputs": [ @@ -582,16 +832,59 @@ "name": "stdout", "output_type": "stream", "text": [ - "Test set accuracy: 0.69 - David\n" + "for: 11, pred: 10\n", + "for: 8, pred: 9\n", + "for: 7, pred: 7\n", + "for: 9, pred: 9\n", + "for: 8, pred: 9\n", + "for: 12, pred: 11\n", + "for: 4, pred: 6\n", + "for: 8, pred: 9\n", + "for: 1, pred: 6\n", + "for: 0, pred: 6\n" ] } ], - "execution_count": 138 + "execution_count": 309 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:43:01.544185Z", + "end_time": "2025-11-05T22:02:31.875471Z", + "start_time": "2025-11-05T22:02:31.869284Z" + } + }, + "cell_type": "code", + "source": [ + "Y_train_hat, _ = full_forward_propagation(np.transpose(X_train[1:10]), params_values, nn_architecture)\n", + "\n", + "for i in range(len(X_train[1:10])):\n", + " print(\"for: {}, pred: {}\".format(int(np.sum(X_train[1:10][i])*10), decode_from_vector(Y_train_hat.T[i])))\n" + ], + "id": "50a9492ad1e6a37c", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "for: 3, pred: 6\n", + "for: 10, pred: 10\n", + "for: 6, pred: 6\n", + "for: 9, pred: 9\n", + "for: 4, pred: 6\n", + "for: 12, pred: 13\n", + "for: 14, pred: 13\n", + "for: 3, pred: 6\n", + "for: 4, pred: 6\n" + ] + } + ], + "execution_count": 310 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-05T16:58:00.360548085Z", "start_time": "2025-11-05T13:43:01.537741Z" } }, @@ -617,7 +910,7 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:38:25.355642Z", + "end_time": "2025-11-05T16:58:00.361037286Z", "start_time": "2025-11-05T13:38:25.347259Z" } }, @@ -628,17 +921,17 @@ " plt.style.use('dark_background')\n", " else:\n", " sns.set_style(\"whitegrid\")\n", - " plt.figure(figsize=(16,12))\n", + " plt.figure(figsize=(16, 12))\n", " axes = plt.gca()\n", " axes.set(xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n", " plt.title(plot_name, fontsize=30)\n", " plt.subplots_adjust(left=0.20)\n", " plt.subplots_adjust(right=0.80)\n", - " if(XX is not None and YY is not None and preds is not None):\n", - " plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha = 1, cmap=cm.Spectral)\n", + " if (XX is not None and YY is not None and preds is not None):\n", + " plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=1, cmap=cm.Spectral)\n", " plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap=\"Greys\", vmin=0, vmax=.6)\n", " plt.scatter(X[:, 0], X[:, 1], c=y.ravel(), s=40, cmap=plt.cm.Spectral, edgecolors='black')\n", - " if(file_name):\n", + " if (file_name):\n", " plt.savefig(file_name)\n", " plt.close()" ], @@ -649,103 +942,7 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:47:50.552445Z", - "start_time": "2025-11-05T13:47:50.545245Z" - } - }, - "cell_type": "code", - "source": "X.T", - "id": "b87560199ee27331", - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-1.00304644, 0.27761395, 0.7121587 , ..., -0.13965464,\n", - " -0.68842612, 0.59102256],\n", - " [ 0.08532924, -1.33189304, 0.49761513, ..., -0.91986245,\n", - " -0.38146264, 0.40804492]], shape=(2, 1000))" - ] - }, - "execution_count": 158, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 158 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-11-05T13:46:41.287971Z", - "start_time": "2025-11-05T13:46:41.280330Z" - } - }, - "cell_type": "code", - "source": "y", - "id": "e0c22b36e47fd9e4", - "outputs": [ - { - "data": { - "text/plain": [ - "array([0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0,\n", - " 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n", - " 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0,\n", - " 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1,\n", - " 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0,\n", - " 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,\n", - " 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1,\n", - " 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0,\n", - " 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0,\n", - " 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,\n", - " 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,\n", - " 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,\n", - " 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0,\n", - " 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,\n", - " 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1,\n", - " 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1,\n", - " 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1,\n", - " 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0,\n", - " 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n", - " 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1,\n", - " 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1,\n", - " 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1,\n", - " 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,\n", - " 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,\n", - " 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1,\n", - " 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0,\n", - " 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1,\n", - " 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1,\n", - " 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1,\n", - " 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1,\n", - " 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1,\n", - " 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n", - " 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1,\n", - " 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0,\n", - " 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1,\n", - " 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0,\n", - " 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0,\n", - " 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,\n", - " 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1,\n", - " 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0,\n", - " 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0,\n", - " 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,\n", - " 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,\n", - " 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0,\n", - " 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1,\n", - " 1, 0, 1, 1, 0, 1, 1, 1, 1, 1])" - ] - }, - "execution_count": 157, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 157 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-11-05T13:38:25.471082Z", + "end_time": "2025-11-05T16:58:00.363234514Z", "start_time": "2025-11-05T13:38:25.370173Z" } }, @@ -772,13 +969,15 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:39:53.206765Z", + "end_time": "2025-11-05T16:58:00.363988157Z", "start_time": "2025-11-05T13:39:27.866713Z" } }, "cell_type": "code", "source": [ "from time import sleep\n", + "\n", + "\n", "def callback_numpy_plot(index, params):\n", " plot_title = \"NumPy Model - It: {:05}\".format(index)\n", " file_name = \"numpy_model_{:05}.png\".format(index // 50)\n", @@ -787,6 +986,7 @@ " prediction_probs = prediction_probs.reshape(prediction_probs.shape[1], 1)\n", " make_plot(X_test, y_test, plot_title, file_name=file_path, XX=XX, YY=YY, preds=prediction_probs, dark=True)\n", "\n", + "\n", "# Training\n", "params_values = train(np.transpose(X_train), np.transpose(y_train.reshape((y_train.shape[0], 1))), nn_architecture,\n", " 30000, 0.01, False, callback_numpy_plot)" @@ -798,7 +998,7 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:43:06.906300Z", + "end_time": "2025-11-05T16:58:00.364712192Z", "start_time": "2025-11-05T13:43:06.780036Z" } }, @@ -840,23 +1040,26 @@ { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:45:25.105461Z", - "start_time": "2025-11-05T13:45:17.001211Z" + "end_time": "2025-11-05T22:28:55.082124Z", + "start_time": "2025-11-05T22:28:29.890543Z" } }, "cell_type": "code", "source": [ + "X_train = np.array(X_train)\n", + "y_train = np.array(y_train)\n", + "\n", "model = Sequential()\n", - "model.add(Dense(25, input_dim=2,activation='relu'))\n", + "model.add(Dense(25, input_dim=2, activation='relu'))\n", "model.add(Dense(50, activation='relu'))\n", "model.add(Dense(50, activation='relu'))\n", "model.add(Dense(25, activation='relu'))\n", - "model.add(Dense(1, activation='sigmoid'))\n", + "model.add(Dense(20, activation='sigmoid'))\n", "\n", "model.compile(loss='binary_crossentropy', optimizer=\"sgd\", metrics=['accuracy'])\n", "\n", "# Training\n", - "history = model.fit(X_train, y_train, epochs=200, verbose=0)" + "history = model.fit(X_train, y_train, epochs=1000, verbose=0)" ], "id": "f05ff40ed26e45c2", "outputs": [ @@ -869,21 +1072,32 @@ ] } ], - "execution_count": 154 + "execution_count": 373 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:45:27.083553Z", - "start_time": "2025-11-05T13:45:26.623760Z" + "end_time": "2025-11-05T22:28:56.234343Z", + "start_time": "2025-11-05T22:28:55.933926Z" } }, "cell_type": "code", "source": [ + "X_test = np.array(X_test)\n", + "y_test = np.array(y_test)\n", + "\n", "Y_test_prob = model.predict(X_test)\n", - "Y_test_hat = (Y_test_prob > 0.5).astype(int).ravel()\n", - "acc_test = accuracy_score(y_test, Y_test_hat)\n", - "print(\"Test set accuracy: {:.2f} - Goliath\".format(acc_test))" + "\n", + "# Convert probabilities to class predictions\n", + "Y_test_pred = np.argmax(Y_test_prob, axis=1)\n", + "\n", + "# Convert y_test to class labels (if it's one-hot encoded)\n", + "# If y_test is already in label format (e.g., [0, 1, 0, 1, ...]), skip this step\n", + "y_test_labels = np.argmax(y_test, axis=1) if len(y_test.shape) > 1 and y_test.shape[1] > 1 else y_test\n", + "\n", + "# Calculate accuracy\n", + "acc_test = accuracy_score(y_test_labels, Y_test_pred)\n", + "print(\"Test set accuracy: {:.2f}\".format(acc_test))" ], "id": "ef52bee9c93081d3", "outputs": [ @@ -891,19 +1105,51 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001B[1m1/4\u001B[0m \u001B[32m━━━━━\u001B[0m\u001B[37m━━━━━━━━━━━━━━━\u001B[0m \u001B[1m0s\u001B[0m 202ms/stepWARNING:tensorflow:5 out of the last 62926 calls to .one_step_on_data_distributed at 0x7a1f345302c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", - "\u001B[1m4/4\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m0s\u001B[0m 78ms/step\n", - "Test set accuracy: 0.69 - Goliath\n" + "\u001B[1m1/1\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m0s\u001B[0m 271ms/step\n", + "Test set accuracy: 0.30\n" ] } ], - "execution_count": 155 + "execution_count": 374 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:41:27.539268Z", - "start_time": "2025-11-05T13:40:32.908771Z" + "end_time": "2025-11-05T22:29:35.732951Z", + "start_time": "2025-11-05T22:29:35.730200Z" + } + }, + "cell_type": "code", + "source": [ + "for i in range(len(X_test)):\n", + " print(\"for: {}, pred: {}\".format(int(np.sum(X_test[i])*10), decode_from_vector(Y_test_prob[i])))" + ], + "id": "7c788fc6411656ea", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "for: 11, pred: 16\n", + "for: 8, pred: 16\n", + "for: 7, pred: 16\n", + "for: 9, pred: 16\n", + "for: 8, pred: 16\n", + "for: 12, pred: 16\n", + "for: 4, pred: 16\n", + "for: 8, pred: 16\n", + "for: 1, pred: 16\n", + "for: 0, pred: 6\n" + ] + } + ], + "execution_count": 377 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-05T22:08:59.757756Z", + "start_time": "2025-11-05T22:08:59.657259Z" } }, "cell_type": "code", @@ -942,14 +1188,27 @@ "/home/oskar/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/layers/core/dense.py:95: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n" ] + }, + { + "ename": "ValueError", + "evalue": "Arguments `target` and `output` must have the same shape. Received: target.shape=(None, 20), output.shape=(None, 1)", + "output_type": "error", + "traceback": [ + "\u001B[31m---------------------------------------------------------------------------\u001B[39m", + "\u001B[31mValueError\u001B[39m Traceback (most recent call last)", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[321]\u001B[39m\u001B[32m, line 24\u001B[39m\n\u001B[32m 21\u001B[39m model.compile(loss=\u001B[33m'\u001B[39m\u001B[33mbinary_crossentropy\u001B[39m\u001B[33m'\u001B[39m, optimizer=\u001B[33m\"\u001B[39m\u001B[33msgd\u001B[39m\u001B[33m\"\u001B[39m, metrics=[\u001B[33m'\u001B[39m\u001B[33maccuracy\u001B[39m\u001B[33m'\u001B[39m])\n\u001B[32m 23\u001B[39m \u001B[38;5;66;03m# Training\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m24\u001B[39m history = \u001B[43mmodel\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX_train\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my_train\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mepochs\u001B[49m\u001B[43m=\u001B[49m\u001B[32;43m200\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mverbose\u001B[49m\u001B[43m=\u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcallbacks\u001B[49m\u001B[43m=\u001B[49m\u001B[43m[\u001B[49m\u001B[43mtestmodelcb\u001B[49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/utils/traceback_utils.py:122\u001B[39m, in \u001B[36mfilter_traceback..error_handler\u001B[39m\u001B[34m(*args, **kwargs)\u001B[39m\n\u001B[32m 119\u001B[39m filtered_tb = _process_traceback_frames(e.__traceback__)\n\u001B[32m 120\u001B[39m \u001B[38;5;66;03m# To get the full stack trace, call:\u001B[39;00m\n\u001B[32m 121\u001B[39m \u001B[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001B[39;00m\n\u001B[32m--> \u001B[39m\u001B[32m122\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m e.with_traceback(filtered_tb) \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 123\u001B[39m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[32m 124\u001B[39m \u001B[38;5;28;01mdel\u001B[39;00m filtered_tb\n", + "\u001B[36mFile \u001B[39m\u001B[32m~/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/backend/tensorflow/nn.py:783\u001B[39m, in \u001B[36mbinary_crossentropy\u001B[39m\u001B[34m(target, output, from_logits)\u001B[39m\n\u001B[32m 781\u001B[39m \u001B[38;5;28;01mfor\u001B[39;00m e1, e2 \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mzip\u001B[39m(target.shape, output.shape):\n\u001B[32m 782\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m e1 \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m e2 \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m e1 != e2:\n\u001B[32m--> \u001B[39m\u001B[32m783\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[32m 784\u001B[39m \u001B[33m\"\u001B[39m\u001B[33mArguments `target` and `output` must have the same shape. \u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 785\u001B[39m \u001B[33m\"\u001B[39m\u001B[33mReceived: \u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 786\u001B[39m \u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33mtarget.shape=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtarget.shape\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m, output.shape=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00moutput.shape\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m\"\u001B[39m\n\u001B[32m 787\u001B[39m )\n\u001B[32m 789\u001B[39m output, from_logits = _get_logits(\n\u001B[32m 790\u001B[39m output, from_logits, \u001B[33m\"\u001B[39m\u001B[33mSigmoid\u001B[39m\u001B[33m\"\u001B[39m, \u001B[33m\"\u001B[39m\u001B[33mbinary_crossentropy\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 791\u001B[39m )\n\u001B[32m 793\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m from_logits:\n", + "\u001B[31mValueError\u001B[39m: Arguments `target` and `output` must have the same shape. Received: target.shape=(None, 20), output.shape=(None, 1)" + ] } ], - "execution_count": 146 + "execution_count": 321 }, { "metadata": { "ExecuteTime": { - "end_time": "2025-11-05T13:43:21.691060Z", + "end_time": "2025-11-05T16:58:00.367286962Z", "start_time": "2025-11-05T13:43:21.443563Z" } }, @@ -991,6 +1250,7 @@ "cell_type": "code", "source": [ "import torch\n", + "\n", "print(torch.cuda.is_available())" ], "id": "9958964dbe0d2732",