None basic_ML_MPP

Ejemplo mínimo de aprendizaje automático

Métodos potenciales de prospección, FCAG, 2024.

Para esta breve experiencia, empleamos JAX. Para el software, escribimos en la terminal de Miniconda:

conda> pip install --upgrade pip
conda> pip install jax flax optax

Problema

Dado un conjunto $\{x_i,y_i\}$ de puntos de estación $x_i$ y observaciones $y_i$, más un modelo directo $f(x;a,b)$, intentaremos invertir el valor de los parámetros $a$ y $b$ que mejor ajustan las observaciones. La estructura del dato es similar a lo que veremos en las prácticas para un modelo directo de anomalía de gravedad.

$$\{x_i,{\color{green}{y_i}}\}, \quad f(x;{\color{blue}a},{\color{blue}b}) = e^{-{\color{blue}a}x^2} + {\color{blue} b}, \quad {{\color{green}{y_i}}}\approx f(x_i;{\color{blue}a},{\color{blue}b}) = {\hat{\color{red}{y_i}}}.$$

Observamos que el modelo directo es no lineal.

In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import treescope   # visualizar arreglos y modelos 
treescope.basic_interactive_setup(autovisualize_arrays=True)

# Estilo para  las notebooks preliminares:
#plt.xkcd();
plt.rcParams.update({'font.size': 16}) # tamaño de fuente
In [2]:
# El modelo y las observaciones

# Las observaciones

A0 = 3.0
B0 = 2.0

N = 100 # número de observaciones

# posiciones de las observaciones:
#x     = np.random.normal(size=[N])
seed  = 2024
key   = jax.random.PRNGKey(seed)
x     = jax.random.normal(key,shape=[N])
# Ruido:
seed  = 2025
key   = jax.random.PRNGKey(seed)
noise = 0.01*jax.random.normal(key,shape=[N])

# Calculamos y
y = jnp.exp(-x**2 * A0) + B0 
In [3]:
# Visualizamos:
plt.figure(figsize=(7,5))
plt.scatter(x,y,label="dato $y_i$",color="royalblue",alpha=0.8)
plt.xlabel("x")
plt.ylabel("y")
plt.ylim([0.5,3.5])
plt.legend()
plt.show()

Modelo

In [4]:
# El modelo directo
def model(x, W = [0.,0.]):
    a,b = W
    return jnp.exp(-x**2 * a) + b 
In [5]:
# Visualizamos:
plt.figure(figsize=(7,5))
plt.scatter(x, y, label=r"$y_i$", color="royalblue",alpha=0.8)
plt.scatter(x, model(x), label=r"$\hat y_i$ modelo inicial", color="tomato")
plt.xlabel(r"$x$")
plt.ylabel(r"$y$")
plt.ylim([0.5,3.5])
plt.legend()
plt.show()

Como observamos, el modelo inicial no resulta en un buen ajuste al dato. De manera cuantitativa, el error cuadrático medio (RMS) entre los datos y el resultado del modelo es:

In [7]:
# Una medida del error del modelo respecto al dato observado
def RMS(y,y0):
    return jnp.mean( (y - y0) ** 2 ) # RMS
                     
print(f"RMS = {RMS(model(x),y):.2f}")                
RMS = 1.99

Entrenamiento

Definimos una función de costo (el cuadrado de la norma L2) y nos decidimos por un algoritmo optimizador.

  • El optimizador es RMSprop.
  • Utilizaremos un tamaño de paso o parámetro de aprendizaje (learning rate) fijo.
  • Estimamos el valor del gradiente tomando el dato de entrada en grupos de batch_size muestras (batch size).
  • Haremos un número de epochs pasadas (epochs) sobre el dato original completo.

Otros optimizadores a probar son descenso de gradiente estocástico (SGD) y Adam (adaptive moment estimation). El algoritmo Adam necesitará un número menor de iteraciones respecto a SGD para llegar a un resultado aceptable.

In [8]:
# Parámetros de entrenamiento:

epochs        = 2000 
batch_size    = len(x)//10
learning_rate = 0.1

print(f"Épocas: {epochs}",)
print(f"Tamaño de batch: {batch_size}")
print(f"Parámetro de aprendizaje: {learning_rate}")
Épocas: 2000
Tamaño de batch: 10
Parámetro de aprendizaje: 0.1
In [21]:
# Función de costo:
def loss(W,ytrue): # La función de costo será el RMS    
    ypred = model(x,W)
    return jnp.mean((ypred-ytrue)**2 )
