Inference of a DigitalEcho

Once a DigitalEcho is trained, we need to know how to use it for predictions, i.e., how to run inference (the forward pass).

Inference works slightly differently for a full-model DigitalEcho and a component DigitalEcho. Let us go through both.

Full Model

First, let us load a DigitalEcho trained on Lotka-Volterra data, along with the dataset we will use for inference.

using JuliaHub, JLSO, Surrogatize, DataGeneration
digitalecho_dataset_name = "lotka_volterra_digitalecho"
path = JuliaHub.download_dataset(("juliasimtutorials", digitalecho_dataset_name), "path to save")
digitalecho = JLSO.load(path)[:result]
A Continous Time Surrogate wrapper with:
prob:
  A `DigitalEchoProblem` with:
  model:
    A DigitalEcho with : 
      RSIZE : 256
      USIZE : 2
      XSIZE : 0
      PSIZE : 4
      ICSIZE : 0
solver: Tsit5(; stage_limiter! = trivial_limiter!, step_limiter! = trivial_limiter!, thread = static(false),)
train_dataset_name = "lotka_volterra"
path = JuliaHub.download_dataset(("juliasimtutorials", train_dataset_name), "path to save")
ed = ExperimentData(JLSO.load(path)[:result])
 Number of Trajectories in ExperimentData: 10
  Basic Statistics for Given Dynamical System's Specifications
  Number of u0s in the ExperimentData: 2
  Number of ps in the ExperimentData: 4
  Field    Labels      LowerBound   UpperBound   Mean    StdDev
  u0s      states_1    1.0          1.0          1.0     0.0
           ...
           states_2    1.0          1.0          1.0     0.0
  ps       p_1         1.562        2.438        1.969   0.302
           ...
           p_4         1.766        1.984        1.87    0.074
 Basic Statistics for Given Dynamical System's Continuous Fields
  Number of states in the ExperimentData: 2
  Field    Labels      LowerBound   UpperBound   Mean    StdDev
  states   states_1    0.61         1.851        1.131   0.294
           ...
           states_2    0.585        1.93         1.068   0.272

To run inference, we index into the dataset to get the inputs of the forward pass and then call the DigitalEcho with them. Let us try the first trajectory in the dataset.

First, let us get the initial condition.

u0 = ed.specs.u0s.vals[1]
2-element Vector{Float64}:
 1.0
 1.0

Next, we need the control function. Since this is a full model, there are no controls. However, the control function passed to a DigitalEcho must always take both the state and the time as arguments, so we create an anonymous function of (u, t) that simply returns nothing.

x = (u, t) -> nothing
#1 (generic function with 1 method)
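As a minimal pure-Julia illustration of that calling convention, the snippet below defines the no-control function used here and, for contrast, a hypothetical constant control for a model that does have two control inputs. x_const and its values are illustrative only; they do not apply to this Lotka-Volterra DigitalEcho, whose XSIZE is 0. Both functions ignore their arguments, but both accept a state u and a time t.

```julia
# Full model: no controls, but the function must still accept (u, t).
x_none = (u, t) -> nothing

# Hypothetical: a model with two controls could use a constant,
# time-independent control vector. Illustrative values only.
x_const = (u, t) -> [0.5, 2.0]

u_test = [1.0, 1.0]       # any state vector works; it is ignored here
x_none(u_test, 0.0)       # returns nothing
x_const(u_test, 3.0)      # returns [0.5, 2.0]
```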

Then, let us grab the parameters.

p = ed.specs.ps.vals[1]
4-element view(::Matrix{Float64}, :, 1) with eltype Float64:
 1.6875
 1.828125
 2.4375
 1.859375

Lastly, let us grab the time points at which we want to save the solution, along with the time span.

ts = ed.results.tss.vals[1]
tspan = (ts[1], ts[end])
(0.0, 12.5)

We can then call the forward pass.

pred = digitalecho(u0, x, p, tspan; saveat = ts)
2×94 Matrix{Float64}:
 0.999987  0.994941  0.99092   …  1.09537  1.03969  1.00227  0.997731
 0.999986  0.97535   0.938104     1.16707  1.09687  1.0088   0.989946
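To judge the prediction, one option is to compare pred against the simulated trajectory for the same sample. The sketch below is a hypothetical helper, relative_l2_error, built only on Julia's LinearAlgebra standard library. It assumes the ground truth is a (states × time points) matrix saved at the same ts as pred; where that matrix lives in ed (for example, something like ed.results.states.vals[1]) is an assumption not confirmed by this tutorial, so the example runs on toy matrices instead.

```julia
using LinearAlgebra: norm

# Hypothetical helper: relative L2 error for each state (row), assuming
# pred and truth are both (number of states) × (number of saved time points).
function relative_l2_error(pred::AbstractMatrix, truth::AbstractMatrix)
    size(pred) == size(truth) ||
        throw(DimensionMismatch("pred and truth must have the same size"))
    return [norm(pred[i, :] .- truth[i, :]) / norm(truth[i, :]) for i in axes(truth, 1)]
end

# Toy data standing in for pred and the ground truth (not real results).
truth = [1.0 2.0 3.0; 4.0 5.0 6.0]
pred_toy = truth .+ [0.0 0.0 0.1; 0.0 0.0 0.0]
errs = relative_l2_error(pred_toy, truth)   # first entry ≈ 0.0267, second is 0.0
```

A per-state error is reported rather than one pooled number because states can live on very different scales, as in the CSTR example below, where one state is around 130 and another around 0.5.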

Component

First, let us load a DigitalEcho trained on CSTR data, along with the dataset we will use for inference.

using JuliaHub, JLSO, Surrogatize, DataGeneration, PreProcessing
digitalecho_dataset_name = "cstr_digitalecho"
path = JuliaHub.download_dataset(("juliasimtutorials", digitalecho_dataset_name), "path to save")
digitalecho = JLSO.load(path)[:result]
A Continous Time Surrogate wrapper with:
prob:
  A `DigitalEchoProblem` with:
  model:
    A DigitalEcho with : 
      RSIZE : 256
      USIZE : 4
      XSIZE : 2
      PSIZE : 0
      ICSIZE : 0
solver: Tsit5(; stage_limiter! = trivial_limiter!, step_limiter! = trivial_limiter!, thread = static(false),)
train_dataset_name = "cstr"
path = JuliaHub.download_dataset(("juliasimtutorials", train_dataset_name), "path to save")
ed = ExperimentData(JLSO.load(path)[:result])
 Number of Trajectories in ExperimentData: 10
  Basic Statistics for Given Dynamical System's Specifications
  Number of u0s in the ExperimentData: 4
  Field    Labels      LowerBound   UpperBound   Mean    StdDev
  u0s      states_1    0.8          0.8          0.8     0.0
           ...
           states_4    130.0        130.0        130.0   0.0
 Basic Statistics for Given Dynamical System's Continuous Fields
  Number of states in the ExperimentData: 4
  Number of controls in the ExperimentData: 2
  Field      Labels      LowerBound   UpperBound   Mean        ...
  states     states_1    0.782        2.694        2.02        ...
             ...
             states_4    125.364      139.993      131.565     ...
  controls   F           8.333        99.468       55.896      ...
             ...
                         -8497.874    -227.416     -4158.499   ...

Before doing inference, we need to convert the controls in the dataset into splines, so that they can be evaluated at any time t during the continuous-time forward pass.

spline = SplineED()
ed_splines = spline(ed)
 (The summary printed for ed_splines is identical to the one shown above for ed: 10 trajectories, 4 states and 2 controls, with the same statistics.)
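The tutorial does not show how SplineED builds its splines. To illustrate the idea of turning sampled controls into something callable at any t, the sketch below uses plain piecewise-linear interpolation as a simpler stand-in for a spline. make_linear_interp, ts_toy and vals_toy are hypothetical names and toy data, not part of Surrogatize or the CSTR dataset.

```julia
# Hypothetical stand-in for a control spline: piecewise-linear interpolation
# of sampled control values. ts must be sorted ascending; vals has one row
# per control and one column per time point. Outside [ts[1], ts[end]] it clamps.
function make_linear_interp(ts::AbstractVector, vals::AbstractMatrix)
    size(vals, 2) == length(ts) ||
        throw(DimensionMismatch("vals needs one column per time point"))
    return function (t)
        t <= ts[1] && return vals[:, 1]
        t >= ts[end] && return vals[:, end]
        k = searchsortedlast(ts, t)           # index with ts[k] <= t < ts[k+1]
        w = (t - ts[k]) / (ts[k+1] - ts[k])   # interpolation weight in [0, 1)
        return (1 - w) .* vals[:, k] .+ w .* vals[:, k+1]
    end
end

# Toy samples for two controls at three time points (not the CSTR data).
ts_toy = [0.0, 0.1, 0.2]
vals_toy = [10.0 20.0 30.0;
            -100.0 -200.0 -300.0]
ctrl = make_linear_interp(ts_toy, vals_toy)

# Wrap it with the (u, t) signature; the state u is ignored (open loop).
x_toy = (u, t) -> ctrl(t)
x_toy([0.8, 0.5, 134.14, 130.0], 0.05)   # returns [15.0, -150.0]
```

Linear interpolation is continuous but has kinks at the sample times; a higher-degree spline smooths those out.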

With the controls converted into splines, the next step is to index into the dataset fields, just as we did for the full model.

First, let us get the initial condition.

u0 = ed_splines.specs.u0s.vals[1]
4-element Vector{Float64}:
   0.8
   0.5
 134.14
 130.0

Next, we get the controls. The dataset was generated with open-loop controls, which depend only on time. Since the control function must still take both the state and the time as arguments, we wrap the spline for the first trajectory in an anonymous function of (u, t) that ignores u and evaluates the spline at t.

x = (u, t) -> ed_splines.results.controls.vals[1](t)
#3 (generic function with 1 method)
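The (u, t) signature is what allows a control function to depend on the state as well. For contrast with the open-loop closure above, here is a hypothetical proportional feedback law. The gain matrix K, setpoint u_ref and nominal input x_nom are illustrative values, not taken from the CSTR dataset, and this tutorial does not show whether the DigitalEcho generalizes to closed-loop inputs it was not trained on. The sizes assume that USIZE (4) counts states and XSIZE (2) counts controls.

```julia
using LinearAlgebra   # matrix-vector product K * v

# Hypothetical proportional feedback law: x = x_nom - K * (u - u_ref).
u_ref = [0.8, 0.5, 134.14, 130.0]   # illustrative setpoint (here: the initial condition)
x_nom = [55.0, -4000.0]             # illustrative nominal control input
K = [0.0 0.0 1.0 0.0;               # illustrative 2×4 gain matrix
     0.0 0.0 0.0 50.0]

# Uses the state u; ignores t, so the law is time-invariant.
x_feedback = (u, t) -> x_nom .- K * (u .- u_ref)

x_feedback(u_ref, 0.0)   # at the setpoint, returns x_nom, i.e. [55.0, -4000.0]
```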

We set the parameters to nothing, since this DigitalEcho was trained only with varying controls (PSIZE is 0) and cannot take parameters as input.

p = nothing

Lastly, let us grab the time points at which we want to save the solution, along with the time span.

ts = ed_splines.results.tss.vals[1]
tspan = (ts[1], ts[end])
(0.0, 0.25)

We can then call the forward pass.

pred = digitalecho(u0, x, p, tspan; saveat = ts)
4×48 Matrix{Float64}:
   0.800167    1.17452     1.34796   …    2.16621    2.16404    2.16218
   0.500015    0.497711    0.511449       1.11122    1.11157    1.1118
 134.14      133.659     133.465        135.057    135.083    135.107
 130.0       129.809     129.682        125.516    125.628    125.719