1.5 Introduction to Neural Networks: my first Neural Network in JAX :D
import matplotlib.pyplot as plt
# Math + array manipulation
import math
import numpy as np
# JAX
import jax
import jax.numpy as jnp
# Fix the random seed so the generated numbers are reproducible
np.random.seed(0)
# Data handling
import pandas as pd
from sklearn.datasets import load_iris, fetch_openml
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
# Utilities
import utils
# Automatically reload dependencies when they change
%load_ext autoreload
%autoreload 2
SYNT_TRAIN_SIZE = 200
# controls how spread out the data points are
STD_DEV = 0.7
def random_error(size, mu=0, std_dev=0.5):
    return np.random.normal(mu, std_dev, size)
def add_batch_dim(tensor):
    # 1-D arrays get a trailing axis so every example becomes a row vector
    if len(tensor.shape) == 1:
        return jnp.expand_dims(tensor, axis=1)
    else:
        return tensor

def remove_batch_dim(tensor):
    return jnp.squeeze(tensor, axis=1)
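A quick shape check (illustrative, not part of the original pipeline) makes the convention explicit: a 1-D array of n examples becomes a column of shape (n, 1), and the helpers invert each other.

# illustrative shape check for the batch-dimension helpers
v = jnp.arange(3.0)
print(add_batch_dim(v).shape)                    # (3, 1)
print(remove_batch_dim(add_batch_dim(v)).shape)  # (3,)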
def generate_x(size, use_batch_dim=True):
    x = np.random.rand(size)
    if use_batch_dim:
        x = add_batch_dim(x)
    return x
def plot_line(x, y, style='-b'):
    # draws a straight segment from (min x, min y) to (max x, max y);
    # this works here because the target line is increasing
    x, y = remove_batch_dim(x), remove_batch_dim(y)
    return plt.plot([min(x), max(x)], [min(y), max(y)], style)

def generate_f(x, a=7, b=15, error_std_dev=0.5, use_batch_dim=True):
    y = a * x + b + random_error(x.shape, std_dev=error_std_dev)
    if use_batch_dim:
        y = add_batch_dim(y)
    return y

def identity(x):
    return x
def _accuracy(pred_y, real_y):
    # softmax is monotonic, so this is just the argmax of the logits
    p = np.argmax(jax.nn.softmax(pred_y), axis=1)
    if real_y.ndim > 1:  # if the labels are one-hot encoded
        real_y = np.argmax(real_y, axis=1)
    return np.sum(p == real_y) / len(pred_y)
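To see what _accuracy computes, here is a minimal sketch with made-up logits and one-hot labels: the first prediction matches its label, the second does not, so the accuracy is 0.5.

# illustrative: two examples, one classified correctly
logits = jnp.array([[2.0, 0.1, -1.0], [0.0, 3.0, 0.5]])
onehot = np.array([[1, 0, 0], [0, 0, 1]], dtype=np.float32)
print(_accuracy(logits, onehot))  # 0.5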
# generate random values for x
synt_x = generate_x(SYNT_TRAIN_SIZE)
# generate the target function: Y = 7 * X + 15 (plus noise)
synt_y = generate_f(synt_x, error_std_dev=STD_DEV)
plt.plot(synt_x, synt_y, 'ro', alpha=0.4)
plot_line(synt_x, synt_x * 7 + 15)
plt.show()
def define_params(sizes=[1, 1]):
    '''He et al. initialization'''
    weights = []
    for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
        weights.append({"w": np.random.randn(in_dim, out_dim) * np.sqrt(2 / in_dim),
                        "b": np.random.randn(out_dim) * np.sqrt(2 / in_dim)})
    return weights
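As a sanity check (illustrative), the sizes list maps directly to the weight shapes: a [4, 10, 3] network gets a (4, 10) and a (10, 3) weight matrix, each with a matching bias vector.

# illustrative: inspect the shapes produced by define_params
params = define_params(sizes=[4, 10, 3])
print([(layer["w"].shape, layer["b"].shape) for layer in params])
# [((4, 10), (10,)), ((10, 3), (3,))]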
def apply_fn(weights, batch_x, activations):
    # feed-forward pass: each layer is an affine transform followed by its activation
    output = batch_x
    for layer, act_fn in zip(weights, activations):
        output = jnp.dot(output, layer["w"]) + layer["b"]
        output = act_fn(output)
    return output
def l2_loss(weights, batch_x, real_y, activations):
    pred_y = apply_fn(weights, batch_x, activations)
    # use jnp (not np) so the loss can be differentiated by jax.grad
    return 0.5 * jnp.mean((pred_y - real_y) ** 2)
def cross_entropy(weights, batch_x, real_y, activations):
    # expects pred_y to be log-probabilities (e.g. the last activation is
    # jax.nn.log_softmax) and real_y to be one-hot encoded
    pred_y = apply_fn(weights, batch_x, activations)
    real_y = jnp.asarray(real_y)
    return -jnp.mean(jnp.sum(pred_y * real_y, axis=1))
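To see what the inner expression computes, a minimal check with made-up values: with log-probabilities and a one-hot target, sum(pred_y * real_y, axis=1) picks out the log-probability of the true class, so the loss is the usual negative log-likelihood.

# illustrative: the loss reduces to -log p(true class)
logp = jax.nn.log_softmax(jnp.array([[2.0, 0.5, -1.0]]))
onehot_target = jnp.array([[1.0, 0.0, 0.0]])
print(-jnp.sum(logp * onehot_target, axis=1))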
def train_step(weights, batch_x, batch_y, activations, loss_fn=l2_loss, lr=0.1):
    loss, grads = jax.value_and_grad(loss_fn)(weights, batch_x, batch_y, activations)
    # SGD update applied to every leaf of the parameter pytree
    weights = jax.tree_util.tree_map(lambda v, g: v - lr * g, weights, grads)
    return weights, loss
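train_step is called many times with the same shapes, which makes it a natural candidate for jax.jit. A minimal sketch, assuming activations is passed as a tuple (lists are not hashable) so that it and loss_fn can be marked as static arguments:

# illustrative: compile the update step; activations and loss_fn are
# Python functions, so they must be static
fast_train_step = jax.jit(train_step, static_argnums=(3, 4))
# usage: fast_train_step(weights, batch_x, batch_y, (identity,), l2_loss)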
def evaluate(weights, activations, batch_x, batch_y, metrics=[], loss_fn=l2_loss):
    # run the feed-forward network
    pred_y = apply_fn(weights, batch_x, activations)
    # loss
    loss = loss_fn(weights, batch_x, batch_y, activations)
    # metrics
    res_metrics = []
    for m in metrics:
        res_metrics.append(m(pred_y, batch_y))
    return loss, res_metrics
def plot_losses(train_losses, eval_losses, step):
    if len(eval_losses) > 0:
        plt.title('Train Loss: %.4f | Test Loss: %.4f for step %d' % (train_losses[-1], eval_losses[-1], step))
        # evaluation runs once every 10 steps
        plt.plot(list(range(0, step, 10)), eval_losses)
    else:
        plt.title('Train Loss: %.4f for step %d' % (train_losses[-1], step))
        plt.plot(list(range(step)), train_losses)
Gradients
L2 loss with 1 layer, no activation
Loss
$$L = \frac{1}{2n} \sum_i (y_i - \hat{y}_i)^{2}$$ $$L = \frac{1}{2n} \sum_i \left(y_i - (w x_i + b)\right)^{2}$$
Gradients
$$\frac{\partial L}{\partial w} = \frac{1}{2n} \sum_i 2 (y_i - \hat{y}_i) \cdot \left(-\frac{\partial \hat{y}_i}{\partial w}\right) = -\frac{1}{n} \sum_i (y_i - \hat{y}_i) \, x_i$$
$$\frac{\partial L}{\partial b} = \frac{1}{2n} \sum_i 2 (y_i - \hat{y}_i) \cdot \left(-\frac{\partial \hat{y}_i}{\partial b}\right) = -\frac{1}{n} \sum_i (y_i - \hat{y}_i)$$
L2 loss with 2 layers, relu activation in the hidden layer
Loss
$$L = \frac{1}{2n} \sum_i (y_i - \hat{y}_i)^{2}, \qquad \hat{y}_i = w_2 h_i + b_2, \qquad h_i = \mathrm{relu}(w_1 x_i + b_1)$$
Gradients
Output layer:
$$\frac{\partial L}{\partial w_2} = -\frac{1}{n} \sum_i (y_i - \hat{y}_i) \, h_i, \qquad \frac{\partial L}{\partial b_2} = -\frac{1}{n} \sum_i (y_i - \hat{y}_i)$$
Hidden layer (the relu only passes gradient where its input is positive, i.e. when $w_1 x_i + b_1 > 0$):
$$\frac{\partial L}{\partial w_1} = -\frac{1}{n} \sum_i (y_i - \hat{y}_i) \, w_2 \, x_i \, \mathbb{1}[w_1 x_i + b_1 > 0]$$
$$\frac{\partial L}{\partial b_1} = -\frac{1}{n} \sum_i (y_i - \hat{y}_i) \, w_2 \, \mathbb{1}[w_1 x_i + b_1 > 0]$$
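Before trusting the hand-derived formulas, we can check the 1-layer case against JAX's autodiff on the synthetic data from above (a minimal sketch):

# illustrative: compare the manual gradient with jax.grad
params_check = define_params(sizes=[1, 1])
grads_check = jax.grad(l2_loss)(params_check, synt_x, synt_y, [identity])
pred_check = apply_fn(params_check, synt_x, [identity])
# hand-derived: dL/dw = -1/n * sum((y - ŷ) x), dL/db = -1/n * sum(y - ŷ)
print(-jnp.mean((synt_y - pred_check) * synt_x), grads_check[0]["w"])
print(-jnp.mean(synt_y - pred_check), grads_check[0]["b"])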
neural_net = define_params()
train_losses = []
eval_losses = []
for i in range(1000):
    neural_net, loss = train_step(neural_net, synt_x, synt_y, [identity])
    train_losses.append(loss)
    # if i % 10 == 0:
    #     loss, metrics = evaluate(neural_net, [identity], synt_x, synt_y)
    #     eval_losses.append(loss)
plot_losses(train_losses, eval_losses, 1000)
print('Learned parameters:')
print('weights:', neural_net[0]["w"])
print('bias:', neural_net[0]["b"])
print('Function that generated the data: 7 * X + 15')
plot_line(synt_x, apply_fn(neural_net, synt_x, [identity]), '--r')
plot_line(synt_x, synt_y)
plt.show()
# the new target is a nonlinear function of X: Y = 7 * log(X) + 1 (plus noise)
synt_x = np.random.rand(SYNT_TRAIN_SIZE)
synt_y = jnp.reshape(7 * np.log(synt_x) + 1 + random_error(SYNT_TRAIN_SIZE, std_dev=0.8), (SYNT_TRAIN_SIZE, 1))
synt_x = jnp.reshape(synt_x, (SYNT_TRAIN_SIZE, 1))
plt.plot(synt_x, synt_y, 'ro', alpha=0.5)
nn = define_params(sizes=[1, 10, 1])
activations = [jax.nn.sigmoid, identity]
train_losses = []
eval_losses = []
for i in range(1000):
    nn, loss = train_step(nn, synt_x, synt_y, activations)
    train_losses.append(loss)
    # if i % 10 == 0:
    #     loss, metrics = evaluate(nn, activations, synt_x, synt_y)
    #     eval_losses.append(loss)
plot_losses(train_losses, eval_losses, 1000)
print('Learned parameters:')
print('weights:', [weight["w"] for weight in nn])
print('bias:', [weight["b"] for weight in nn])
print('Function that generated the data: 7 * log(X) + 1')
plt.plot(synt_x, apply_fn(nn, synt_x, activations), 'or', alpha=0.3)
plt.plot(synt_x, synt_y, 'og', alpha=0.3)
plt.show()
xor_x = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_y = jnp.array([[0], [1], [1], [0]])
# XOR is not linearly separable, so we use a relu hidden layer
activations = [jax.nn.relu, identity]
nn = define_params(sizes=[2, 10, 1])  # one output to match xor_y's shape
for i in range(1000):
    nn, loss = train_step(nn, xor_x, xor_y, activations)
plt.plot(xor_x, apply_fn(nn, xor_x, activations), 'bo', xor_x, xor_y, 'ro', alpha=0.3)
# without a nonlinearity the stacked layers collapse into a single affine map
activations = [identity, identity]
nn = define_params(sizes=[2, 10, 1])
for i in range(1000):
    nn, loss = train_step(nn, xor_x, xor_y, activations)
plt.plot(xor_x, apply_fn(nn, xor_x, activations), 'bo', xor_x, xor_y, 'ro', alpha=0.3)
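To make the failure visible, printing the raw predictions (an illustrative check) shows the linear net stuck near the mean label for every input, while the relu net above separates the classes.

# illustrative: the purely linear net predicts ~0.5 for every input
print(apply_fn(nn, xor_x, [identity, identity]))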
Example: the Iris dataset
Say that, for each example in the dataset, we want to determine the species of the plant.
Inputs
The Iris dataset has 4 attributes per plant, which we will use as the input.
Outputs
Here the output we care about is the plant's species, so let's say the output is a number indicating the species:
0 = Iris Setosa , 1 = Iris Versicolour, 2 = Iris Virginica
iris = load_iris()
# np.c_ concatenates the dataset's features and targets column-wise
iris_data = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                         columns=['x0', 'x1', 'x2', 'x3', 'target'])
iris_data.head()
iris_data.describe()
iris_data.drop(['target'], axis=1).hist(color='k', alpha=0.5, bins=10, figsize=(4, 5))
plt.show()
def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype)
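A tiny usage example (illustrative): class indices become rows of the identity matrix.

print(_one_hot(np.array([0, 2, 1]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]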
x = iris.data
y = iris.target
y = _one_hot(y, 3)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)
def batches(x, y, batch_size=16):
    # shuffle the data, then yield full batches (the last incomplete batch is dropped)
    idx = np.random.permutation(len(x))
    x = x[idx]
    y = y[idx]
    for i in range(0, len(x) - batch_size + 1, batch_size):
        batch_x = x[i:i + batch_size]
        batch_y = y[i:i + batch_size]
        yield batch_x, batch_y
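A quick look at what the generator yields (illustrative): each batch is a (batch_size, n_features) slice of the shuffled data, here (16, 4) inputs and (16, 3) one-hot targets for Iris.

# illustrative: inspect the first batch
bx, by = next(batches(x_train, y_train, batch_size=16))
print(bx.shape, by.shape)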
activations = [jax.nn.relu, jax.nn.log_softmax]
nn = define_params(sizes=[4, 10, 3])
train_losses = []
eval_losses = []
for i in range(100):
    for batch_x, batch_y in batches(x_train, y_train):
        nn, loss = train_step(nn, batch_x, batch_y, activations, loss_fn=cross_entropy, lr=0.01)
    train_losses.append(loss)
    if i % 10 == 0:
        loss, metrics = evaluate(nn, activations, x_test, y_test, metrics=[_accuracy], loss_fn=cross_entropy)
        eval_losses.append(loss)
        print('Test loss = %.5f, accuracy %.5f' % (loss, metrics[0]))
plot_losses(train_losses, eval_losses, 100)
mnist = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
# scale pixel values to [0, 1]
x = mnist[0] / np.max(mnist[0])
y = np.array([int(label) for label in mnist[1]])
y = _one_hot(y, 10)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)
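Each row of x is a flattened 28x28 grayscale digit; visualizing one (illustrative) is a quick sanity check on the loading and normalization.

# illustrative: show the first training digit
plt.imshow(x_train[0].reshape(28, 28), cmap='gray')
plt.show()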
activations = [jax.nn.relu, jax.nn.relu, jax.nn.log_softmax]
nn = define_params(sizes=[784, 512, 256, 10])
for i in range(20):
    for batch_x, batch_y in batches(x_train, y_train, 64):
        nn, loss = train_step(nn, batch_x, batch_y, activations, loss_fn=cross_entropy, lr=0.001)
    loss, metrics = evaluate(nn, activations, x_test, y_test, metrics=[_accuracy], loss_fn=cross_entropy)
    print('Test loss = %.5f, accuracy %.5f' % (loss, metrics[0]))