Calibration workflows
This demonstration is also available in notebook form, and has been tested in the Julia package environment specified by these Project.toml and Manifest.toml files.
using TumorGrowth
using Statistics
using Plots
using IterationControl
Plots.scalefontsizes() # reset font sizes
Plots.scalefontsizes(0.85)
Data ingestion
Get the records that have at least 6 measurements:
records = patient_data();
records6 = filter(records) do record
record.readings >= 6
end;
Helpers
Wrapper to apply a control only once every 100 steps:
sometimes(control) = IterationControl.skip(control, predicate=100)
sometimes (generic function with 1 method)
Wrapper to apply a control only after the first 30 steps:
warmup(control) = IterationControl.Warmup(control, 30)
warmup (generic function with 1 method)
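These wrappers get attached to controls with Julia's pipe operator, as in GL() |> warmup below. The pipe is nothing exotic: x |> f is just f(x), so GL() |> warmup means warmup(GL()). A minimal illustration:

```julia
# `x |> f` is simply `f(x)`, so `GL() |> warmup` is `warmup(GL())`
double(x) = 2x
@assert (3 |> double) == double(3) == 6
```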
Patient A - a volume that is mostly decreasing
record = records6[2]
(Pt_hashID = "19ce84cc1b10000b63820280995107c2-S1", Study_Arm = InlineStrings.String15("Study_1_Arm_1"), Study_id = 1, Arm_id = 1, T_weeks = [0.1, 6.0, 16.0, 24.0, 32.0, 39.0], T_days = [-12, 38, 106, 166, 222, 270], Lesion_diam = [14.0, 10.0, 9.0, 8.0, 8.0, 8.0], Lesion_vol = [1426.88, 520.0, 379.08, 266.24, 266.24, 266.24], Lesion_normvol = [0.000231515230057292, 8.4371439525252e-5, 6.15067794139087e-5, 4.3198177036929e-5, 4.3198177036929e-5, 4.3198177036929e-5], response = InlineStrings.String7("down"), readings = 6)
times = record.T_weeks
volumes = record.Lesion_normvol;
We'll try calibrating the General Bertalanffy model, bertalanffy, with fixed parameter λ=1/5:
problem = CalibrationProblem(
times,
volumes,
bertalanffy;
frozen=(; λ=1/5),
learning_rate=0.001,
half_life=21, # place greater weight on recent measurements
)
CalibrationProblem:
model: bertalanffy
optimiser: Adam(0.001, (0.9, 0.999), 1.0e-8)
current solution: v0=0.000232 v∞=4.32e-5 ω=0.0432 λ=0.2
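The half_life option places greater weight on recent measurements. The exact scheme is internal to TumorGrowth, but one plausible reading, stated here purely as an illustrative assumption, is an exponential weight that halves for every half_life weeks a measurement lies before the most recent one:

```julia
# Hypothetical exponential weighting (illustrative assumption; not
# necessarily TumorGrowth's exact internal scheme): the weight halves
# for every `half_life` weeks before the latest measurement.
weights(times, half_life) = 2.0 .^ (-(maximum(times) .- times) ./ half_life)

times = [0.1, 6.0, 16.0, 24.0, 32.0, 39.0]
w = weights(times, 21)
# The latest measurement gets weight 1; one half-life earlier gets weight 0.5.
```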
The controls in the solve! call below have the following interpretations:

- Step(1): compute 1 iteration at a time
- InvalidValue(): catch parameters going out of bounds
- NumberLimit(6000): stop after 6000 steps
- GL() |> warmup: stop using Prechelt's GL criterion, after the warm-up period
- NumberSinceBest(10) |> warmup: stop when it's 10 steps since the best so far
- Callback(prob -> (plot(prob); gui())) |> sometimes: periodically plot the problem
Some other possible controls are:

- TimeLimit(1/60): stop after 1 minute
- WithLossDo(): log the current loss to Info
See IterationControl.jl for a complete list.
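For reference, Prechelt's generalization loss is the percentage by which the current loss exceeds the lowest loss seen so far; the GL() control stops once this exceeds a threshold (2.0 by default, per the EarlyStopping.GL(2.0) printed in the solve! report). A minimal sketch of the criterion itself, assuming that standard definition:

```julia
# Prechelt's generalization loss: percentage increase of the current
# loss over the best loss observed so far. `GL()` triggers a stop when
# this exceeds its threshold (2.0 by default).
gl(losses) = 100 * (last(losses) / minimum(losses) - 1)

gl([1.0, 0.5, 0.52])  # ≈ 4.0: the loss has crept 4% above its minimum
```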
solve!(
problem,
Step(1),
InvalidValue(),
NumberLimit(6000),
GL() |> warmup,
NumberSinceBest(10) |> warmup,
Callback(prob-> (plot(prob); gui())) |> sometimes,
)
((IterationControl.Step(1), (new_iterations = 6000,)), (EarlyStopping.InvalidValue(), (done = false, log = "")), (EarlyStopping.NumberLimit(6000), (done = true, log = "Stop triggered by EarlyStopping.NumberLimit(6000) stopping criterion. ")), (EarlyStopping.Warmup{EarlyStopping.GL}(EarlyStopping.GL(2.0), 30), (done = false, log = "")), (EarlyStopping.Warmup{EarlyStopping.NumberSinceBest}(EarlyStopping.NumberSinceBest(10), 30), (done = false, log = "")))
p = solution(problem)
extended_times = vcat(times, [40.0, 47.0])
bertalanffy(extended_times, p)
8-element Vector{Float64}:
0.0002188341893460766
0.00010826304912371365
5.720601298511401e-5
4.5669494984156554e-5
4.1096282718025634e-5
3.931318815519471e-5
3.9148263686165166e-5
3.836347468826367e-5
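As a sanity check on such predictions, the General Bertalanffy model admits a closed-form solution. Assuming the ODE form dv/dt = ω Bλ(v∞/v) v with Bλ the Box-Cox transform, Bλ(x) = (x^λ - 1)/λ (our reading of the model, stated here as an assumption), the substitution y = (v/v∞)^λ gives dy/dt = ω(1 - y), and hence:

```julia
# Closed-form solution under the assumed ODE dv/dt = ω*Bλ(v∞/v)*v,
# where Bλ(x) = (x^λ - 1)/λ is the Box-Cox transform:
#     v(t) = v∞ * (1 + ((v0/v∞)^λ - 1) * exp(-ω*t))^(1/λ)
bert(t; v0, v∞, ω, λ) = v∞ * (1 + ((v0 / v∞)^λ - 1) * exp(-ω * t))^(1 / λ)

bert(0.0; v0=0.000232, v∞=4.32e-5, ω=0.0432, λ=0.2)  # recovers v0 at t = 0
```

For ω > 0 the solution decays monotonically from v0 toward the asymptote v∞, consistent with the extrapolated values above levelling off near v∞.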
plot(problem, title="Patient A, λ=1/5 fixed", color=:black)
savefig(joinpath(dir, "patientA.png"))
"/Users/anthony/GoogleDrive/Julia/TumorGrowth/docs/src/examples/03_calibration/patientA.png"
Patient B - relapse following initial improvement
record = records6[10]
times = record.T_weeks
volumes = record.Lesion_normvol;
We'll first try the earlier simple model, but this time we won't freeze λ. Also, we won't specify a half_life, giving all the data equal weight.
problem = CalibrationProblem(
times,
volumes,
bertalanffy;
learning_rate=0.001,
)
solve!(
problem,
Step(1),
InvalidValue(),
NumberLimit(6000),
)
plot(problem, label="bertalanffy")
Let's try the 2D generalization of the General Bertalanffy model:
problem = CalibrationProblem(
times,
volumes,
bertalanffy2;
learning_rate=0.001,
)
solve!(
problem,
Step(1),
InvalidValue(),
NumberLimit(6000),
)
plot!(problem, label="bertalanffy2")
Here's how we can inspect the final parameters:
solution(problem)
(v0 = 0.012683548940358924, v∞ = 0.0010008510378552815, ω = 0.12190919867691268, λ = 0.9134381946630703, γ = 0.498384973747)
Or we can do:
solution(problem) |> pretty
"v0=0.0127 v∞=0.001 ω=0.122 λ=0.913 γ=0.498"
And finally, we'll try a 2D neural ODE model, with fixed volume scale v∞.
using Lux, Random
Note well the zero-initialization of weights in first layer:
network2 = Chain(
Dense(2, 2, Lux.tanh, init_weight=Lux.zeros64),
Dense(2, 2),
)
Chain(
layer_1 = Dense(2 => 2, tanh_fast), # 6 parameters
layer_2 = Dense(2 => 2), # 6 parameters
) # Total: 12 parameters,
# plus 0 states.
Notice this network has a total of 12 parameters. To these we'll be adding the initial value u0 of the latent variable, so this is a model with relatively high complexity for this problem.
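The 12-parameter count is just the usual dense-layer arithmetic: each Dense(in => out) layer carries in×out weights plus out biases.

```julia
# Parameters in a Dense(in => out) layer: in*out weights + out biases
nparams(nin, nout) = nin * nout + nout

nparams(2, 2) + nparams(2, 2)  # → 12; adding the 2 components of u0 gives 14
```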
n2 = neural2(Xoshiro(123), network2) # `Xoshiro` is a random number generator
Neural2 model, (times, p) -> volumes, where length(p) = 14
transform: log
Note the reduced learning rate.
v∞ = mean(volumes)
problem = CalibrationProblem(
times,
volumes,
n2;
frozen = (; v∞),
learning_rate=0.001,
)
solve!(
problem,
Step(1),
InvalidValue(),
NumberLimit(6000),
)
plot!(
problem,
title = "Model comparison for Patient B",
label = "neural2",
legend=:inside,
)
savefig(joinpath(dir, "patientB.png"))
"/Users/anthony/GoogleDrive/Julia/TumorGrowth/docs/src/examples/03_calibration/patientB.png"
For a more principled comparison, we compare the models on a holdout set. We'll additionally throw in a 1D neural ODE model.
network1 = Chain(
Dense(1, 3, Lux.tanh, init_weight=Lux.zeros64),
Dense(3, 1),
)
n1 = neural(Xoshiro(123), network1)
models = [bertalanffy, bertalanffy2, n1, n2]
calibration_options = [
(frozen = (; λ=1/5), learning_rate=0.001, half_life=21), # bertalanffy
(frozen = (; λ=1/5), learning_rate=0.001, half_life=21), # bertalanffy2
(frozen = (; v∞), learning_rate=0.001, half_life=21), # neural
(frozen = (; v∞), learning_rate=0.001, half_life=21), # neural2
]
iterations = [6000, 6000, 6000, 6000]
comparison = compare(times, volumes, models; calibration_options, iterations)
ModelComparison with 3 holdouts:
metric: mae
bertalanffy: 0.004521
bertalanffy2: 0.004324
neural (12 params): 0.004685
neural2 (14 params): 0.004149
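The metric reported is mae, which we take to be the standard mean absolute error on the holdout measurements:

```julia
# Mean absolute error, our reading of the `mae` metric reported by `compare`
mae(predicted, observed) = sum(abs.(predicted .- observed)) / length(observed)

mae([1.0, 2.0, 4.0], [1.0, 3.0, 2.0])  # → 1.0
```

On this patient, all four models land within about 13% of one another, with neural2 ahead despite its higher parameter count.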
plot(comparison)
savefig(joinpath(dir, "patientB_validation.png"))
"/Users/anthony/GoogleDrive/Julia/TumorGrowth/docs/src/examples/03_calibration/patientB_validation.png"
This page was generated using Literate.jl.