2. Learning to Dampen a Simple Harmonic Oscillator¶
In this notebook we will explore training a neural network to dampen a simple harmonic oscillator.
[1]:
%matplotlib inline
from matplotlib import pyplot as plt
import copy
import desolver as de
import torch
Using `autoray` backend
2.1. Specifying the Dynamical System¶
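Now let’s specify the right-hand side of our dynamical system. It should be

$$
m\ddot{x} = -kx
$$

But desolver only works with first-order differential equations, so we must cast this into a first-order system before we can solve it. Doing so gives the following system

$$
\frac{\mathrm{d}}{\mathrm{d}t}\begin{bmatrix} x \\ \dot x \end{bmatrix} = \begin{bmatrix} \dot x \\ -\frac{k}{m} x \end{bmatrix}
$$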
[2]:
@de.rhs_prettifier(
    equ_repr="[vx, -k*x/m]",
    md_repr=r"""
$$
\frac{\mathrm{d}y}{\mathrm{dt}} = \begin{bmatrix}
0 & 1 \\
-\frac{k}{m} & 0
\end{bmatrix} \cdot \vec y
$$
"""
)
def rhs(t, state, k, m, **kwargs):
    """Right-hand side of the undamped oscillator written as dy/dt = A @ y.

    ``A = [[0, 1], [-k/m, 0]]`` is built with the same dtype and device as
    ``state`` so that the product is computed where ``state`` lives.
    The state is expected to hold ``[x, v]`` along its first axis.
    ``t`` and any extra keyword arguments are not used.
    """
    system_matrix = torch.tensor(
        [[0.0, 1.0],
         [-k/m, 0.0]],
        dtype=state.dtype,
        device=state.device,
    )
    derivative = system_matrix @ state
    return derivative
[3]:
# Show both renderings of the prettified RHS: plain text first, then the
# notebook's rich (Markdown/LaTeX) display.
for render in (print, display):
    render(rhs)
$$
\frac{\mathrm{d}y}{\mathrm{dt}} = \begin{bmatrix}
0 & 1 \\
-\frac{k}{m} & 0
\end{bmatrix} \cdot \vec y
$$
Let’s specify the initial conditions as well: a unit displacement $x(0) = 1$ and zero velocity $\dot x(0) = 0$.
[4]:
# Initial state [x(0), v(0)]: unit displacement, starting at rest.
y_init = torch.tensor((1.0, 0.0), dtype=torch.float64)
And now we’re ready to integrate!
2.2. The Numerical Integration¶
We set the spring constant and the mass to $k = 1$ and $m = 1$, so the natural period is $T = 2\pi\sqrt{m/k}$, and we integrate over 40 periods.
[5]:
# Fixed physical constants: spring constant k and mass m.
constants = {"k": 1.0, "m": 1.0}

# Natural period of the oscillator, T = 2*pi*sqrt(m/k).
T = 2*torch.pi*(constants['m']/constants['k'])**0.5

# Integration window: start at t = 0 and run for 40 natural periods.
t0, tf = 0.0, 40 * T
[6]:
# Tight-tolerance reference integration of the uncontrolled oscillator.
# dense_output=True lets the solution be evaluated later via a.sol(times).
a = de.OdeSystem(
    rhs,
    y0=y_init,
    dense_output=True,
    t=(t0, tf),
    dt=0.0001,
    rtol=1e-12,
    atol=1e-12,
    constants=dict(constants),
)
a.method = "RK87"
a.integrate()
2.3. Plotting the State and Phase Portrait¶
[7]:
# Times to evaluate the system at
# 1000 evaluation times covering 40 periods, created on the same device and
# with the same dtype as the reference solution so a.sol(...) accepts them.
eval_times = T * torch.linspace(
    start=0.0,
    end=40.0,
    steps=1000,
    dtype=a.y[-1].dtype,
    device=a.y[-1].device,
)
[8]:
from matplotlib import gridspec

# Layout: a wide time-series panel next to a square phase-portrait panel.
fig = plt.figure(figsize=(20, 4))
gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
ax0 = fig.add_subplot(gs[0])
ax1 = fig.add_subplot(gs[1])
ax1.set_aspect(1)

# Left panel: position against time measured in periods.
ax0.plot(eval_times/T, a.sol(eval_times)[:, 0])
ax0.set_xlabel(r"$t/T$")
ax0.set_ylabel(r"$x$")
ax0.set_xlim(0.0, 40.0)
ax0.set_ylim(-1.0, 1.0)
ax0.set_title(r"$k={},m={}$".format(a.constants['k'], a.constants['m']))

# Right panel: phase portrait (x, dx/dt) from the integrator's own steps.
phase_limits = (-1.6, 1.6)
ax1.plot(a.y[:, 0], a.y[:, 1])
ax1.set_xlabel(r"$x$")
ax1.set_ylabel(r"$\dot x$")
ax1.set_xlim(*phase_limits)
ax1.set_ylim(*phase_limits)
ax1.grid(which='major')

plt.tight_layout()
2.4. Defining a Simple Neural Network¶
Now we can define a simple neural network, in this case a feed-forward (dense) network, to dampen the oscillations of the system. Specifically, we treat the network as providing a continuous force $F$ that is applied at every instant, assuming no lag in the controller and no discretisation issues. Note that the network receives the full state $(x, \dot x)$, even though the printed equation below writes it as $\mathcal{NN}(x)$.
[9]:
@de.rhs_prettifier(
    equ_repr="[vx, -k*x/m+NN(x)/m]",
    md_repr=r"""
$$
\frac{\mathrm{d}y}{\mathrm{dt}} = \begin{bmatrix}
0 & 1 \\
-\frac{k}{m} & 0
\end{bmatrix} \cdot \vec y + \begin{bmatrix}
0 \\
\mathcal{NN}(x)/m
\end{bmatrix}
$$
"""
)
def nn_rhs(t, state, k, m, nn_controller, **kwargs):
    """Oscillator RHS plus a neural-network control force.

    The controller output is treated as a force F on the mass, so it only
    enters the velocity equation, as F/m.

    Parameters
    ----------
    t : time; forwarded to ``rhs`` (which does not use it).
    state : tensor holding ``[x, v]`` along its first axis, either of shape
        ``(2,)`` or a batch of shape ``(2, B)`` (the shapes ``rhs``'s matrix
        product supports).
    k, m : spring constant and mass.
    nn_controller : callable applied to the state with its component axis
        moved last; the first entry of its last output axis is the force.
    **kwargs : ignored.

    Note: although the printed form says ``NN(x)``, the controller receives
    the full state (position and velocity).
    """
    base_dynamics_rhs = rhs(t, state, k, m)
    # torch.nn layers act on the LAST axis, but the state stores its
    # components on the FIRST axis. Move them last so a batched (2, B) state
    # is fed as B samples instead of mixing x and v across samples.
    # For a 1-D (2,) state this is a no-op.
    controller_input = torch.movedim(state, 0, -1)
    nn_force = nn_controller(controller_input)[..., 0]
    control_term = torch.stack([
        torch.zeros_like(nn_force),
        nn_force/m,
    ])
    return base_dynamics_rhs + control_term
[10]:
# Show both renderings of the controlled RHS: plain text first, then the
# notebook's rich (Markdown/LaTeX) display.
for render in (print, display):
    render(nn_rhs)
$$
\frac{\mathrm{d}y}{\mathrm{dt}} = \begin{bmatrix}
0 & 1 \\
-\frac{k}{m} & 0
\end{bmatrix} \cdot \vec y + \begin{bmatrix}
0 \\
\mathcal{NN}(x)/m
\end{bmatrix}
$$
[11]:
state_dim = y_init.shape[0]
hidden_dim = 32
output_dim = 1

# Train on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Small MLP controller mapping the state [x, v] to a scalar force.
simple_nn = torch.nn.Sequential(
    torch.nn.Linear(state_dim, hidden_dim),
    torch.nn.GELU(),
    torch.nn.Linear(hidden_dim, hidden_dim),
    torch.nn.GELU(),
    torch.nn.Linear(hidden_dim, output_dim),
).to(device)
optimizer = torch.optim.AdamW(simple_nn.parameters(), lr=4e-3, weight_decay=1e-2)
number_of_steps = 512

# The network parameters are float32 on `device`; the initial state must match.
y_init = y_init.to(device, torch.float32)


def closure():
    """Integrate the controlled system, compute the loss and backpropagate.

    Clears the gradients, integrates ``nn_rhs`` from ``t0`` to ``tf``, and
    returns the scalar loss evaluated at the CURRENT network parameters
    (after calling ``backward`` on it when it carries a graph).
    """
    optimizer.zero_grad()
    integrated_system = de.solve_ivp(
        nn_rhs, t_span=(t0, tf), y0=y_init, method='RK87',
        args=(constants['k'], constants['m'], simple_nn),
    )
    # NOTE(review): assumes a scipy-style layout for integrated_system.y,
    # i.e. (n_states, n_times), so y[0] is the position trace x(t).
    times = integrated_system.t
    positions = integrated_system.y[0]
    # The loss is the time-weighted integrated squared displacement
    # (right-endpoint Riemann sum of t * x(t)^2). The weight t penalises
    # the network for taking more time to dampen the system.
    loss = torch.sum((times[1:] * positions[1:].square()) * torch.diff(times))
    # Terminal penalty on the displacement at the FINAL time, y[0, -1].
    # Under the same layout, y[-1, 0] would be the initial velocity, a
    # constant 0 that contributes no gradient.
    loss = 0.1*loss + 0.9*positions[-1].square()
    if loss.requires_grad:
        loss.backward()
    return loss


best_loss = torch.inf
best_params = copy.deepcopy(simple_nn.state_dict())
for step_idx in range(number_of_steps):
    # Evaluate the loss and gradients at the current parameters BEFORE
    # updating them, so the snapshot below matches the loss it is recorded
    # with. (optimizer.step(closure) returns the pre-update loss but has
    # already moved the parameters, which made best_params one step stale.)
    loss = closure().item()
    if loss < best_loss:
        best_loss = loss
        best_params = copy.deepcopy(simple_nn.state_dict())
    optimizer.step()
    print(f"[{step_idx+1}/{number_of_steps}] - loss: {loss:.4e}, best_loss: {best_loss:.4e}")
[1/512] - loss: 1.0276e+02, best_loss: 1.0276e+02
[2/512] - loss: 3.6690e+01, best_loss: 3.6690e+01
[3/512] - loss: 2.2941e+01, best_loss: 2.2941e+01
[4/512] - loss: 1.4030e+01, best_loss: 1.4030e+01
[5/512] - loss: 7.8588e+00, best_loss: 7.8588e+00
[6/512] - loss: 4.0314e+00, best_loss: 4.0314e+00
[7/512] - loss: 2.3058e+00, best_loss: 2.3058e+00
[8/512] - loss: 2.2212e+00, best_loss: 2.2212e+00
[9/512] - loss: 3.1359e+00, best_loss: 2.2212e+00
[10/512] - loss: 4.4046e+00, best_loss: 2.2212e+00
[11/512] - loss: 5.4471e+00, best_loss: 2.2212e+00
[12/512] - loss: 5.9278e+00, best_loss: 2.2212e+00
[13/512] - loss: 5.7600e+00, best_loss: 2.2212e+00
[14/512] - loss: 5.0671e+00, best_loss: 2.2212e+00
[15/512] - loss: 4.0472e+00, best_loss: 2.2212e+00
[16/512] - loss: 2.9402e+00, best_loss: 2.2212e+00
[17/512] - loss: 1.9815e+00, best_loss: 1.9815e+00
[18/512] - loss: 1.3024e+00, best_loss: 1.3024e+00
[19/512] - loss: 9.6044e-01, best_loss: 9.6044e-01
[20/512] - loss: 9.6575e-01, best_loss: 9.6044e-01
[21/512] - loss: 1.1901e+00, best_loss: 9.6044e-01
[22/512] - loss: 1.5168e+00, best_loss: 9.6044e-01
[23/512] - loss: 1.8395e+00, best_loss: 9.6044e-01
[24/512] - loss: 2.0401e+00, best_loss: 9.6044e-01
[25/512] - loss: 2.1047e+00, best_loss: 9.6044e-01
[26/512] - loss: 1.9871e+00, best_loss: 9.6044e-01
[27/512] - loss: 1.7565e+00, best_loss: 9.6044e-01
[28/512] - loss: 1.4550e+00, best_loss: 9.6044e-01
[29/512] - loss: 1.1563e+00, best_loss: 9.6044e-01
[30/512] - loss: 9.2318e-01, best_loss: 9.2318e-01
[31/512] - loss: 7.8288e-01, best_loss: 7.8288e-01
[32/512] - loss: 7.5376e-01, best_loss: 7.5376e-01
[33/512] - loss: 8.0010e-01, best_loss: 7.5376e-01
[34/512] - loss: 8.9981e-01, best_loss: 7.5376e-01
[35/512] - loss: 1.0018e+00, best_loss: 7.5376e-01
[36/512] - loss: 1.0762e+00, best_loss: 7.5376e-01
[37/512] - loss: 1.0992e+00, best_loss: 7.5376e-01
[38/512] - loss: 1.0612e+00, best_loss: 7.5376e-01
[39/512] - loss: 9.8249e-01, best_loss: 7.5376e-01
[40/512] - loss: 8.7389e-01, best_loss: 7.5376e-01
[41/512] - loss: 7.8036e-01, best_loss: 7.5376e-01
[42/512] - loss: 7.2208e-01, best_loss: 7.2208e-01
[43/512] - loss: 6.9142e-01, best_loss: 6.9142e-01
[44/512] - loss: 6.9089e-01, best_loss: 6.9089e-01
[45/512] - loss: 7.1510e-01, best_loss: 6.9089e-01
[46/512] - loss: 7.5018e-01, best_loss: 6.9089e-01
[47/512] - loss: 7.8624e-01, best_loss: 6.9089e-01
[48/512] - loss: 7.8939e-01, best_loss: 6.9089e-01
[49/512] - loss: 7.8281e-01, best_loss: 6.9089e-01
[50/512] - loss: 7.5906e-01, best_loss: 6.9089e-01
[51/512] - loss: 7.2187e-01, best_loss: 6.9089e-01
[52/512] - loss: 6.9182e-01, best_loss: 6.9089e-01
[53/512] - loss: 6.6322e-01, best_loss: 6.6322e-01
[54/512] - loss: 6.5370e-01, best_loss: 6.5370e-01
[55/512] - loss: 6.4848e-01, best_loss: 6.4848e-01
[56/512] - loss: 6.5786e-01, best_loss: 6.4848e-01
[57/512] - loss: 6.6951e-01, best_loss: 6.4848e-01
[58/512] - loss: 6.7516e-01, best_loss: 6.4848e-01
[59/512] - loss: 6.7464e-01, best_loss: 6.4848e-01
[60/512] - loss: 6.6787e-01, best_loss: 6.4848e-01
[61/512] - loss: 6.5964e-01, best_loss: 6.4848e-01
[62/512] - loss: 6.4635e-01, best_loss: 6.4635e-01
[63/512] - loss: 6.3659e-01, best_loss: 6.3659e-01
[64/512] - loss: 6.2101e-01, best_loss: 6.2101e-01
[65/512] - loss: 6.1710e-01, best_loss: 6.1710e-01
[66/512] - loss: 6.1960e-01, best_loss: 6.1710e-01
[67/512] - loss: 6.2047e-01, best_loss: 6.1710e-01
[68/512] - loss: 6.2469e-01, best_loss: 6.1710e-01
[69/512] - loss: 6.2413e-01, best_loss: 6.1710e-01
[70/512] - loss: 6.2035e-01, best_loss: 6.1710e-01
[71/512] - loss: 6.0573e-01, best_loss: 6.0573e-01
[72/512] - loss: 6.0319e-01, best_loss: 6.0319e-01
[73/512] - loss: 6.0036e-01, best_loss: 6.0036e-01
[74/512] - loss: 5.9306e-01, best_loss: 5.9306e-01
[75/512] - loss: 5.9308e-01, best_loss: 5.9306e-01
[76/512] - loss: 5.9185e-01, best_loss: 5.9185e-01
[77/512] - loss: 5.9069e-01, best_loss: 5.9069e-01
[78/512] - loss: 5.8648e-01, best_loss: 5.8648e-01
[79/512] - loss: 5.8765e-01, best_loss: 5.8648e-01
[80/512] - loss: 5.7839e-01, best_loss: 5.7839e-01
[81/512] - loss: 5.8004e-01, best_loss: 5.7839e-01
[82/512] - loss: 5.7908e-01, best_loss: 5.7839e-01
[83/512] - loss: 5.7618e-01, best_loss: 5.7618e-01
[84/512] - loss: 5.6985e-01, best_loss: 5.6985e-01
[85/512] - loss: 5.7002e-01, best_loss: 5.6985e-01
[86/512] - loss: 5.6801e-01, best_loss: 5.6801e-01
[87/512] - loss: 5.6148e-01, best_loss: 5.6148e-01
[88/512] - loss: 5.6230e-01, best_loss: 5.6148e-01
[89/512] - loss: 5.5741e-01, best_loss: 5.5741e-01
[90/512] - loss: 5.5789e-01, best_loss: 5.5741e-01
[91/512] - loss: 5.5658e-01, best_loss: 5.5658e-01
[92/512] - loss: 5.5043e-01, best_loss: 5.5043e-01
[93/512] - loss: 5.4833e-01, best_loss: 5.4833e-01
[94/512] - loss: 5.4774e-01, best_loss: 5.4774e-01
[95/512] - loss: 5.4236e-01, best_loss: 5.4236e-01
[96/512] - loss: 5.4092e-01, best_loss: 5.4092e-01
[97/512] - loss: 5.4340e-01, best_loss: 5.4092e-01
[98/512] - loss: 5.3762e-01, best_loss: 5.3762e-01
[99/512] - loss: 5.3484e-01, best_loss: 5.3484e-01
[100/512] - loss: 5.3273e-01, best_loss: 5.3273e-01
[101/512] - loss: 5.2527e-01, best_loss: 5.2527e-01
[102/512] - loss: 5.3121e-01, best_loss: 5.2527e-01
[103/512] - loss: 5.3130e-01, best_loss: 5.2527e-01
[104/512] - loss: 5.2264e-01, best_loss: 5.2264e-01
[105/512] - loss: 5.2503e-01, best_loss: 5.2264e-01
[106/512] - loss: 5.2197e-01, best_loss: 5.2197e-01
[107/512] - loss: 5.1335e-01, best_loss: 5.1335e-01
[108/512] - loss: 5.1270e-01, best_loss: 5.1270e-01
[109/512] - loss: 5.1180e-01, best_loss: 5.1180e-01
[110/512] - loss: 5.0993e-01, best_loss: 5.0993e-01
[111/512] - loss: 5.0728e-01, best_loss: 5.0728e-01
[112/512] - loss: 5.0703e-01, best_loss: 5.0703e-01
[113/512] - loss: 5.0240e-01, best_loss: 5.0240e-01
[114/512] - loss: 5.0066e-01, best_loss: 5.0066e-01
[115/512] - loss: 5.0296e-01, best_loss: 5.0066e-01
[116/512] - loss: 4.9707e-01, best_loss: 4.9707e-01
[117/512] - loss: 4.9520e-01, best_loss: 4.9520e-01
[118/512] - loss: 4.9305e-01, best_loss: 4.9305e-01
[119/512] - loss: 4.9221e-01, best_loss: 4.9221e-01
[120/512] - loss: 4.9083e-01, best_loss: 4.9083e-01
[121/512] - loss: 4.8820e-01, best_loss: 4.8820e-01
[122/512] - loss: 4.8224e-01, best_loss: 4.8224e-01
[123/512] - loss: 4.8209e-01, best_loss: 4.8209e-01
[124/512] - loss: 4.8293e-01, best_loss: 4.8209e-01
[125/512] - loss: 4.8214e-01, best_loss: 4.8209e-01
[126/512] - loss: 4.7807e-01, best_loss: 4.7807e-01
[127/512] - loss: 4.7588e-01, best_loss: 4.7588e-01
[128/512] - loss: 4.7191e-01, best_loss: 4.7191e-01
[129/512] - loss: 4.7102e-01, best_loss: 4.7102e-01
[130/512] - loss: 4.6909e-01, best_loss: 4.6909e-01
[131/512] - loss: 4.6318e-01, best_loss: 4.6318e-01
[132/512] - loss: 4.6540e-01, best_loss: 4.6318e-01
[133/512] - loss: 4.6481e-01, best_loss: 4.6318e-01
[134/512] - loss: 4.5604e-01, best_loss: 4.5604e-01
[135/512] - loss: 4.6465e-01, best_loss: 4.5604e-01
[136/512] - loss: 4.5570e-01, best_loss: 4.5570e-01
[137/512] - loss: 4.5633e-01, best_loss: 4.5570e-01
[138/512] - loss: 4.5069e-01, best_loss: 4.5069e-01
[139/512] - loss: 4.5232e-01, best_loss: 4.5069e-01
[140/512] - loss: 4.4908e-01, best_loss: 4.4908e-01
[141/512] - loss: 4.4980e-01, best_loss: 4.4908e-01
[142/512] - loss: 4.4772e-01, best_loss: 4.4772e-01
[143/512] - loss: 4.4738e-01, best_loss: 4.4738e-01
[144/512] - loss: 4.4430e-01, best_loss: 4.4430e-01
[145/512] - loss: 4.3330e-01, best_loss: 4.3330e-01
[146/512] - loss: 4.2843e-01, best_loss: 4.2843e-01
[147/512] - loss: 4.3974e-01, best_loss: 4.2843e-01
[148/512] - loss: 4.3489e-01, best_loss: 4.2843e-01
[149/512] - loss: 4.3324e-01, best_loss: 4.2843e-01
[150/512] - loss: 4.3396e-01, best_loss: 4.2843e-01
[151/512] - loss: 4.2346e-01, best_loss: 4.2346e-01
[152/512] - loss: 4.2978e-01, best_loss: 4.2346e-01
[153/512] - loss: 4.2889e-01, best_loss: 4.2346e-01
[154/512] - loss: 4.2591e-01, best_loss: 4.2346e-01
[155/512] - loss: 4.2496e-01, best_loss: 4.2346e-01
[156/512] - loss: 4.2438e-01, best_loss: 4.2346e-01
[157/512] - loss: 4.2159e-01, best_loss: 4.2159e-01
[158/512] - loss: 4.1877e-01, best_loss: 4.1877e-01
[159/512] - loss: 4.1793e-01, best_loss: 4.1793e-01
[160/512] - loss: 4.1395e-01, best_loss: 4.1395e-01
[161/512] - loss: 4.0784e-01, best_loss: 4.0784e-01
[162/512] - loss: 4.1151e-01, best_loss: 4.0784e-01
[163/512] - loss: 4.1223e-01, best_loss: 4.0784e-01
[164/512] - loss: 4.1153e-01, best_loss: 4.0784e-01
[165/512] - loss: 4.0816e-01, best_loss: 4.0784e-01
[166/512] - loss: 4.0737e-01, best_loss: 4.0737e-01
[167/512] - loss: 4.0445e-01, best_loss: 4.0445e-01
[168/512] - loss: 4.0315e-01, best_loss: 4.0315e-01
[169/512] - loss: 3.9165e-01, best_loss: 3.9165e-01
[170/512] - loss: 3.9821e-01, best_loss: 3.9165e-01
[171/512] - loss: 3.9672e-01, best_loss: 3.9165e-01
[172/512] - loss: 3.9793e-01, best_loss: 3.9165e-01
[173/512] - loss: 3.9238e-01, best_loss: 3.9165e-01
[174/512] - loss: 3.9377e-01, best_loss: 3.9165e-01
[175/512] - loss: 3.9561e-01, best_loss: 3.9165e-01
[176/512] - loss: 3.9119e-01, best_loss: 3.9119e-01
[177/512] - loss: 3.8916e-01, best_loss: 3.8916e-01
[178/512] - loss: 3.8662e-01, best_loss: 3.8662e-01
[179/512] - loss: 3.8618e-01, best_loss: 3.8618e-01
[180/512] - loss: 3.8011e-01, best_loss: 3.8011e-01
[181/512] - loss: 3.8507e-01, best_loss: 3.8011e-01
[182/512] - loss: 3.8090e-01, best_loss: 3.8011e-01
[183/512] - loss: 3.8020e-01, best_loss: 3.8011e-01
[184/512] - loss: 3.7786e-01, best_loss: 3.7786e-01
[185/512] - loss: 3.7444e-01, best_loss: 3.7444e-01
[186/512] - loss: 3.7505e-01, best_loss: 3.7444e-01
[187/512] - loss: 3.7645e-01, best_loss: 3.7444e-01
[188/512] - loss: 3.7141e-01, best_loss: 3.7141e-01
[189/512] - loss: 3.7205e-01, best_loss: 3.7141e-01
[190/512] - loss: 3.7037e-01, best_loss: 3.7037e-01
[191/512] - loss: 3.6993e-01, best_loss: 3.6993e-01
[192/512] - loss: 3.7065e-01, best_loss: 3.6993e-01
[193/512] - loss: 3.6444e-01, best_loss: 3.6444e-01
[194/512] - loss: 3.6079e-01, best_loss: 3.6079e-01
[195/512] - loss: 3.6232e-01, best_loss: 3.6079e-01
[196/512] - loss: 3.5824e-01, best_loss: 3.5824e-01
[197/512] - loss: 3.5996e-01, best_loss: 3.5824e-01
[198/512] - loss: 3.5631e-01, best_loss: 3.5631e-01
[199/512] - loss: 3.5955e-01, best_loss: 3.5631e-01
[200/512] - loss: 3.5384e-01, best_loss: 3.5384e-01
[201/512] - loss: 3.5368e-01, best_loss: 3.5368e-01
[202/512] - loss: 3.5523e-01, best_loss: 3.5368e-01
[203/512] - loss: 3.4994e-01, best_loss: 3.4994e-01
[204/512] - loss: 3.5170e-01, best_loss: 3.4994e-01
[205/512] - loss: 3.4656e-01, best_loss: 3.4656e-01
[206/512] - loss: 3.4121e-01, best_loss: 3.4121e-01
[207/512] - loss: 3.4773e-01, best_loss: 3.4121e-01
[208/512] - loss: 3.3901e-01, best_loss: 3.3901e-01
[209/512] - loss: 3.4312e-01, best_loss: 3.3901e-01
[210/512] - loss: 3.3763e-01, best_loss: 3.3763e-01
[211/512] - loss: 3.3979e-01, best_loss: 3.3763e-01
[212/512] - loss: 3.4031e-01, best_loss: 3.3763e-01
[213/512] - loss: 3.3892e-01, best_loss: 3.3763e-01
[214/512] - loss: 3.3521e-01, best_loss: 3.3521e-01
[215/512] - loss: 3.3071e-01, best_loss: 3.3071e-01
[216/512] - loss: 3.3290e-01, best_loss: 3.3071e-01
[217/512] - loss: 3.3430e-01, best_loss: 3.3071e-01
[218/512] - loss: 3.2641e-01, best_loss: 3.2641e-01
[219/512] - loss: 3.3060e-01, best_loss: 3.2641e-01
[220/512] - loss: 3.2998e-01, best_loss: 3.2641e-01
[221/512] - loss: 3.2905e-01, best_loss: 3.2641e-01
[222/512] - loss: 3.2944e-01, best_loss: 3.2641e-01
[223/512] - loss: 3.2294e-01, best_loss: 3.2294e-01
[224/512] - loss: 3.2078e-01, best_loss: 3.2078e-01
[225/512] - loss: 3.2067e-01, best_loss: 3.2067e-01
[226/512] - loss: 3.2019e-01, best_loss: 3.2019e-01
[227/512] - loss: 3.2161e-01, best_loss: 3.2019e-01
[228/512] - loss: 3.2033e-01, best_loss: 3.2019e-01
[229/512] - loss: 3.1950e-01, best_loss: 3.1950e-01
[230/512] - loss: 3.1634e-01, best_loss: 3.1634e-01
[231/512] - loss: 3.1733e-01, best_loss: 3.1634e-01
[232/512] - loss: 3.1351e-01, best_loss: 3.1351e-01
[233/512] - loss: 3.0986e-01, best_loss: 3.0986e-01
[234/512] - loss: 3.1226e-01, best_loss: 3.0986e-01
[235/512] - loss: 3.1272e-01, best_loss: 3.0986e-01
[236/512] - loss: 3.0967e-01, best_loss: 3.0967e-01
[237/512] - loss: 3.1047e-01, best_loss: 3.0967e-01
[238/512] - loss: 3.0376e-01, best_loss: 3.0376e-01
[239/512] - loss: 3.0978e-01, best_loss: 3.0376e-01
[240/512] - loss: 3.0671e-01, best_loss: 3.0376e-01
[241/512] - loss: 3.0423e-01, best_loss: 3.0376e-01
[242/512] - loss: 3.0425e-01, best_loss: 3.0376e-01
[243/512] - loss: 3.0122e-01, best_loss: 3.0122e-01
[244/512] - loss: 3.0334e-01, best_loss: 3.0122e-01
[245/512] - loss: 2.9494e-01, best_loss: 2.9494e-01
[246/512] - loss: 2.9323e-01, best_loss: 2.9323e-01
[247/512] - loss: 3.0101e-01, best_loss: 2.9323e-01
[248/512] - loss: 2.9204e-01, best_loss: 2.9204e-01
[249/512] - loss: 2.9709e-01, best_loss: 2.9204e-01
[250/512] - loss: 2.9580e-01, best_loss: 2.9204e-01
[251/512] - loss: 2.9090e-01, best_loss: 2.9090e-01
[252/512] - loss: 2.9430e-01, best_loss: 2.9090e-01
[253/512] - loss: 2.9344e-01, best_loss: 2.9090e-01
[254/512] - loss: 2.9250e-01, best_loss: 2.9090e-01
[255/512] - loss: 2.8767e-01, best_loss: 2.8767e-01
[256/512] - loss: 2.8920e-01, best_loss: 2.8767e-01
[257/512] - loss: 2.9105e-01, best_loss: 2.8767e-01
[258/512] - loss: 2.8523e-01, best_loss: 2.8523e-01
[259/512] - loss: 2.8819e-01, best_loss: 2.8523e-01
[260/512] - loss: 2.8620e-01, best_loss: 2.8523e-01
[261/512] - loss: 2.7887e-01, best_loss: 2.7887e-01
[262/512] - loss: 2.8332e-01, best_loss: 2.7887e-01
[263/512] - loss: 2.8625e-01, best_loss: 2.7887e-01
[264/512] - loss: 2.7988e-01, best_loss: 2.7887e-01
[265/512] - loss: 2.7296e-01, best_loss: 2.7296e-01
[266/512] - loss: 2.8007e-01, best_loss: 2.7296e-01
[267/512] - loss: 2.6948e-01, best_loss: 2.6948e-01
[268/512] - loss: 2.8071e-01, best_loss: 2.6948e-01
[269/512] - loss: 2.7235e-01, best_loss: 2.6948e-01
[270/512] - loss: 2.7728e-01, best_loss: 2.6948e-01
[271/512] - loss: 2.7737e-01, best_loss: 2.6948e-01
[272/512] - loss: 2.7393e-01, best_loss: 2.6948e-01
[273/512] - loss: 2.7243e-01, best_loss: 2.6948e-01
[274/512] - loss: 2.7241e-01, best_loss: 2.6948e-01
[275/512] - loss: 2.7019e-01, best_loss: 2.6948e-01
[276/512] - loss: 2.6924e-01, best_loss: 2.6924e-01
[277/512] - loss: 2.6417e-01, best_loss: 2.6417e-01
[278/512] - loss: 2.5768e-01, best_loss: 2.5768e-01
[279/512] - loss: 2.6919e-01, best_loss: 2.5768e-01
[280/512] - loss: 2.6600e-01, best_loss: 2.5768e-01
[281/512] - loss: 2.6645e-01, best_loss: 2.5768e-01
[282/512] - loss: 2.6103e-01, best_loss: 2.5768e-01
[283/512] - loss: 2.6427e-01, best_loss: 2.5768e-01
[284/512] - loss: 2.6991e-01, best_loss: 2.5768e-01
[285/512] - loss: 2.5735e-01, best_loss: 2.5735e-01
[286/512] - loss: 2.6068e-01, best_loss: 2.5735e-01
[287/512] - loss: 2.5895e-01, best_loss: 2.5735e-01
[288/512] - loss: 2.5420e-01, best_loss: 2.5420e-01
[289/512] - loss: 2.5974e-01, best_loss: 2.5420e-01
[290/512] - loss: 2.5917e-01, best_loss: 2.5420e-01
[291/512] - loss: 2.5894e-01, best_loss: 2.5420e-01
[292/512] - loss: 2.5976e-01, best_loss: 2.5420e-01
[293/512] - loss: 2.4660e-01, best_loss: 2.4660e-01
[294/512] - loss: 2.5394e-01, best_loss: 2.4660e-01
[295/512] - loss: 2.4948e-01, best_loss: 2.4660e-01
[296/512] - loss: 2.5528e-01, best_loss: 2.4660e-01
[297/512] - loss: 2.5361e-01, best_loss: 2.4660e-01
[298/512] - loss: 2.5096e-01, best_loss: 2.4660e-01
[299/512] - loss: 2.4970e-01, best_loss: 2.4660e-01
[300/512] - loss: 2.4689e-01, best_loss: 2.4660e-01
[301/512] - loss: 2.5025e-01, best_loss: 2.4660e-01
[302/512] - loss: 2.4242e-01, best_loss: 2.4242e-01
[303/512] - loss: 2.5053e-01, best_loss: 2.4242e-01
[304/512] - loss: 2.4773e-01, best_loss: 2.4242e-01
[305/512] - loss: 2.4292e-01, best_loss: 2.4242e-01
[306/512] - loss: 2.4457e-01, best_loss: 2.4242e-01
[307/512] - loss: 2.4337e-01, best_loss: 2.4242e-01
[308/512] - loss: 2.3846e-01, best_loss: 2.3846e-01
[309/512] - loss: 2.4112e-01, best_loss: 2.3846e-01
[310/512] - loss: 2.3861e-01, best_loss: 2.3846e-01
[311/512] - loss: 2.4133e-01, best_loss: 2.3846e-01
[312/512] - loss: 2.3843e-01, best_loss: 2.3843e-01
[313/512] - loss: 2.4066e-01, best_loss: 2.3843e-01
[314/512] - loss: 2.4338e-01, best_loss: 2.3843e-01
[315/512] - loss: 2.3967e-01, best_loss: 2.3843e-01
[316/512] - loss: 2.3457e-01, best_loss: 2.3457e-01
[317/512] - loss: 2.3450e-01, best_loss: 2.3450e-01
[318/512] - loss: 2.3681e-01, best_loss: 2.3450e-01
[319/512] - loss: 2.3609e-01, best_loss: 2.3450e-01
[320/512] - loss: 2.3898e-01, best_loss: 2.3450e-01
[321/512] - loss: 2.3839e-01, best_loss: 2.3450e-01
[322/512] - loss: 2.2739e-01, best_loss: 2.2739e-01
[323/512] - loss: 2.3240e-01, best_loss: 2.2739e-01
[324/512] - loss: 2.3945e-01, best_loss: 2.2739e-01
[325/512] - loss: 2.3549e-01, best_loss: 2.2739e-01
[326/512] - loss: 2.2753e-01, best_loss: 2.2739e-01
[327/512] - loss: 2.2719e-01, best_loss: 2.2719e-01
[328/512] - loss: 2.2724e-01, best_loss: 2.2719e-01
[329/512] - loss: 2.2611e-01, best_loss: 2.2611e-01
[330/512] - loss: 2.3121e-01, best_loss: 2.2611e-01
[331/512] - loss: 2.2547e-01, best_loss: 2.2547e-01
[332/512] - loss: 2.2701e-01, best_loss: 2.2547e-01
[333/512] - loss: 2.2746e-01, best_loss: 2.2547e-01
[334/512] - loss: 2.2346e-01, best_loss: 2.2346e-01
[335/512] - loss: 2.2508e-01, best_loss: 2.2346e-01
[336/512] - loss: 2.2288e-01, best_loss: 2.2288e-01
[337/512] - loss: 2.1826e-01, best_loss: 2.1826e-01
[338/512] - loss: 2.1991e-01, best_loss: 2.1826e-01
[339/512] - loss: 2.1765e-01, best_loss: 2.1765e-01
[340/512] - loss: 2.2461e-01, best_loss: 2.1765e-01
[341/512] - loss: 2.1661e-01, best_loss: 2.1661e-01
[342/512] - loss: 2.1299e-01, best_loss: 2.1299e-01
[343/512] - loss: 2.1902e-01, best_loss: 2.1299e-01
[344/512] - loss: 2.1491e-01, best_loss: 2.1299e-01
[345/512] - loss: 2.1459e-01, best_loss: 2.1299e-01
[346/512] - loss: 2.2002e-01, best_loss: 2.1299e-01
[347/512] - loss: 2.1687e-01, best_loss: 2.1299e-01
[348/512] - loss: 2.1438e-01, best_loss: 2.1299e-01
[349/512] - loss: 2.0962e-01, best_loss: 2.0962e-01
[350/512] - loss: 2.1270e-01, best_loss: 2.0962e-01
[351/512] - loss: 2.1363e-01, best_loss: 2.0962e-01
[352/512] - loss: 2.0828e-01, best_loss: 2.0828e-01
[353/512] - loss: 2.1354e-01, best_loss: 2.0828e-01
[354/512] - loss: 2.1195e-01, best_loss: 2.0828e-01
[355/512] - loss: 2.1093e-01, best_loss: 2.0828e-01
[356/512] - loss: 2.1460e-01, best_loss: 2.0828e-01
[357/512] - loss: 2.0757e-01, best_loss: 2.0757e-01
[358/512] - loss: 2.0821e-01, best_loss: 2.0757e-01
[359/512] - loss: 2.0818e-01, best_loss: 2.0757e-01
[360/512] - loss: 2.0443e-01, best_loss: 2.0443e-01
[361/512] - loss: 2.0524e-01, best_loss: 2.0443e-01
[362/512] - loss: 2.0048e-01, best_loss: 2.0048e-01
[363/512] - loss: 2.0337e-01, best_loss: 2.0048e-01
[364/512] - loss: 2.0052e-01, best_loss: 2.0048e-01
[365/512] - loss: 2.0197e-01, best_loss: 2.0048e-01
[366/512] - loss: 2.0305e-01, best_loss: 2.0048e-01
[367/512] - loss: 2.0539e-01, best_loss: 2.0048e-01
[368/512] - loss: 2.0785e-01, best_loss: 2.0048e-01
[369/512] - loss: 1.9518e-01, best_loss: 1.9518e-01
[370/512] - loss: 1.9255e-01, best_loss: 1.9255e-01
[371/512] - loss: 2.0096e-01, best_loss: 1.9255e-01
[372/512] - loss: 2.0411e-01, best_loss: 1.9255e-01
[373/512] - loss: 1.9850e-01, best_loss: 1.9255e-01
[374/512] - loss: 1.9799e-01, best_loss: 1.9255e-01
[375/512] - loss: 1.9813e-01, best_loss: 1.9255e-01
[376/512] - loss: 1.9944e-01, best_loss: 1.9255e-01
[377/512] - loss: 2.0354e-01, best_loss: 1.9255e-01
[378/512] - loss: 1.9580e-01, best_loss: 1.9255e-01
[379/512] - loss: 1.9921e-01, best_loss: 1.9255e-01
[380/512] - loss: 1.9457e-01, best_loss: 1.9255e-01
[381/512] - loss: 1.9245e-01, best_loss: 1.9245e-01
[382/512] - loss: 1.8619e-01, best_loss: 1.8619e-01
[383/512] - loss: 1.8886e-01, best_loss: 1.8619e-01
[384/512] - loss: 1.9619e-01, best_loss: 1.8619e-01
[385/512] - loss: 1.9340e-01, best_loss: 1.8619e-01
[386/512] - loss: 1.8860e-01, best_loss: 1.8619e-01
[387/512] - loss: 1.9318e-01, best_loss: 1.8619e-01
[388/512] - loss: 1.9229e-01, best_loss: 1.8619e-01
[389/512] - loss: 1.9564e-01, best_loss: 1.8619e-01
[390/512] - loss: 1.8999e-01, best_loss: 1.8619e-01
[391/512] - loss: 1.8733e-01, best_loss: 1.8619e-01
[392/512] - loss: 1.9266e-01, best_loss: 1.8619e-01
[393/512] - loss: 1.8957e-01, best_loss: 1.8619e-01
[394/512] - loss: 1.9191e-01, best_loss: 1.8619e-01
[395/512] - loss: 1.8868e-01, best_loss: 1.8619e-01
[396/512] - loss: 1.8443e-01, best_loss: 1.8443e-01
[397/512] - loss: 1.8274e-01, best_loss: 1.8274e-01
[398/512] - loss: 1.8846e-01, best_loss: 1.8274e-01
[399/512] - loss: 1.7969e-01, best_loss: 1.7969e-01
[400/512] - loss: 1.8529e-01, best_loss: 1.7969e-01
[401/512] - loss: 1.8966e-01, best_loss: 1.7969e-01
[402/512] - loss: 1.8594e-01, best_loss: 1.7969e-01
[403/512] - loss: 1.8756e-01, best_loss: 1.7969e-01
[404/512] - loss: 1.7983e-01, best_loss: 1.7969e-01
[405/512] - loss: 1.8668e-01, best_loss: 1.7969e-01
[406/512] - loss: 1.8258e-01, best_loss: 1.7969e-01
[407/512] - loss: 1.8292e-01, best_loss: 1.7969e-01
[408/512] - loss: 1.7759e-01, best_loss: 1.7759e-01
[409/512] - loss: 1.7643e-01, best_loss: 1.7643e-01
[410/512] - loss: 1.8323e-01, best_loss: 1.7643e-01
[411/512] - loss: 1.7841e-01, best_loss: 1.7643e-01
[412/512] - loss: 1.7894e-01, best_loss: 1.7643e-01
[413/512] - loss: 1.7939e-01, best_loss: 1.7643e-01
[414/512] - loss: 1.8198e-01, best_loss: 1.7643e-01
[415/512] - loss: 1.7243e-01, best_loss: 1.7243e-01
[416/512] - loss: 1.7040e-01, best_loss: 1.7040e-01
[417/512] - loss: 1.7374e-01, best_loss: 1.7040e-01
[418/512] - loss: 1.7231e-01, best_loss: 1.7040e-01
[419/512] - loss: 1.8193e-01, best_loss: 1.7040e-01
[420/512] - loss: 1.7402e-01, best_loss: 1.7040e-01
[421/512] - loss: 1.7440e-01, best_loss: 1.7040e-01
[422/512] - loss: 1.7335e-01, best_loss: 1.7040e-01
[423/512] - loss: 1.7713e-01, best_loss: 1.7040e-01
[424/512] - loss: 1.6680e-01, best_loss: 1.6680e-01
[425/512] - loss: 1.6837e-01, best_loss: 1.6680e-01
[426/512] - loss: 1.6747e-01, best_loss: 1.6680e-01
[427/512] - loss: 1.7101e-01, best_loss: 1.6680e-01
[428/512] - loss: 1.7578e-01, best_loss: 1.6680e-01
[429/512] - loss: 1.6948e-01, best_loss: 1.6680e-01
[430/512] - loss: 1.7175e-01, best_loss: 1.6680e-01
[431/512] - loss: 1.7204e-01, best_loss: 1.6680e-01
[432/512] - loss: 1.6704e-01, best_loss: 1.6680e-01
[433/512] - loss: 1.7065e-01, best_loss: 1.6680e-01
[434/512] - loss: 1.6260e-01, best_loss: 1.6260e-01
[435/512] - loss: 1.6655e-01, best_loss: 1.6260e-01
[436/512] - loss: 1.6541e-01, best_loss: 1.6260e-01
[437/512] - loss: 1.6734e-01, best_loss: 1.6260e-01
[438/512] - loss: 1.7210e-01, best_loss: 1.6260e-01
[439/512] - loss: 1.6116e-01, best_loss: 1.6116e-01
[440/512] - loss: 1.5981e-01, best_loss: 1.5981e-01
[441/512] - loss: 1.6484e-01, best_loss: 1.5981e-01
[442/512] - loss: 1.6900e-01, best_loss: 1.5981e-01
[443/512] - loss: 1.6763e-01, best_loss: 1.5981e-01
[444/512] - loss: 1.6734e-01, best_loss: 1.5981e-01
[445/512] - loss: 1.6303e-01, best_loss: 1.5981e-01
[446/512] - loss: 1.6154e-01, best_loss: 1.5981e-01
[447/512] - loss: 1.6337e-01, best_loss: 1.5981e-01
[448/512] - loss: 1.5837e-01, best_loss: 1.5837e-01
[449/512] - loss: 1.6005e-01, best_loss: 1.5837e-01
[450/512] - loss: 1.6221e-01, best_loss: 1.5837e-01
[451/512] - loss: 1.5967e-01, best_loss: 1.5837e-01
[452/512] - loss: 1.6368e-01, best_loss: 1.5837e-01
[453/512] - loss: 1.6062e-01, best_loss: 1.5837e-01
[454/512] - loss: 1.5766e-01, best_loss: 1.5766e-01
[455/512] - loss: 1.6342e-01, best_loss: 1.5766e-01
[456/512] - loss: 1.5879e-01, best_loss: 1.5766e-01
[457/512] - loss: 1.5943e-01, best_loss: 1.5766e-01
[458/512] - loss: 1.5472e-01, best_loss: 1.5472e-01
[459/512] - loss: 1.5520e-01, best_loss: 1.5472e-01
[460/512] - loss: 1.5270e-01, best_loss: 1.5270e-01
[461/512] - loss: 1.5607e-01, best_loss: 1.5270e-01
[462/512] - loss: 1.5206e-01, best_loss: 1.5206e-01
[463/512] - loss: 1.5661e-01, best_loss: 1.5206e-01
[464/512] - loss: 1.5847e-01, best_loss: 1.5206e-01
[465/512] - loss: 1.5879e-01, best_loss: 1.5206e-01
[466/512] - loss: 1.4938e-01, best_loss: 1.4938e-01
[467/512] - loss: 1.5401e-01, best_loss: 1.4938e-01
[468/512] - loss: 1.5025e-01, best_loss: 1.4938e-01
[469/512] - loss: 1.5279e-01, best_loss: 1.4938e-01
[470/512] - loss: 1.5346e-01, best_loss: 1.4938e-01
[471/512] - loss: 1.5068e-01, best_loss: 1.4938e-01
[472/512] - loss: 1.5194e-01, best_loss: 1.4938e-01
[473/512] - loss: 1.5174e-01, best_loss: 1.4938e-01
[474/512] - loss: 1.5204e-01, best_loss: 1.4938e-01
[475/512] - loss: 1.4829e-01, best_loss: 1.4829e-01
[476/512] - loss: 1.5120e-01, best_loss: 1.4829e-01
[477/512] - loss: 1.4858e-01, best_loss: 1.4829e-01
[478/512] - loss: 1.5125e-01, best_loss: 1.4829e-01
[479/512] - loss: 1.4861e-01, best_loss: 1.4829e-01
[480/512] - loss: 1.4962e-01, best_loss: 1.4829e-01
[481/512] - loss: 1.4606e-01, best_loss: 1.4606e-01
[482/512] - loss: 1.4842e-01, best_loss: 1.4606e-01
[483/512] - loss: 1.4460e-01, best_loss: 1.4460e-01
[484/512] - loss: 1.5112e-01, best_loss: 1.4460e-01
[485/512] - loss: 1.4698e-01, best_loss: 1.4460e-01
[486/512] - loss: 1.4864e-01, best_loss: 1.4460e-01
[487/512] - loss: 1.4818e-01, best_loss: 1.4460e-01
[488/512] - loss: 1.4451e-01, best_loss: 1.4451e-01
[489/512] - loss: 1.4262e-01, best_loss: 1.4262e-01
[490/512] - loss: 1.4153e-01, best_loss: 1.4153e-01
[491/512] - loss: 1.4380e-01, best_loss: 1.4153e-01
[492/512] - loss: 1.4295e-01, best_loss: 1.4153e-01
[493/512] - loss: 1.4405e-01, best_loss: 1.4153e-01
[494/512] - loss: 1.4205e-01, best_loss: 1.4153e-01
[495/512] - loss: 1.4320e-01, best_loss: 1.4153e-01
[496/512] - loss: 1.4463e-01, best_loss: 1.4153e-01
[497/512] - loss: 1.4294e-01, best_loss: 1.4153e-01
[498/512] - loss: 1.4085e-01, best_loss: 1.4085e-01
[499/512] - loss: 1.4199e-01, best_loss: 1.4085e-01
[500/512] - loss: 1.4216e-01, best_loss: 1.4085e-01
[501/512] - loss: 1.3861e-01, best_loss: 1.3861e-01
[502/512] - loss: 1.4020e-01, best_loss: 1.3861e-01
[503/512] - loss: 1.4087e-01, best_loss: 1.3861e-01
[504/512] - loss: 1.4205e-01, best_loss: 1.3861e-01
[505/512] - loss: 1.3615e-01, best_loss: 1.3615e-01
[506/512] - loss: 1.4051e-01, best_loss: 1.3615e-01
[507/512] - loss: 1.3961e-01, best_loss: 1.3615e-01
[508/512] - loss: 1.3654e-01, best_loss: 1.3615e-01
[509/512] - loss: 1.3921e-01, best_loss: 1.3615e-01
[510/512] - loss: 1.3764e-01, best_loss: 1.3615e-01
[511/512] - loss: 1.3717e-01, best_loss: 1.3615e-01
[512/512] - loss: 1.3703e-01, best_loss: 1.3615e-01
[12]:
# Restore the lowest-loss controller parameters recorded during training.
simple_nn.load_state_dict(best_params)

# Re-integrate the controlled system without building an autograd graph;
# only the trajectory is needed for plotting.
controlled_constants = {**constants, "nn_controller": simple_nn}
with torch.no_grad():
    a_nn = de.OdeSystem(
        nn_rhs,
        y0=y_init,
        dense_output=True,
        t=(t0, tf),
        rtol=1e-7,
        atol=1e-7,
        constants=controlled_constants,
    )
    a_nn.method = "RK87"
    a_nn.integrate()
[13]:
# Compare the uncontrolled reference solution `a` with the NN-controlled
# solution `a_nn`: time series on the left, phase portrait on the right.
fig = plt.figure(figsize=(20, 4))
gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
ax0 = fig.add_subplot(gs[0])
ax1 = fig.add_subplot(gs[1])
ax1.set_aspect(1)

ax0.plot(eval_times/T, a.sol(eval_times)[:, 0], label="Without NN")
# a_nn was integrated with y_init moved to the training device/dtype. Evaluate
# its dense output with matching times, then move to the CPU for matplotlib.
ax0.plot(eval_times/T, a_nn.sol(eval_times.to(a_nn.y)).cpu()[:, 0], label="With NN")
ax0.set_xlim(0.0, 40.0)
ax0.set_ylim(-1.0, 1.0)
ax0.set_xlabel(r"$t/T$")
ax0.set_ylabel(r"$x$")
ax0.set_title(r"$k={},m={}$".format(a.constants['k'], a.constants['m']))
# The series carry labels; without a legend they are not shown.
ax0.legend()

ax1.plot(a.y[:, 0], a.y[:, 1], label="Without NN")
ax1.plot(a_nn.y[:, 0].cpu(), a_nn.y[:, 1].cpu(), label="With NN")
ax1.set_xlim(-1.6, 1.6)
ax1.set_ylim(-1.6, 1.6)
ax1.set_xlabel(r"$x$")
ax1.set_ylabel(r"$\dot x$")
ax1.grid(which='major')
ax1.legend()

plt.tight_layout()