Validation and CallBacks with DigitalEcho
In this demonstration, we show how to hook callbacks into the training process of DigitalEcho. Callbacks are functions that are executed at predefined events, or periodically, during training. They let a user customize the training process and give a degree of control over it. A callback can print the current state, save models, log data, stop training early, and much more. We start by setting up our Julia environment and loading the necessary packages.
using JuliaHub, JLSO, DataGeneration, PreProcessing, Training, Surrogatize
We will load an existing dataset into an ExperimentData object. For more details on how to load a dataset and train a DigitalEcho, see Training a Full Model DigitalEcho.
train_dataset_name = "lotka_volterra"
path = JuliaHub.download_dataset(("juliasimtutorials", train_dataset_name), "path to save")
ed = ExperimentData(JLSO.load(path)[:result])
Number of Trajectories in ExperimentData: 10
Basic Statistics for Given Dynamical System's Specifications
Number of u0s in the ExperimentData: 2
Number of ps in the ExperimentData: 4
u0s:
  Labels      LowerBound   UpperBound   Mean      StdDev
  states_1    1            1            1         …
  states_2    1            1            1         …
ps:
  Labels      LowerBound   UpperBound   Mean      StdDev
  p_1         1.5625       2.4375       1.96875   0.301904
  ⋮           ⋮            ⋮            ⋮         ⋮
  p_4         1.76562      1.98438      1.87031   0.074316…
Basic Statistics for Given Dynamical System's Continuous Fields
Number of states in the ExperimentData: 2
states:
  Labels      LowerBound   UpperBound   Mean      …
  states_1    0.60988      1.85127      1.13137   …
  states_2    0.585184     1.92984      1.0678    …
We will split this ExperimentData object into training and validation sets using train_valid_split with train_ratio = 0.8, so that 80% of the trajectories go to the training set:
train_ed, val_ed = train_valid_split(ed, train_ratio = 0.8)
The call returns a 2-element Vector{ExperimentData{…}} (the full type parameters and field dump are omitted here). Its first element, train_ed, contains 8 trajectories and its second, val_ed, contains 2; both keep the 2 states, the 4 parameters, and the time span (0.0, 12.5) of the original dataset.
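The train_ratio keyword is the fraction of trajectories assigned to the training set: the dataset above has 10 trajectories, so train_ratio = 0.8 leaves 8 for training and 2 for validation. The sketch below reproduces only that index bookkeeping. Whether train_valid_split shuffles, and how it rounds, is not stated here, so the seeded shuffle and the `round` in the hypothetical `split_indices` function are assumptions.

```julia
using Random

# Hypothetical helper: split trajectory indices 1:n into disjoint training and
# validation sets. Assumption: indices are shuffled with a seeded RNG and the
# training set has round(train_ratio * n) entries; train_valid_split may differ.
function split_indices(n::Integer; train_ratio::Real = 0.8, seed::Integer = 0)
    0 < train_ratio < 1 || throw(ArgumentError("train_ratio must lie in (0, 1)"))
    perm = randperm(MersenneTwister(seed), n)    # random permutation of 1:n
    ntrain = round(Int, train_ratio * n)         # number of training trajectories
    return sort(perm[1:ntrain]), sort(perm[ntrain+1:end])
end

train_idx, val_idx = split_indices(10; train_ratio = 0.8)
(length(train_idx), length(val_idx))             # (8, 2)
```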
Adding validation and test data
To see how our DigitalEcho model performs on unseen data, we need to monitor the validation and test losses during training. We can simply provide val_ed as a keyword argument to DigitalEcho:
model = DigitalEcho(
train_ed;
tau = 1e-2,
val_ed = val_ed
)
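Conceptually, a validation loss is the same error metric as the training loss, evaluated on trajectories the model never trains on. The sketch below assumes a mean-squared-error metric and represents each trajectory as a plain (states × time points) matrix; the hypothetical `trajectory_mse` and `dataset_loss` functions are illustrative only, and the loss DigitalEcho actually reports may be defined differently.

```julia
# Assumption: mean squared error between a predicted and a reference trajectory,
# both stored as (number of states) × (number of time points) matrices.
trajectory_mse(pred::AbstractMatrix, ref::AbstractMatrix) =
    sum(abs2, pred .- ref) / length(ref)

# Average the per-trajectory error over a whole held-out dataset.
function dataset_loss(preds::Vector{<:AbstractMatrix}, refs::Vector{<:AbstractMatrix})
    @assert length(preds) == length(refs) "need one prediction per reference trajectory"
    return sum(trajectory_mse.(preds, refs)) / length(refs)
end

# Toy validation set: two trajectories with 2 states and 3 time points each.
refs  = [ones(2, 3), fill(2.0, 2, 3)]
preds = [ones(2, 3), fill(2.5, 2, 3)]
val_loss = dataset_loss(preds, refs)   # (0.0 + 0.25) / 2 = 0.125
```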
CallBacks
We will now walk through various pre-implemented callbacks and show how to hook them into the training process.
Defining existing callbacks
1. PrintCallBack
The PrintCallBack allows a user to print the state and metrics of the training process to the REPL. It prints the epoch number, the current learning rate, and the losses (train, validation and test).
We define the callback with an argument every, which sets how often the callback is called. In the following example, every is set to 10, so this callback is called every 10 epochs.
print_cb = PrintCallBack(every = 10)
A `PrintCallBack` model
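To make the every semantics concrete, here is a minimal, library-independent sketch of a print-style callback written as a callable Julia struct. `EpochPrinter` and its call signature `(epoch, lr, train_loss, val_loss)` are hypothetical and are not the interface of PrintCallBack; they only show the pattern of acting on every `every`-th epoch.

```julia
# Hypothetical print-style callback (not the PrintCallBack API):
# reports metrics on every `every`-th epoch.
struct EpochPrinter
    every::Int
end

# Calling the struct like a function; returns true on epochs where it printed.
function (cb::EpochPrinter)(epoch::Int, lr::Real, train_loss::Real, val_loss::Real)
    epoch % cb.every == 0 || return false
    println("epoch $epoch | lr $lr | train $train_loss | val $val_loss")
    return true
end

printer = EpochPrinter(10)
printed = [epoch for epoch in 1:100 if printer(epoch, 1e-3, 0.5, 0.6)]
# printed == [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
```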
2. CheckPoint
The CheckPoint callback allows the user to store the model in a .bson file.
We define the callback with a model_path argument, the path to the directory where the models should be stored, and an every argument, which sets how often the callback is called.
model_path = "path to save models"  # placeholder: directory where checkpoint files are written
checkpoint_cb = CheckPoint(model_path = model_path, every = 25)
A `CheckPoint` model
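A checkpoint callback boils down to writing the current model into a directory on a fixed epoch interval. The sketch below is hypothetical: it uses Julia's standard Serialization library instead of the .bson format that CheckPoint writes, and the `EpochCheckpointer` type and the `model_epoch_<n>.jls` file naming are illustrative choices, not the library's.

```julia
using Serialization

# Hypothetical checkpoint callback (not the CheckPoint API): saves `model` into
# `dir` every `every` epochs, using Julia's Serialization format instead of BSON.
struct EpochCheckpointer
    dir::String
    every::Int
end

# Returns the path of the file it wrote, or `nothing` on epochs it skips.
function (cb::EpochCheckpointer)(epoch::Int, model)
    epoch % cb.every == 0 || return nothing
    mkpath(cb.dir)                                   # create the directory if missing
    path = joinpath(cb.dir, "model_epoch_$(epoch).jls")
    open(io -> serialize(io, model), path, "w")      # write the model to disk
    return path
end

ckpt = EpochCheckpointer(mktempdir(), 25)
toy_model = (weights = [0.1, 0.2], bias = 0.0)       # stand-in for a trained model
saved = filter(!isnothing, [ckpt(epoch, toy_model) for epoch in 1:100])
```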
Adding the callbacks to training process
We can provide callbacks to the training process through the keyword argument cb. We also need to pass verbose = true for PrintCallBack to be called.
The callevery keyword argument controls the calling interval for all callbacks collectively, while the every argument of each callback controls the interval for that specific callback only. callevery takes precedence, followed by the every argument of the specific callback. For example, if callevery is set to 10 and every is set to 5, the callback is called every 10 epochs.
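Read this way, a callback fires only on epochs where both intervals line up. The sketch below encodes that reading, assuming each callback compares its own every against the epoch number whenever callevery triggers; the hypothetical `fires` and `firing_epochs` helpers are not part of the library. Under that assumption the effective interval is `lcm(callevery, every)`, which gives 10 epochs for the example above and would give 50 epochs for a callback with every = 25 combined with callevery = 10.

```julia
# Assumed gating rule: callevery is checked first, then the callback's own every.
fires(epoch::Int, callevery::Int, every::Int) =
    epoch % callevery == 0 && epoch % every == 0

# Epochs in 1:nepochs at which a callback would run under that rule.
firing_epochs(nepochs::Int, callevery::Int, every::Int) =
    [epoch for epoch in 1:nepochs if fires(epoch, callevery, every)]

firing_epochs(100, 10, 5)    # [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
firing_epochs(100, 10, 25)   # [50, 100]
```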
We can provide a single callback as:
model = DigitalEcho(train_ed;
nepochs = 100,
verbose = true,
callevery = 10,
cb = [print_cb],
val_ed = val_ed)
We can provide multiple callbacks as:
model = DigitalEcho(train_ed;
nepochs = 100,
verbose = true,
callevery = 10,
cb = [print_cb, checkpoint_cb],
val_ed = val_ed)
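Passing a vector to cb means the training loop invokes each callback in turn. The library-independent skeleton below sketches that dispatch, gating all callbacks on callevery and letting each callback check its own interval. `train_with_callbacks`, the NamedTuple state, and the decaying toy loss are hypothetical stand-ins for DigitalEcho's internal training loop, whose details are not shown here.

```julia
# Hypothetical training-loop skeleton: every `callevery` epochs, each callback
# receives the current training state and decides for itself whether to act.
function train_with_callbacks(; nepochs::Int, callevery::Int, cb::Vector)
    for epoch in 1:nepochs
        loss = 1.0 / epoch                      # toy stand-in for one training epoch
        epoch % callevery == 0 || continue      # callevery gates all callbacks
        state = (epoch = epoch, train_loss = loss)
        foreach(f -> f(state), cb)              # run callbacks in the order given
    end
end

# Two toy callbacks that record when they act, with intervals 10 and 25.
calls = Tuple{Symbol,Int}[]
every_10 = state -> state.epoch % 10 == 0 && push!(calls, (:print, state.epoch))
every_25 = state -> state.epoch % 25 == 0 && push!(calls, (:checkpoint, state.epoch))
train_with_callbacks(nepochs = 100, callevery = 10, cb = [every_10, every_25])
```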