JAX é uma nova biblioteca de Python da Google com foco em pesquisa de alta performance em Aprendizado de Máquina e seguindo o paradigma de programação funcional.

Mais especificamente JAX nos dá acesso a: e , as principais sendo grad, jit, vmap e pmap(que vai ter seu próprio post no futuro).

  • uma API compatível com numpy e scipy
  • uma própria API de números aleatório manejados manualmente
  • transformações de função composicionais (você pode aplicar elas em conjunto sem muito problema)
  • transformações para trivialmente computar derivadas de funções
  • transformações para usar aceleradores (CPU, GPU e TPU)
  • transformações que permitem paralelizar seu código facilmente

Para usar o JAX é recomendado checar as instruções de instalação, e a seguir importá-lo:

# jax.numpy toda hora, então fazemos um alias dele para jnp
import jax
import jax.numpy as jnp

Agora vamos ver cada uma dessas features da biblioteca

O Wrapper de Numpy: jax.numpy:

import numpy as np

a = np.array([1., 2., 3.])
b = np.array([1., 1., -1.])
print(np.dot(a, b), jnp.dot(a, b))
0.0 0.0
/home/joaogui/miniconda3/envs/jax/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')

Esse warning que aparece é ele dizendo que sou pobre e não tenho nem GPU nem TPU :(

(np.square(a), jnp.square(a))
(array([1., 4., 9.]), DeviceArray([1., 4., 9.], dtype=float32))

Note que JAX tem seu próprio tipo de array, o DeviceArray, em geral as funções vão transformar arrays de numpy em DeviceArrays, então se você quiser boa performance é melhor fazer essa transformação manualmente antes de passar os dados para várias funções.

Números aleatórios jax.random

Uma das partes mais peculiares de JAX, para faciliar implementações usando paralelismo não existe uma semente global para geradores de números aleatórios, em vez disso em JAX você passa explicitamente a seed para cada função que envolve aleatoriedade, e cabe a você atualizá-la

key = jax.random.PRNGKey(42) #cria um semente aleatória
a = jax.random.normal(key, ())
b = jax.random.normal(key, ())
print(a, b)
print(a == b) #como usamos a mesma semente para a mesma função temos valores iguais
k1, k2 = jax.random.split(key, 2) #vamos criar duas novas seeds a partir da primeira
a = jax.random.normal(k1, ())
b = jax.random.normal(k2, ())
print(a, b) #Agora são diferentes
-0.18471184 -0.18471184
True
0.13790314 1.3694694

Transformações

O principal diferencial de JAX são suas tranformações de funções, que nos permitem modificar facilmente funções definidas a partir de outras funções do JAX e algumas primitivas de Python.

Algo muito útil e legal delas é que podem ser utilizadas em conjunto (são "composicionais"), nos permitindo por exemplo compilar a derivada de uma função vetorizada apenas aplicando 3 transformações uma seguida da outra a função original.

Porém existem alguns cuidados que dever ser tomados ao se usar esses transformações, para entender esse cuidados melhor cheque esse link e abra o notebook

Diferenciação Automática: jax.grad

Em aprendizado de máquina, principalmente quando estamos tratando de redes neurais, lidamos com muitas derivadas, gradientes e afins: Para treinar uma regressão linear ou logística, precisamos computar um hessiano, para treinar uma rede neural usamos descida de gradiente, que requer o cálculo de um gradiente, dentre outros exemplos.

Computar essas derivadas na mão é muitas vezes extremamente trabalhoso, ou até mesmo impossível dado o tempo disponível, assim temos algoritmos como o backpropagation para redes neurais, porém se sempre tivessemos que implementar nós mesmos esse algoritmo, e implementar a derivada de cada uma das funções que vamos usar, terminaríamos com uma quatidade imensa de código duplicado, além de uma imensa chance de errarmos algo na implementação e terminarmos sem conseguir bons resultados ou com resultados que não correspodem a realidade.

Para lidar com isso temos diferenciação automática, transformações que recebem uma função e retornam algum tipo de derivada dela. Simplesmente ter diferenciação automática para as funções de Numpy já é o bastante para uma biblioteca mostrar seu valor, tanto que existe uma biblioteca que é exatamente isso, chamada de Autograd, em muitos sentidos JAX é um sucessor dessa biblioteca.

from jax import grad
from math import pi, sqrt
dup = grad(jnp.square)
print(dup(3.0)) #A derivada de x² é 2x
print(grad(dup)(3.0)) #Podemos aplicar várias vezes a grad

@grad #Podemos usar as transformações como decoradores
def composite_func(x):
    y = x**2
    return jnp.cos(y)
# Pela regra da cadeia, dcos(x²)/dx = -2xsen(x²)
print(composite_func(jnp.sqrt(pi/2)), -2*sqrt(pi/2))
6.0
2.0
-2.5066283 -2.5066282746310002

Para funções com várias variáveis de entrada a grad por padrão nos dá a derivada em função do primeiro parâmetro, mas podemos mudar isso com o argumento argnums. Também vale ressaltar que os argumentos não precisam ser apenas números e podem ser vetores

def f(x, y):
    return x*(y**2)
dfdy = grad(f, argnums=(1))
print(dfdy(3.0, 4.0))
gradient = grad(f, argnums=(0, 1))
print(gradient(3.0, 4.0))

def g(v):
    return jnp.linalg.norm(v)
print(grad(g)(a))
24.0
(DeviceArray(16., dtype=float32), DeviceArray(24., dtype=float32))
1.0

Compilação com XLA: jit

Mas as vantagens de jax não param em diferenciação automática, se não seria apenas um clone do autograd, jax também tem a habilidade de compilar funções usando o XLA (accelerated linear algebra) da Google, tornando-as bem mais rápidas, além de permitir o uso de aceleradores como GPUs e TPUs.

from jax import jit
a = 1 + jax.random.normal(k1, (2024, 2024))
b = 1 + jax.random.normal(k2, (2024, 2024))

A vantagem se torna maior (> 20x mais rápido) quando usamos aceleradores.

@jit #equivalente a definir jcos e escrever jcos = jit(jcos)
def jcos(a, b):
    return jnp.dot(a, b)/jnp.sqrt(jnp.dot(a, a)*jnp.dot(b, b))
def npcos(a, b):
    return np.dot(a, b)/np.sqrt(np.dot(a, a)*np.dot(b, b))
%%timeit
npcos(a, b)
241 ms ± 60.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jcos(a, b) #Rodamos uma vez fora para compilar a função
DeviceArray([[1.0336248 , 1.028297  , 1.0310295 , ..., 0.96507126,
              1.0339284 , 1.0405782 ],
             [0.94755113, 1.0175968 , 0.9993737 , ..., 1.0073111 ,
              1.0127649 , 1.0145004 ],
             [0.99184996, 1.0617024 , 1.0004048 , ..., 1.0069526 ,
              1.0510893 , 1.0679886 ],
             ...,
             [0.965744  , 1.0246744 , 1.0708025 , ..., 1.0135127 ,
              1.0477784 , 0.98690724],
             [0.9184314 , 0.99969995, 0.9819234 , ..., 0.9972953 ,
              0.9442775 , 0.9897808 ],
             [0.99618256, 1.0751995 , 1.0236498 , ..., 1.0285234 ,
              1.0351353 , 1.0573723 ]], dtype=float32)
%%timeit
jcos(a, b)
220 ms ± 4.32 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

Vetorização Automática: vmap

Vmap é uma transformação muito interessante, usando ela é possível vetorizar automaticamente nossas funções, ou seja, em vez de ter que fazer uma função que lida com um batch de dados, podemos fazer uma função que recebe um único dado e depois usar a trasformação para ganhar a versão que lida com o batch.

a = np.array([1., 2., 3.])
b = np.array([1., 1., -1.])
c = np.array([[1., 2., 3.], [4., 5., 6.]])

@jax.vmap #Podemos usar as transformações como decoradores
def f(x, y):
    return x/y + 1.
print(f(a, b))

def prod(x, y):
    return x@y
print(prod(a, b))
[ 2.  3. -2.]
0.0
try:
    prod(a, c) #a e c não têm dimensões compatíveis
except Exception as e:
    print(e)
matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 3)
batch_prod = jax.vmap(prod, in_axes=(None, 0)) #vamos multiplica a por cada linha de c
batch_prod(a, c)
DeviceArray([14., 32.], dtype=float32)

Nessa primeira parte vimos qual o propósito da biblioteca e suas principais funções, nos próximos posts vamos explorar como criar redes neurais com jax, suas bibliotecas experimentais, o ecossistema de bibliotecas escritas usando jax.