In [23]:
# Entrenamiento

W = [0.,0.]   # Parámetros iniciales del modelo
history = []

for epoch in range(epochs):   # No utilizamos batches aquí para simplificar la demostración   
    ytrue = y
   
    # derivadas parciales de la función de costo respecto de los
    # parámetros del modelo:
    loss_value, derivatives = jax.value_and_grad(loss)(W,ytrue)
   
    # Actualización de los parámetros del modelo:
    A,B       = W
    dLdA,dLdB = derivatives
    A   = A - learning_rate * dLdA
    B   = B - learning_rate * dLdB
    W         = [A,B]

    # Guardamos el costo:
    history.append(loss_value)
    # Resultados cada un cierto número de épocas:
    if epoch % 100 == 0:
        print(f"Época {epoch}, loss = {loss_value:.4f}")
print("Listo.")
Época 0, loss = 1.9931
Época 100, loss = 0.0179
Época 200, loss = 0.0129
Época 300, loss = 0.0098
Época 400, loss = 0.0076
Época 500, loss = 0.0061
Época 600, loss = 0.0050
Época 700, loss = 0.0041
Época 800, loss = 0.0034
Época 900, loss = 0.0029
Época 1000, loss = 0.0025
Época 1100, loss = 0.0021
Época 1200, loss = 0.0018
Época 1300, loss = 0.0016
Época 1400, loss = 0.0014
Época 1500, loss = 0.0012
Época 1600, loss = 0.0010
Época 1700, loss = 0.0009
Época 1800, loss = 0.0008
Época 1900, loss = 0.0007
Listo.
In [24]:
# Los parámetros invertidos son:
    
a,b = W

print(f"a* = {a:.3f} " )
print(f"b* = {b:.3f} " )
a* = 2.506 
b* = 1.968 
In [25]:
# Visualizamos
plt.figure(figsize=(7,5))
plt.scatter(x,y,label=r"$y_i$",color="royalblue",alpha=1)
plt.scatter(x,model(x,W),label=r"$\hat y_i$ modelo entrenado",color="orange",alpha=0.65)
plt.ylim([0.5,3.5])
plt.xlabel(r"$x$")
plt.ylabel(r"$y$")
plt.legend()
plt.show()

Apreciamos un mejor ajuste a los datos observados que el obtenido para el modelo inicial, con un RMS dado por:

In [ ]:
# Ajuste a los datos observados que el modelo entrenado:
print("RMS = ", RMS(model(x,W),y))    

Por último, observamos la función de costo (normalizada en la primer época) en función del número de épocas.

In [26]:
# Visualizamos el costo (normalizado) a lo largo de las iteraciones:

history = jnp.array(history)
plt.figure()
#plt.plot(jnp.log(history/history[0]),color="gray",lw=4)
plt.plot(history/history[0],color="gray",lw=4)
plt.xlabel("# Épocas")
plt.ylabel("Costo (normalizado) []")
plt.show()

Con estos conceptos y herramientas, en otra actividad práctica probaremos invertir datos de anomalía de gravedad para estimar parámetros físicos y geométricos de un modelo directo no lineal.

Empleando NNX y Optax

Para otros problemas emplearemos la librería nnx de FLAX con optimizadores provistos por optax. Veamos como hacerlo en este ejemplo.

In [27]:
from flax import nnx  # Neural Networks for JAX
from optax import sgd
In [28]:
epochs        = 2000 
learning_rate = 0.1
In [29]:
# Definimos la clase a la cual pertenece el modelo:

class Model(nnx.Module):
    def __init__(self):
        self.A = nnx.Param(0.,name="A")
        self.B = nnx.Param(0.,name="B")
    def __call__(self,x):
        return jnp.exp( - x ** 2 * self.A ) +  self.B 
    
# Crear una instancia del modelo y asociar un optimizador:
model     = Model()
optimizer = nnx.Optimizer(model, sgd(learning_rate))  # optimizador
#optimizer = nnx.Optimizer(model, adam(learning_rate))  # optimizador

# Rutina de entrenamiento

