Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ authors = ["Lazaro Alonso", "Bernhard Ahrens", "Markus Reichstein"]
[deps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Expand Down Expand Up @@ -44,6 +45,7 @@ EasyHybridMakie = "Makie"
[compat]
AxisKeys = "0.2"
CSV = "0.10.15"
CUDA = "5.11.0"
Chain = "0.6, 1"
ChainRulesCore = "1.25.1"
ComponentArrays = "0.15.28"
Expand Down
1 change: 1 addition & 0 deletions src/EasyHybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ using Static: False, True
@reexport begin
import LuxCore
using Lux: Lux, Dense, Chain, Dropout, relu, sigmoid, swish, tanh, Recurrence, LSTMCell
using Lux: gpu_device, cpu_device
using Random
using Statistics
using DataFrames
Expand Down
6 changes: 6 additions & 0 deletions src/config/DataConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,10 @@ cross-validation, and sequence construction for time-series training.

"Whether to apply batch normalization to the model inputs. Default: `false`."
input_batchnorm::Bool = false

"Select a gpu_device or default to cpu if none available"
gdev = gpu_device()

"Set the `cpu_device`, useful for sending back to the cpu model parameters"
cdev = cpu_device()
end
6 changes: 6 additions & 0 deletions src/config/TrainingConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ loss computation, data handling, output, and visualization.
"Whether to return gradients during training. Default: True()."
return_gradients = True()

"Select a gpu_device or default to cpu if none available"
gdev = gpu_device()

"Set the `cpu_device`, useful for sending back to the cpu model parameters"
cdev = cpu_device()

"Loss type to use during training. Default: `:mse`."
training_loss::Symbol = :mse

Expand Down
5 changes: 3 additions & 2 deletions src/data/loaders.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function build_loader(x_train, y_train, cfg::TrainConfig)
function build_loader(x_train, forcings_train, y_train, cfg::TrainConfig)
loader = DataLoader(
(x_train, y_train);
((x_train, forcings_train), y_train);
parallel = true,
batchsize = cfg.batchsize,
shuffle = true,
)
Expand Down
30 changes: 18 additions & 12 deletions src/data/prepare_data.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
export prepare_data

function prepare_data(hm, data::KeyedArray; kwargs...)
predictors_forcing, targets = get_prediction_target_names(hm)
function prepare_data(hm, data::KeyedArray; cfg=DataConfig(), kwargs...)
predictors, forcings, targets = get_prediction_target_names(hm)
# KeyedArray: use () syntax for views that are differentiable
return (data(predictors_forcing), data(targets))
dev = cfg.gdev
targets_nt = NamedTuple([target => dev(Array(data(target))) for target in targets])
forcings_nt = NamedTuple([forcing => dev(Array(data(forcing))) for forcing in forcings])
Copy link
Copy Markdown
Member

@lazarusA lazarusA Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should do dev/ Array at the batch loader level. Up to this point data could still be lazy.

return ((dev(Array(data(predictors))), forcings_nt), targets_nt)
end

function prepare_data(hm, data::AbstractDimArray; kwargs...)
Expand All @@ -13,10 +16,10 @@ function prepare_data(hm, data::AbstractDimArray; kwargs...)
end

function prepare_data(hm, data::DataFrame; array_type = :KeyedArray, drop_missing_rows = true)
predictors_forcing, targets = get_prediction_target_names(hm)
predictors, forcings, targets = get_prediction_target_names(hm)

all_predictor_cols = unique(vcat(values(predictors_forcing)...))
col_to_select = unique([all_predictor_cols; targets])
# all_predictor_cols = unique(vcat(values(predictors_forcing)...))
col_to_select = unique([predictors; forcings; targets])

# subset to only the cols we care about
sdf = data[!, col_to_select]
Expand Down Expand Up @@ -84,33 +87,36 @@ Returns a tuple of (predictors_forcing, targets) names.
function get_prediction_target_names(hm)
targets = hm.targets
predictors_forcing = Symbol[]
predictors = Symbol[]
forcings = Symbol[]
for prop in propertynames(hm)
if occursin("predictors", string(prop))
val = getproperty(hm, prop)
if isa(val, AbstractVector)
append!(predictors_forcing, val)
append!(predictors, val)
elseif isa(val, Union{NamedTuple, Tuple})
append!(predictors_forcing, unique(vcat(values(val)...)))
append!(predictors, unique(vcat(values(val)...)))
end
end
end
for prop in propertynames(hm)
if occursin("forcing", string(prop))
val = getproperty(hm, prop)
if isa(val, AbstractVector)
append!(predictors_forcing, val)
append!(forcings, val)
elseif isa(val, Union{Tuple, NamedTuple})
append!(predictors_forcing, unique(vcat(values(val)...)))
append!(forcings, unique(vcat(values(val)...)))
end
end
end
predictors_forcing = unique(predictors_forcing)
# predicto
# predictors_forcing = unique(predictors_forcing)
Comment on lines +112 to +113
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These lines contain an incomplete comment and commented-out code that should be removed to maintain code cleanliness.


if isempty(predictors_forcing)
@warn "Note that you don't have predictors or forcing variables."
end
if isempty(targets)
@warn "Note that you don't have target names."
end
return predictors_forcing, targets
return predictors, forcings, targets
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function now returns predictors and forcings separately, but the warning check at line 115 (visible in context) still references predictors_forcing. Since predictors_forcing is initialized as an empty array at line 89 and never populated in the new logic, this warning will be triggered on every call. The check should be updated to verify if both predictors and forcings are empty.

end
15 changes: 7 additions & 8 deletions src/data/split_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,16 @@ function split_data(
)

if sequence_kwargs !== nothing
x_all, y_all = data_
(x_all, forcings_all), y_all = data_
sis_default = (; input_window = 10, output_window = 1, output_shift = 1, lead_time = 1)
sis = merge(sis_default, sequence_kwargs)
@info "Using split_into_sequences: $sis"
x_all, y_all = split_into_sequences(x_all, y_all; sis.input_window, sis.output_window, sis.output_shift, sis.lead_time)
x_all, y_all = filter_sequences(x_all, y_all)
else
x_all, y_all = data_
(x_all, forcings_all), y_all = data_
end


if split_by_id !== nothing && folds !== nothing

throw(ArgumentError("split_by_id and folds are not supported together; do the split when constructing folds"))
Expand All @@ -50,9 +49,9 @@ function split_data(
@info "Number of unique $(split_by_id): $(length(unique_ids))"
@info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))"

