nn-from-scratch/add-sc.ipynb

1325 lines
749 KiB
Plaintext
Raw Normal View History

2025-11-05 17:53:45 +01:00
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:58.696888Z",
"start_time": "2025-11-05T16:57:58.695703Z"
2025-11-05 17:53:45 +01:00
}
},
"source": [
"from math import floor\n",
"\n",
"import numpy as np\n"
],
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 1
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:58.710359Z",
"start_time": "2025-11-05T16:57:58.708927Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"# nn_architecture = [\n",
"# {\"input_dim\": 2, \"output_dim\": 4, \"activation\": \"relu\"},\n",
"# {\"input_dim\": 4, \"output_dim\": 6, \"activation\": \"relu\"},\n",
"# {\"input_dim\": 6, \"output_dim\": 6, \"activation\": \"relu\"},\n",
"# {\"input_dim\": 6, \"output_dim\": 4, \"activation\": \"relu\"},\n",
"# {\"input_dim\": 4, \"output_dim\": 1, \"activation\": \"sigmoid\"},\n",
"# ]\n",
"nn_architecture = [\n",
" {\"input_dim\": 2, \"output_dim\": 25, \"activation\": \"relu\"},\n",
" {\"input_dim\": 25, \"output_dim\": 50, \"activation\": \"relu\"},\n",
" {\"input_dim\": 50, \"output_dim\": 50, \"activation\": \"relu\"},\n",
" {\"input_dim\": 50, \"output_dim\": 25, \"activation\": \"relu\"},\n",
" {\"input_dim\": 25, \"output_dim\": 20, \"activation\": \"sigmoid\"},\n",
"]"
],
"id": "48cafaf4b64967bb",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 2
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:58.767674Z",
"start_time": "2025-11-05T16:57:58.760718Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
2025-11-05 23:31:08 +01:00
"def init_layers(nn_architecture, seed=99):\n",
2025-11-05 17:53:45 +01:00
" np.random.seed(seed)\n",
" number_of_layers = len(nn_architecture)\n",
" params_values = {}\n",
"\n",
" for idx, layer in enumerate(nn_architecture):\n",
" layer_idx = idx + 1\n",
" layer_input_size = layer[\"input_dim\"]\n",
" layer_output_size = layer[\"output_dim\"]\n",
"\n",
" params_values['W' + str(layer_idx)] = np.random.randn(\n",
" layer_output_size, layer_input_size) * 0.1\n",
" params_values['b' + str(layer_idx)] = np.random.randn(\n",
" layer_output_size, 1) * 0.1\n",
"\n",
" return params_values\n"
],
"id": "d13137630b41b756",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 3
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:58.824122Z",
"start_time": "2025-11-05T16:57:58.819526Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"params = init_layers(nn_architecture)\n",
"# params"
],
"id": "31f205147667dea6",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 4
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:58.876505Z",
"start_time": "2025-11-05T16:57:58.871388Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def sigmoid(Z):\n",
2025-11-05 23:31:08 +01:00
" return 1 / (1 + np.exp(-Z))\n",
"\n",
2025-11-05 17:53:45 +01:00
"\n",
"def relu(Z):\n",
2025-11-05 23:31:08 +01:00
" return np.maximum(0, Z)\n",
"\n",
2025-11-05 17:53:45 +01:00
"\n",
"def sigmoid_backward(dA, Z):\n",
" sig = sigmoid(Z)\n",
" return dA * sig * (1 - sig)\n",
"\n",
2025-11-05 23:31:08 +01:00
"\n",
2025-11-05 17:53:45 +01:00
"def relu_backward(dA, Z):\n",
2025-11-05 23:31:08 +01:00
" dZ = np.array(dA, copy=True)\n",
2025-11-05 17:53:45 +01:00
" dZ[Z <= 0] = 0;\n",
" return dZ;"
],
"id": "c1b960e7dcf09d91",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 5
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:58.924888Z",
"start_time": "2025-11-05T16:57:58.921980Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def single_layer_forward_propagation(A_prev, W_curr, b_curr, activation=\"relu\"):\n",
" Z_curr = np.dot(W_curr, A_prev) + b_curr\n",
"\n",
" if activation == \"relu\":\n",
" activation_func = relu\n",
" elif activation == \"sigmoid\":\n",
" activation_func = sigmoid\n",
" else:\n",
" raise Exception('Non-supported activation function')\n",
"\n",
" return activation_func(Z_curr), Z_curr"
],
"id": "efae2e184daf2fce",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 6
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:58.981719Z",
"start_time": "2025-11-05T16:57:58.976016Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def full_forward_propagation(X, params_values, nn_architecture):\n",
" memory = {}\n",
" A_curr = X\n",
"\n",
" for idx, layer in enumerate(nn_architecture):\n",
" layer_idx = idx + 1\n",
" A_prev = A_curr\n",
"\n",
" activ_function_curr = layer[\"activation\"]\n",
" W_curr = params_values[\"W\" + str(layer_idx)]\n",
" b_curr = params_values[\"b\" + str(layer_idx)]\n",
" A_curr, Z_curr = single_layer_forward_propagation(A_prev, W_curr, b_curr, activ_function_curr)\n",
"\n",
" memory[\"A\" + str(idx)] = A_prev\n",
" memory[\"Z\" + str(layer_idx)] = Z_curr\n",
"\n",
" return A_curr, memory"
],
"id": "c3cd9e8f51dbe967",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 7
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T21:38:03.035821Z",
"start_time": "2025-11-05T21:38:03.030450Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def get_cost_value(Y_hat, Y):\n",
" m = Y_hat.shape[1]\n",
" cost = -1 / m * (np.dot(Y, np.log(Y_hat).T) + np.dot(1 - Y, np.log(1 - Y_hat).T))\n",
" return np.squeeze(cost)\n",
"\n",
2025-11-05 23:31:08 +01:00
"\n",
2025-11-05 17:53:45 +01:00
"# an auxiliary function that converts probability into class\n",
"def convert_prob_into_class(probs):\n",
" probs_ = np.copy(probs)\n",
" probs_[probs_ > 0.5] = 1\n",
" probs_[probs_ <= 0.5] = 0\n",
" return probs_\n",
"\n",
2025-11-05 23:31:08 +01:00
"\n",
2025-11-05 17:53:45 +01:00
"def get_accuracy_value(Y_hat, Y):\n",
" Y_hat_ = convert_prob_into_class(Y_hat)\n",
2025-11-05 23:31:08 +01:00
" return (Y_hat_ == Y).all(axis=0).mean()\n",
"\n",
"def get_accuracy_vector(Y_hat, Y):\n",
" diff = np.subtract(Y, Y_hat)\n",
" dot = np.dot(diff.T, diff)\n",
" return dot.trace() / dot.shape[0]\n",
"\n",
"# u = np.random.rand(2, 9)\n",
"# v = np.random.rand(2, 9)"
2025-11-05 17:53:45 +01:00
],
"id": "121416e7bbab57bb",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 258
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:59.095180Z",
"start_time": "2025-11-05T16:57:59.092006Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def single_layer_backward_propagation(dA_curr, W_curr, b_curr, Z_curr, A_prev, activation=\"relu\"):\n",
" m = A_prev.shape[1]\n",
"\n",
" if activation == \"relu\":\n",
" backward_activation_func = relu_backward\n",
" elif activation == \"sigmoid\":\n",
" backward_activation_func = sigmoid_backward\n",
" else:\n",
" raise Exception('Non-supported activation function')\n",
"\n",
" dZ_curr = backward_activation_func(dA_curr, Z_curr)\n",
" dW_curr = np.dot(dZ_curr, A_prev.T) / m\n",
" db_curr = np.sum(dZ_curr, axis=1, keepdims=True) / m\n",
" dA_prev = np.dot(W_curr.T, dZ_curr)\n",
"\n",
" return dA_prev, dW_curr, db_curr"
],
"id": "92e4b87664f18a63",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 9
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T17:33:20.050712Z",
"start_time": "2025-11-05T17:33:20.045249Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def full_backward_propagation(Y_hat, Y, memory, params_values, nn_architecture):\n",
" grads_values = {}\n",
2025-11-05 23:31:08 +01:00
" m = Y.shape\n",
" # Y = Y.reshape(Y_hat.shape)\n",
2025-11-05 17:53:45 +01:00
"\n",
" dA_prev = - (np.divide(Y, Y_hat) - np.divide(1 - Y, 1 - Y_hat));\n",
"\n",
" for layer_idx_prev, layer in reversed(list(enumerate(nn_architecture))):\n",
" layer_idx_curr = layer_idx_prev + 1\n",
" activ_function_curr = layer[\"activation\"]\n",
"\n",
" dA_curr = dA_prev\n",
"\n",
" A_prev = memory[\"A\" + str(layer_idx_prev)]\n",
" Z_curr = memory[\"Z\" + str(layer_idx_curr)]\n",
" W_curr = params_values[\"W\" + str(layer_idx_curr)]\n",
" b_curr = params_values[\"b\" + str(layer_idx_curr)]\n",
"\n",
" dA_prev, dW_curr, db_curr = single_layer_backward_propagation(\n",
" dA_curr, W_curr, b_curr, Z_curr, A_prev, activ_function_curr)\n",
"\n",
" grads_values[\"dW\" + str(layer_idx_curr)] = dW_curr\n",
" grads_values[\"db\" + str(layer_idx_curr)] = db_curr\n",
"\n",
" return grads_values"
],
"id": "2c8e4eed1846f003",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 64
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:57:59.200900Z",
"start_time": "2025-11-05T16:57:59.195743Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def update(params_values, grads_values, nn_architecture, learning_rate):\n",
" for layer_idx, layer in enumerate(nn_architecture, 1):\n",
" params_values[\"W\" + str(layer_idx)] -= learning_rate * grads_values[\"dW\" + str(layer_idx)]\n",
" params_values[\"b\" + str(layer_idx)] -= learning_rate * grads_values[\"db\" + str(layer_idx)]\n",
"\n",
" return params_values;"
],
"id": "16320b953a183511",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 11
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T22:00:52.345331Z",
"start_time": "2025-11-05T22:00:52.339775Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def train(X, Y, nn_architecture, epochs, learning_rate, verbose=False, callback=None):\n",
" # initiation of neural net parameters\n",
" params_values = init_layers(nn_architecture, 2)\n",
" # initiation of lists storing the history\n",
" # of metrics calculated during the learning process\n",
" cost_history = []\n",
" accuracy_history = []\n",
"\n",
" # performing calculations for subsequent iterations\n",
" for i in range(epochs):\n",
" # step forward\n",
" Y_hat, cashe = full_forward_propagation(X, params_values, nn_architecture)\n",
"\n",
" # calculating metrics and saving them in history\n",
" cost = get_cost_value(Y_hat, Y)\n",
" cost_history.append(cost)\n",
2025-11-05 23:31:08 +01:00
"\n",
" accuracy = get_accuracy_vector(Y_hat, Y)\n",
"\n",
2025-11-05 17:53:45 +01:00
" accuracy_history.append(accuracy)\n",
"\n",
2025-11-05 23:31:08 +01:00
" # print(\"Y_hat.shape: {}, Y.shape: {}\".format(Y_hat.shape, Y.shape))\n",
"\n",
2025-11-05 17:53:45 +01:00
" # step backward - calculating gradient\n",
" grads_values = full_backward_propagation(Y_hat, Y, cashe, params_values, nn_architecture)\n",
2025-11-05 23:31:08 +01:00
"\n",
" if (i % 50000 == 0):\n",
" print(\"Learning rate: {}\".format(learning_rate))\n",
" learning_rate = learning_rate / 10.0\n",
"\n",
2025-11-05 17:53:45 +01:00
" params_values = update(params_values, grads_values, nn_architecture, learning_rate)\n",
"\n",
2025-11-05 23:31:08 +01:00
" if (i % 1000 == 0):\n",
" print(\"dW1 norm 2 grad: {}, accu: {}\".format(np.linalg.norm(grads_values[\"dW1\"]), accuracy))\n",
" if (verbose):\n",
2025-11-05 17:53:45 +01:00
" print(\"Iteration: {:05} - cost: {:.5f} - accuracy: {:.5f}\".format(i, cost, accuracy))\n",
2025-11-05 23:31:08 +01:00
" if (callback is not None):\n",
2025-11-05 17:53:45 +01:00
" callback(i, params_values)\n",
"\n",
" return params_values"
],
"id": "fce33f70bba3898",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 306
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:58:00.043457Z",
"start_time": "2025-11-05T16:57:59.317515Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"import os\n",
"import tensorflow as tf\n",
"\n",
"import sklearn.datasets as ds\n",
"import sklearn.utils as su\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
2025-11-05 23:31:08 +01:00
"\n",
2025-11-05 17:53:45 +01:00
"sns.set_style(\"whitegrid\")\n",
"\n",
"import keras\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense\n",
"# from keras.utils import np_utils\n",
"from keras import regularizers\n",
"\n",
"from sklearn.metrics import accuracy_score"
],
"id": "cccd73b5018799d4",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-11-05 23:31:08 +01:00
"2025-11-05 17:57:59.430264: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2025-11-05 17:57:59.456029: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
2025-11-05 17:53:45 +01:00
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
2025-11-05 23:31:08 +01:00
"2025-11-05 17:57:59.930030: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n"
2025-11-05 17:53:45 +01:00
]
}
],
2025-11-05 23:31:08 +01:00
"execution_count": 13
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:58:02.921068Z",
"start_time": "2025-11-05T16:58:02.919524Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"# number of samples in the data set\n",
"N_SAMPLES = 1000\n",
"# ratio between training and test sets\n",
"TEST_SIZE = 0.1"
],
"id": "4f66ffa878f01c02",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 18
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T21:46:39.352413Z",
"start_time": "2025-11-05T21:46:39.344853Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def encode_add(i: int) -> float:\n",
" return float(i) / 10.0\n",
"\n",
2025-11-05 23:31:08 +01:00
"\n",
2025-11-05 17:53:45 +01:00
"def decode_add(i) -> int:\n",
" return int(i[0] * 10 + i[1])\n",
"\n",
2025-11-05 23:31:08 +01:00
"\n",
"def add(a: float, b: float):\n",
2025-11-05 17:53:45 +01:00
" r = a * 10.0 + b * 10.0\n",
"\n",
" r0 = floor(r % 10)\n",
" r1 = floor(r / 10)\n",
2025-11-05 23:31:08 +01:00
" return r1, r0\n",
2025-11-05 17:53:45 +01:00
"\n",
2025-11-05 23:31:08 +01:00
"\n",
"def encode_to_vector(x: float, y: float):\n",
" i, j = add(x, y)\n",
2025-11-05 17:53:45 +01:00
" vector = np.zeros(20)\n",
" vector[i] = 1\n",
2025-11-05 23:31:08 +01:00
" vector[j + 10] = 1\n",
2025-11-05 17:53:45 +01:00
" return vector\n",
"\n",
2025-11-05 23:31:08 +01:00
"\n",
"def decode_from_vector(vector):\n",
" i = np.argmax(vector[0:10]) + 1\n",
" j = np.argmax(vector[10:20]) - 10\n",
" return decode_add((i, j))\n",
"\n",
"\n",
2025-11-05 17:53:45 +01:00
"# add(encode_add(2),encode_add(3))\n",
"# encode_to_vector(encode_add(2),encode_add(3))\n",
"\n",
"def make_sums(\n",
2025-11-05 23:31:08 +01:00
" n_samples=100, *, shuffle=False, noise=None, random_state=None, factor=0.8\n",
2025-11-05 17:53:45 +01:00
"):\n",
" X = []\n",
" y = []\n",
"\n",
" for i in np.linspace(0, 9, 10):\n",
" for j in np.linspace(0, 9, 10):\n",
" i_int = int(i)\n",
" j_int = int(j)\n",
2025-11-05 23:31:08 +01:00
" X.append([encode_add(i_int), encode_add(j_int)])\n",
" y.append(encode_to_vector(encode_add(i_int), encode_add(j_int)))\n",
2025-11-05 17:53:45 +01:00
"\n",
2025-11-05 23:31:08 +01:00
" # X = np.array(X).T # Shape: (2, 100)\n",
" # y = np.array(y).T # Shape: (20, 100)\n",
2025-11-05 17:53:45 +01:00
"\n",
" if shuffle and random_state is not None:\n",
" np.random.seed(random_state)\n",
" indices = np.random.permutation(X.shape[1])\n",
" X = X[:, indices]\n",
" y = y[:, indices]\n",
"\n",
" return X, y\n",
"\n",
2025-11-05 23:31:08 +01:00
"# decode_from_vector(encode_to_vector(encode_add(9), encode_add(9)))\n"
2025-11-05 17:53:45 +01:00
],
"id": "7ce930351bba500c",
2025-11-05 23:31:08 +01:00
"outputs": [],
"execution_count": 272
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T22:27:36.512542Z",
"start_time": "2025-11-05T22:27:36.510154Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
2025-11-05 23:31:08 +01:00
"X, y = make_sums()\n",
"# X, y = ds.make_circles(n_samples = N_SAMPLES, noise=0.2, random_state=100)\n",
2025-11-05 17:53:45 +01:00
"# X, y = make_moons(n_samples = N_SAMPLES, noise=0.2, random_state=100)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)"
],
"id": "bebe0ed00a2d514",
"outputs": [],
2025-11-05 23:31:08 +01:00
"execution_count": 370
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-05T21:46:45.705860Z",
"start_time": "2025-11-05T21:46:45.702814Z"
}
},
"cell_type": "code",
"source": [
"X = np.transpose(X_train)\n",
"Y = np.transpose(y_train)\n",
"epochs = 30000\n",
"learning_rate = 0.01\n",
"verbose = True\n",
"callback = None\n",
"\n",
"params_values = init_layers(nn_architecture, 2)\n",
"# initiation of lists storing the history\n",
"# of metrics calculated during the learning process\n",
"cost_history = []\n",
"accuracy_history = []\n",
"\n",
"# performing calculations for subsequent iterations\n",
"# for i in range(epochs):\n",
"# step forward\n",
"Y_hat, cashe = full_forward_propagation(X, params_values, nn_architecture)\n",
"\n",
"# calculating metrics and saving them in history\n",
"cost = get_cost_value(Y_hat, Y)\n",
"cost_history.append(cost)\n",
"accuracy = get_accuracy_vector(Y_hat, Y)\n",
"accuracy_history.append(accuracy)\n",
"\n",
"print(\"Y_hat.shape: {}, Y.shape: {}\".format(Y_hat.shape, Y.shape))\n",
"\n",
"grads_values = full_backward_propagation(Y_hat, Y, cashe, params_values, nn_architecture)\n",
"\n",
"params_values = update(params_values, grads_values, nn_architecture, learning_rate)\n"
],
"id": "f362f4b889723f85",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Y_hat.shape: (20, 90), Y.shape: (20, 90)\n"
]
}
],
"execution_count": 274
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T22:02:19.390360Z",
"start_time": "2025-11-05T22:01:41.944726Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
2025-11-05 23:31:08 +01:00
"params_values = train(np.transpose(X_train), np.transpose(y_train), nn_architecture, 210000, 0.0092, verbose=False)\n",
2025-11-05 17:53:45 +01:00
"# params_values\n"
],
"id": "ce04892d496c5147",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-11-05 23:31:08 +01:00
"Learning rate: 0.0092\n",
"dW1 norm 2 grad: 0.030587792178820513, accu: 4.941336157937026\n",
"dW1 norm 2 grad: 0.1646929550834701, accu: 1.9357215120084397\n",
"dW1 norm 2 grad: 0.019128990025118184, accu: 1.403621959589698\n",
"dW1 norm 2 grad: 0.010786917974214204, accu: 1.3994765844539645\n",
"dW1 norm 2 grad: 0.006973910701561598, accu: 1.398801328858226\n",
"dW1 norm 2 grad: 0.004804662293654282, accu: 1.398564108821892\n",
"dW1 norm 2 grad: 0.0037441349283433623, accu: 1.3984079114655157\n",
"dW1 norm 2 grad: 0.0033954272366276093, accu: 1.3982712361598055\n",
"dW1 norm 2 grad: 0.003206636031118767, accu: 1.3981389888275535\n",
"dW1 norm 2 grad: 0.003465722901160753, accu: 1.3980025782117487\n",
"dW1 norm 2 grad: 0.003626722252604109, accu: 1.397860855459184\n",
"dW1 norm 2 grad: 0.003897822935664194, accu: 1.3977138200524064\n",
"dW1 norm 2 grad: 0.004241887307465233, accu: 1.3975571070828419\n",
"dW1 norm 2 grad: 0.004657598488948141, accu: 1.3973921901095299\n",
"dW1 norm 2 grad: 0.005220580588807355, accu: 1.397215829967114\n",
"dW1 norm 2 grad: 0.005846951555849162, accu: 1.3970250436043212\n",
"dW1 norm 2 grad: 0.006524684719971642, accu: 1.3968167894331327\n",
"dW1 norm 2 grad: 0.007310990726029376, accu: 1.3965867688366898\n",
"dW1 norm 2 grad: 0.008200467149319903, accu: 1.3963309075297812\n",
"dW1 norm 2 grad: 0.009301152067821774, accu: 1.396042748753016\n",
"dW1 norm 2 grad: 0.010166278854072197, accu: 1.3957157327814165\n",
"dW1 norm 2 grad: 0.011425139837519494, accu: 1.3953405108268921\n",
"dW1 norm 2 grad: 0.012950014512970601, accu: 1.3949055628030786\n",
"dW1 norm 2 grad: 0.014355878579536059, accu: 1.394394841643409\n",
"dW1 norm 2 grad: 0.016137061178322212, accu: 1.393787043523792\n",
"dW1 norm 2 grad: 0.01826892984944323, accu: 1.393051951259501\n",
"dW1 norm 2 grad: 0.021143954900072977, accu: 1.3921339908391197\n",
"dW1 norm 2 grad: 0.02423583026714934, accu: 1.3909084263048683\n",
"dW1 norm 2 grad: 0.02787064236557585, accu: 1.3893470499486298\n",
"dW1 norm 2 grad: 0.03272810256003809, accu: 1.3873267538809206\n",
"dW1 norm 2 grad: 0.03832710708326575, accu: 1.3845979543009885\n",
"dW1 norm 2 grad: 0.045545571050636745, accu: 1.3807623420159982\n",
"dW1 norm 2 grad: 0.05624419801088249, accu: 1.375112751337416\n",
"dW1 norm 2 grad: 0.07016642479210054, accu: 1.3663006786836638\n",
"dW1 norm 2 grad: 0.0905448224502262, accu: 1.3515789855937521\n",
"dW1 norm 2 grad: 0.1232520179315588, accu: 1.3240964990083572\n",
"dW1 norm 2 grad: 0.16541711360085276, accu: 1.2706817647472526\n",
"dW1 norm 2 grad: 0.19353060647709652, accu: 1.1786673745786147\n",
"dW1 norm 2 grad: 0.18934424610944875, accu: 1.0755383964359753\n",
"dW1 norm 2 grad: 0.1408928760505227, accu: 1.0107582411953095\n",
"dW1 norm 2 grad: 0.10766369209801253, accu: 0.9758933289489259\n",
"dW1 norm 2 grad: 0.0866161481253012, accu: 0.9546471539965528\n",
"dW1 norm 2 grad: 0.07664582736675311, accu: 0.9382959188648609\n",
"dW1 norm 2 grad: 0.06378446361836485, accu: 0.92428916845143\n",
"dW1 norm 2 grad: 0.0608145095307446, accu: 0.9114837085330408\n",
"dW1 norm 2 grad: 0.06972571774543508, accu: 0.8973253645899818\n",
"dW1 norm 2 grad: 0.09983976688230996, accu: 0.8800802854305243\n",
"dW1 norm 2 grad: 0.11626330494728294, accu: 0.8619511316626135\n",
"dW1 norm 2 grad: 0.049351391928600205, accu: 0.8416075309665876\n",
"dW1 norm 2 grad: 0.03913529011211546, accu: 0.8213559372122169\n",
"Learning rate: 0.00092\n",
"dW1 norm 2 grad: 0.0698354824909347, accu: 0.8021223071821478\n",
"dW1 norm 2 grad: 0.06931343312498642, accu: 0.8002995625700543\n",
"dW1 norm 2 grad: 0.04781421567390308, accu: 0.7985017709141918\n",
"dW1 norm 2 grad: 0.04748953758278176, accu: 0.796723505677864\n",
"dW1 norm 2 grad: 0.05246500533030227, accu: 0.7949704071822465\n",
"dW1 norm 2 grad: 0.027031577043732952, accu: 0.7932488950712784\n",
"dW1 norm 2 grad: 0.05116794968101623, accu: 0.7915531315627051\n",
"dW1 norm 2 grad: 0.11213636049360046, accu: 0.7898833836960827\n",
"dW1 norm 2 grad: 0.06823964221603801, accu: 0.7882536651796254\n",
"dW1 norm 2 grad: 0.10714340854837753, accu: 0.7866551302317062\n",
"dW1 norm 2 grad: 0.04994649850404826, accu: 0.7850793043793554\n",
"dW1 norm 2 grad: 0.1099960451952437, accu: 0.7835262077531631\n",
"dW1 norm 2 grad: 0.0439103861538206, accu: 0.7819942812222288\n",
"dW1 norm 2 grad: 0.10874239209642468, accu: 0.7804808471860741\n",
"dW1 norm 2 grad: 0.11393738889197819, accu: 0.7789666328853002\n",
"dW1 norm 2 grad: 0.04495814658240637, accu: 0.7774557442802752\n",
"dW1 norm 2 grad: 0.04395459066772111, accu: 0.7759593873186026\n",
"dW1 norm 2 grad: 0.043967202328786065, accu: 0.7744766168993356\n",
"dW1 norm 2 grad: 0.04442609020118015, accu: 0.7730071428706276\n",
"dW1 norm 2 grad: 0.04458841504135935, accu: 0.7715629241882856\n",
"dW1 norm 2 grad: 0.10906288744417346, accu: 0.7701327907640183\n",
"dW1 norm 2 grad: 0.04608212489484624, accu: 0.7686767623882792\n",
"dW1 norm 2 grad: 0.04205250484603993, accu: 0.7671757348155648\n",
"dW1 norm 2 grad: 0.10858734045274064, accu: 0.7656854562107144\n",
"dW1 norm 2 grad: 0.04079451613320205, accu: 0.7641965674293068\n",
"dW1 norm 2 grad: 0.04119869852007931, accu: 0.7627151990621612\n",
"dW1 norm 2 grad: 0.022823050772103445, accu: 0.7612407621215996\n",
"dW1 norm 2 grad: 0.05988234296477944, accu: 0.7597502571796344\n",
"dW1 norm 2 grad: 0.03994801225051277, accu: 0.758263715093205\n",
"dW1 norm 2 grad: 0.02001563328832784, accu: 0.7567800459820742\n",
"dW1 norm 2 grad: 0.09389221519341912, accu: 0.7553161878877876\n",
"dW1 norm 2 grad: 0.036694546941468544, accu: 0.7538650579919823\n",
"dW1 norm 2 grad: 0.09363821631353306, accu: 0.7524245293792678\n",
"dW1 norm 2 grad: 0.08890011864334593, accu: 0.751074225488152\n",
"dW1 norm 2 grad: 0.08711660000940519, accu: 0.7497468566289912\n",
"dW1 norm 2 grad: 0.08433113821881433, accu: 0.7484255278238835\n",
"dW1 norm 2 grad: 0.08274136620344481, accu: 0.7471096846730314\n",
"dW1 norm 2 grad: 0.08263159774910961, accu: 0.7457989989854615\n",
"dW1 norm 2 grad: 0.07996019815679523, accu: 0.7444924666609721\n",
"dW1 norm 2 grad: 0.07590974689644236, accu: 0.74319290888101\n",
"dW1 norm 2 grad: 0.07671148899598569, accu: 0.7418974636517687\n",
"dW1 norm 2 grad: 0.07651282066883193, accu: 0.7406057881746861\n",
"dW1 norm 2 grad: 0.07194127532607032, accu: 0.7393175244984712\n",
"dW1 norm 2 grad: 0.046034760180805066, accu: 0.7380394106977187\n",
"dW1 norm 2 grad: 0.04574718764808875, accu: 0.7368222319175388\n",
"dW1 norm 2 grad: 0.025349022989991577, accu: 0.7356395543948386\n",
"dW1 norm 2 grad: 0.03560280602726019, accu: 0.7344641888766926\n",
"dW1 norm 2 grad: 0.09259080063105282, accu: 0.7332947744286727\n",
"dW1 norm 2 grad: 0.0911961020909095, accu: 0.7321292733139444\n",
"dW1 norm 2 grad: 0.08942583365863134, accu: 0.7309675839019333\n",
"Learning rate: 9.2e-05\n",
"dW1 norm 2 grad: 0.05081509253625319, accu: 0.7298165738413436\n",
"dW1 norm 2 grad: 0.08069013256772081, accu: 0.7297031031059652\n",
"dW1 norm 2 grad: 0.033465976546192586, accu: 0.7295897688790668\n",
"dW1 norm 2 grad: 0.0838071759382461, accu: 0.7294765059452227\n",
"dW1 norm 2 grad: 0.08369039264117137, accu: 0.7293633551904505\n",
"dW1 norm 2 grad: 0.025063043610271528, accu: 0.7292503109024019\n",
"dW1 norm 2 grad: 0.0250903067228421, accu: 0.7291372655826458\n",
"dW1 norm 2 grad: 0.024807189346951402, accu: 0.7290242245133101\n",
"dW1 norm 2 grad: 0.07990580710201337, accu: 0.7289111175614715\n",
"dW1 norm 2 grad: 0.024621601100137287, accu: 0.72879804031431\n",
"dW1 norm 2 grad: 0.023355088185910863, accu: 0.7286849775453704\n",
"dW1 norm 2 grad: 0.02274113942950061, accu: 0.7285719795169106\n",
"dW1 norm 2 grad: 0.049949659194981566, accu: 0.7284593451650372\n",
"dW1 norm 2 grad: 0.023248965569912117, accu: 0.7283468190487229\n",
"dW1 norm 2 grad: 0.023152204671378268, accu: 0.7282343171189314\n",
"dW1 norm 2 grad: 0.04969774869533673, accu: 0.7281217905460505\n",
"dW1 norm 2 grad: 0.08333052873343195, accu: 0.7280093859426575\n",
"dW1 norm 2 grad: 0.07823212307227495, accu: 0.727897047104545\n",
"dW1 norm 2 grad: 0.08257422509504603, accu: 0.7277847567346591\n",
"dW1 norm 2 grad: 0.02224991771397424, accu: 0.7276725529494683\n",
"dW1 norm 2 grad: 0.02711910610218876, accu: 0.7275603644520168\n",
"dW1 norm 2 grad: 0.07886200292692834, accu: 0.7274482275685278\n",
"dW1 norm 2 grad: 0.02198031237106057, accu: 0.7273361722482247\n",
"dW1 norm 2 grad: 0.02115994730449518, accu: 0.7272241086079306\n",
"dW1 norm 2 grad: 0.021285257622553382, accu: 0.7271122986359888\n",
"dW1 norm 2 grad: 0.02357392567635191, accu: 0.7270004888545573\n",
"dW1 norm 2 grad: 0.022688686003681072, accu: 0.7268886540360114\n",
"dW1 norm 2 grad: 0.0774804194383182, accu: 0.7267768612068347\n",
"dW1 norm 2 grad: 0.023296298417766592, accu: 0.726665878152204\n",
"dW1 norm 2 grad: 0.0268274942998492, accu: 0.7265561850412595\n",
"dW1 norm 2 grad: 0.06052005017135857, accu: 0.7264469435954387\n",
"dW1 norm 2 grad: 0.04945838108083401, accu: 0.7263377323697943\n",
"dW1 norm 2 grad: 0.027067935503537366, accu: 0.726229201502864\n",
"dW1 norm 2 grad: 0.026893122048369005, accu: 0.7261207218241621\n",
"dW1 norm 2 grad: 0.02174669349434738, accu: 0.7260123279798082\n",
"dW1 norm 2 grad: 0.027435227540288768, accu: 0.7259041066903453\n",
"dW1 norm 2 grad: 0.08375911784691781, accu: 0.7257962378376251\n",
"dW1 norm 2 grad: 0.022652301601058415, accu: 0.7256884416312775\n",
"dW1 norm 2 grad: 0.08289691349070971, accu: 0.7255806565534714\n",
"dW1 norm 2 grad: 0.08364752167205262, accu: 0.7254731199083936\n",
"dW1 norm 2 grad: 0.026692046014157863, accu: 0.7253656043834251\n",
"dW1 norm 2 grad: 0.04913025170849942, accu: 0.7252580886894395\n",
"dW1 norm 2 grad: 0.023689498661408678, accu: 0.7251506668124373\n",
"dW1 norm 2 grad: 0.02657062856609085, accu: 0.725042860431935\n",
"dW1 norm 2 grad: 0.023702493689244892, accu: 0.7249343664607922\n",
"dW1 norm 2 grad: 0.02342068477793123, accu: 0.724825581515084\n",
"dW1 norm 2 grad: 0.023512325200450875, accu: 0.7247168048663374\n",
"dW1 norm 2 grad: 0.026039992426218454, accu: 0.7246080322347013\n",
"dW1 norm 2 grad: 0.06301257689207454, accu: 0.7244992898894886\n",
"dW1 norm 2 grad: 0.08123844996098815, accu: 0.724390594267674\n",
"Learning rate: 9.2e-06\n",
"dW1 norm 2 grad: 0.05660752226816397, accu: 0.724281912921313\n",
"dW1 norm 2 grad: 0.026860772665232544, accu: 0.7242710584653618\n",
"dW1 norm 2 grad: 0.024496875972681183, accu: 0.7242602093657362\n",
"dW1 norm 2 grad: 0.025294084526166837, accu: 0.7242493410645534\n",
"dW1 norm 2 grad: 0.022663938572628706, accu: 0.724238471439105\n",
"dW1 norm 2 grad: 0.08092590829169215, accu: 0.7242275992665872\n",
"dW1 norm 2 grad: 0.021902604575931107, accu: 0.7242167354020932\n",
"dW1 norm 2 grad: 0.025736603559630894, accu: 0.7242058633500534\n",
"dW1 norm 2 grad: 0.0581973314338413, accu: 0.7241949951587566\n",
"dW1 norm 2 grad: 0.02344384040486643, accu: 0.7241841285869589\n",
"dW1 norm 2 grad: 0.021904167014846036, accu: 0.7241732618842978\n",
"dW1 norm 2 grad: 0.02494119041657796, accu: 0.7241623921154894\n",
"dW1 norm 2 grad: 0.07836203187703025, accu: 0.7241515204595299\n",
"dW1 norm 2 grad: 0.025193938991455386, accu: 0.7241406544392891\n",
"dW1 norm 2 grad: 0.07560872303681854, accu: 0.7241297849270177\n",
"dW1 norm 2 grad: 0.02556473409099411, accu: 0.7241189184717867\n",
"dW1 norm 2 grad: 0.07650664220580596, accu: 0.7241080494594142\n",
"dW1 norm 2 grad: 0.08074806296038219, accu: 0.7240971814540356\n",
"dW1 norm 2 grad: 0.024689714218072648, accu: 0.7240863166411479\n",
"dW1 norm 2 grad: 0.08164790148339035, accu: 0.7240754467496378\n",
"dW1 norm 2 grad: 0.023112710622302055, accu: 0.7240645839138531\n",
"dW1 norm 2 grad: 0.025532646872846385, accu: 0.7240537140180505\n",
"dW1 norm 2 grad: 0.08065985712188835, accu: 0.7240428454529251\n",
"dW1 norm 2 grad: 0.0471336793636139, accu: 0.7240319784959498\n",
"dW1 norm 2 grad: 0.08153184230557896, accu: 0.7240211117910492\n",
"dW1 norm 2 grad: 0.0815184663751406, accu: 0.7240102446800086\n",
"dW1 norm 2 grad: 0.025456334538617304, accu: 0.7239993797762518\n",
"dW1 norm 2 grad: 0.023264980480717054, accu: 0.7239885152111308\n",
"dW1 norm 2 grad: 0.07961769303874187, accu: 0.7239776454654755\n",
"dW1 norm 2 grad: 0.07625319759743494, accu: 0.7239667793159715\n",
"dW1 norm 2 grad: 0.02539726871451324, accu: 0.7239559145024411\n",
"dW1 norm 2 grad: 0.0237378349559094, accu: 0.7239450513078237\n",
"dW1 norm 2 grad: 0.02513608435427682, accu: 0.7239341840865249\n",
"dW1 norm 2 grad: 0.08137682892886208, accu: 0.7239233148786304\n",
"dW1 norm 2 grad: 0.04811297421618333, accu: 0.7239124656845819\n",
"dW1 norm 2 grad: 0.023193417673846712, accu: 0.7239016316445489\n",
"dW1 norm 2 grad: 0.0256038658999626, accu: 0.723890791326119\n",
"dW1 norm 2 grad: 0.021947534582446696, accu: 0.723879955190287\n",
"dW1 norm 2 grad: 0.08167978225793682, accu: 0.7238691117330809\n",
"dW1 norm 2 grad: 0.0231233474515702, accu: 0.7238582754116796\n",
"dW1 norm 2 grad: 0.08161827390446047, accu: 0.7238474330611652\n",
"dW1 norm 2 grad: 0.07974789689423355, accu: 0.7238365943467001\n",
"dW1 norm 2 grad: 0.022592499032915825, accu: 0.7238257594435297\n",
"dW1 norm 2 grad: 0.022504525579480045, accu: 0.7238149204937029\n",
"dW1 norm 2 grad: 0.08156846999929206, accu: 0.7238040775315533\n",
"dW1 norm 2 grad: 0.023618593965625455, accu: 0.7237932416255065\n",
"dW1 norm 2 grad: 0.08150850470928286, accu: 0.7237823997171771\n",
"dW1 norm 2 grad: 0.07696639270478159, accu: 0.7237715613280187\n",
"dW1 norm 2 grad: 0.025295856277240028, accu: 0.7237607248761024\n",
"dW1 norm 2 grad: 0.0796405362393356, accu: 0.72374988335138\n",
"Learning rate: 9.2e-07\n",
"dW1 norm 2 grad: 0.02533836373189624, accu: 0.7237390482440271\n",
"dW1 norm 2 grad: 0.07694830713060108, accu: 0.723737963788499\n",
"dW1 norm 2 grad: 0.07942747965329897, accu: 0.7237368813798639\n",
"dW1 norm 2 grad: 0.024989911441789227, accu: 0.7237357976831307\n",
"dW1 norm 2 grad: 0.025115561130679122, accu: 0.723734713740944\n",
"dW1 norm 2 grad: 0.024750581477314067, accu: 0.7237336301018029\n",
"dW1 norm 2 grad: 0.025057096033220413, accu: 0.7237325461940796\n",
"dW1 norm 2 grad: 0.08142729217753651, accu: 0.7237314619742643\n",
"dW1 norm 2 grad: 0.023236646366497984, accu: 0.7237303784489091\n",
"dW1 norm 2 grad: 0.02622057086597789, accu: 0.7237292943668882\n"
2025-11-05 17:53:45 +01:00
]
}
],
2025-11-05 23:31:08 +01:00
"execution_count": 308
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T22:02:26.485044Z",
"start_time": "2025-11-05T22:02:26.478191Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"Y_test_hat, _ = full_forward_propagation(np.transpose(X_test), params_values, nn_architecture)\n",
2025-11-05 23:31:08 +01:00
"# X_test[0]\n",
"# Y_test_hat[0]\n",
"# Y_test_hat.T[0][10:20]\n",
"# np.argmax(Y_test_hat[0][0:10])\n",
"# Y_test_hat[0][0:10][8]\n",
"\n",
"# decode_from_vector(Y_test_hat.T[0])\n",
2025-11-05 17:53:45 +01:00
"\n",
2025-11-05 23:31:08 +01:00
"for i in range(len(X_test)):\n",
" print(\"for: {}, pred: {}\".format(int(np.sum(X_test[i])*10), decode_from_vector(Y_test_hat.T[i])))\n",
"\n"
2025-11-05 17:53:45 +01:00
],
"id": "26e7a2a8848714d9",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-11-05 23:31:08 +01:00
"for: 11, pred: 10\n",
"for: 8, pred: 9\n",
"for: 7, pred: 7\n",
"for: 9, pred: 9\n",
"for: 8, pred: 9\n",
"for: 12, pred: 11\n",
"for: 4, pred: 6\n",
"for: 8, pred: 9\n",
"for: 1, pred: 6\n",
"for: 0, pred: 6\n"
]
}
],
"execution_count": 309
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-05T22:02:31.875471Z",
"start_time": "2025-11-05T22:02:31.869284Z"
}
},
"cell_type": "code",
"source": [
"Y_train_hat, _ = full_forward_propagation(np.transpose(X_train[1:10]), params_values, nn_architecture)\n",
"\n",
"for i in range(len(X_train[1:10])):\n",
" print(\"for: {}, pred: {}\".format(int(np.sum(X_train[1:10][i])*10), decode_from_vector(Y_train_hat.T[i])))\n"
],
"id": "50a9492ad1e6a37c",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"for: 3, pred: 6\n",
"for: 10, pred: 10\n",
"for: 6, pred: 6\n",
"for: 9, pred: 9\n",
"for: 4, pred: 6\n",
"for: 12, pred: 13\n",
"for: 14, pred: 13\n",
"for: 3, pred: 6\n",
"for: 4, pred: 6\n"
2025-11-05 17:53:45 +01:00
]
}
],
2025-11-05 23:31:08 +01:00
"execution_count": 310
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:58:00.360548085Z",
2025-11-05 17:53:45 +01:00
"start_time": "2025-11-05T13:43:01.537741Z"
}
},
"cell_type": "code",
"source": [
"# boundary of the graph\n",
"GRID_X_START = -1.5\n",
"GRID_X_END = 1.5\n",
"GRID_Y_START = -1.5\n",
"GRID_Y_END = 1.5\n",
"# output directory (the folder must be created on the drive)\n",
"OUTPUT_DIR = \"./binary_classification_vizualizations/\"\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"### Definition of grid boundaries\n",
"grid = np.mgrid[GRID_X_START:GRID_X_END:100j, GRID_Y_START:GRID_Y_END:100j]\n",
"grid_2d = grid.reshape(2, -1).T\n",
"XX, YY = grid"
],
"id": "b070f03d55981894",
"outputs": [],
"execution_count": 148
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:58:00.361037286Z",
2025-11-05 17:53:45 +01:00
"start_time": "2025-11-05T13:38:25.347259Z"
}
},
"cell_type": "code",
"source": [
"def make_plot(X, y, plot_name, file_name=None, XX=None, YY=None, preds=None, dark=False):\n",
" if (dark):\n",
" plt.style.use('dark_background')\n",
" else:\n",
" sns.set_style(\"whitegrid\")\n",
2025-11-05 23:31:08 +01:00
" plt.figure(figsize=(16, 12))\n",
2025-11-05 17:53:45 +01:00
" axes = plt.gca()\n",
" axes.set(xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n",
" plt.title(plot_name, fontsize=30)\n",
" plt.subplots_adjust(left=0.20)\n",
" plt.subplots_adjust(right=0.80)\n",
2025-11-05 23:31:08 +01:00
" if (XX is not None and YY is not None and preds is not None):\n",
" plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=1, cmap=cm.Spectral)\n",
2025-11-05 17:53:45 +01:00
" plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap=\"Greys\", vmin=0, vmax=.6)\n",
" plt.scatter(X[:, 0], X[:, 1], c=y.ravel(), s=40, cmap=plt.cm.Spectral, edgecolors='black')\n",
2025-11-05 23:31:08 +01:00
" if (file_name):\n",
2025-11-05 17:53:45 +01:00
" plt.savefig(file_name)\n",
" plt.close()"
],
"id": "553e08ddc23ab78c",
"outputs": [],
"execution_count": 140
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:58:00.363234514Z",
2025-11-05 17:53:45 +01:00
"start_time": "2025-11-05T13:38:25.370173Z"
}
},
"cell_type": "code",
"source": " make_plot(X, y, \"Dataset\")",
"id": "36c83562b7404392",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1600x1200 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABBUAAAQHCAYAAACjn1GXAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdAVeX/wPH3ufeyl2xBceHCCe49cs8cpQ1z5EzNLNuWZZqmqWWalXtvTcs90lTc4N7iRJAhsue99/z+8CffCDBE4IJ+Xv/UPc9znudzrge493OeoaiqqiKEEEIIIYQQQgjxlDSmDkAIIYQQQgghhBBFkyQVhBBCCCGEEEIIkSuSVBBCCCGEEEIIIUSuSFJBCCGEEEIIIYQQuSJJBSGEEEIIIYQQQuSKJBWEEEIIIYQQQgiRK5JUEEIIIYQQQgghRK5IUkEIIYQQQgghhBC5IkkFIYQQQgghhBBC5IokFYQQQgghhBBCCJErklQQQgghhBBCCCFEruhMHYAQQgjxInnppZe4d+/eE+uYm5tjbm5OsWLFcHV1pXTp0pQvX55atWpRo0YNzMzMCihaIYQQQognU1RVVU0dhBBCCPGiyElS4Uns7e1p06YNb731Fj4+PnkYWd6YNWsWs2fPTn+9dOlS6tevb8KInj/Hjh2jb9++6a9HjhzJu+++a8KIhBBCvMhk+oMQQghRhMTGxrJhwwa6devGqFGjCAsLM3VIQgghhHiByfQHIYQQwoQ++eQTKleunOFYWloasbGxxMbGEhISwunTpzl//jzJyckZ6u3cuZPjx48zc+ZMGQ0ghBBCCJOQpIIQQghhQlWrVs1RQiA5OZnNmzezZMkSgoKC0o8/fPiQIUOGMG/ePOrVq5efoQohhBBCZCLTH4QQQogiwNLSkt69e/PHH3/Qv3//DGXJycm89957hIeHmyY4IYQQQrywJKkghBBCFCE6nY7PPvuMzz77LMPxqKgopk6daqKohBBCCPGikukPQgghRBHUv39/Tp48ye7du9OPbdmyheHDh1OuXLkctXH//n2uXbtGcHAwcXFxABQrVgx3d3d8fX1xcHDIl9if1t27d7l+/TohISHEx8ej1WpxcHCgRIkS1KxZExsbm2fu4+bNm1y+fJmIiAgSEhLQarVYW1vj7u6Ol5cX5cuXR6fL/cemBw8ecPr0aSIjI4mOjsba2hpnZ2eqV6+Ol5fXM8cvhBBCmIpsKSmEEEIUoH9vKfksWy7evXuXtm3bYjQa04/1798/0yiGx/R6PYcPH2bnzp0cOXLkiVtbKoqCr68vgwYNolWrViiKkm3df29x+DSuXLmS6VhKSgr79+9n165dHDt2jIiIiGzP12q1NGzYkCFDhjz1+5iamsrixYtZt24dd+7ceWJdS0tLfH19ad++Pa+//nqO2jcajfzxxx8sW7aMCxcukN1HLm9vbwYNGkS3bt3QaLIfRFqpUqUc9ftvsq2nEEKI/CRJBSGEEKIA5WVSAWD48OHs3bs3/XXJkiUzvP6nUaNGsXPnzqfuo23btkyZMgVra+ssy/M6qdCjRw8uXLjw1G29+eabfP755zkaURASEsLAgQO5cePGU/dz4cKF/+zj1q1bvPfee1y+fDnH7fr6+vLLL7/g5OSUZbkkFYQQQhRGMv1BCCGEKMLatm2bIYkQHBzMvXv3KFGiRKa6KSkpmY45OTnh5OSEjY0NKSkphIWF8fDhwwx1du3aRVxcHAsXLnzik/S8kpqamumYm5sbxYoVw9ramsTEREJDQ9OnbDy2YsUKkpKSmDx58hPbT05OZsCAAdy6dSvDcY1Gg4eHB8WKFUOr1RIfH8/9+/dJTEx8qvjPnDnD0KFDM72PWq2WkiVL4uDgQGJiInfv3s3wb3L69Gl69+7NmjVrsk0sCCGEEIWNJBWEEEKIIqxGjRqZjl26dCnLpAKAo6MjHTp0oEWLFlSvXj3LL6+3b99mw4YNLF68OP1L75EjR1i6dGmmnScAKleuzKJFiwDYtGkTmzdvTi/75JNPqFy58lNfl6enJ+3bt6dZs2ZUr14dW1vbDOWqqnLlyhVWr17N2rVrMRgMAGzcuJGXXnqJNm3aZNv28uXLMyQUnJyceP/992nXrl2mdSRUVeXu3bscPnyY3bt34+/v/8S4IyIieOeddzIkFCpVqsTQoUNp0aJFhvUfUlJS2Lt3Lz/88EP69Is7d+7w6aef8ttvv2WacvL4Pb58+TJTpkxJP/7yyy/TrVu3bGPKzfsvhBBC5JRMfxBCCCEKUF5Pf1BVlVq1amV4mv7pp58yYMCATHVPnTpFlSpVsLCwyFHbly5don///kRHRwPg7u7OX3/99cSh/7NmzWL27Nnpr3NzfSdPnsTPzw+tVpuj+v7+/gwbNix9hEONGjVYt25dtvVfeeUVzp07B4C5uTmbNm3C29s7R30FBQVRrly5bNeYGDRoEAcPHkx/3bt3b7788kvMzMyybTM2NpYhQ4Zw6tSp9GOzZ8/ONjHy7+kmI0eO5N13381R/EIIIUReky0lhRBCiCJMURQcHR0zHAsPD8+yrp+fX44TCgA+Pj58+OGH6a/DwsL+80l9XqhTp06OEwoAjRs3ZuDAgemvz549y/Xr17Ot/89RCvXr189xQgEeLaqYXULh9OnTGRIKzZo1Y/z48U9MKADY29sza9asDKMYFi5cmOOYhBBCCFOSpIIQQghRxNnZ2WV4/bRrADxJp06dMnzBDwwMzLO281LXrl0zvP7nU/9/S05OTv//Z9km8t+WLFmS4fVnn332xF0z/snV1ZVXX301/XVgYCCRkZF5FpsQQgiRXySpIIQQQhRx/96VIS0tLU/b/ue6C5cuXcqztvNSyZIlM7y+ePFitnXd3NzS///kyZOEhIQ8c/9GozHDKIUaNWpQrly5p2qjcePGGV6fPHnymeMSQggh8pss1CiEEEIUcQkJCRlem5ub/+c5165dY+fOnVy4cIGgoCBiYmJISEj4z4TEv3c0yG9nz55lz549XLp0iRs3bhAXF0dCQgJ6vf6J5z0pzsaNG7N27VoA4uLi6Nu3L2PGjKF169b/OVUhO1evXs2wG0W1atWeug1PT88Mr4OCgnIVixBCCFGQJKkghBBCFHHx8fEZXv975MI/XblyhQkTJnDixIlc9RUbG5ur857WyZMn+eabb7hy5Uquzv/3dpP/NHDgQP7880+SkpIAuHv3LqNHj8be3p4mTZpQr149/Pz8qFixYo630Px3AmDlypWsXLkyV7E/FhMT80znCyGEEAVBkgpCCCFEEaaqaqan8v8c3v9P+/bt4913332m6RF5ObUiO6tXr+brr7/mWTaoerwTRFbKlCnDzJkz+eCDDzIkZGJjY9m2bRvbtm0DwMHBgfr169OuXTtat26NpaVltm0+3iEjLz0pMSKEEEIUFpJUEEIIIYqwGzduZFqYsVSpUpnq3bx5k1GjRmVICiiKQo0aNfDz88PLywsXFxcsLCwy7RDx0UcfFdiigUePHs2UUNDpdNSqVYuaNWvi6emJs7MzFhYWmaZ5ZLWNZnaaN2/O1q1bmTNnDn/++WeWi1vGxMSwa9cudu3ahZOTE8OHD6dPnz5ZLr6YHyM4ZNdvIYQQRYEkFYQQQogi7OzZs5mOValSJdOx6dOnZ3h6X6NGDb777rscbaeY0x0M8sKUKVMyfJlu0aIF48ePp3jx4k8870kjE7JTvHhxvvnmGz799FMOHz7M8ePHOXnyJJcvX8ZgMGSoGxUVxcSJEzlx4gQ//PBDpi0vraysMrzu3LkzPXv2fOqY/im7ESdCCCFEYSJJBSGEEKII27lzZ4bXpUuXzvQFPCEhgf3796e/dnFxYf78+Tg4OOSoj4Ka23/z5s0MuzZUrFiRWbNm5WjhyWeZfmBtbU3r1q1p3bo18GiNioCAAPbv38/WrVszXP/OnTtZuHAhgwcPztBGsWLFMry2s7OjUaNGuY5JCCGEKCpkS0khhBCiiLp79y5///13hmOPvxj/08WLFzNMe+jUqVOOEwq3b9/O1SiA3Dhz5kyG16+88kqOEgoA169fz7M4bG1tad6
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 141
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:58:00.363988157Z",
2025-11-05 17:53:45 +01:00
"start_time": "2025-11-05T13:39:27.866713Z"
}
},
"cell_type": "code",
"source": [
"from time import sleep\n",
2025-11-05 23:31:08 +01:00
"\n",
"\n",
2025-11-05 17:53:45 +01:00
"def callback_numpy_plot(index, params):\n",
" plot_title = \"NumPy Model - It: {:05}\".format(index)\n",
" file_name = \"numpy_model_{:05}.png\".format(index // 50)\n",
" file_path = os.path.join(OUTPUT_DIR, file_name)\n",
" prediction_probs, _ = full_forward_propagation(np.transpose(grid_2d), params, nn_architecture)\n",
" prediction_probs = prediction_probs.reshape(prediction_probs.shape[1], 1)\n",
" make_plot(X_test, y_test, plot_title, file_name=file_path, XX=XX, YY=YY, preds=prediction_probs, dark=True)\n",
"\n",
2025-11-05 23:31:08 +01:00
"\n",
2025-11-05 17:53:45 +01:00
"# Training\n",
"params_values = train(np.transpose(X_train), np.transpose(y_train.reshape((y_train.shape[0], 1))), nn_architecture,\n",
" 30000, 0.01, False, callback_numpy_plot)"
],
"id": "b6a4d6a1a1fb289",
"outputs": [],
"execution_count": 144
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:58:00.364712192Z",
2025-11-05 17:53:45 +01:00
"start_time": "2025-11-05T13:43:06.780036Z"
}
},
"cell_type": "code",
"source": [
"\n",
"\n",
"prediction_probs_numpy, _ = full_forward_propagation(np.transpose(grid_2d), params_values, nn_architecture)\n",
"prediction_probs_numpy = prediction_probs_numpy.reshape(prediction_probs_numpy.shape[1], 1)\n",
"\n",
"make_plot(X_test, y_test, \"NumPy Model\", file_name=None, XX=XX, YY=YY, preds=prediction_probs_numpy)"
],
"id": "6b36e606efa7f99a",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1600x1200 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABCAAAAQHCAYAAAAtRRjrAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3Xd8FHX+x/H37Kb3kE4KJUCAUELvIkgRRUSwe3p6nu2sZ++e7UQ9e9fTs5w/zwIqgtJEBaSXAKG3EAghDVJI3935/RFTNtlNdnZnZmdm38/Hg4cm2yaQbLKvfL7fEURRFEFEREREREREpCCTtw+AiIiIiIiIiIyPAYKIiIiIiIiIFMcAQURERERERESKY4AgIiIiIiIiIsUxQBARERERERGR4hggiIiIiIiIiEhxDBBEREREREREpDgGCCIiIiIiIiJSHAMEERERERERESnOz9sHQERERETad/XVV2Pjxo3Nb+/bt8+LR9NIi8dERETOcQKCiIiIiIiIiBTHCQgiIvKKyZMnIz8/3+59qamp+Omnn+Dv7+/RfW3atAkRERGyHKfROfp3aMtkMiE8PBwRERFIT0/HwIEDce6556JXr14qHaUynH3sTz75JC6//HK37nPBggV46KGH2r0/OTkZK1eudOs+iYiIjIITEEREpBnHjh3D/PnzvX0Y1IbNZkN5eTmOHTuGX3/9FW+88QbOP/98XHvttcjNzfX24cnuu+++c/u23377rXwHQkREZDAMEEREpClvv/026urqvH0Y5IJ169Zh9uzZ+O2337x9KLLatm0b8vLyJN8uPz8fmzZtUuCIiIiIjIFLMIiISFMKCwvxxRdf4Nprr/X2ofikBx54AH379rV7n9VqRVlZGXbv3o3FixejsLCw+bKamhrcdddd+OKLL9rdTm9MJhNsNhuAximIO+64Q9Ltv/vuO4ii2O6+iIiIqBEnIIiISHPee+89VFVVefswfFJmZibGjh1r92fChAm44IIL8MADD2DFihW45ppr7G5TXV2N559/3ktHLJ9Ro0Y1///333/fHBNc9f333zf//+jRo2U7LiIiIqNggCAiIk0YMmRI8/+fOnUKn3zyiRePhpwJCAjAI488ggsuuMDu/WvXrsX+/fu9dFTymD17dvP/Hz9+HJs3b3b5tlu2bMHRo0eb377wwgvlPDQiIiJDYIAgIiJNuOuuuyAIQvPbH330EcrLy714RNSRe++9FyaT/Y8Rv//+u5eORh5nnXUWunTp0vy2lM0oW183JiYGZ511loxHRkREZAzcA4KIiDShf//+mDZtGpYuXQoAqKysxIcffoi7777by0emnNOnT2Pbtm0oLCxEeXk5oqOjMWjQIPTr16/T2x49ehTZ2dkoKiqCIAiIj4/HqFGjkJCQoMKRA4mJiejbty92797d/D69T0D4+flh5syZ+PTTTwEAS5cuxeOPP47AwMAOb1dXV4effvqp+e2ZM2fCz0+eH7EqKiqwdetWFBUV4fTp0wgJCUFMTAz69euHHj16yPIYx44dw44dO1BYWAiLxYLY2FgMGDAAffr0keX+W7NarcjJycHRo0dx6tQp1NfXIzo6GikpKRg2bBgCAgJkf0wiItIOBggiItKMO++8EytWrIDVagUAfPbZZ7jmmmsQGxsr22MsWLAADz30UPPbzz33HObMmePy7TMyMpr/f+TIkfjss8+cXvfqq6/Gxo0bm9/et28fAODQoUN4/fXX8fPPP6OhoaHd7fr374/HH3/cbllKk/Xr1+OVV15BdnZ2u8sEQcCkSZPw2GOPoWvXri5/TO5KTU21CxCnT59u/v/3338fL730UvPbf//733HzzTdLfozbbrsNy5cvt7vfiRMnunnEnZs9e3ZzgKisrMSKFStw/vnnd3ibFStWoLKy0u4+PLVp0ya8+eab2Lx5MywWi8PrdOvWDVdeeSWuvPJKt164b9u2Dc8//zy2bdvm8PLevXvjjjvuwLRp0yTfd1vHjx/H22+/jZ9//hllZWUOrxMcHIypU6fijjvuQGpqqsePSURE2sMlGEREpBnp6emYNWtW89vV1dV49913vXhE8luxYgXmzJmDJUuWOIwPALB7925cffXVdi+8AeDNN9/Etdde6zA+AIAoili5ciUuvfRSHD58WO5Db6ftb/lbv1CeO3cu/P39m9/+5ptvJG/qWFxcjF9++aX57a5du2LChAluHq1rMjMz7X7z/+2333Z6m9bX6dOnD/r37+/249fX1+P+++/Hn/70J6xfv95pfAAap2Cee+45XHDBBTh06JCkx3nzzTdx5ZVXOo0PAHDgwAHcfvvteOaZZyT/27X29ttv49xzz8X8+fOdxgeg8YwqCxcuxIwZM/D111+7/XhERKRdDBBERKQpt956q90L1y+//BIFBQVePCL5bNmyBXfddRdqa2sBAIGBgUhPT8eAAQPs9h4AgIaGBtx3333Izc0F0Pib/zfeeKP5hWB4eDgyMjLQr18/hISE2N22uLgYt99+u9PAIZeioiK7t6Oiopr/PyYmBlOmTGl++9ixY1i/fr2k+//222/bRY22+04oofUGkmvXrkVxcbHT6xYVFWHt2rXNb3sy/VBfX4+bbrrJ7mwaTeLi4jBgwAB0797d7usDAHJzc3HllVfaTaN05N1338Ubb7zR7jShUVFR6N+/P3r16oWgoKDm93/22Wd45513JH88VqsVDz74IF577bV2n4tRUVHIyMjAgAED2i0bamhowKOPPoqPP/5Y8mMSEZG2MUAQEZGmpKam4pJLLml+u76+Hm+99ZYXj0g+9913HxoaGhAfH4958+Zhw4YN+PHHHzF//nysXbsWH374od3SiZqaGrz88svYsmULXnnlFQDAwIED8fHHH2PDhg1YuHAhvvvuO2zYsAFPPPGE3Rj+wYMH8b///U+xj6WmpgY7d+60e19KSord25dddpnd21999ZXL9y+KIr755pvmt81mMy6++GI3jlS6WbNmwWw2A2h8Ef3DDz84ve7ChQublwyZzeZ2ZweR4uWXX7aLGQAwZcoUfP/991izZg3mz5+PpUuXYs2aNbjvvvvsIkFZWRnuvPPOTk9fu3XrVrz66qt278vMzMSnn36K9evX49tvv8XixYuxfv16PPvss81R6a233rI7y4cr3nrrLbvpEH9/f1xzzTVYvHhx8+fv/PnzsWrVKixfvhyXXXaZ3Ua0L774IrZu3SrpMYmISNsYIIiISHNuueUWuxdX3377bfMkgJ7l5+ejR48e+Oabb3DRRRchODi4+TJBEDB+/Hh89NFHdiFhxYoVePjhh2Gz2TB9+nR88cUXGDNmTPMLZKDx1JhXXnklnnzySbvHa/0CXm6fffZZ8yRHk9GjR9u9PWbMGHTv3r357eXLl+PUqVMu3f+GDRvsXvBOmDABiYmJ7h+wBPHx8Rg7dmzz2x2dDaP1ZePGjUN8fLxbj7ljx452v/G/9dZb8dZbb6Fv375274+KisJf//pXfP755wgLC2t+f15eXru40JrNZsPjjz9ut5xi4sSJ+PLLLzFq1Ci7F//BwcG4+OKLsWDBAiQkJMBisaCwsNDlj2fr1q12UxPR0dH44osv8Mgjj6BXr17trp+WloannnoKL7/8cvOUi8ViwT/+8Q+XH5OIiLSPAYKIiDQnPj4eV155ZfPbFosFb7zxhhePSB7+/v549dVXOzxTRY8ePew2xbRarcjNzUW3bt3w/PPPtxu/b+2iiy6yOzPC3r172y2TkMPKlSvx+uuv272v6SwGbbWegmhoaHC4vMCRtnsAtJ6KUUPrpRT79u3Dnj172l0nJycHBw4ccHgbqT755BO7MDBp0iTccccdHd5mwIABePrpp+3e980339htiNna77//bne8cXFxePnllzv8nEpOTsbLL7/syodg56233mpe4mEymfD2229j4MCBnd7uvPPOw3XXXdf89r59+9pNhRARkX4xQBARkSbdeOONCA0NbX77xx9/bD6LhF5Nnz693W+zHZk8eXK7991www12ExOOCILQ7raOXjhLZbVacerUKaxevRr33HMP/va3v7V
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 149
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "dc0b9266ae37298a",
"outputs": [],
"execution_count": null
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T22:28:55.082124Z",
"start_time": "2025-11-05T22:28:29.890543Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
2025-11-05 23:31:08 +01:00
"X_train = np.array(X_train)\n",
"y_train = np.array(y_train)\n",
"\n",
2025-11-05 17:53:45 +01:00
"model = Sequential()\n",
2025-11-05 23:31:08 +01:00
"model.add(Dense(25, input_dim=2, activation='relu'))\n",
2025-11-05 17:53:45 +01:00
"model.add(Dense(50, activation='relu'))\n",
"model.add(Dense(50, activation='relu'))\n",
"model.add(Dense(25, activation='relu'))\n",
2025-11-05 23:31:08 +01:00
"model.add(Dense(20, activation='sigmoid'))\n",
2025-11-05 17:53:45 +01:00
"\n",
"model.compile(loss='binary_crossentropy', optimizer=\"sgd\", metrics=['accuracy'])\n",
"\n",
"# Training\n",
2025-11-05 23:31:08 +01:00
"history = model.fit(X_train, y_train, epochs=1000, verbose=0)"
2025-11-05 17:53:45 +01:00
],
"id": "f05ff40ed26e45c2",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/oskar/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/layers/core/dense.py:95: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
" super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n"
]
}
],
2025-11-05 23:31:08 +01:00
"execution_count": 373
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T22:28:56.234343Z",
"start_time": "2025-11-05T22:28:55.933926Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
2025-11-05 23:31:08 +01:00
"X_test = np.array(X_test)\n",
"y_test = np.array(y_test)\n",
"\n",
2025-11-05 17:53:45 +01:00
"Y_test_prob = model.predict(X_test)\n",
2025-11-05 23:31:08 +01:00
"\n",
"# Convert probabilities to class predictions\n",
"Y_test_pred = np.argmax(Y_test_prob, axis=1)\n",
"\n",
"# Convert y_test to class labels (if it's one-hot encoded)\n",
"# If y_test is already in label format (e.g., [0, 1, 0, 1, ...]), skip this step\n",
"y_test_labels = np.argmax(y_test, axis=1) if len(y_test.shape) > 1 and y_test.shape[1] > 1 else y_test\n",
"\n",
"# Calculate accuracy\n",
"acc_test = accuracy_score(y_test_labels, Y_test_pred)\n",
"print(\"Test set accuracy: {:.2f}\".format(acc_test))"
2025-11-05 17:53:45 +01:00
],
"id": "ef52bee9c93081d3",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-11-05 23:31:08 +01:00
"\u001B[1m1/1\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m0s\u001B[0m 271ms/step\n",
"Test set accuracy: 0.30\n"
]
}
],
"execution_count": 374
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-05T22:29:35.732951Z",
"start_time": "2025-11-05T22:29:35.730200Z"
}
},
"cell_type": "code",
"source": [
"for i in range(len(X_test)):\n",
" print(\"for: {}, pred: {}\".format(int(np.sum(X_test[i])*10), decode_from_vector(Y_test_prob[i])))"
],
"id": "7c788fc6411656ea",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"for: 11, pred: 16\n",
"for: 8, pred: 16\n",
"for: 7, pred: 16\n",
"for: 9, pred: 16\n",
"for: 8, pred: 16\n",
"for: 12, pred: 16\n",
"for: 4, pred: 16\n",
"for: 8, pred: 16\n",
"for: 1, pred: 16\n",
"for: 0, pred: 6\n"
2025-11-05 17:53:45 +01:00
]
}
],
2025-11-05 23:31:08 +01:00
"execution_count": 377
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T22:08:59.757756Z",
"start_time": "2025-11-05T22:08:59.657259Z"
2025-11-05 17:53:45 +01:00
}
},
"cell_type": "code",
"source": [
"def callback_keras_plot(epoch, logs):\n",
" plot_title = \"Keras Model - It: {:05}\".format(epoch)\n",
" file_name = \"keras_model_{:05}.png\".format(epoch)\n",
" file_path = os.path.join(OUTPUT_DIR, file_name)\n",
" prediction_probs = model.predict(grid_2d, batch_size=32, verbose=0)\n",
" prediction_probs = prediction_probs.reshape(-1)\n",
" make_plot(X_test, y_test, plot_title, file_name=file_path, XX=XX, YY=YY, preds=prediction_probs)\n",
"\n",
"\n",
"# Adding callback functions that they will run in every epoch\n",
"testmodelcb = keras.callbacks.LambdaCallback(on_epoch_end=callback_keras_plot)\n",
"\n",
"# Building a model\n",
"model = Sequential()\n",
"model.add(Dense(25, input_dim=2, activation='relu'))\n",
"model.add(Dense(50, activation='relu'))\n",
"model.add(Dense(50, activation='relu'))\n",
"model.add(Dense(25, activation='relu'))\n",
"model.add(Dense(1, activation='sigmoid'))\n",
"\n",
"model.compile(loss='binary_crossentropy', optimizer=\"sgd\", metrics=['accuracy'])\n",
"\n",
"# Training\n",
"history = model.fit(X_train, y_train, epochs=200, verbose=0, callbacks=[testmodelcb])\n"
],
"id": "6feab7da06e7a828",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/oskar/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/layers/core/dense.py:95: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
" super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n"
]
2025-11-05 23:31:08 +01:00
},
{
"ename": "ValueError",
"evalue": "Arguments `target` and `output` must have the same shape. Received: target.shape=(None, 20), output.shape=(None, 1)",
"output_type": "error",
"traceback": [
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
"\u001B[31mValueError\u001B[39m Traceback (most recent call last)",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[321]\u001B[39m\u001B[32m, line 24\u001B[39m\n\u001B[32m 21\u001B[39m model.compile(loss=\u001B[33m'\u001B[39m\u001B[33mbinary_crossentropy\u001B[39m\u001B[33m'\u001B[39m, optimizer=\u001B[33m\"\u001B[39m\u001B[33msgd\u001B[39m\u001B[33m\"\u001B[39m, metrics=[\u001B[33m'\u001B[39m\u001B[33maccuracy\u001B[39m\u001B[33m'\u001B[39m])\n\u001B[32m 23\u001B[39m \u001B[38;5;66;03m# Training\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m24\u001B[39m history = \u001B[43mmodel\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX_train\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my_train\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mepochs\u001B[49m\u001B[43m=\u001B[49m\u001B[32;43m200\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mverbose\u001B[49m\u001B[43m=\u001B[49m\u001B[32;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcallbacks\u001B[49m\u001B[43m=\u001B[49m\u001B[43m[\u001B[49m\u001B[43mtestmodelcb\u001B[49m\u001B[43m]\u001B[49m\u001B[43m)\u001B[49m\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/utils/traceback_utils.py:122\u001B[39m, in \u001B[36mfilter_traceback.<locals>.error_handler\u001B[39m\u001B[34m(*args, **kwargs)\u001B[39m\n\u001B[32m 119\u001B[39m filtered_tb = _process_traceback_frames(e.__traceback__)\n\u001B[32m 120\u001B[39m \u001B[38;5;66;03m# To get the full stack trace, call:\u001B[39;00m\n\u001B[32m 121\u001B[39m \u001B[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001B[39;00m\n\u001B[32m--> \u001B[39m\u001B[32m122\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m e.with_traceback(filtered_tb) \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 123\u001B[39m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[32m 124\u001B[39m \u001B[38;5;28;01mdel\u001B[39;00m filtered_tb\n",
"\u001B[36mFile \u001B[39m\u001B[32m~/projects/nn-from-scratch/.venv/lib/python3.13/site-packages/keras/src/backend/tensorflow/nn.py:783\u001B[39m, in \u001B[36mbinary_crossentropy\u001B[39m\u001B[34m(target, output, from_logits)\u001B[39m\n\u001B[32m 781\u001B[39m \u001B[38;5;28;01mfor\u001B[39;00m e1, e2 \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mzip\u001B[39m(target.shape, output.shape):\n\u001B[32m 782\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m e1 \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m e2 \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m e1 != e2:\n\u001B[32m--> \u001B[39m\u001B[32m783\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[32m 784\u001B[39m \u001B[33m\"\u001B[39m\u001B[33mArguments `target` and `output` must have the same shape. \u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 785\u001B[39m \u001B[33m\"\u001B[39m\u001B[33mReceived: \u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 786\u001B[39m \u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33mtarget.shape=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtarget.shape\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m, output.shape=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00moutput.shape\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m\"\u001B[39m\n\u001B[32m 787\u001B[39m )\n\u001B[32m 789\u001B[39m output, from_logits = _get_logits(\n\u001B[32m 790\u001B[39m output, from_logits, \u001B[33m\"\u001B[39m\u001B[33mSigmoid\u001B[39m\u001B[33m\"\u001B[39m, \u001B[33m\"\u001B[39m\u001B[33mbinary_crossentropy\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 791\u001B[39m )\n\u001B[32m 793\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m from_logits:\n",
"\u001B[31mValueError\u001B[39m: Arguments `target` and `output` must have the same shape. Received: target.shape=(None, 20), output.shape=(None, 1)"
]
2025-11-05 17:53:45 +01:00
}
],
2025-11-05 23:31:08 +01:00
"execution_count": 321
2025-11-05 17:53:45 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-05 23:31:08 +01:00
"end_time": "2025-11-05T16:58:00.367286962Z",
2025-11-05 17:53:45 +01:00
"start_time": "2025-11-05T13:43:21.443563Z"
}
},
"cell_type": "code",
"source": [
"\n",
"prediction_probs = model.predict(grid_2d, batch_size=32, verbose=0)\n",
"prediction_probs = prediction_probs.reshape(-1)\n",
"make_plot(X_test, y_test, \"Keras Model\", file_name=None, XX=XX, YY=YY, preds=prediction_probs)\n"
],
"id": "38dee4608746a358",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1600x1200 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABCAAAAQHCAYAAAAtRRjrAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XV4VGfexvH7zMSdCHEcgntxKwWKlJZCdSvUqG2pvd1uqVKj1HW3tnUXqFC0eHF3hwQiEKLEbea8f6QJGWaSjByf+3NdXG0mIyeQTDLf/J7nCKIoiiAiIiIiIiIikpFJ7QMgIiIiIiIiIuNjgCAiIiIiIiIi2TFAEBEREREREZHsGCCIiIiIiIiISHYMEEREREREREQkOwYIIiIiIiIiIpIdAwQRERERERERyY4BgoiIiIiIiIhkxwBBRERERERERLJjgCAiIiLyAps3b0ZKSkr9n3fffVftQ9LkMRERkXwYIIiIiIiIiIhIdj5qHwAREZG7Ro8ejczMzPq3v/zySwwcONDp21ssFjz++OP49ddfbS6fOnUqXnjhBZjNZqkOlWQwf/58zJo1y+7y2NhYrF69GiaT679nEUURY8aMQUZGht37XnrpJUydOtWtYyUiIiJOQBARkZeqrq7Gww8/bBcfrr/+esyZM4fxQceys7OxceNGt267detWh/GBiIiIPMcAQUREXqeqqgozZ87EkiVLbC6/7bbbMHv2bAiCoNKRkVQuDEvO+uWXX6Q9ECIiIqrHAEFERF6loqIC99xzD1atWmVz+b333ot///vfKh0VSaHhkovly5ejtLTUpduXl5dj6dKlDu+PiIiIPMfvrERE5DVKS0sxY8YMrFu3zuby//u//8MDDzyg0lGRVBru/1FWVoZly5a5dPtly5bZRItBgwZJdmxERETEAEFERF6iqKgIt912G7Zs2VJ/mSAIeOKJJ3DnnXeqeGQklY4dO6Jr1671b7u6nKLhso1u3bqhQ4cOUh0aERERgQGCiIi8QEFBAaZPn45du3bVX2YymfD888/j5ptvVu/ASHJTpkyp//8tW7bg9OnTTt3uzJkz2LRpk8P7ISIiImnwNJxERGRoubm5uPXWW3HkyJH6y3x8fDB37lxMnjzZ7fvNy8vDrl27kJubi8LCQgQFBSEqKgo9evRAcnKyFIduY+/evTh16hRycnJQWVmJhISEZo8/PT0dx44dQ1ZWFkpKSmA2mxEeHo7ExET06tULwcHBHh9XamoqDh06hJycHJSWlsJsNiMoKAixsbFITk5Ghw4d4OOj3I8bkydPxquvvorq6mqIoojffvsNd999d7O3++2332C1WgEAvr6+uOyyy/D+++9LckwHDx7EsWPHkJeXh6qqKkRGRiI+Ph79+vVDQECAx/dfXV2NrVu3Ij09HQUFBQgMDESbNm3Qr18/hISESPAR2FL6c5+IiIyDAYKIiAzrzJkzmD59OtLS0uov8/X1xRtvvIFx48a5fH9WqxW///47vvrqK+zfvx+iKDq8Xvv27XHHHXdgypQpTm1kOH/+fMyaNav+7ZdeeglTp05FRUUFPvnkE8yfP9/u1JChoaF2AaKyshKrV6/GsmXLsHnzZuTk5DT6mGazGYMHD8add95ps3eCM6qqqvD555/jp59+wqlTp5q8bkBAAHr37o3x48fj+uuvd+lx3BEZGYlhw4bVbzL666+/OhUgGi7XGD58OCIjIz06jpKSEnz88ceYP38+zp496/A6/v7+GD58OB544AF06tTJ5ceoqKjAf/7zH/z4448oLCy0e7+fnx+mTJmChx56yOOPR67PfSIi8i4MEEREZEjp6em45ZZbbF64+/v7491338XIkSNdvr+0tDQ88MADOHToULPXPX78OGbNmoUffvgB77//vlsv/jIzM3HnnXfi2LFjTt/m+uuvx/79+526rsViwbp167Bu3TrccMMNePzxx52aVMjKysLtt9+OEydOOPU4FRUV2LRpEzZt2oSrr75akWmIK6+8sj5ApKamYvfu3ejVq1ej19+9ezdSU1Ntbu+JLVu24MEHH0ReXl6T16usrMTy5cuxatUqzJgxAw899JDTj5Geno477rjDJq5dqKqqCj/++CNWr16N//3vf07f94WU/twnIiLjYpomIiLDSU1NxY033mgTH4KCgvDhhx+6FR92796N6667zu4FmNlsRuvWrdGzZ0906NAB/v7+Nu/ftWsXrr32WuTn57v0eCUlJbjtttts4kNUVBS6du2KDh06ICgoyOHtqqqq7C5r2bIlOnXqhN69e6NTp04IDQ21u84333yDp556qtnjqqiowK233moXH0wmExITE9GtWzf07NkT7dq1a/QYlXDxxRcjPDy8/u2Gm0s60nD6ISIiAqNGjXL7sVevXo077rjDLj74+/ujXbt26Natm92LcovFgg8++ACPP/64U4+RnZ1tN9kDnP987N69O1q2bFl/+dmzZx0ekzOU/twnIiJj4wQEEREZytGjR3HrrbfaLD8IDQ3FRx99hL59+7p8fzk5ObjnnntQUFBQf1lKSgruuusujBo1ymYfhcrKSqxYsQJvvvlm/dKEU6dO4bHHHsOHH34IQRCceswPP/wQubm5AICJEyfirrvuQufOnevfX11djQ0bNji8bUJCAsaPH48RI0agR48ednsAiKKIw4cP4/vvv8ePP/4Ii8UCoHYZyOjRozF27NhGj+vrr7+2edEbGRmJhx56CJdeeqnNC/66x0lPT8eGDRvw559/Yv369U597FLw8/PDxIkT8d133wEAFi1ahFmzZsHPz8/uulVVVVi0aFH92xMnTnR4PWecPn0a//rXv1BZWVl/WUREBB555BFMnDjR5nNl586dePXVV7F9+/b6y+bNm4cePXo0u1TliSeeQGZmZv3bvr6+uPvuu3H99dcjKiqq/vKjR4/inXfewbJly3D27Fm89tprLn08anzuExGRsXECgoiIDOPgwYO46aabbOJDREQEPv/8c7fiAwDMmjXL5jfH1157LebNm4dJkybZbeLo7++PiRMnYt68eejTp0/95WvWrMHy5cudfsy6+PD444/jzTfftIkPQO0LTkeTHLNnz8by5cvx73//G4MHD3a4AaEgCOjcuTNmz56Njz/+2ObF9kcffdTkcS1ZsqT+//38/PD111/jmmuusYsPdY/TqlUrXHfddfjkk0+wcOFCmM3mpj9wCTVcRlFYWIg1a9Y4vN7KlStx7tw5h7dz1bPPPouioqL6t+Pj4zF//nxcffXVdp8rffr0wddff40rrrjC5vKXX34Z2dnZjT7GokWL8Ndff9W/7efnh48//hj33XefTXwAak9L+u6779bvgdEwWjhDjc99IiIyNgYIIiIyjLlz59r8tjY6Ohpffvklunfv7tb97dq1y+bF3ogRI/Dss8/C19e3yduFhYXh3XfftXmR9umnn7r02JMmTcL06dNduk3//v1depE/dOhQ3H777fVv79mzp8k9JxpOPwwcOBDt27d3+rHat2+v6G/Be/XqhbZt29a/3XCZRUMNl2e0a9cOPXv2dOvxTpw4gdWrV9e/bTKZ8M477yAxMbHR25hMJsyZM8dmA8ry8vL6yQ1HvvjiC5u3H3roIQwePLjJY3vooYcwdOjQZj4CW2p+7hMRkXExQBARkWFcuDP/448/jpSUFLfv78IXe7NmzXL6RXRMTAyuvvrq+rd37NhRP9ngjAceeMDp63ri8ssvt3l7586djV63oqKi/v+VPLWmu6ZMmVL//2vXrrWJU0Dt6SQbvshueH1X/fzzzzaff5MmTXIqZvj4+ODRRx+1ueynn35yeJaJ48ePY9euXfVvx8bG4qabbnLq+C58jOao+blPRETGxQBBRESG9dxzzzm1c78jVqvV5sVp3eaKrrjwt87btm1z6nY9evRA69atXXosdyUlJdm8feDAgUav23Bjw23btiErK0u245LCFVdcUX8qyOrqaixcuNDm/QsWLEBNTQ2A2mmEC5dDuGLr1q02b0+bNs3p2w4dOhRxcXH1b+fm5tqclaPO5s2bbd6eNGlSsxMJdTp37ow
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 150
},
{
"metadata": {},
"cell_type": "code",
"source": "!nvidia-smi\n",
"id": "7369490697e1ed40",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import torch\n",
2025-11-05 23:31:08 +01:00
"\n",
2025-11-05 17:53:45 +01:00
"print(torch.cuda.is_available())"
],
"id": "9958964dbe0d2732",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"print(torch.cuda.get_device_name(0))\n",
"print(torch.cuda.current_device())"
],
"id": "f88bc28cff654bde",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "print(tf.config.list_physical_devices('GPU'))",
"id": "d60dd759b2bf2bf",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import torch, time\n",
"\n",
"x_cpu = torch.randn(10000, 10000)\n",
"start = time.time()\n",
"y_cpu = x_cpu @ x_cpu\n",
"print(\"CPU:\", time.time() - start, \"s\")\n",
"\n",
"if torch.cuda.is_available():\n",
" x_gpu = x_cpu.cuda()\n",
" torch.cuda.synchronize()\n",
" start = time.time()\n",
" y_gpu = x_gpu @ x_gpu\n",
" torch.cuda.synchronize()\n",
" print(\"GPU:\", time.time() - start, \"s\")\n"
],
"id": "b482f6b3594de45e",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}