Inference of a DigitalEcho
Once we train a DigitalEcho, it is important to understand as to how to use it for predictions, i.e, inference or forward pass.
The way we do inference is slightly different for full model and component DigitalEcho. Let us go through both of them.
Full Model
First let us load a DigitalEcho trained on lotka-volterra data and also the data which we will use it for inference.
using JuliaHub, JLSO, Surrogatize, DataGeneration
digitalecho_dataset_name = "lotka_volterra_digitalecho"
path = JuliaHub.download_dataset(("juliasimtutorials", digitalecho_dataset_name), "path to save")
digitalecho = JLSO.load(path)[:result]
A Continous Time Surrogate wrapper with:
prob:
A `DigitalEchoProblem` with:
model:
A DigitalEcho with :
RSIZE : 256
USIZE : 2
XSIZE : 0
PSIZE : 4
ICSIZE : 0
solver: Tsit5(; stage_limiter! = trivial_limiter!, step_limiter! = trivial_limiter!, thread = static(false),)
train_dataset_name = "lotka_volterra"
path = JuliaHub.download_dataset(("juliasimtutorials", train_dataset_name), "path to save")
ed = ExperimentData(JLSO.load(path)[:result])
Number of Trajectories in ExperimentData: 10
Basic Statistics for Given Dynamical System's Specifications
Number of u0s in the ExperimentData: 2
Number of ps in the ExperimentData: 4
╭─────────┬────────────────────────────────────────────────────────────────────╮
│ Field │ │
├─────────┼────────────────────────────────────────────────────────────────────┤
│ │ ╭────────────┬──────────────┬──────────────┬────────┬──────────╮ │
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │ StdDev │ │
│ │ ├────────────┼──────────────┼──────────────┼────────┼──────────┤ │
│ │ │ states_1 │ 1.0 │ 1.0 │ 1.0 │ 0.0 │ │
│ u0s │ ├────────────┼──────────────┼──────────────┼────────┼──────────┤ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ ├────────────┼──────────────┼──────────────┼────────┼──────────┤ │
│ │ │ states_2 │ 1.0 │ 1.0 │ 1.0 │ 0.0 │ │
│ │ ╰────────────┴──────────────┴──────────────┴────────┴──────────╯ │
├─────────┼────────────────────────────────────────────────────────────────────┤
│ │ ╭──────────┬──────────────┬──────────────┬─────────┬──────────╮ │
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │ StdDev │ │
│ │ ├──────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ p_1 │ 1.562 │ 2.438 │ 1.969 │ 0.302 │ │
│ ps │ ├──────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ │
│ │ ├──────────┼──────────────┼──────────────┼─────────┼──────────┤ │
│ │ │ p_4 │ 1.766 │ 1.984 │ 1.87 │ 0.074 │ │
│ │ ╰──────────┴──────────────┴──────────────┴─────────┴──────────╯ │
╰─────────┴────────────────────────────────────────────────────────────────────╯
Basic Statistics for Given Dynamical System's Continuous Fields
Number of states in the ExperimentData: 2
╭──────────┬─────────────────────────────────────────────────────────────────...
──╮...
│ Field │...
│...
├──────────┼─────────────────────────────────────────────────────────────────...
──┤...
│ │ ╭────────────┬──────────────┬──────────────┬─────────┬─────────...
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │ StdDev...
│ │ ├────────────┼──────────────┼──────────────┼─────────┼─────────...
│ │ │ states_1 │ 0.61 │ 1.851 │ 1.131 │ 0.294...
│ states │ ├────────────┼──────────────┼──────────────┼─────────┼─────────...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │...
│ │ ├────────────┼──────────────┼──────────────┼─────────┼─────────...
│ │ │ states_2 │ 0.585 │ 1.93 │ 1.068 │ 0.272...
│ │ ╰────────────┴──────────────┴──────────────┴─────────┴─────────...
╰──────────┴─────────────────────────────────────────────────────────────────...
──╯...
To do inference, we need to index into the dataset appropriately and pass into the forward pass of the DigitalEcho. Let us try with the first sample of the dataset.
Firstly, let us get the initial condition.
u0 = ed.specs.u0s.vals[1]
2-element Vector{Float64}:
1.0
1.0
Next, we need to get the contol function. As it is a full model, we don't have any controls. The function signature for passing control always takes in both state and time as arguments. Hence we create a closure and return nothing
.
x = (u, t) -> nothing
#1 (generic function with 1 method)
Then, let us grab the parameters.
p = ed.specs.ps.vals[1]
4-element view(::Matrix{Float64}, :, 1) with eltype Float64:
1.6875
1.828125
2.4375
1.859375
Lastly, let us grab the timepoints we want to save the values and the time span.
ts = ed.results.tss.vals[1]
tspan = (ts[1], ts[end])
(0.0, 12.5)
We can then call the forward pass.
pred = digitalecho(u0, x, p, tspan; saveat = ts)
2×94 Matrix{Float64}:
0.999987 0.994941 0.99092 … 1.09537 1.03969 1.00227 0.997731
0.999986 0.97535 0.938104 1.16707 1.09687 1.0088 0.989946
Component
First let us load a DigitalEcho trained on CSTR data and also the data which we will use it for inference.
using JuliaHub, JLSO, Surrogatize, DataGeneration, PreProcessing
digitalecho_dataset_name = "cstr_digitalecho"
path = JuliaHub.download_dataset(digitalecho_dataset, "path to save")
digitalecho = JLSO.load(path)[:result]
A Continous Time Surrogate wrapper with:
prob:
A `DigitalEchoProblem` with:
model:
A DigitalEcho with :
RSIZE : 256
USIZE : 4
XSIZE : 2
PSIZE : 0
ICSIZE : 0
solver: Tsit5(; stage_limiter! = trivial_limiter!, step_limiter! = trivial_limiter!, thread = static(false),)
train_dataset_name = "cstr"
path = JuliaHub.download_dataset(("juliasimtutorials", train_dataset_name), "path to save")
ed = ExperimentData(JLSO.load(path)[:result])
Number of Trajectories in ExperimentData: 10
Basic Statistics for Given Dynamical System's Specifications
Number of u0s in the ExperimentData: 4
╭─────────┬──────────────────────────────────────────────────────────────────...
─╮...
│ Field │...
│...
├─────────┼──────────────────────────────────────────────────────────────────...
─┤...
│ │ ╭────────────┬──────────────┬──────────────┬─────────┬──────────...
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │ StdDev...
│ │ ├────────────┼──────────────┼──────────────┼─────────┼──────────...
│ │ │ states_1 │ 0.8 │ 0.8 │ 0.8 │ 0.0...
│ u0s │ ├────────────┼──────────────┼──────────────┼─────────┼──────────...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮...
│ │ ├────────────┼──────────────┼──────────────┼─────────┼──────────...
│ │ │ states_4 │ 130.0 │ 130.0 │ 130.0 │ 0.0...
│ │ ╰────────────┴──────────────┴──────────────┴─────────┴──────────...
╰─────────┴──────────────────────────────────────────────────────────────────...
─╯...
Basic Statistics for Given Dynamical System's Continuous Fields
Number of states in the ExperimentData: 4
Number of controls in the ExperimentData: 2
╭────────────┬───────────────────────────────────────────────────────────────...
────────╮...
│ Field │...
│...
├────────────┼───────────────────────────────────────────────────────────────...
────────┤...
│ │ ╭────────────┬──────────────┬──────────────┬───────────┬────...
│ │ │ Labels │ LowerBound │ UpperBound │ Mean...
│ │ ├────────────┼──────────────┼──────────────┼───────────┼────...
│ │ │ states_1 │ 0.782 │ 2.694 │ 2.02...
│ states │ ├────────────┼──────────────┼──────────────┼───────────┼────...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮...
│ │ ├────────────┼──────────────┼──────────────┼───────────┼────...
│ │ │ states_4 │ 125.364 │ 139.993 │ 131.565...
│ │ ╰────────────┴──────────────┴──────────────┴───────────┴────...
├────────────┼───────────────────────────────────────────────────────────────...
────────┤...
│ │ ╭──────────┬──────────────┬──────────────┬─────────────┬─────...
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │...
│ │ ├──────────┼──────────────┼──────────────┼─────────────┼─────...
│ │ │ F │ 8.333 │ 99.468 │ 55.896 │...
│ controls │ ├──────────┼──────────────┼──────────────┼─────────────┼─────...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │...
│ │ ├──────────┼──────────────┼──────────────┼─────────────┼─────...
│ │ │ Q̇ │ -8497.874 │ -227.416 │ -4158.499 │...
│ │ ╰──────────┴──────────────┴──────────────┴─────────────┴─────...
╰────────────┴───────────────────────────────────────────────────────────────...
────────╯...
Before doing inference, we need to convert the dataset into splines for continuous controls.
spline = SplineED()
ed_splines = spline(ed)
Number of Trajectories in ExperimentData: 10
Basic Statistics for Given Dynamical System's Specifications
Number of u0s in the ExperimentData: 4
╭─────────┬──────────────────────────────────────────────────────────────────...
─╮...
│ Field │...
│...
├─────────┼──────────────────────────────────────────────────────────────────...
─┤...
│ │ ╭────────────┬──────────────┬──────────────┬─────────┬──────────...
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │ StdDev...
│ │ ├────────────┼──────────────┼──────────────┼─────────┼──────────...
│ │ │ states_1 │ 0.8 │ 0.8 │ 0.8 │ 0.0...
│ u0s │ ├────────────┼──────────────┼──────────────┼─────────┼──────────...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │ ⋮...
│ │ ├────────────┼──────────────┼──────────────┼─────────┼──────────...
│ │ │ states_4 │ 130.0 │ 130.0 │ 130.0 │ 0.0...
│ │ ╰────────────┴──────────────┴──────────────┴─────────┴──────────...
╰─────────┴──────────────────────────────────────────────────────────────────...
─╯...
Basic Statistics for Given Dynamical System's Continuous Fields
Number of states in the ExperimentData: 4
Number of controls in the ExperimentData: 2
╭────────────┬───────────────────────────────────────────────────────────────...
────────╮...
│ Field │...
│...
├────────────┼───────────────────────────────────────────────────────────────...
────────┤...
│ │ ╭────────────┬──────────────┬──────────────┬───────────┬────...
│ │ │ Labels │ LowerBound │ UpperBound │ Mean...
│ │ ├────────────┼──────────────┼──────────────┼───────────┼────...
│ │ │ states_1 │ 0.782 │ 2.694 │ 2.02...
│ states │ ├────────────┼──────────────┼──────────────┼───────────┼────...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮...
│ │ ├────────────┼──────────────┼──────────────┼───────────┼────...
│ │ │ states_4 │ 125.364 │ 139.993 │ 131.565...
│ │ ╰────────────┴──────────────┴──────────────┴───────────┴────...
├────────────┼───────────────────────────────────────────────────────────────...
────────┤...
│ │ ╭──────────┬──────────────┬──────────────┬─────────────┬─────...
│ │ │ Labels │ LowerBound │ UpperBound │ Mean │...
│ │ ├──────────┼──────────────┼──────────────┼─────────────┼─────...
│ │ │ F │ 8.333 │ 99.468 │ 55.896 │...
│ controls │ ├──────────┼──────────────┼──────────────┼─────────────┼─────...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │...
│ │ │ ⋮ │ ⋮ │ ⋮ │ ⋮ │...
│ │ ├──────────┼──────────────┼──────────────┼─────────────┼─────...
│ │ │ Q̇ │ -8497.874 │ -227.416 │ -4158.499 │...
│ │ ╰──────────┴──────────────┴──────────────┴─────────────┴─────...
╰────────────┴───────────────────────────────────────────────────────────────...
────────╯...
Once we have converted into splines, next step is to index into various fields similar to how we do inference in the full model.
Firstly, let us get the initial condition.
u0 = ed_splines.specs.u0s.vals[1]
4-element Vector{Float64}:
0.8
0.5
134.14
130.0
Next, we get the controls. The dataset was generated using open loop controls. The function signature for passing control always takes in both state and time as arguments. Hence we create a closure with the spline we index from the dataset.
x = (u, t) -> ed_splines.results.controls.vals[1](t)
#3 (generic function with 1 method)
We set the parameters as nothing
as the DigitalEcho was only trained on controls and cannot vary parameters.
p = nothing
Lastly, let us grab the timepoints we want to save the values and the time span.
ts = ed_splines.results.tss.vals[1]
tspan = (ts[1], ts[end])
(0.0, 0.25)
We can then call the forward pass.
pred = digitalecho(u0, x, p, tspan; saveat = ts)
4×48 Matrix{Float64}:
0.800167 1.17452 1.34796 … 2.16621 2.16404 2.16218
0.500015 0.497711 0.511449 1.11122 1.11157 1.1118
134.14 133.659 133.465 135.057 135.083 135.107
130.0 129.809 129.682 125.516 125.628 125.719