x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx)
x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx)
return (x_train, y_train), (x_val, y_val)
x_train, forcings_train, y_train = view_end_dim(x_all, train_idx),view_end_dim(forcings_all, train_idx), view_end_dim(y_all, train_idx)
x_val, forcings_val, y_val = view_end_dim(x_all, val_idx),view_end_dim(forcings_all, val_idx), view_end_dim(y_all, val_idx)
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)

elseif folds !== nothing || val_fold !== nothing
# --- Option B: external K-fold assignment ---
Expand All @@ -76,8 +75,8 @@ function split_data(

else
# --- Fallback: simple random/chronological split of prepared data ---
(x_train, y_train), (x_val, y_val) = splitobs((x_all, y_all); at = split_data_at, shuffle = shuffleobs)
return (x_train, y_train), (x_val, y_val)
(x_train, forcings_train, y_train), (x_val, forcings_val, y_val) = splitobs((x_all, forcings_all, y_all); at = split_data_at, shuffle = shuffleobs)
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/data/splits.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export prepare_splits, maybe_build_sequences

function prepare_splits(data, model, cfg::DataConfig)
(x_train, y_train), (x_val, y_val) = split_data(
((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val) = split_data(
data, model;
array_type = cfg.array_type,
shuffleobs = cfg.shuffleobs,
Expand All @@ -16,7 +16,7 @@ function prepare_splits(data, model, cfg::DataConfig)
@debug "Train size: $(size(x_train)), Val size: $(size(x_val))"
@debug "Data type: $(typeof(x_train))"

return (x_train, y_train), (x_val, y_val)
return ((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val)
end

function maybe_build_sequence_kwargs(cfg::DataConfig)
Expand Down
12 changes: 6 additions & 6 deletions src/io/checkpoints.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
function save_initial_state!(paths::TrainingPaths, model, ps, st, cfg::TrainConfig)
save_ps_st(paths.checkpoint, model, ps, st, cfg.tracked_params)
save_ps_st(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params)
save_train_val_loss!(paths.checkpoint, nothing, "training_loss", 0)
save_train_val_loss!(paths.checkpoint, nothing, "validation_loss", 0)
return nothing
end

function save_epoch!(paths::TrainingPaths, model, ps, st, snapshot::EpochSnapshot, epoch::Int, cfg::TrainConfig)
save_ps_st!(paths.checkpoint, model, ps, st, cfg.tracked_params, epoch)
save_ps_st!(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, epoch)
save_train_val_loss!(paths.checkpoint, snapshot.l_train, "training_loss", epoch)
save_train_val_loss!(paths.checkpoint, snapshot.l_val, "validation_loss", epoch)
return nothing
end

function save_final!(paths::TrainingPaths, model, ps, st, x_train, y_train, x_val, y_val, stopper::EarlyStopping, cfg::TrainConfig)
function save_final!(paths::TrainingPaths, model, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, stopper::EarlyStopping, cfg::TrainConfig)
target_names = model.targets
save_epoch = stopper.best_epoch == 0 ? 0 : stopper.best_epoch
save_ps_st!(paths.best_model, model, ps, st, cfg.tracked_params, save_epoch)
save_ps_st!(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch)

ŷ_train, αst_train = model(x_train, ps, LuxCore.testmode(st))
ŷ_val, αst_val = model(x_val, ps, LuxCore.testmode(st))
ŷ_train, αst_train = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st)))
ŷ_val, αst_val = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st)))

