Training a Component Model DigitalEcho

Component models are models that form part of a larger system. They receive time-varying inputs, and the nature of these inputs depends heavily on the larger system and its interactions.

Generating Data

Before training a DigitalEcho, we need to generate data. For component models, we use random neural controllers to generate data. For detailed information on how to generate data, refer to DataGeneration Strategies.
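The idea behind a random neural controller can be sketched as follows: a randomly initialised neural network maps the current state and time to a control input kept within given bounds, so each random seed drives the system along a different trajectory. This is a hypothetical stand-in written with the Julia standard library only, not the DataGeneration API; `RandomController`, `rollout`, the linear toy plant and the control bounds are all illustrative assumptions.

```julia
using Random

# Hypothetical random controller: a randomly initialised one-hidden-layer network
# mapping (state, time) to a control vector that stays within [lb, ub].
struct RandomController
    W1::Matrix{Float64}
    b1::Vector{Float64}
    W2::Matrix{Float64}
    b2::Vector{Float64}
    lb::Vector{Float64}  # lower bound of each control input
    ub::Vector{Float64}  # upper bound of each control input
end

function RandomController(rng::AbstractRNG, nx::Int, lb::Vector{Float64},
                          ub::Vector{Float64}; hidden::Int = 16)
    nu = length(lb)
    nin = nx + 1  # network inputs: the states plus time
    return RandomController(randn(rng, hidden, nin), randn(rng, hidden),
                            randn(rng, nu, hidden), randn(rng, nu), lb, ub)
end

# Evaluate the controller: tanh keeps z in (-1, 1), which is rescaled into [lb, ub].
function (c::RandomController)(x::AbstractVector, t::Real)
    h = tanh.(c.W1 * vcat(x, t) .+ c.b1)
    z = tanh.(c.W2 * h .+ c.b2)
    return c.lb .+ (z .+ 1) ./ 2 .* (c.ub .- c.lb)
end

# Closed-loop rollout of a hypothetical linear toy plant dx/dt = -x + B*u (explicit Euler).
function rollout(ctrl::RandomController, x0::Vector{Float64}; dt = 0.01, nsteps = 500)
    B = [1.0 0.0; 0.0 1.0; 0.5 0.5]  # 3 states, 2 controls
    x = copy(x0)
    traj = zeros(length(x0), nsteps + 1)
    traj[:, 1] = x
    for k in 1:nsteps
        u = ctrl(x, (k - 1) * dt)
        x = x .+ dt .* (B * u .- x)
        traj[:, k + 1] = x
    end
    return traj
end

# Ten controllers with different seeds give ten different trajectories.
lb, ub = [-1.0, 0.0], [1.0, 2.0]
controllers = [RandomController(MersenneTwister(seed), 3, lb, ub) for seed in 1:10]
trajectories = [rollout(c, zeros(3)) for c in controllers]
```

Bounding the output with `tanh` and rescaling is one simple way to keep every randomly drawn controller within admissible control limits while still exciting the system differently per seed.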

Let us load a dataset of a Continuous Stirred Tank Reactor (CSTR) model as an example to walk through the training process for the DigitalEcho. We do this by downloading the dataset, which is publicly hosted on JuliaHub, and loading it as an ExperimentData.

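using JuliaHub, JLSO, DataGeneration
train_dataset_name = "cstr"
# Download the publicly hosted dataset; replace "path to save" with a local file path
path = JuliaHub.download_dataset(("juliasimtutorials", train_dataset_name), "path to save")
# Deserialize the JLSO file and wrap its :result entry as an ExperimentData
ed = ExperimentData(JLSO.load(path)[:result])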
Number of Trajectories in ExperimentData: 10

Basic Statistics for Given Dynamical System's Specifications
Number of u0s in the ExperimentData: 4

  u0s
    Labels      LowerBound    UpperBound    Mean      StdDev
    states_1    0.8           0.8           0.8       0.0
    ...
    states_4    130.0         130.0         130.0     0.0

Basic Statistics for Given Dynamical System's Continuous Fields
Number of states in the ExperimentData: 4
Number of controls in the ExperimentData: 2

  states
    Labels      LowerBound    UpperBound    Mean
    states_1    0.782         2.694         2.02
    ...
    states_4    125.364       139.993       131.565

  controls
    Labels      LowerBound    UpperBound    Mean
    F           8.333         99.468        55.896
                -8497.874     -227.416      -4158.499
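The summary above reports, for each label, its lower bound, upper bound, mean and (for the initial conditions) standard deviation. Because LowerBound equals UpperBound and StdDev is 0.0 for the shown initial conditions (states_1 and states_4), those values are identical in all 10 trajectories. The sketch below shows one plausible way to compute such per-label statistics, assuming each trajectory is stored as a matrix with one row per label and statistics pooled over every time point of every trajectory; the layout and `field_statistics` are hypothetical, not how ExperimentData computes its summary.

```julia
using Statistics

# Hypothetical layout: each trajectory is an (n_labels x n_timepoints) matrix.
# Statistics are pooled over every time point of every trajectory.
function field_statistics(trajs::Vector{<:AbstractMatrix}, labels::Vector{String})
    data = reduce(hcat, trajs)  # n_labels x (total number of time points)
    return [(label = labels[i],
             lowerbound = minimum(data[i, :]),
             upperbound = maximum(data[i, :]),
             mean_value = mean(data[i, :]),
             stddev = std(data[i, :])) for i in eachindex(labels)]
end

# Two toy trajectories: label "a" varies over time, label "b" is constant.
traj1 = [1.0 2.0 3.0; 10.0 10.0 10.0]
traj2 = [4.0 5.0; 10.0 10.0]
stats = field_statistics([traj1, traj2], ["a", "b"])
```

The constant label "b" behaves like the u0s rows above: equal bounds and a standard deviation of zero.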

Fitting a DigitalEcho

Fitting a DigitalEcho takes a single function call.

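using Surrogatize, Optimisers, OrdinaryDiffEq

# Fit a DigitalEcho on the ExperimentData `ed` loaded above
digitalecho = DigitalEcho(ed;
    ground_truth_port = :states,    # train against the state trajectories
    solver = Tsit5(),               # ODE solver for the reservoir system
    RSIZE = 256,                    # reservoir dimension
    max_eig = 0.9,
    tau = 1e-1,                     # time constant, the key parameter to tune
    n_layers = 2,                   # number of decoder layers
    n_epochs = 24500,
    opt = Optimisers.Adam(),
    lambda = 1e-6,                  # L2 regularisation constant
    batchsize = 2048,
    solver_kwargs = (abstol = 1e-9, reltol = 1e-9),
    train_on_gpu = true,            # set to false to train without a GPU
    verbose = true,
    callevery = 200)                # apply callbacks and print the loss every 200 epochs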
A Continous Time Surrogate wrapper with:
prob:
  A `DigitalEchoProblem` with:
  model:
    A DigitalEcho with : 
      RSIZE : 256
      USIZE : 4
      XSIZE : 2
      PSIZE : 0
      ICSIZE : 0
solver: Tsit5(; stage_limiter! = trivial_limiter!, step_limiter! = trivial_limiter!, thread = static(false),)

Let us look at some of the hyperparameters of the DigitalEcho. Understanding them builds intuition for how to set them for different kinds of models.

The first argument is the dataset, loaded as an ExperimentData. The rest are keyword arguments; let us go through each of them:

  1. ground_truth_port refers to the field of ExperimentData that is the ground truth for the simulations. It can be either :states or :observables, depending on which data we want to train on.

  2. solver is the solver with which the reservoir system of the DigitalEcho is solved. Any solver from OrdinaryDiffEq.jl can be passed.

  3. RSIZE is the dimension of the reservoir of the DigitalEcho.

  4. max_eig is the maximum eigenvalue of the encoder.

  5. tau is the time constant of the DigitalEcho, i.e., how fast its dynamical system responds to the driver signal.

  6. n_layers is the number of layers in the decoder embedded inside the DigitalEcho.

  7. n_epochs is the number of epochs for the training process.

  8. opt refers to the optimiser to use in the training process. Any optimiser defined using Optimisers.jl can be used for training.

  9. lambda is the L2 regularisation constant for the model weights used in the loss function.

  10. batchsize is the number of data points in each batch, since training uses mini-batch gradient descent.

  11. solver_kwargs are the keyword arguments used when solving the reservoir system, i.e., keyword arguments accepted by the solve function in OrdinaryDiffEq.jl, such as abstol and reltol.

  12. train_on_gpu is a boolean flag indicating whether training should run on a GPU.

  13. verbose is a flag that, when set to true, prints informative messages about the progress of the training process.

  14. callevery is the interval, in epochs, at which all callbacks are applied. If verbose is true, the loss is also printed at this interval.
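Three of these hyperparameters have standard counterparts in echo state network training. The sketch below shows them in isolation, assuming that max_eig acts like the spectral-radius target commonly used to rescale a random reservoir matrix, that lambda multiplies a squared-weight (L2) penalty added to a data-fitting loss, and that batchsize is the number of samples per shuffled mini-batch. The functions `scaled_reservoir`, `batch_loss` and `minibatches`, the linear readout standing in for the decoder, and the exact loss form are hypothetical illustrations, not Surrogatize internals.

```julia
using LinearAlgebra, Random

# max_eig (assumed role): rescale a random matrix so its spectral radius equals max_eig.
function scaled_reservoir(rng::AbstractRNG, rsize::Int, max_eig::Real)
    A = randn(rng, rsize, rsize)
    rho = maximum(abs, eigvals(A))  # current spectral radius
    return A .* (max_eig / rho)
end

# lambda (assumed role): squared error averaged over the batch samples plus an L2 weight penalty.
# W is a linear readout standing in for the decoder; columns of R and Y are samples.
function batch_loss(W, R, Y, idx, lambda)
    err = W * R[:, idx] .- Y[:, idx]
    return sum(abs2, err) / length(idx) + lambda * sum(abs2, W)
end

# batchsize: shuffle the n sample indices and split them into mini-batches.
minibatches(rng::AbstractRNG, n::Int, batchsize::Int) =
    Iterators.partition(randperm(rng, n), batchsize)

rng = MersenneTwister(42)
A = scaled_reservoir(rng, 256, 0.9)          # RSIZE = 256, max_eig = 0.9 as in the example
batches = collect(minibatches(rng, 10, 4))   # 10 samples split into batches of size 4, 4, 2
```

Keeping the spectral radius below 1 is the usual way echo state networks keep reservoir activity from growing without bound, which is consistent with the value 0.9 used above.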

Info

Not all hyperparameters need tuning when fitting a DigitalEcho. The most important one is:

  1. tau - The time constant should be tuned according to knowledge of the time scale of the original model/system. As a rule of thumb, values between 0.01 and 1.0 work well for most systems.

For the rest of the hyperparameters, the defaults are strong and should only be changed if necessary.
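To build intuition for tau, consider the simplest system with a time constant, dr/dt = (u - r) / tau. After a unit step in u, r(t) = 1 - exp(-t / tau), so r reaches about 63 % of its final value at t = tau: a smaller tau means a faster response. The sketch below checks this numerically with explicit Euler. It only illustrates what a time constant means and assumes nothing about the DigitalEcho's actual reservoir equations; `step_response` and `time_to_63` are hypothetical helpers.

```julia
# Unit-step response of the first-order system dr/dt = (u - r) / tau with u = 1,
# integrated with explicit Euler from r(0) = 0.
function step_response(tau::Real; dt = 1e-4, T = 5.0)
    n = round(Int, T / dt)
    ts = range(0.0, T; length = n + 1)
    rs = zeros(n + 1)
    r = 0.0
    for k in 1:n
        r += dt * (1.0 - r) / tau
        rs[k + 1] = r
    end
    return ts, rs
end

# Time at which the response first reaches 1 - exp(-1), about 63.2 % of the step.
function time_to_63(tau)
    ts, rs = step_response(tau)
    return ts[findfirst(>=(1 - exp(-1)), rs)]
end

for tau in (0.01, 0.1, 1.0)
    println("tau = ", tau, "  ->  time to 63 % = ", time_to_63(tau))
end
```

Matching tau to the time scale on which the original system reacts to its inputs is the intuition behind the rule of thumb above.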