Training a Full Model DigitalEcho
Full models are completely parameterized models with no controls: the only thing we can vary between simulations is the set of parameters. Models of this type generally encapsulate the whole system.
Loading Data
Before training a DigitalEcho, we need to generate data. For a full model, this means simulating the model for various sets of parameters. For detailed information on how to generate data, refer to DataGeneration Strategies.
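To make "simulating for various sets of parameters" concrete, here is a minimal sketch that uses only OrdinaryDiffEq.jl. It is not the DataGeneration workflow and not how the hosted dataset was produced. The parameter ordering, the nominal values in `base_p`, the time span, and the scaling scheme are all assumptions chosen for illustration.

```julia
using OrdinaryDiffEq

# Lotka-Volterra right-hand side. The ordering p = (α, β, γ, δ) is an
# assumption made for this sketch, not necessarily the dataset's ordering.
function lotka_volterra!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[1] * u[2]   # prey
    du[2] = -γ * u[2] + δ * u[1] * u[2]  # predator
    return nothing
end

u0 = [1.0, 1.0]                 # same initial condition for every trajectory
tspan = (0.0, 10.0)             # hypothetical time span
base_p = [2.0, 1.9, 2.0, 1.9]   # hypothetical nominal parameters
prob = ODEProblem(lotka_volterra!, u0, tspan, base_p)

# Ten illustrative parameter sets, obtained by scaling the nominal values.
scales = range(0.8, 1.2; length = 10)
param_sets = [base_p .* s for s in scales]

# One trajectory per parameter set; only the parameters change between runs.
solutions = [solve(remake(prob; p = p), Tsit5(); saveat = 0.1) for p in param_sets]
```

Each element of `solutions` is one trajectory of the two states. A real dataset would typically sample the parameter space more carefully than this one-dimensional scaling does.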
Let us load a pre-generated dataset of the Lotka-Volterra model as an example to walk through the training process for the DigitalEcho. We can do this by downloading the dataset, which is publicly hosted on JuliaHub, and loading it as an ExperimentData.
using JuliaHub, JLSO, DataGeneration
train_dataset_name = "lotka_volterra"
path = JuliaHub.download_dataset(("juliasimtutorials", train_dataset_name), "path to save")
ed = ExperimentData(JLSO.load(path)[:result])
Number of Trajectories in ExperimentData: 10
Basic Statistics for Given Dynamical System's Specifications
Number of u0s in the ExperimentData: 2
Number of ps in the ExperimentData: 4
╭─────────┬────────────────────────────────────────────────────────────────────╮
│ Field │ │
├─────────┼────────────────────────────────────────────────────────────────────┤
│ │ ╭────────────┬──────────────┬──────────────┬────────┬──────────╮ │
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │ StdDev │ │
│ │ ├────────────┼──────────────┼──────────────┼────────┼──────────┤ │
│ │ │ states_1 │ 1.0 │ 1.0 │ 1.0 │ 0.0 │ │
│ u0s │ ├────────────┼──────────────┼──────────────┼────────┼──────────┤ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ ├────────────┼──────────────┼──────────────┼────────┼──────────┤ │
│ │ │ states_2 │ 1.0 │ 1.0 │ 1.0 │ 0.0 │ │
│ │ ╰────────────┴──────────────┴──────────────┴────────┴──────────╯ │
├─────────┼────────────────────────────────────────────────────────────────────┤
│ │ ╭──────────┬──────────────┬──────────────┬─────────┬──────────╮ │
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │ StdDev │ │
│ │ ├──────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ p_1 │ 1.562 │ 2.438 │ 1.969 │ 0.302 │ │
│ ps │ ├──────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ ├──────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ p_4 │ 1.766 │ 1.984 │ 1.87 │ 0.074 │ │
│ │ ╰──────────┴──────────────┴──────────────┴─────────┴──────────╯ │
╰─────────┴────────────────────────────────────────────────────────────────────╯
Basic Statistics for Given Dynamical System's Continuous Fields
Number of states in the ExperimentData: 2
╭──────────┬────────────────────────────────────────────────────────────────────╮
│ Field │ │
├──────────┼────────────────────────────────────────────────────────────────────┤
│ │ ╭────────────┬──────────────┬──────────────┬─────────┬──────────╮ │
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │ StdDev │ │
│ │ ├────────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ states_1 │ 0.61 │ 1.851 │ 1.131 │ 0.294 │ │
│ states │ ├────────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ ├────────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ states_2 │ 0.585 │ 1.93 │ 1.068 │ 0.272 │ │
│ │ ╰────────────┴──────────────┴──────────────┴─────────┴──────────╯ │
╰──────────┴────────────────────────────────────────────────────────────────────╯
Fitting a DigitalEcho
Fitting a DigitalEcho is as easy as calling a function.
using Surrogatize, Optimisers, OrdinaryDiffEq
digitalecho = DigitalEcho(ed;
    ground_truth_port = :states,
    solver = Tsit5(),
    RSIZE = 256,
    max_eig = 0.9,
    tau = 1e-1,
    n_layers = 2,
    n_epochs = 24500,
    opt = Optimisers.Adam(),
    lambda = 1e-6,
    batchsize = 2048,
    solver_kwargs = (abstol = 1e-9, reltol = 1e-9),
    train_on_gpu = true,
    verbose = true,
    callevery = 200)
A Continous Time Surrogate wrapper with:
prob:
A `DigitalEchoProblem` with:
model:
A DigitalEcho with :
RSIZE : 256
USIZE : 2
XSIZE : 0
PSIZE : 4
ICSIZE : 0
solver: Tsit5(; stage_limiter! = trivial_limiter!, step_limiter! = trivial_limiter!, thread = static(false),)
Let us look at some of the hyperparameters for the DigitalEcho. This is important for gaining intuition into how to use them for different kinds of models.
First, we pass in the dataset, loaded as an ExperimentData, as the positional argument. The rest are keyword arguments. Let us go through each of them:
ground_truth_port refers to the field of the ExperimentData that serves as the ground truth for the simulations. It can be either :states or :observables, depending on which data we want to train on.
solver is the solver used to solve the reservoir system of the DigitalEcho. Any solver from OrdinaryDiffEq.jl can be passed.
RSIZE is the dimension of the reservoir of the DigitalEcho.
max_eig is the maximum eigenvalue of the encoder.
tau is the time constant of the DigitalEcho, i.e., how fast the dynamical system responds to the driver signal.
n_layers is the number of layers in the decoder embedded inside the DigitalEcho.
n_epochs is the number of epochs for the training process.
opt is the optimiser used in the training process. Any optimiser defined using Optimisers.jl can be used.
lambda is the L2 regularisation constant for the model weights in the loss function.
batchsize is the number of data points in each training batch; training uses mini-batch gradient descent.
solver_kwargs are the keyword arguments used when solving the reservoir system. These are the keyword arguments accepted by the solve function in OrdinaryDiffEq.jl.
train_on_gpu is a Boolean flag indicating whether training should run on the GPU.
verbose is a flag that, when set to true, prints informative statements about the progress of training.
callevery is the frequency, in epochs, at which all callbacks are applied. If verbose is true, the loss is also printed every callevery epochs.
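To give some intuition for what a bound such as max_eig means, the sketch below rescales a random square matrix so that its largest eigenvalue magnitude (its spectral radius) equals a target value. This is a common construction in reservoir computing in general. Whether DigitalEcho builds its encoder exactly this way is an assumption, and the matrix `W` is a hypothetical stand-in, not an object exposed by the package.

```julia
using LinearAlgebra, Random

rsize = 256      # mirrors RSIZE in the call above
max_eig = 0.9    # mirrors max_eig in the call above

# Hypothetical stand-in for a reservoir weight matrix.
rng = MersenneTwister(0)
W = randn(rng, rsize, rsize)

# Spectral radius: the largest magnitude among the (complex) eigenvalues.
spectral_radius(A) = maximum(abs, eigvals(A))

# Scaling a matrix by a constant scales every eigenvalue by that constant,
# so this sets the spectral radius of W_scaled to max_eig.
W_scaled = W .* (max_eig / spectral_radius(W))
```

Keeping this value below 1 is the usual way to make the influence of past inputs on a reservoir fade over time rather than grow.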
Not all of the hyperparameters need tuning when fitting a DigitalEcho. The most important one is:
tau - The time constant is one of the most important hyperparameters, and we should tune it according to our knowledge of the time scale of the original model or system. As a rule of thumb, values from 0.01 to 1.0 work well for most systems.
The rest of the hyperparameters have strong defaults and should only be tuned if necessary.
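The rule of thumb for tau can be turned into a small sweep over log-spaced candidates. The helpers `tau_grid` and `sweep_tau` below are hypothetical names introduced for this sketch and are not part of Surrogatize. The commented usage assumes that the remaining keyword arguments can be left at their defaults.

```julia
# Log-spaced candidates covering the 0.01 to 1.0 rule-of-thumb range.
tau_grid(n) = 10.0 .^ range(-2, 0; length = n)

# `fit` is any function that maps a tau value to a trained model.
# Returns one (tau, model) pair per candidate, in the order given.
sweep_tau(fit, taus) = [(tau = tau, model = fit(tau)) for tau in taus]

# Usage with the objects from this tutorial (not run here, since each
# call trains a full DigitalEcho):
#
# models = sweep_tau(tau_grid(5)) do tau
#     DigitalEcho(ed; ground_truth_port = :states, tau = tau)
# end
```

Passing the fitting step in as a function keeps the sweep independent of how each model is trained, so the same helper works with whichever keyword arguments a particular system needs.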