save_predictions!(paths.checkpoint, ŷ_train, αst_train, "training")
save_predictions!(paths.checkpoint, ŷ_val, αst_val, "validation")
Expand Down
2 changes: 1 addition & 1 deletion src/io/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function save_observations!(file_name, target_names, yobs, train_or_val_name)
end

function to_named_tuple(ka, target_names)
arrays = [Array(ka(variable = k)) for k in target_names]
arrays = [Array(ka[k]) for k in target_names]
return NamedTuple{Tuple(target_names)}(arrays)
end

Expand Down
13 changes: 7 additions & 6 deletions src/losses/compute_loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ Main loss function for hybrid models that handles both training and evaluation m
- `(loss_values, st, ŷ)`: NamedTuple of losses, state and predictions
"""
function compute_loss(
HM::LuxCore.AbstractLuxContainerLayer, ps, st, (x, (y_t, y_nan));
HM::LuxCore.AbstractLuxContainerLayer, ps, st, ((x, forcings), (y_t, y_nan));
logging::LoggingLoss
)

targets = HM.targets
ext_loss = extra_loss(logging)
if logging.train_mode
ŷ, st = HM(x, ps, st)
ŷ, st = HM((x, forcings), ps, st)
loss_value = _compute_loss(ŷ, y_t, y_nan, targets, training_loss(logging), logging.agg)
# Add extra_loss if provided
if ext_loss !== nothing
Expand All @@ -34,7 +34,7 @@ function compute_loss(
end
stats = NamedTuple()
else
ŷ, _ = HM(x, ps, LuxCore.testmode(st))
ŷ, _ = HM((x, forcings), ps, LuxCore.testmode(st))
loss_value = _compute_loss(ŷ, y_t, y_nan, targets, loss_types(logging), logging.agg)
# Add extra_loss entries if provided
if ext_loss !== nothing
Expand Down Expand Up @@ -105,9 +105,10 @@ _get_target_ŷ(ŷ, y_t, target) =
function assemble_loss(ŷ, y, y_nan, targets, loss_spec)
return [
begin
y_t = _get_target_y(y, target)
ŷ_t = _get_target_ŷ(ŷ, y_t, target)
_apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
y_t = y[target]# _get_target_y(y, target)
ŷ_t = ŷ[target]#_get_target_ŷ(ŷ, y_t, target)
Comment on lines +108 to +109
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a typo in line 109: ŷ (y with combining circumflex, NFD) is used instead of the argument ŷ (U+0177) defined at line 105. While Julia normalizes identifiers to NFC, mixing these characters is confusing and can lead to issues in environments with different normalization rules. Additionally, the commented-out code should be removed.

                y_t = y[target]
                ŷ_t = ŷ[target]

_apply_loss(ŷ_t, y_t, y_nan, loss_spec)
# _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
end
for target in targets
]
Expand Down
3 changes: 1 addition & 2 deletions src/losses/loss_fn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ function loss_fn(ŷ, y, y_nan, ::Val{:pearson})
return cor(ŷ[y_nan], y[y_nan])
end
function loss_fn(ŷ, y, y_nan, ::Val{:r2})
r = cor(ŷ[y_nan], y[y_nan])
return r * r
return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(ŷ[y_nan])).^2)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The R-squared calculation is incorrect. The denominator should use the mean of the observed values (y), not the predicted values (ŷ). The standard definition of R² is $1 - SS_{res}/SS_{tot}$, where $SS_{tot}$ is calculated relative to the mean of the observations.

    return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(y[y_nan])).^2)

end

function loss_fn(ŷ, y, y_nan, ::Val{:pearsonLoss})
Expand Down
7 changes: 4 additions & 3 deletions src/models/GenericHybridModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ end

# ───────────────────────────────────────────────────────────────────────────
# Forward pass for SingleNNHybridModel (optimized, no branching)
function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps, st)
function (m::SingleNNHybridModel)(ds_k, ps, st)
# 1) get features
predictors = toArray(ds_k, m.predictors)
predictors = ds_k[1]#toArray(ds_k, m.predictors)

parameters = m.parameters

Expand Down Expand Up @@ -407,11 +407,12 @@ function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps,
end

# 5) unpack forcing data
forcing_data = toNamedTuple(ds_k, m.forcing)
forcing_data = ds_k[2]#toNamedTuple(ds_k, m.forcing)

# 6) merge all parameters
all_params = merge(scaled_nn_params, global_params, fixed_params)
all_kwargs = merge(forcing_data, all_params)
# all_kwargs = merge(forcing_data, all_params)

# 7) physics
y_pred = m.mechanistic_model(; all_kwargs...)
Expand Down
22 changes: 11 additions & 11 deletions src/training/early_stopping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ mutable struct EarlyStopping
done::Bool
end

function EarlyStopping(init_loss, ps, st, patience::Int)
function EarlyStopping(init_loss, ps, st, cfg)
best_loss = extract_agg_loss(init_loss)
return EarlyStopping(best_loss, deepcopy(ps), deepcopy(st), 0, 0, patience, false)
return EarlyStopping(best_loss, deepcopy(cfg.cdev(ps)), deepcopy(cfg.cdev(st)), 0, 0, cfg.patience, false)
end

function update!(es::EarlyStopping, snapshot::EpochSnapshot, ps, st, epoch, cfg::TrainConfig)
current_loss = extract_agg_loss(snapshot.l_val)

if isbetter(current_loss, es.best_loss, first(cfg.loss_types))
es.best_loss = current_loss
es.best_ps = deepcopy(ps)
es.best_st = deepcopy(st)
es.best_ps = deepcopy(cfg.cdev(ps))
es.best_st = deepcopy(cfg.cdev(st))
es.best_epoch = epoch
es.counter = 0
else
Expand Down Expand Up @@ -62,16 +62,16 @@ function best_or_final(stopper::EarlyStopping, ps, st, cfg::TrainConfig)
end
end

function build_results(model, history::TrainingHistory, stopper::EarlyStopping, ps, st, x_train, y_train, x_val, y_val)
function build_results(model, history::TrainingHistory, stopper::EarlyStopping, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, cfg::TrainConfig)
target_names = model.targets

# final predictions in test mode
ŷ_train, _ = model(x_train, ps, LuxCore.testmode(st))
ŷ_val, _ = model(x_val, ps, LuxCore.testmode(st))
ŷ_train, _ = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(st))
ŷ_val, _ = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(st))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can evaluate this still on the GPU side and just pipe the result of into the cfg.dev function.


# observed vs predicted DataFrames
train_obs_pred = hcat(toDataFrame(y_train), toDataFrame(ŷ_train, target_names))
val_obs_pred = hcat(toDataFrame(y_val), toDataFrame(ŷ_val, target_names))
train_obs_pred = hcat(DataFrame(y_train), toDataFrame(ŷ_train, target_names))
val_obs_pred = hcat(DataFrame(y_val), toDataFrame(ŷ_val, target_names))

# extra predictions without observational counterparts
train_diffs, val_diffs = extract_diffs(ŷ_train, ŷ_val, target_names)
Expand All @@ -84,8 +84,8 @@ function build_results(model, history::TrainingHistory, stopper::EarlyStopping,
val_obs_pred,
train_diffs,
val_diffs,
ps,
st,
cfg.cdev(ps),
cfg.cdev(st),
stopper.best_epoch,
stopper.best_loss,
)
Expand Down
Loading