@nnx.jit # 
def train_step(model, optimizer, x, y):
  
  def loss_fn(model):
    y_pred = model(x)  
    loss_per_sample = (y_pred - y) ** 2
    return loss_per_sample.mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  return loss
In [30]:
print(f"Parámetros iniciales: A = {model.A.value} y B = {model.B.value}")
Parámetros iniciales: A = 0.0 y B = 0.0
In [31]:
# Entrenamiento:
history = []
for epoch in range(epochs):   # No utilizamos batches aquí para simplificar la demostración
  
    loss_value = train_step(model, optimizer, x, y)
   
    # Guardamos el costo:
    history.append(loss_value)
    # Resultados cada un cierto número de épocas:
    if epoch % 100 == 0:
        print(f"Época {epoch}, loss = {loss_value:.4f}")
print("Listo.")
Época 0, loss = 1.9931
Época 100, loss = 0.0179
Época 200, loss = 0.0129
Época 300, loss = 0.0098
Época 400, loss = 0.0076
Época 500, loss = 0.0061
Época 600, loss = 0.0050
Época 700, loss = 0.0041
Época 800, loss = 0.0034
Época 900, loss = 0.0029
Época 1000, loss = 0.0025
Época 1100, loss = 0.0021
Época 1200, loss = 0.0018
Época 1300, loss = 0.0016
Época 1400, loss = 0.0014
Época 1500, loss = 0.0012
Época 1600, loss = 0.0010
Época 1700, loss = 0.0009
Época 1800, loss = 0.0008
Época 1900, loss = 0.0007
Listo.
In [32]:
print(f"Parámetros obtenidos: A* = {model.A.value:.3f} y B* = {model.B.value:.3f}")
#nnx.display(model)
Parámetros obtenidos: A* = 2.506 y B* = 1.968
In [33]:
# utilizando treescope para una visualización diferente:
nnx.display(model)

Eso es todo por hoy.

Misceláneas

Utilizamos un optimizador dado por una librería básica de JAX. Podemos emplear SGD, Adam, etc.

In [35]:
# Utilizando un optimizador de JAX (example library)

# El modelo directo
def model(x, W = [0.,0.]):
    a,b = W
    return jnp.exp(-x**2 * a) + b 

epochs        = 2000 
batch_size    = len(x)//10
learning_rate = 0.1


from jax.example_libraries import optimizers

params    = [0.,0.]

opt_init, opt_update, get_params = optimizers.sgd(learning_rate) # un optimizador devuelve un tríptico (init,update,get)
#opt_init, opt_update, get_params = optimizers.adam(learning_rate) # un optimizador devuelve un tríptico (init,update,get)

opt_state = opt_init(params)

def step(step, opt_state, ypred, ytrue):
    value, grads = jax.value_and_grad(loss)(get_params(opt_state), ytrue)
    opt_state = opt_update(step, grads, opt_state)
    return value, opt_state


history = []
for epoch in range(epochs):    
    
    ypred, ytrue = model(x,get_params(opt_state)), y
    loss_value, opt_state = step(epoch, opt_state, ypred, ytrue)  # actualiza los parámetros y da el valor de la función de costo
    
    # Guardamos el costo:
    history.append(loss_value)
    # Resultados cada un cierto número de épocas:
    if epoch % 100 == 0:
        print(f"Época {epoch}, loss = {loss_value:.4f}")
    history.append(loss_value)

# Parámetros obtenidos
W = get_params(opt_state)

# Los parámetros invertidos son:
    
a,b = W
print(f"")
print(f"a* = {a:.3f} " )
print(f"b* = {b:.3f} " )
print(f"")
Época 0, loss = 1.9931
Época 100, loss = 0.0179
Época 200, loss = 0.0129
Época 300, loss = 0.0098
Época 400, loss = 0.0076
Época 500, loss = 0.0061
Época 600, loss = 0.0050
Época 700, loss = 0.0041
Época 800, loss = 0.0034
Época 900, loss = 0.0029
Época 1000, loss = 0.0025
Época 1100, loss = 0.0021
Época 1200, loss = 0.0018
Época 1300, loss = 0.0016
Época 1400, loss = 0.0014
Época 1500, loss = 0.0012
Época 1600, loss = 0.0010
Época 1700, loss = 0.0009
Época 1800, loss = 0.0008
Época 1900, loss = 0.0007

a* = 2.506 
b* = 1.968 

Eso es todo por hoy.

Referencias