Validation and CallBacks with DigitalEcho

In this demonstration we will show how to hook callbacks into the training process for DigitalEcho. Callbacks are functions that are executed at predefined events, or periodically, during training. They allow a user to customize the training process and provide a degree of control over it. A callback can be used to print the current state, save models, log data, stop training early, and much more. We will start by setting up our Julia environment and loading the necessary packages.

using JuliaHub, JLSO, DataGeneration, PreProcessing, Training, Surrogatize

We will load an existing dataset into an ExperimentData object. For more details on how to load a dataset and train a DigitalEcho, see: Training a Full Model DigitalEcho.

train_dataset_name = "lotka_volterra"
path = JuliaHub.download_dataset(("juliasimtutorials", train_dataset_name), "path to save")
ed = ExperimentData(JLSO.load(path)[:result])
 Number of Trajectories in ExperimentData: 10
  Basic Statistics for Given Dynamical System's Specifications
  Number of u0s in the ExperimentData: 2
  Number of ps in the ExperimentData: 4

  u0s:
    Labels     LowerBound   UpperBound   Mean    StdDev
    states_1   1.0          1.0          1.0     0.0
    states_2   1.0          1.0          1.0     0.0

  ps:
    Labels     LowerBound   UpperBound   Mean    StdDev
    p_1        1.562        2.438        1.969   0.302
    ⋮
    p_4        1.766        1.984        1.87    0.074

 Basic Statistics for Given Dynamical System's Continuous Fields
  Number of states in the ExperimentData: 2

  states:
    Labels     LowerBound   UpperBound   Mean    StdDev
    states_1   0.61         1.851        1.131   0.294
    states_2   0.585        1.93         1.068   0.272

We will split this ExperimentData object into training and validation sets using train_valid_split with train_ratio = 0.8, so that 80% of the trajectories are used for training:

train_ed, val_ed = train_valid_split(ed, train_ratio = 0.8)
2-element Vector{ExperimentData{…}}:
 ExperimentData{…}(…)
 ExperimentData{…}(…)

The first ExperimentData holds the 8 training trajectories and the second holds the remaining 2 validation trajectories.
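The 8/2 split follows from 0.8 × 10 trajectories. A minimal sketch of a ratio-based random split, using only Julia's Random standard library, is shown below. `split_indices` is a hypothetical helper, not the implementation of train_valid_split, which may shuffle or round differently:

```julia
using Random

# Hypothetical helper: randomly partition trajectory indices 1:n by `train_ratio`.
# Illustrative only; train_valid_split's actual implementation may differ.
function split_indices(n::Integer, train_ratio::Real; rng = MersenneTwister(0))
    perm = randperm(rng, n)                # shuffled trajectory indices
    ntrain = round(Int, train_ratio * n)   # e.g. 0.8 * 10 = 8 training trajectories
    return perm[1:ntrain], perm[ntrain+1:end]
end

train_idx, val_idx = split_indices(10, 0.8)
println(length(train_idx), " training / ", length(val_idx), " validation trajectories")
```

This prints `8 training / 2 validation trajectories`; every index lands in exactly one of the two sets.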

Adding validation and test data

To see how our DigitalEcho model performs on unseen data, we need to monitor the validation and test losses during training. We can simply provide val_ed as a keyword argument to DigitalEcho:

model = DigitalEcho(
    train_ed;
    tau = 1e-2,
    val_ed = val_ed
)
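Conceptually, the validation loss is the model's error on trajectories it never trains on. The sketch below assumes a plain mean-squared error over states × time matrices; DigitalEcho's reported loss may be defined differently, and `mse` and `validation_loss` are hypothetical names:

```julia
# Hypothetical sketch: mean-squared error on held-out trajectories.
# The loss DigitalEcho actually reports may be defined differently.
mse(pred::AbstractMatrix, truth::AbstractMatrix) = sum(abs2, pred .- truth) / length(truth)

# Average the per-trajectory error over all validation trajectories.
validation_loss(preds, truths) =
    sum(mse(p, t) for (p, t) in zip(preds, truths)) / length(truths)

# Two random stand-in validation trajectories: 2 states × 50 time points each.
truths = [rand(2, 50) for _ in 1:2]
preds_exact  = copy.(truths)               # a perfect surrogate
preds_offset = [t .+ 0.1 for t in truths]  # a surrogate with a constant 0.1 bias

validation_loss(preds_exact, truths)    # 0.0
validation_loss(preds_offset, truths)   # ≈ 0.01
```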

CallBacks

We will now walk through various pre-implemented callbacks and show how to hook them into the training process.
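Before turning to the built-in callbacks, the general pattern can be sketched in plain Julia. Everything here is hypothetical: `toy_train!`, the `(epoch, loss) -> Bool` callback signature and the stop-when-`true` convention are illustrative assumptions, not DigitalEcho's API. The sketch runs a logging callback and an early-stopping callback every `callevery` epochs:

```julia
# Hypothetical callback-driven training loop; not DigitalEcho's API.
# Each callback is called as f(epoch, loss) and returns `true` to request a stop.
function toy_train!(history::Vector{Float64}; nepochs::Int, callevery::Int, cb = Function[])
    loss = 1.0
    for epoch in 1:nepochs
        loss *= 0.9                 # stand-in for one epoch of optimisation
        push!(history, loss)
        if epoch % callevery == 0
            stop = false
            for f in cb             # run every callback, then honour any stop request
                stop |= f(epoch, loss)
            end
            stop && return epoch
        end
    end
    return nepochs
end

# Logging callback: prints the current state and never requests a stop.
log_cb(epoch, loss) = (println("epoch $epoch: loss = $(round(loss; digits = 4))"); false)

# Early-stopping callback: stop once the loss falls below 0.2.
early_stop_cb(epoch, loss) = loss < 0.2

loss_history = Float64[]
last_epoch = toy_train!(loss_history; nepochs = 100, callevery = 5, cb = [log_cb, early_stop_cb])
```

Because callbacks only run every 5 epochs, training stops at epoch 20 (the first checked epoch with loss below 0.2) rather than at epoch 16.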

Defining existing callbacks

1. PrintCallBack

The PrintCallBack prints the state and metrics of the training process to the REPL: the epoch number, the current learning rate, and the losses (train, validation and test).

We define the callback with an every argument, which specifies how often the callback is called. In the following example every is set to 10, so the callback is called every 10 epochs.

print_cb = PrintCallBack(every = 10)
A `PrintCallBack` model
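To make the every semantics concrete, here is a toy callable struct that behaves like a print callback. `ToyPrintCallback`, its call signature and the loss names are assumptions for illustration, not the actual definition of PrintCallBack:

```julia
# Hypothetical stand-in for a print callback (not PrintCallBack's real definition).
struct ToyPrintCallback
    every::Int   # print only on epochs that are multiples of `every`
end

# Print epoch, learning rate and losses; return `true` if something was printed.
function (cb::ToyPrintCallback)(epoch::Integer, lr::Real, losses::NamedTuple)
    epoch % cb.every == 0 || return false
    println("epoch $epoch | lr = $lr | ", join(("$k = $v" for (k, v) in pairs(losses)), ", "))
    return true
end

toy_print = ToyPrintCallback(10)
placeholder_losses = (train = 0.10, valid = 0.20, test = 0.25)   # made-up numbers
printed = [e for e in 1:100 if toy_print(e, 1e-3, placeholder_losses)]
```

Over 100 epochs this prints on epochs 10, 20, …, 100 only.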

2. CheckPoint

The CheckPoint callback allows the user to store the model in a .bson file.

We define the callback with a model_path argument, the path to the directory where the models should be stored, and an every argument, which specifies how often the callback is called.

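model_path = mkpath(joinpath(pwd(), "checkpoints"))  # directory where the .bson checkpoints are stored
checkpoint_cb = CheckPoint(model_path = model_path, every = 25)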
A `CheckPoint` model
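A toy version of a checkpoint callback can clarify what model_path and every do. `ToyCheckpoint` is hypothetical; since the real CheckPoint writes .bson files, this sketch swaps in Julia's Serialization standard library (`.jls` files) so that it runs without extra packages, and its file naming scheme is an assumption:

```julia
using Serialization

# Hypothetical checkpoint callback; the real CheckPoint writes .bson files instead.
struct ToyCheckpoint
    model_path::String   # directory that receives the checkpoint files
    every::Int           # save only on epochs that are multiples of `every`
end

function (cb::ToyCheckpoint)(epoch::Integer, model)
    epoch % cb.every == 0 || return nothing
    file = joinpath(cb.model_path, "model_epoch_$(epoch).jls")  # one file per checkpoint
    serialize(file, model)
    return file
end

dir = mktempdir()
ckpt = ToyCheckpoint(dir, 25)
# Stand-in "model": a NamedTuple whose weights record the epoch they were saved at.
saved = filter(!isnothing, [ckpt(e, (weights = fill(Float64(e), 3),)) for e in 1:100])
restored = deserialize(saved[end])   # reload the most recent checkpoint
```

With every = 25 over 100 epochs, four files are written (epochs 25, 50, 75 and 100), and the last one can be reloaded to resume from the final state.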

Adding the callbacks to training process

We can provide callbacks to the training process by using the keyword argument cb. The callevery keyword argument controls how often all callbacks are collectively called.

Note

PrintCallBack is only called when verbose = true is also passed.

Note

The callevery argument controls the calling interval for all callbacks collectively, while the every argument of each callback only controls the interval for that specific callback. callevery takes precedence, and each callback's own every is applied after it. For example, if callevery is set to 10 and every is set to 5, the callback will be called every 10 epochs.
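One way to read this precedence, assumed here rather than taken from the library's source, is that a callback fires on an epoch that is a multiple of both callevery and its own every. Under that assumption the rule reduces to two divisibility checks; `fires` and `fire_epochs` are hypothetical helpers:

```julia
# Hypothetical model of the precedence rule: callbacks are only considered every
# `callevery` epochs, and a callback then fires if its own `every` also divides
# the epoch. This is an assumed reading, not the library's implementation.
fires(epoch, callevery, every) = epoch % callevery == 0 && epoch % every == 0

fire_epochs(callevery, every, nepochs) = [e for e in 1:nepochs if fires(e, callevery, every)]

fire_epochs(10, 5, 50)     # the note's example: [10, 20, 30, 40, 50]
fire_epochs(10, 25, 100)   # callevery = 10 with every = 25: [50, 100]
```

The first call reproduces the example in the note. If the library instead counts invocations, the schedule for mismatched intervals such as 10 and 25 would differ, so treat the second result as illustrative only.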

We can provide a single callback as:

model = DigitalEcho(train_ed;
    nepochs = 100,
    verbose = true,
    callevery = 10,
    cb = [print_cb],
    val_ed = val_ed)

We can provide multiple callbacks as:

model = DigitalEcho(train_ed;
    nepochs = 100,
    verbose = true,
    callevery = 10,
    cb = [print_cb, checkpoint_cb],
    val_ed = val_ed)