None
Métodos potenciales de prospección, FCAG, 2024.
Para esta breve experiencia, empleamos JAX
. Para el software, escribimos en la terminal de Miniconda:
conda> pip install --upgrade pip
conda> pip install jax flax optax
Dado un conjunto $\{x_i,y_i\}$ de puntos de estación $x_i$ y observaciones $y_i$, más un modelo directo $f(x;a,b)$, intentaremos invertir el valor de los parámetros $a$ y $b$ que mejor ajustan las observaciones. La estructura del dato es similar a lo que veremos en las prácticas para un modelo directo de anomalía de gravedad.
Observamos que el modelo directo es no lineal.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import treescope # visualizar arreglos y modelos
treescope.basic_interactive_setup(autovisualize_arrays=True)
# Estilo para las notebooks preliminares:
#plt.xkcd();
plt.rcParams.update({'font.size': 16}) # tamaño de fuente
# El modelo y las observaciones
# Las observaciones
A0 = 3.0
B0 = 2.0
N = 100 # número de observaciones
# posiciones de las observaciones:
#x = np.random.normal(size=[N])
seed = 2024
key = jax.random.PRNGKey(seed)
x = jax.random.normal(key,shape=[N])
# Ruido:
seed = 2025
key = jax.random.PRNGKey(seed)
noise = 0.01*jax.random.normal(key,shape=[N])
# Calculamos y
y = jnp.exp(-x**2 * A0) + B0
# Visualizamos:
plt.figure(figsize=(7,5))
plt.scatter(x,y,label="dato $y_i$",color="royalblue",alpha=0.8)
plt.xlabel("x")
plt.ylabel("y")
plt.ylim([0.5,3.5])
plt.legend()
plt.show()
# El modelo directo
def model(x, W = [0.,0.]):
a,b = W
return jnp.exp(-x**2 * a) + b
# Visualizamos:
plt.figure(figsize=(7,5))
plt.scatter(x, y, label=r"$y_i$", color="royalblue",alpha=0.8)
plt.scatter(x, model(x), label=r"$\hat y_i$ modelo inicial", color="tomato")
plt.xlabel(r"$x$")
plt.ylabel(r"$y$")
plt.ylim([0.5,3.5])
plt.legend()
plt.show()
Como observamos, el modelo inicial no resulta en un buen ajuste al dato. De manera cuantitativa, el error cuadrático medio (RMS) entre los datos y el resultado del modelo es:
# Una medida del error del modelo respecto al dato observado
def RMS(y,y0):
return jnp.mean( (y - y0) ** 2 ) # RMS
print(f"RMS = {RMS(model(x),y):.2f}")
Definimos una función de costo (el cuadrado de la norma L2) y nos decidimos por un algoritmo optimizador.
batch_size
muestras (batch size).epochs
pasadas (epochs) sobre el dato original completo.Otros optimizadores a probar son descenso de gradiente estocástico (SGD) y Adam (adaptive moment estimation). El algoritmo Adam necesitará un número menor de iteraciones respecto a SGD para llegar a un resultado aceptable.
# Parámetros de entrenamiento:
epochs = 2000
batch_size = len(x)//10
learning_rate = 0.1
print(f"Épocas: {epochs}",)
print(f"Tamaño de batch: {batch_size}")
print(f"Parámetro de aprendizaje: {learning_rate}")
# Función de costo:
def loss(W,ytrue): # La función de costo será el RMS
ypred = model(x,W)
return jnp.mean((ypred-ytrue)**2 )
# Entrenamiento
W = [0.,0.] # Parámetros iniciales del modelo
history = []
for epoch in range(epochs): # No utilizamos batches aquí para simplificar la demostración
ytrue = y
# derivadas parciales de la función de costo respecto de los
# parámetros del modelo:
loss_value, derivatives = jax.value_and_grad(loss)(W,ytrue)
# Actualización de los parámetros del modelo:
A,B = W
dLdA,dLdB = derivatives
A = A - learning_rate * dLdA
B = B - learning_rate * dLdB
W = [A,B]
# Guardamos el costo:
history.append(loss_value)
# Resultados cada un cierto número de épocas:
if epoch % 100 == 0:
print(f"Época {epoch}, loss = {loss_value:.4f}")
print("Listo.")
# Los parámetros invertidos son:
a,b = W
print(f"a* = {a:.3f} " )
print(f"b* = {b:.3f} " )
# Visualizamos
plt.figure(figsize=(7,5))
plt.scatter(x,y,label=r"$y_i$",color="royalblue",alpha=1)
plt.scatter(x,model(x,W),label=r"$\hat y_i$ modelo entrenado",color="orange",alpha=0.65)
plt.ylim([0.5,3.5])
plt.xlabel(r"$x$")
plt.ylabel(r"$y$")
plt.legend()
plt.show()
Apreciamos un mejor ajuste a los datos observados que el obtenido para el modelo inicial, con un RMS dado por:
# Ajuste a los datos observados que el modelo entrenado:
print("RMS = ", RMS(model(x,W),y))
Por último, observamos la función de costo (normalizada en la primer época) en función del número de épocas.
# Visualizamos el costo (normalizado) a lo largo de las iteraciones:
history = jnp.array(history)
plt.figure()
#plt.plot(jnp.log(history/history[0]),color="gray",lw=4)
plt.plot(history/history[0],color="gray",lw=4)
plt.xlabel("# Épocas")
plt.ylabel("Costo (normalizado) []")
plt.show()
Con estos conceptos y herramientas, en otra actividad práctica probaremos invertir datos de anomalía de gravedad para estimar parámetros físicos y geométricos de un modelo directo no lineal.
Para otros problemas emplearemos la librería nnx
de FLAX
con optimizadores provistos por optax
. Veamos como hacerlo en este ejemplo.
from flax import nnx # Neural Networks for JAX
from optax import sgd
epochs = 2000
learning_rate = 0.1
# Definimos la clase a la cual pertenece el modelo:
class Model(nnx.Module):
def __init__(self):
self.A = nnx.Param(0.,name="A")
self.B = nnx.Param(0.,name="B")
def __call__(self,x):
return jnp.exp( - x ** 2 * self.A ) + self.B
# Crear una instancia del modelo y asociar un optimizador:
model = Model()
optimizer = nnx.Optimizer(model, sgd(learning_rate)) # optimizador
#optimizer = nnx.Optimizer(model, adam(learning_rate)) # optimizador
# Rutina de entrenamiento
@nnx.jit #
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x)
loss_per_sample = (y_pred - y) ** 2
return loss_per_sample.mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)
return loss
print(f"Parámetros iniciales: A = {model.A.value} y B = {model.B.value}")
# Entrenamiento:
history = []
for epoch in range(epochs): # No utilizamos batches aquí para simplificar la demostración
loss_value = train_step(model, optimizer, x, y)
# Guardamos el costo:
history.append(loss_value)
# Resultados cada un cierto número de épocas:
if epoch % 100 == 0:
print(f"Época {epoch}, loss = {loss_value:.4f}")
print("Listo.")
print(f"Parámetros obtenidos: A* = {model.A.value:.3f} y B* = {model.B.value:.3f}")
#nnx.display(model)
# utilizando treescope para una visualización diferente:
nnx.display(model)
Eso es todo por hoy.
Utilizamos un optimizador dado por una librería básica de JAX. Podemos emplear SGD, Adam, etc.
# Utilizando un optimizador de JAX (example library)
# El modelo directo
def model(x, W = [0.,0.]):
a,b = W
return jnp.exp(-x**2 * a) + b
epochs = 2000
batch_size = len(x)//10
learning_rate = 0.1
from jax.example_libraries import optimizers
params = [0.,0.]
opt_init, opt_update, get_params = optimizers.sgd(learning_rate) # un optimizador devuelve un tríptico (init,update,get)
#opt_init, opt_update, get_params = optimizers.adam(learning_rate) # un optimizador devuelve un tríptico (init,update,get)
opt_state = opt_init(params)
def step(step, opt_state, ypred, ytrue):
value, grads = jax.value_and_grad(loss)(get_params(opt_state), ytrue)
opt_state = opt_update(step, grads, opt_state)
return value, opt_state
history = []
for epoch in range(epochs):
ypred, ytrue = model(x,get_params(opt_state)), y
loss_value, opt_state = step(epoch, opt_state, ypred, ytrue) # actualiza los parámetros y da el valor de la función de costo
# Guardamos el costo:
history.append(loss_value)
# Resultados cada un cierto número de épocas:
if epoch % 100 == 0:
print(f"Época {epoch}, loss = {loss_value:.4f}")
history.append(loss_value)
# Parámetros obtenidos
W = get_params(opt_state)
# Los parámetros invertidos son:
a,b = W
print(f"")
print(f"a* = {a:.3f} " )
print(f"b* = {b:.3f} " )
print(f"")
Eso es todo por hoy.