nn-from-scratch/work-sc.ipynb

767 lines
157 KiB
Plaintext
Raw Normal View History

2025-11-04 18:05:00 +01:00
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:01.935976Z",
"start_time": "2025-11-04T22:20:01.932910Z"
2025-11-04 18:05:00 +01:00
}
},
"source": "import numpy as np",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 49
2025-11-04 18:05:00 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:01.948882Z",
"start_time": "2025-11-04T22:20:01.944370Z"
2025-11-04 18:05:00 +01:00
}
},
"cell_type": "code",
"source": [
"# Layer stack for the NumPy network: 2 inputs -> 4 -> 6 -> 6 -> 4 -> 1 output.\n",
"# Hidden layers use ReLU; the single output unit uses a sigmoid.\n",
"nn_architecture = [\n",
"    {\"input_dim\": n_in, \"output_dim\": n_out, \"activation\": act}\n",
"    for n_in, n_out, act in (\n",
"        (2, 4, \"relu\"),\n",
"        (4, 6, \"relu\"),\n",
"        (6, 6, \"relu\"),\n",
"        (6, 4, \"relu\"),\n",
"        (4, 1, \"sigmoid\"),\n",
"    )\n",
"]"
],
"id": "48cafaf4b64967bb",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 50
2025-11-04 18:05:00 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.002162Z",
"start_time": "2025-11-04T22:20:01.996198Z"
2025-11-04 18:05:00 +01:00
}
},
"cell_type": "code",
"source": [
"def init_layers(nn_architecture, seed = 99):\n",
"    \"\"\"Create small random weights and biases for every layer.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    nn_architecture : list of dict\n",
"        One dict per layer with \"input_dim\" and \"output_dim\" keys.\n",
"    seed : int, default 99\n",
"        Passed to np.random.seed so initialisation is reproducible.\n",
"        Note that this reseeds NumPy's global random state.\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict\n",
"        \"W1\", \"W2\", ... with shape (output_dim, input_dim) and\n",
"        \"b1\", \"b2\", ... with shape (output_dim, 1), one pair per layer.\n",
"    \"\"\"\n",
"    np.random.seed(seed)\n",
"    params_values = {}\n",
"\n",
"    # Layers are numbered from 1, matching the \"W1\"/\"b1\" keys used by the\n",
"    # forward and backward passes.\n",
"    for layer_idx, layer in enumerate(nn_architecture, start=1):\n",
"        layer_input_size = layer[\"input_dim\"]\n",
"        layer_output_size = layer[\"output_dim\"]\n",
"\n",
"        # Draw W then b for each layer (this order fixes the RNG sequence);\n",
"        # scaling by 0.1 keeps the initial parameters small.\n",
"        params_values['W' + str(layer_idx)] = np.random.randn(\n",
"            layer_output_size, layer_input_size) * 0.1\n",
"        params_values['b' + str(layer_idx)] = np.random.randn(\n",
"            layer_output_size, 1) * 0.1\n",
"\n",
"    return params_values\n"
],
"id": "d13137630b41b756",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 51
2025-11-04 18:05:00 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.054626Z",
"start_time": "2025-11-04T22:20:02.050524Z"
2025-11-04 18:05:00 +01:00
}
},
"cell_type": "code",
2025-11-04 22:56:05 +01:00
"source": [
"# Initial parameters with the default seed (99), for inspection only:\n",
"# train() below initialises its own parameters.\n",
"params = init_layers(nn_architecture, seed=99)"
],
2025-11-04 18:05:00 +01:00
"id": "31f205147667dea6",
2025-11-04 22:56:05 +01:00
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 52
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.109755Z",
"start_time": "2025-11-04T22:20:02.102747Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"def sigmoid(Z):\n",
"    \"\"\"Element-wise logistic function 1 / (1 + exp(-Z)), computed without overflow.\n",
"\n",
"    The naive form overflows inside exp (RuntimeWarning) for very negative Z.\n",
"    exp(-|Z|) is at most 1, and for each sign we pick the algebraically\n",
"    equivalent expression built from it.\n",
"    \"\"\"\n",
"    E = np.exp(-np.abs(Z))\n",
"    return np.where(Z >= 0, 1 / (1 + E), E / (1 + E))\n",
"\n",
"def relu(Z):\n",
"    \"\"\"Element-wise max(0, Z).\"\"\"\n",
"    return np.maximum(0, Z)\n",
"\n",
"def sigmoid_backward(dA, Z):\n",
"    \"\"\"Chain rule through the sigmoid: dA * s * (1 - s) with s = sigmoid(Z).\"\"\"\n",
"    sig = sigmoid(Z)\n",
"    return dA * sig * (1 - sig)\n",
"\n",
"def relu_backward(dA, Z):\n",
"    \"\"\"Chain rule through ReLU: keep dA where Z > 0, zero it elsewhere.\n",
"\n",
"    Works on a copy, so the caller's dA is left unchanged.\n",
"    \"\"\"\n",
"    dZ = np.array(dA, copy=True)\n",
"    dZ[Z <= 0] = 0\n",
"    return dZ"
],
"id": "c1b960e7dcf09d91",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 53
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.164149Z",
"start_time": "2025-11-04T22:20:02.158664Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"def single_layer_forward_propagation(A_prev, W_curr, b_curr, activation=\"relu\"):\n",
"    \"\"\"Run one dense layer: Z = W_curr . A_prev + b_curr, then A = activation(Z).\n",
"\n",
"    A_prev holds one sample per column; b_curr (output_dim, 1) is broadcast\n",
"    across the columns. Returns (A_curr, Z_curr); Z_curr is returned so the\n",
"    backward pass can reuse it.\n",
"\n",
"    Raises ValueError (a subclass of Exception) for an unknown activation name.\n",
"    \"\"\"\n",
"    if activation == \"relu\":\n",
"        activation_func = relu\n",
"    elif activation == \"sigmoid\":\n",
"        activation_func = sigmoid\n",
"    else:\n",
"        raise ValueError('Non-supported activation function: {!r}'.format(activation))\n",
"\n",
"    Z_curr = np.dot(W_curr, A_prev) + b_curr\n",
"    return activation_func(Z_curr), Z_curr"
],
"id": "efae2e184daf2fce",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 54
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.218489Z",
"start_time": "2025-11-04T22:20:02.212053Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"def full_forward_propagation(X, params_values, nn_architecture):\n",
"    \"\"\"Push X through every layer and cache what backprop needs.\n",
"\n",
"    Returns (output, memory). memory[\"A0\"] is X, memory[\"A\" + str(k)] is the\n",
"    input to layer k + 1, and memory[\"Z\" + str(k)] is layer k's pre-activation.\n",
"    \"\"\"\n",
"    memory = {}\n",
"    activations = X\n",
"\n",
"    for layer_no, layer in enumerate(nn_architecture, start=1):\n",
"        act_name = layer[\"activation\"]\n",
"        weights = params_values[\"W\" + str(layer_no)]\n",
"        bias = params_values[\"b\" + str(layer_no)]\n",
"\n",
"        memory[\"A\" + str(layer_no - 1)] = activations\n",
"        activations, pre_activation = single_layer_forward_propagation(\n",
"            activations, weights, bias, act_name)\n",
"        memory[\"Z\" + str(layer_no)] = pre_activation\n",
"\n",
"    return activations, memory"
],
"id": "c3cd9e8f51dbe967",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 55
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.265846Z",
"start_time": "2025-11-04T22:20:02.262056Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"def get_cost_value(Y_hat, Y, eps=1e-12):\n",
"    \"\"\"Mean binary cross-entropy between predicted probabilities and 0/1 labels.\n",
"\n",
"    Y_hat and Y are expected to be (1, m) arrays with one column per sample\n",
"    (that is how this notebook calls it). Y_hat is clipped to [eps, 1 - eps]\n",
"    first: a saturated sigmoid can return exactly 0.0 or 1.0, and log(0) = -inf\n",
"    would otherwise turn the cost into inf, or nan via 0 * -inf.\n",
"    \"\"\"\n",
"    m = Y_hat.shape[1]\n",
"    Y_hat = np.clip(Y_hat, eps, 1 - eps)\n",
"    cost = -1 / m * (np.dot(Y, np.log(Y_hat).T) + np.dot(1 - Y, np.log(1 - Y_hat).T))\n",
"    return np.squeeze(cost)\n",
"\n",
"# an auxiliary function that converts probability into class\n",
"def convert_prob_into_class(probs):\n",
"    \"\"\"Threshold at 0.5: entries above 0.5 become 1, all others 0 (returns a copy).\"\"\"\n",
"    probs_ = np.copy(probs)\n",
"    probs_[probs_ > 0.5] = 1\n",
"    probs_[probs_ <= 0.5] = 0\n",
"    return probs_\n",
"\n",
"def get_accuracy_value(Y_hat, Y):\n",
"    \"\"\"Fraction of samples (columns) whose thresholded prediction equals Y.\"\"\"\n",
"    Y_hat_ = convert_prob_into_class(Y_hat)\n",
"    return (Y_hat_ == Y).all(axis=0).mean()"
],
"id": "121416e7bbab57bb",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 56
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.319756Z",
"start_time": "2025-11-04T22:20:02.314518Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"def single_layer_backward_propagation(dA_curr, W_curr, b_curr, Z_curr, A_prev, activation=\"relu\"):\n",
"    \"\"\"Backpropagate through one dense layer.\n",
"\n",
"    dA_curr is the gradient w.r.t. this layer's output; Z_curr and A_prev are\n",
"    the values cached by the forward pass. Returns (dA_prev, dW_curr, db_curr),\n",
"    where dW and db are averaged over the m samples (columns of A_prev).\n",
"    b_curr is not used by the computation.\n",
"\n",
"    Raises ValueError (a subclass of Exception) for an unknown activation name.\n",
"    \"\"\"\n",
"    m = A_prev.shape[1]\n",
"\n",
"    if activation == \"relu\":\n",
"        backward_activation_func = relu_backward\n",
"    elif activation == \"sigmoid\":\n",
"        backward_activation_func = sigmoid_backward\n",
"    else:\n",
"        raise ValueError('Non-supported activation function: {!r}'.format(activation))\n",
"\n",
"    dZ_curr = backward_activation_func(dA_curr, Z_curr)\n",
"    dW_curr = np.dot(dZ_curr, A_prev.T) / m\n",
"    db_curr = np.sum(dZ_curr, axis=1, keepdims=True) / m\n",
"    # Per-sample gradient for the previous layer; averaging happens in dW/db.\n",
"    dA_prev = np.dot(W_curr.T, dZ_curr)\n",
"\n",
"    return dA_prev, dW_curr, db_curr"
],
"id": "92e4b87664f18a63",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 57
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.375728Z",
"start_time": "2025-11-04T22:20:02.368707Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"def full_backward_propagation(Y_hat, Y, memory, params_values, nn_architecture, eps=1e-12):\n",
"    \"\"\"Backpropagate the binary cross-entropy loss through every layer.\n",
"\n",
"    memory is the cache produced by full_forward_propagation (\"A0\", \"Z1\", ...).\n",
"    Returns a dict with \"dW1\", \"db1\", ... for each layer in nn_architecture.\n",
"\n",
"    Y_hat is clipped to [eps, 1 - eps] before forming the loss derivative:\n",
"    Y / Y_hat and (1 - Y) / (1 - Y_hat) divide by zero (inf, or nan for 0/0)\n",
"    when the output sigmoid saturates to exactly 0.0 or 1.0.\n",
"    \"\"\"\n",
"    grads_values = {}\n",
"    Y = Y.reshape(Y_hat.shape)\n",
"    Y_hat_safe = np.clip(Y_hat, eps, 1 - eps)\n",
"\n",
"    # d(loss)/d(Y_hat) for binary cross-entropy, per sample.\n",
"    dA_prev = - (np.divide(Y, Y_hat_safe) - np.divide(1 - Y, 1 - Y_hat_safe))\n",
"\n",
"    # Walk from the output layer back to the first; layer_idx_prev indexes the\n",
"    # cached activation (\"A\" + str(layer_idx_prev)) that fed the current layer.\n",
"    for layer_idx_prev, layer in reversed(list(enumerate(nn_architecture))):\n",
"        layer_idx_curr = layer_idx_prev + 1\n",
"        activ_function_curr = layer[\"activation\"]\n",
"\n",
"        dA_curr = dA_prev\n",
"\n",
"        A_prev = memory[\"A\" + str(layer_idx_prev)]\n",
"        Z_curr = memory[\"Z\" + str(layer_idx_curr)]\n",
"        W_curr = params_values[\"W\" + str(layer_idx_curr)]\n",
"        b_curr = params_values[\"b\" + str(layer_idx_curr)]\n",
"\n",
"        dA_prev, dW_curr, db_curr = single_layer_backward_propagation(\n",
"            dA_curr, W_curr, b_curr, Z_curr, A_prev, activ_function_curr)\n",
"\n",
"        grads_values[\"dW\" + str(layer_idx_curr)] = dW_curr\n",
"        grads_values[\"db\" + str(layer_idx_curr)] = db_curr\n",
"\n",
"    return grads_values"
],
"id": "2c8e4eed1846f003",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 58
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.429757Z",
"start_time": "2025-11-04T22:20:02.424922Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"def update(params_values, grads_values, nn_architecture, learning_rate):\n",
"    \"\"\"Apply one gradient-descent step to params_values, in place.\n",
"\n",
"    Each \"W\"/\"b\" entry of every layer is decreased by learning_rate times the\n",
"    matching \"dW\"/\"db\" gradient. The same dict is returned for convenience.\n",
"    \"\"\"\n",
"    for layer_no, _ in enumerate(nn_architecture, start=1):\n",
"        for prefix in (\"W\", \"b\"):\n",
"            key = prefix + str(layer_no)\n",
"            params_values[key] -= learning_rate * grads_values[\"d\" + key]\n",
"\n",
"    return params_values"
],
"id": "16320b953a183511",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 59
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.487991Z",
"start_time": "2025-11-04T22:20:02.480828Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"def train(X, Y, nn_architecture, epochs, learning_rate, verbose=False, callback=None, seed=2, log_every=1000):\n",
"    \"\"\"Fit the NumPy network with full-batch gradient descent.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    X : array (input_dim, m), one column per sample.\n",
"    Y : array (1, m) of 0/1 labels.\n",
"    nn_architecture : list of layer dicts (see init_layers).\n",
"    epochs : number of gradient-descent iterations.\n",
"    learning_rate : step size passed to update().\n",
"    verbose : if True, print cost and accuracy every log_every iterations.\n",
"    callback : optional callable(i, params_values), invoked every log_every\n",
"        iterations.\n",
"    seed : seed passed to init_layers (default 2).\n",
"    log_every : reporting/callback period in iterations (default 1000).\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict of trained parameters.\n",
"    \"\"\"\n",
"    if log_every < 1:\n",
"        raise ValueError('log_every must be a positive integer')\n",
"\n",
"    params_values = init_layers(nn_architecture, seed)\n",
"    # Per-iteration metrics; collected here but not returned.\n",
"    cost_history = []\n",
"    accuracy_history = []\n",
"\n",
"    for i in range(epochs):\n",
"        # forward pass; `cache` holds the A/Z values the backward pass needs\n",
"        Y_hat, cache = full_forward_propagation(X, params_values, nn_architecture)\n",
"\n",
"        cost = get_cost_value(Y_hat, Y)\n",
"        cost_history.append(cost)\n",
"        accuracy = get_accuracy_value(Y_hat, Y)\n",
"        accuracy_history.append(accuracy)\n",
"\n",
"        # backward pass, then a gradient-descent step\n",
"        grads_values = full_backward_propagation(Y_hat, Y, cache, params_values, nn_architecture)\n",
"        params_values = update(params_values, grads_values, nn_architecture, learning_rate)\n",
"\n",
"        if i % log_every == 0:\n",
"            if verbose:\n",
"                print(\"Iteration: {:05} - cost: {:.5f} - accuracy: {:.5f}\".format(i, cost, accuracy))\n",
"            if callback is not None:\n",
"                callback(i, params_values)\n",
"\n",
"    return params_values"
],
"id": "fce33f70bba3898",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 60
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.541893Z",
"start_time": "2025-11-04T22:20:02.536342Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"import os\n",
"import tensorflow as tf\n",
"\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"sns.set_style(\"whitegrid\")\n",
"\n",
"import keras\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense\n",
"# from keras.utils import np_utils\n",
"from keras import regularizers\n",
"\n",
"from sklearn.metrics import accuracy_score"
],
"id": "cccd73b5018799d4",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 61
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.595656Z",
"start_time": "2025-11-04T22:20:02.591558Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"# Total number of points generated by make_moons.\n",
"N_SAMPLES = 1_000\n",
"# Fraction of those points held out as the test split (10%).\n",
"TEST_SIZE = 0.1"
],
"id": "4f66ffa878f01c02",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 62
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:21:10.863339Z",
"start_time": "2025-11-04T22:21:10.858278Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"# Two interleaving half-moons with Gaussian noise; fixed seeds keep data and split reproducible.\n",
"X, y = make_moons(n_samples=N_SAMPLES, noise=0.2, random_state=100)\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y,\n",
"    test_size=TEST_SIZE,\n",
"    random_state=42,\n",
")"
],
"id": "bebe0ed00a2d514",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 65
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:21:50.510912Z",
"start_time": "2025-11-04T22:21:38.081823Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"# Full-batch training of the NumPy net; inputs are passed as (features, samples).\n",
"# NOTE(review): with learning_rate=0.001 the logged cost stays at ~0.693 (= ln 2)\n",
"# and accuracy at ~0.50 for all 100000 iterations, i.e. the net does not learn.\n",
"# The later plotting cell trains with learning_rate=0.01 for 10000 epochs instead.\n",
"params_values = train(\n",
"    X_train.T,\n",
"    y_train.reshape(1, -1),\n",
"    nn_architecture,\n",
"    epochs=100000,\n",
"    learning_rate=0.001,\n",
"    verbose=True,\n",
")\n"
],
"id": "ce04892d496c5147",
2025-11-04 23:26:22 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 00000 - cost: 0.69407 - accuracy: 0.49556\n",
"Iteration: 01000 - cost: 0.69369 - accuracy: 0.49556\n",
"Iteration: 02000 - cost: 0.69346 - accuracy: 0.49556\n",
"Iteration: 03000 - cost: 0.69332 - accuracy: 0.49556\n",
"Iteration: 04000 - cost: 0.69324 - accuracy: 0.49556\n",
"Iteration: 05000 - cost: 0.69319 - accuracy: 0.49556\n",
"Iteration: 06000 - cost: 0.69316 - accuracy: 0.49556\n",
"Iteration: 07000 - cost: 0.69314 - accuracy: 0.50444\n",
"Iteration: 08000 - cost: 0.69313 - accuracy: 0.50444\n",
"Iteration: 09000 - cost: 0.69312 - accuracy: 0.50444\n",
"Iteration: 10000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 11000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 12000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 13000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 14000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 15000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 16000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 17000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 18000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 19000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 20000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 21000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 22000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 23000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 24000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 25000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 26000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 27000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 28000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 29000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 30000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 31000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 32000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 33000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 34000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 35000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 36000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 37000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 38000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 39000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 40000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 41000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 42000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 43000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 44000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 45000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 46000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 47000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 48000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 49000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 50000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 51000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 52000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 53000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 54000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 55000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 56000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 57000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 58000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 59000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 60000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 61000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 62000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 63000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 64000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 65000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 66000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 67000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 68000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 69000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 70000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 71000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 72000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 73000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 74000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 75000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 76000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 77000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 78000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 79000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 80000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 81000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 82000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 83000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 84000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 85000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 86000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 87000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 88000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 89000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 90000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 91000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 92000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 93000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 94000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 95000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 96000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 97000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 98000 - cost: 0.69311 - accuracy: 0.50444\n",
"Iteration: 99000 - cost: 0.69311 - accuracy: 0.50444\n"
]
}
],
"execution_count": 66
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.749879571Z",
"start_time": "2025-11-04T22:18:29.064616Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"# Evaluate the trained NumPy model on the held-out split.\n",
"Y_test_hat, _ = full_forward_propagation(X_test.T, params_values, nn_architecture)\n",
"\n",
"acc_test = get_accuracy_value(Y_test_hat, y_test.reshape(1, -1))\n",
"print(\"Test set accuracy: {:.2f} - David\".format(acc_test))\n"
],
"id": "26e7a2a8848714d9",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test set accuracy: 0.46 - David\n"
]
}
],
2025-11-04 23:26:22 +01:00
"execution_count": 47
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-04T22:25:47.659577Z",
"start_time": "2025-11-04T22:25:47.652275Z"
}
},
"cell_type": "code",
"source": [
"def make_plot(X, y, plot_name, file_name=None, XX=None, YY=None, preds=None, dark=False):\n",
"    \"\"\"Scatter the 2-D points X coloured by label y, optionally over a decision surface.\n",
"\n",
"    If XX, YY and preds are all given, preds (one value per grid point) is\n",
"    drawn as filled contours plus the 0.5 iso-line. If file_name is given the\n",
"    figure is saved there and closed; otherwise it is left open for display.\n",
"\n",
"    dark=True applies matplotlib's 'dark_background' style only while this\n",
"    figure is built. plt.style.use would change the global style for every\n",
"    later plot.\n",
"    \"\"\"\n",
"    import contextlib\n",
"\n",
"    if dark:\n",
"        style_ctx = plt.style.context('dark_background')\n",
"    else:\n",
"        sns.set_style(\"whitegrid\")\n",
"        style_ctx = contextlib.nullcontext()\n",
"\n",
"    with style_ctx:\n",
"        plt.figure(figsize=(16, 12))\n",
"        axes = plt.gca()\n",
"        axes.set(xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n",
"        plt.title(plot_name, fontsize=30)\n",
"        plt.subplots_adjust(left=0.20, right=0.80)\n",
"        if XX is not None and YY is not None and preds is not None:\n",
"            plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=1, cmap=cm.Spectral)\n",
"            plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap=\"Greys\", vmin=0, vmax=.6)\n",
"        plt.scatter(X[:, 0], X[:, 1], c=y.ravel(), s=40, cmap=plt.cm.Spectral, edgecolors='black')\n",
"        if file_name:\n",
"            plt.savefig(file_name)\n",
"            plt.close()"
],
"id": "553e08ddc23ab78c",
"outputs": [],
"execution_count": 70
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# NOTE(review): OUTPUT_DIR, grid_2d, XX and YY are defined in a cell further down\n",
"# this notebook. Run that cell first, or a fresh-kernel \"Run All\" stops here\n",
"# with a NameError.\n",
"def callback_numpy_plot(index, params):\n",
"    \"\"\"Save the NumPy model's decision surface at iteration `index` as a PNG frame.\"\"\"\n",
"    frame_path = os.path.join(OUTPUT_DIR, \"numpy_model_{:05}.png\".format(index // 50))\n",
"    probs, _ = full_forward_propagation(grid_2d.T, params, nn_architecture)\n",
"    make_plot(X_test, y_test, \"NumPy Model - It: {:05}\".format(index), file_name=frame_path,\n",
"              XX=XX, YY=YY, preds=probs.reshape(probs.shape[1], 1), dark=True)\n",
"\n",
"\n",
"# Training; train() invokes the callback on its progress-reporting schedule.\n",
"params_values = train(\n",
"    X_train.T,\n",
"    y_train.reshape(1, -1),\n",
"    nn_architecture,\n",
"    epochs=10000,\n",
"    learning_rate=0.01,\n",
"    verbose=False,\n",
"    callback=callback_numpy_plot,\n",
")\n",
"prediction_probs_numpy, _ = full_forward_propagation(grid_2d.T, params_values, nn_architecture)\n",
"prediction_probs_numpy = prediction_probs_numpy.reshape(prediction_probs_numpy.shape[1], 1)\n",
"make_plot(X_test, y_test, \"NumPy Model\", file_name=None, XX=XX, YY=YY, preds=prediction_probs_numpy)"
],
"id": "b6a4d6a1a1fb289"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "",
"id": "dc0b9266ae37298a"
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.760128452Z",
"start_time": "2025-11-04T22:17:38.579795Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"# Reference Keras model: 2 -> 25 -> 50 -> 50 -> 25 -> 1, ReLU hidden layers, sigmoid output.\n",
"model = Sequential([\n",
"    Dense(25, input_dim=2, activation='relu'),\n",
"    Dense(50, activation='relu'),\n",
"    Dense(50, activation='relu'),\n",
"    Dense(25, activation='relu'),\n",
"    Dense(1, activation='sigmoid'),\n",
"])\n",
"\n",
"model.compile(loss='binary_crossentropy', optimizer=\"sgd\", metrics=['accuracy'])\n",
"\n",
"# Training\n",
"history = model.fit(X_train, y_train, epochs=200, verbose=0)"
],
"id": "f05ff40ed26e45c2",
2025-11-04 23:26:22 +01:00
"outputs": [],
"execution_count": 42
2025-11-04 22:56:05 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:20:02.761214479Z",
"start_time": "2025-11-04T22:17:44.410525Z"
2025-11-04 22:56:05 +01:00
}
},
"cell_type": "code",
"source": [
"# Evaluate the Keras model: threshold the sigmoid outputs at 0.5.\n",
"Y_test_prob = model.predict(X_test)\n",
"Y_test_hat = (Y_test_prob.ravel() > 0.5).astype(int)\n",
"acc_test = accuracy_score(y_test, Y_test_hat)\n",
"print(\"Test set accuracy: {:.2f} - Goliath\".format(acc_test))"
],
"id": "ef52bee9c93081d3",
"outputs": [
2025-11-04 23:06:08 +01:00
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-11-04 23:26:22 +01:00
"\u001B[1m4/4\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m0s\u001B[0m 7ms/step \n",
"Test set accuracy: 1.00 - Goliath\n"
2025-11-04 23:06:08 +01:00
]
}
],
2025-11-04 23:26:22 +01:00
"execution_count": 43
2025-11-04 23:06:08 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:22:16.887963Z",
"start_time": "2025-11-04T22:22:16.881924Z"
2025-11-04 23:06:08 +01:00
}
},
"cell_type": "code",
"source": [
"# Plot window for the decision-boundary figures (data-space coordinates).\n",
"GRID_X_START = -1.5\n",
"GRID_X_END = 2.5\n",
"GRID_Y_START = -1.0\n",
"GRID_Y_END = 2\n",
"# Folder for the per-iteration / per-epoch PNG frames (created if missing).\n",
"OUTPUT_DIR = \"./binary_classification_vizualizations/\"\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"\n",
"# 100 x 100 evaluation grid: XX/YY for contour plots, grid_2d with one (x, y) row per point.\n",
"grid = np.mgrid[GRID_X_START:GRID_X_END:100j, GRID_Y_START:GRID_Y_END:100j]\n",
"XX, YY = grid\n",
"grid_2d = np.column_stack((XX.ravel(), YY.ravel()))"
],
"id": "b070f03d55981894",
"outputs": [],
2025-11-04 23:26:22 +01:00
"execution_count": 68
2025-11-04 23:06:08 +01:00
},
{
"metadata": {
"ExecuteTime": {
2025-11-04 23:26:22 +01:00
"end_time": "2025-11-04T22:23:06.066935Z",
"start_time": "2025-11-04T22:22:22.111811Z"
2025-11-04 23:06:08 +01:00
}
},
"cell_type": "code",
"source": [
"def callback_keras_plot(epoch, logs):\n",
"    \"\"\"Keras epoch-end hook: save the current model's decision surface as a PNG.\n",
"\n",
"    `logs` is supplied by Keras and unused here; the global `model` is read.\n",
"    \"\"\"\n",
"    frame_path = os.path.join(OUTPUT_DIR, \"keras_model_{:05}.png\".format(epoch))\n",
"    probs = model.predict(grid_2d, batch_size=32, verbose=0).reshape(-1)\n",
"    make_plot(X_test, y_test, \"Keras Model - It: {:05}\".format(epoch), file_name=frame_path,\n",
"              XX=XX, YY=YY, preds=probs)\n",
"\n",
"\n",
"# Run the plotting hook at the end of every epoch.\n",
"testmodelcb = keras.callbacks.LambdaCallback(on_epoch_end=callback_keras_plot)\n",
"\n",
"# Fresh model with the same layer stack as the earlier Keras cell.\n",
"model = Sequential([\n",
"    Dense(25, input_dim=2, activation='relu'),\n",
"    Dense(50, activation='relu'),\n",
"    Dense(50, activation='relu'),\n",
"    Dense(25, activation='relu'),\n",
"    Dense(1, activation='sigmoid'),\n",
"])\n",
"model.compile(loss='binary_crossentropy', optimizer=\"sgd\", metrics=['accuracy'])\n",
"\n",
"# Training (one PNG frame per epoch), then a final on-screen plot.\n",
"history = model.fit(X_train, y_train, epochs=200, verbose=0, callbacks=[testmodelcb])\n",
"prediction_probs = model.predict(grid_2d, batch_size=32, verbose=0).reshape(-1)\n",
"make_plot(X_test, y_test, \"Keras Model\", file_name=None, XX=XX, YY=YY, preds=prediction_probs)\n"
],
"id": "6feab7da06e7a828",
"outputs": [
{
2025-11-04 23:26:22 +01:00
"data": {
"text/plain": [
"<Figure size 1600x1200 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABCAAAAQHCAYAAAAtRRjrAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XeY3GW9/vH7O2V77yVlU0hCIKQAIShIESsWigWPAiogNlSsYPkJ6AEUO8eCXYpyVFBALIhHFIGQUEIgpPdsdmd3syXJtmnf3x+bnZ3Zndlp36n7fl0Xl0x/NqybnXue53MbpmmaAgAAAAAASCFbphcAAAAAAADyHwEEAAAAAABIOQIIAAAAAACQcgQQAAAAAAAg5QggAAAAAABAyhFAAAAAAACAlCOAAAAAAAAAKUcAAQAAAAAAUo4AAgAAAAAApBwBBAAAwAzw9NNPa/HixYF/br/99kwvKSvXBABIHQIIAAAAAACQco5MLwAAgESde+65am9vD1y+8847ddppp8X8eJ/Pp89//vP64x//GHL9RRddpK9+9auy2+1WLRUpcP/99+v666+fcn1jY6Mee+wx2Wzxf85imqbOO+88HThwYMptt9xyiy666KKE1goAANgBAQCYoTwejz75yU9OCR/e9a536eabbyZ8yGEul0tPPfVUQo9dv3592PABAAAkjwACADDjuN1uXXPNNfrrX/8acv373/9+3XDDDTIMI0Mrg1UmB0ux+sMf/mDtQgAAQAABBABgRhkZGdGHPvQh/fOf/wy5/sMf/rA+97nPZWhVsELwkYtHH31Ug4ODcT1+eHhYf/vb38I+HwAASB5/swIAZozBwUFdddVV+s9//hNy/ac+9Sl9/OMfz9CqYJXg+R9DQ0N65JFH4nr8I488EhJarFmzxrK1AQAAAggAwAxx+PBhvf/979e6desC1xmGoS984Qv6wAc+kMGVwSrHHXecli5dGrgc73GK4GMbJ5xwghYuXGjV0gAAgAggAAAzQF9fny6//HJt2LAhcJ3NZtNXvvIVXXbZZZlbGCx3wQUXBP593bp16ujoiOlxnZ2dWrt2bdjnAQAA1qCGEwCQ13p6evS+971P27ZtC1zncDh066236s1vfnPCz3vo0CFt2LBBPT096u/vV0lJiWpra7Vs2TLNnj3biqWHePHFF7Vv3z51d3drdHRULS0tUde/f/9+7dixQwcPHtTRo0dlt9tVWVmp1tZWLV++XKWlpUmva/fu3dqyZYu6u7s1ODgou92ukpISNTY2avbs2Vq4cKEcjvT9uvHmN79Zt912mzwej0zT1AMPPKAPfvCDUR/3wAMPyO/3S5KcTqfe9KY36Yc//KEla9q8ebN27NihQ4cOye12q6amRs3NzTr55JNVVFSU9PN7PB6tX79e+/fvV19fn4qLi9XW1qaTTz5ZZWVlFnwFodL9vQ8AyB8EEACAvNXZ2anLL79ce/bsCVzndDr1rW99S6997Wvjfj6/368HH3xQd911lzZt2iTTNMPeb8GCBbryyit1wQUXxDTI8P7779f1118fuHzLLbfooosu0sjIiH72s5/p/vvvn1INWV5ePiWAGB0d1WOPPaZHHnlETz/9tLq7uyO+pt1u1+mnn64PfOADIbMTYuF2u/XLX/5Sv/vd77Rv375p71tUVKQVK1bo9a9/vd71rnfF9TqJqKmp0RlnnBEYMvrHP/4xpgAi+LjGmWeeqZqamqTWcfToUf3kJz/R/fffr66urrD3KSws1JlnnqmPf/zjWrRoUdyvMTIyou9///v67W9/q/7+/im3FxQU6IILLtC1116b9NeTqu99AMDMQgABAMhL+/fv13vf+96QN+6FhYW6/fbbddZZZ8X9fHv27NHHP/5xbdmyJep9d+7cqeuvv17/+7//qx/+8IcJvflrb2/XBz7wAe3YsSPmx7zrXe/Spk2bYrqvz+fTf/7zH/3nP//Ru9/9bn3+85+Paa
fCwYMHdcUVV2jXrl0xvc7IyIjWrl2rtWvX6u1vf3tadkNceOGFgQBi9+7deuGFF7R8+fKI93/hhRe0e/fukMcnY926dfrEJz6hQ4cOTXu/0dFRPfroo/rnP/+pq666Stdee23Mr7F//35deeWVIeHaZG63W7/97W/12GOP6ac//WnMzz1Zur/3AQD5i2gaAJB3du/erfe85z0h4UNJSYnuuOOOhMKHF154QZdccsmUN2B2u11z587VSSedpIULF6qwsDDk9g0bNuid73ynent743q9o0eP6v3vf39I+FBbW6ulS5dq4cKFKikpCfs4t9s95bqGhgYtWrRIK1as0KJFi1ReXj7lPvfcc4++9KUvRV3XyMiI3ve+900JH2w2m1pbW3XCCSfopJNO0vz58yOuMR3OOeccVVZWBi4HD5cMJ3j3Q1VVlc4+++yEX/uxxx7TlVdeOSV8KCws1Pz583XCCSdMeVPu8/n0ox/9SJ///Odjeg2XyzVlZ4808f144oknqqGhIXB9V1dX2DXFIt3f+wCA/MYOCABAXtm+fbve9773hRw/KC8v149//GOtWrUq7ufr7u7Whz70IfX19QWuW7x4sa6++mqdffbZIXMURkdH9Y9//EPf/va3A0cT9u3bp+uuu0533HGHDMOI6TXvuOMO9fT0SJLe+MY36uqrr9aSJUsCt3s8Hj355JNhH9vS0qLXv/71etWrXqVly5ZNmQFgmqa2bt2qe++9V7/97W/l8/kkjR0DOffcc/Wa17wm4rruvvvukDe9NTU1uvbaa/W6170u5A3/+Ovs379fTz75pP7+97/riSeeiOlrt0JBQYHe+MY36je/+Y0k6c9//rOuv/56FRQUTLmv2+3Wn//858DlN77xjWHvF4uOjg595jOf0ejoaOC6qqoqffrTn9Yb3/jGkO+V559/XrfddpueffbZwHX33Xefli1bFvWoyhe+8AW1t7cHLjudTn3wgx/Uu971LtXW1gau3759u773ve/pkUceUVdXl77xjW/E9fVk4nsfAJDf2AEBAMgbmzdv1qWXXhoSPlRVVemXv/xlQuGDJF1//fUhnxy/853v1H333afzzz9/yhDHwsJCvfGNb9R9992nlStXBq7/17/+pUcffTTm1xwPHz7/+c/r29/+dkj4II294Qy3k+OGG27Qo48+qs997nM6/fTTww4gNAxDS5Ys0Q033KCf/OQnIW+2f/zjH0+7rr/+9a+Bfy8oKNDdd9+td7zjHVPCh/HXmTNnji655BL97Gc/08MPPyy73T79F26h4GMU/f39+te//hX2fv/3f/+ngYGBsI+L14033qjDhw8HLjc3N+v+++/X29/+9infKytXrtTdd9+tt771rSHXf+1rX5PL5Yr4Gn/+85/1+OOPBy4XFBToJz/5iT760Y+GhA/SWC3p7bffHpiBERxaxCIT3/sAgPxGAAEAyBu33npryKe1dXV1uvPOO3XiiScm9HwbNmwIebP3qle9SjfeeKOcTue0j6uoqNDtt98e8ibt5z//eVyvff755+vyyy+P6zGnnHJKXG/yX/nKV+qKK64IXN64ceO0MyeCdz+cdtppWrBgQcyvtWDBgrR+Cr58+XLNmzcvcDn4mEWw4OMZ8+fP10knnZTQ6+3atUuPPfZY4LLNZtP3vvc9tba2RnyMzWbTzTffHDKAcnh4OLBzI5xf/epXIZevvfZanX766dOu7dprr9UrX/nKKF9BqEx+7wMA8hcBBAAgb0yezP/5z39eixcvTvj5Jr/Zu/7662N+E11fX6+3v/3tgcvPPfdcYGdDLD7+8Y/HfN9kvOUtbwm5/Pzzz0e878jISODf01mtmagLLrgg8O///ve/Q8IpaaxOMvhNdvD94/X73/8+5Pvv/PPPjynMcDgc+uxnPxty3e9+97uwLRM7d+7Uhg0bApcbGxt16aWXxrS+ya8RTSa/9wEA+YsAAgCQt2666aaYJveH4/f7Q96cjg9XjMfkT5
2feeaZmB63bNkyzZ07N67XStSsWbNCLr/88ssR7xs82PCZZ57RwYMHU7YuK7z1rW8NVEF6PB49/PDDIbc/9NBD8nq
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
2025-11-04 22:56:05 +01:00
}
],
2025-11-04 23:26:22 +01:00
"execution_count": 69
2025-11-04 18:05:00 +01:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}