Model Battle
The analysis below is also available in notebook form, and was tested in a Julia package environment specified by these Project.toml and Manifest.toml files.
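To reproduce the environment, assuming Project.toml and Manifest.toml have been placed in the notebook's directory, the standard Pkg incantation applies:

using Pkg
Pkg.activate(@__DIR__)  # the directory containing Project.toml and Manifest.toml
Pkg.instantiate()       # installs the exact versions recorded in Manifest.toml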
Note. The `@threads` cell in this notebook takes about 4 hours to complete on a 2018 MacBook Pro.
We compare the predictive performance of several tumor growth models on data collected in Laleh et al. (2022), "Classical mathematical models for prediction of response to chemotherapy and immunotherapy", PLOS Computational Biology. In particular, we determine whether the observed differences are statistically significant.
In addition to classical models, we include a 2D generalization of the General Bertalanffy model, `bertalanffy2`, and some 1D and 2D neural ODEs. The 2D models still model a single lesion feature, namely its volume, but add a second latent variable coupled to the volume, effectively making the model second order. For further details, refer to the TumorGrowth.jl package documentation.
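To make "second order" concrete: such a model evolves the observed volume v together with a latent variable u. The sketch below is purely illustrative, with made-up right-hand sides and hypothetical parameters ω and γ; the actual `bertalanffy2` law is given in the package documentation.

# Illustrative only: a generic latent-variable growth law, NOT the
# `bertalanffy2` equations. The observed volume `v` is coupled to a
# latent variable `u`, making the model effectively second order:
function latent_rhs(v, u, p)
    dv = p.ω * u * v       # volume growth rate, modulated by the latent u
    du = -p.γ * (u - 1.0)  # latent variable relaxes toward a baseline
    return dv, du
end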
Conclusions
We needed to eliminate about 2% of the patient records because the neural network models failed to converge before their parameters went out of bounds. A bootstrap comparison of the differences in mean absolute errors suggests that the General Bertalanffy model performs significantly better than all other models, with the exception of the 1D neural ODE. However, in pairwise comparisons the neural ODE model was not significantly better than any other model. Results are summarised in the table below. Arrows point to the bootstrap winner, named in the top row or first column.
| model | gompertz | logistic | classical_bertalanffy | bertalanffy | bertalanffy2 | n1 | n2 |
|---|---|---|---|---|---|---|---|
| exponential | ↑ | draw | ↑ | ↑ | draw | draw | ← |
| gompertz | n/a | draw | draw | ↑ | draw | draw | ← |
| logistic | draw | n/a | draw | ↑ | draw | draw | ← |
| classical_bertalanffy | draw | draw | n/a | ↑ | draw | draw | ← |
| bertalanffy | ← | ← | ← | n/a | ← | draw | ← |
| bertalanffy2 | draw | draw | draw | ↑ | n/a | draw | ← |
| n1 | draw | draw | draw | draw | draw | n/a | ← |
using Random
using Statistics
using TumorGrowth
using Lux
using Plots
import PrettyPrint.pprint
using PrettyTables
using Bootstrap
using Serialization
using ProgressMeter
using .Threads
Activating project at `~/GoogleDrive/Julia/TumorGrowth/docs/src/examples/04_model_battle`
Data ingestion
Collect all records with at least six measurements from the data:
records = filter(patient_data()) do record
record.readings >= 6
end;
Here's what a single record looks like:
pprint(records[13])
@NamedTuple{Pt_hashID::String, Study_Arm::InlineStrings.String15, Study_id::Int64, Arm_id::Int64, T_weeks::Vector{Float64}, T_days::Vector{Int64}, Lesion_diam::Vector{Float64}, Lesion_vol::Vector{Float64}, Lesion_normvol::Vector{Float64}, response::InlineStrings.String7, readings::Int64}(
Pt_hashID="d9b90f39d6a0b35cbc230adadbd50753-S1",
Study_Arm=InlineStrings.String15("Study_1_Arm_1"),
Study_id=1,
Arm_id=1,
T_weeks=[0.1, 6.0, 12.0, 18.0,
24.0, 36.0, 40.0, 42.0,
48.0],
T_days=[-16, 39, 82, 124, 165,
249, 277, 292, 334],
Lesion_diam=[17.0, 18.0, 16.0,
9.0, 8.0, 9.0, 7.0,
7.0, 7.0],
Lesion_vol=[2554.76, 3032.64, 2129.92,
379.08, 266.24, 379.08,
178.36, 178.36, 178.36],
Lesion_normvol=[0.000414516882387563,
0.00049205423531127,
0.000345585416295432,
6.15067794139087e-5,
4.3198177036929e-5,
6.15067794139087e-5,
2.89394037571615e-5,
2.89394037571615e-5,
2.89394037571615e-5],
response=InlineStrings.String7("flux"),
readings=9,
)
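Only two of these fields are used in the analysis below: the measurement times, in weeks, and the normalized lesion volumes. They are extracted like this:

record = records[13]
times, volumes = record.T_weeks, record.Lesion_normvol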
Neural ODEs
We define some one and two-dimensional neural ODE models we want to include in our comparison. The choice of architecture here is somewhat ad hoc and further experimentation might give better results.
network = Chain(
Dense(1, 3, Lux.tanh, init_weight=Lux.zeros64),
Dense(3, 1),
)
network2 = Chain(
Dense(2, 2, Lux.tanh, init_weight=Lux.zeros64),
Dense(2, 2),
)
n1 = neural(Xoshiro(123), network) # `Xoshiro` is a random number generator
n2 = neural2(Xoshiro(123), network2)
Neural2 model, (times, p) -> volumes, where length(p) = 14
transform: log
Models to be compared
model_exs =
[:exponential, :gompertz, :logistic, :classical_bertalanffy, :bertalanffy,
:bertalanffy2, :n1, :n2]
models = eval.(model_exs)
8-element Vector{Any}:
exponential (generic function with 1 method)
gompertz (generic function with 1 method)
logistic (generic function with 1 method)
classical_bertalanffy (generic function with 1 method)
bertalanffy (generic function with 1 method)
bertalanffy2 (generic function with 1 method)
neural (12 params)
neural2 (14 params)
Computing prediction errors on a holdout set
holdouts = 2
errs = fill(Inf, length(records), length(models))
p = Progress(length(records))
@threads for i in eachindex(records)  # thread-safe: each task writes its own row of `errs`
record = records[i]
times, volumes = record.T_weeks, record.Lesion_normvol
comparison = compare(times, volumes, models; holdouts)
errs[i,:] = TumorGrowth.errors(comparison)
next!(p)
end
finish!(p)
Progress: 100%|█████████████████████████████████████████████████████| Time: 3:28:43
serialize(joinpath(dir, "errors.jls"), errs);  # `dir` is the notebook's directory, defined during setup
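To experiment without the multi-hour wait, compare can be run on a single record, for example:

record = records[13]
comparison = compare(record.T_weeks, record.Lesion_normvol, models; holdouts=2)
TumorGrowth.errors(comparison)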
Bootstrap comparisons (neural ODEs excluded)
Because the neural ODE errors contain more `NaN` values, we start with a comparison that excludes the neural models, discarding only those records where a non-neural model fails (a `NaN`, infinite, or implausibly large error):
bad_error_rows = filter(axes(errs, 1)) do i
    es = errs[i,1:end-2]  # drop the two neural-model columns
    any(isnan, es) || any(isinf, es) || max(es...) > 0.1
end
proportion_bad = length(bad_error_rows)/size(errs, 1)
@show proportion_bad
0.0078003120124804995
That's less than 1%. Let's remove them:
good_error_rows = setdiff(axes(errs, 1), bad_error_rows);
errs = errs[good_error_rows,:];
Errors are evidently not normally distributed (and we were not able to transform them to approximately normal):
plt = histogram(errs[:, 1], normalize=:pdf, alpha=0.4)
histogram!(errs[:, end-2], normalize=:pdf, alpha=0.4)
savefig(joinpath(dir, "errors_distribution.png"))
"/Users/anthony/GoogleDrive/Julia/TumorGrowth/docs/src/examples/04_model_battle/errors_distribution.png"
We deem a Student t-test inappropriate and instead compute bootstrap confidence intervals for the pairwise differences in model errors:
confidence_intervals = Array{Any}(undef, length(models) - 2, length(models) - 2)
for i in 1:(length(models) - 2)
for j in 1:(length(models) - 2)
b = bootstrap(
mean,
errs[:,i] - errs[:,j],
BasicSampling(10000),
)
confidence_intervals[i,j] = only(confint(b, BasicConfInt(0.95)))[2:3]
end
end
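Note the trailing [2:3]: each entry returned by confint is a triple (point estimate, lower, upper), and we keep only the interval endpoints. A standalone illustration of the same Bootstrap.jl calls:

b = bootstrap(mean, randn(1000), BasicSampling(10000))
only(confint(b, BasicConfInt(0.95)))  # returns (estimate, lower, upper)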
confidence_intervals
6×6 Matrix{Any}:
(0.0, 0.0) (2.56833e-5, 0.000859913) (-4.31786e-5, 0.000838573) (4.00769e-5, 0.000853559) (0.000152248, 0.000921184) (-0.000465458, 0.000405489)
(-0.000857682, -3.6013e-5) (0.0, 0.0) (-0.000120715, 2.75543e-5) (-2.59767e-5, 3.76029e-5) (1.97563e-5, 0.000177436) (-0.000967472, 4.09321e-6)
(-0.000839244, 4.63642e-5) (-3.01687e-5, 0.000123503) (0.0, 0.0) (-5.48039e-5, 0.000153696) (1.75595e-5, 0.000256715) (-0.000965075, 6.73056e-5)
(-0.00084966, -4.01049e-5) (-3.78307e-5, 2.56403e-5) (-0.000152721, 5.20309e-5) (0.0, 0.0) (1.64366e-5, 0.000169683) (-0.000988484, -7.33092e-6)
(-0.000936245, -0.000162257) (-0.000179215, -1.78991e-5) (-0.000254896, -1.99733e-5) (-0.000170176, -1.25176e-5) (0.0, 0.0) (-0.00104126, -0.000136899)
(-0.000391998, 0.000464662) (-6.38049e-6, 0.000960798) (-7.79031e-5, 0.000948532) (6.22836e-6, 0.000963191) (0.000107021, 0.00104599) (0.0, 0.0)
We can interpret the confidence intervals as follows:
- if both endpoints are negative, the row model wins (←)
- if both endpoints are positive, the column model wins (↑)
- otherwise it's a draw
winner_pointer(ci) = ci == (0, 0) ? "n/a" :
isnan(first(ci)) && isnan(last(ci)) ? "inconclusive" :
first(ci) < 0 && last(ci) < 0 ? "←" :
first(ci) > 0 && last(ci) > 0 ? "↑" :
"draw"
# Present a square matrix `A` as a named-tuple table with a `model` label column,
# dropping the redundant first column and last row:
tabular(A, model_exs) = NamedTuple{(:model, model_exs[2:end]...)}((
    model_exs[1:end-1],
    (A[1:end-1, j] for j in 2:length(model_exs))...,
))
pretty_table(
tabular(winner_pointer.(confidence_intervals), model_exs[1:6]),
show_subheader=false,
)
┌───────────────────────┬──────────┬──────────┬───────────────────────┬─────────────┬──────────────┐
│ model │ gompertz │ logistic │ classical_bertalanffy │ bertalanffy │ bertalanffy2 │
├───────────────────────┼──────────┼──────────┼───────────────────────┼─────────────┼──────────────┤
│ exponential │ ↑ │ draw │ ↑ │ ↑ │ draw │
│ gompertz │ n/a │ draw │ draw │ ↑ │ draw │
│ logistic │ draw │ n/a │ draw │ ↑ │ draw │
│ classical_bertalanffy │ draw │ draw │ n/a │ ↑ │ ← │
│ bertalanffy │ ← │ ← │ ← │ n/a │ ← │
└───────────────────────┴──────────┴──────────┴───────────────────────┴─────────────┴──────────────┘
Bootstrap comparison of errors (neural ODEs included)
bad_error_rows = filter(axes(errs, 1)) do i
    es = errs[i,:]  # this time, errors for all models, neural included
    any(isnan, es) || any(isinf, es) || max(es...) > 0.1
end
proportion_bad = length(bad_error_rows)/size(errs, 1)
@show proportion_bad
0.020440251572327043
We remove the additional 2%:
good_error_rows = setdiff(axes(errs, 1), bad_error_rows);
errs = errs[good_error_rows,:];
And proceed as before, but with all columns of `errs` (all models):
confidence_intervals = Array{Any}(undef, length(models), length(models))
for i in 1:length(models)
for j in 1:length(models)
b = bootstrap(
mean,
errs[:,i] - errs[:,j],
BasicSampling(10000),
)
confidence_intervals[i, j] = only(confint(b, BasicConfInt(0.95)))[2:3]
end
end
pretty_table(
tabular(winner_pointer.(confidence_intervals), model_exs),
show_subheader=false,
tf=PrettyTables.tf_markdown, vlines=:all,
)
| model | gompertz | logistic | classical_bertalanffy | bertalanffy | bertalanffy2 | n1 | n2 |
|-----------------------|----------|----------|-----------------------|-------------|--------------|------|----|
| exponential | ↑ | draw | ↑ | ↑ | draw | draw | ← |
| gompertz | n/a | draw | draw | ↑ | draw | draw | ← |
| logistic | draw | n/a | draw | ↑ | draw | draw | ← |
| classical_bertalanffy | draw | draw | n/a | ↑ | draw | draw | ← |
| bertalanffy | ← | ← | ← | n/a | ← | draw | ← |
| bertalanffy2 | draw | draw | draw | ↑ | n/a | draw | ← |
| n1 | draw | draw | draw | draw | draw | n/a | ← |
The lack of statistical significance notwithstanding, here are the models, listed in order of increasing mean absolute error (best first):
zipped = collect(zip(models, vec(mean(errs, dims=1))))
sort!(zipped, by=last)
model, error = collect.(zip(zipped...))
rankings = (; model, error)
pretty_table(
rankings,
show_subheader=false,
tf=PrettyTables.tf_markdown, vlines=:all,
)
| model | error |
|-----------------------|------------|
| bertalanffy | 0.00272664 |
| classical_bertalanffy | 0.00279946 |
| gompertz | 0.0028149 |
| logistic | 0.00288491 |
| neural (12 params) | 0.0031024 |
| bertalanffy2 | 0.00318344 |
| exponential | 0.00331202 |
| neural2 (14 params) | 0.0045919 |
This page was generated using Literate.jl.