Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/EasyHybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ using Static: False, True
@reexport begin
import LuxCore
using Lux: Lux, Dense, Chain, Dropout, relu, sigmoid, swish, tanh, Recurrence, LSTMCell
using Lux: gpu_device, cpu_device
using Random
using Statistics
using DataFrames
Expand Down
6 changes: 6 additions & 0 deletions src/config/DataConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,10 @@ cross-validation, and sequence construction for time-series training.

"Whether to apply batch normalization to the model inputs. Default: `false`."
input_batchnorm::Bool = false

"Select a gpu_device or default to cpu if none available"
gdev = gpu_device()

"Set the `cpu_device`, useful for sending back to the cpu model parameters"
cdev = cpu_device()
Comment on lines +61 to +64
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The fields gdev and cdev should have explicit type annotations (e.g., Lux.AbstractLuxDevice) to improve code clarity and potentially help with compiler optimizations. Additionally, ensure that gpu_device() is the intended default for all instances of DataConfig, as it may trigger device initialization.

    "Select a gpu_device or default to cpu if none available"
    gdev::Lux.AbstractLuxDevice = gpu_device()

    "Set the `cpu_device`, useful for sending back to the cpu model parameters"
    cdev::Lux.AbstractLuxDevice = cpu_device()

end
6 changes: 6 additions & 0 deletions src/config/TrainingConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ loss computation, data handling, output, and visualization.
"Whether to return gradients during training. Default: True()."
return_gradients = True()

"Select a gpu_device or default to cpu if none available"
gdev = gpu_device()

"Set the `cpu_device`, useful for sending back to the cpu model parameters"
cdev = cpu_device()

"Loss type to use during training. Default: `:mse`."
training_loss::Symbol = :mse

Expand Down
5 changes: 3 additions & 2 deletions src/data/loaders.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function build_loader(x_train, y_train, cfg::TrainConfig)
function build_loader(x_train, forcings_train, y_train, cfg::TrainConfig)
loader = DataLoader(
(x_train, y_train);
((x_train, forcings_train), y_train);
parallel = true,
batchsize = cfg.batchsize,
shuffle = true,
)
Expand Down
30 changes: 18 additions & 12 deletions src/data/prepare_data.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
export prepare_data

function prepare_data(hm, data::KeyedArray; kwargs...)
predictors_forcing, targets = get_prediction_target_names(hm)
function prepare_data(hm, data::KeyedArray; cfg=DataConfig(), kwargs...)
predictors, forcings, targets = get_prediction_target_names(hm)
# KeyedArray: use () syntax for views that are differentiable
return (data(predictors_forcing), data(targets))
dev = cfg.gdev
targets_nt = NamedTuple([target => dev(Array(data(target))) for target in targets])
forcings_nt = NamedTuple([forcing => dev(Array(data(forcing))) for forcing in forcings])
return ((dev(Array(data(predictors))), forcings_nt), targets_nt)
end

function prepare_data(hm, data::AbstractDimArray; kwargs...)
Expand All @@ -13,10 +16,10 @@ function prepare_data(hm, data::AbstractDimArray; kwargs...)
end

function prepare_data(hm, data::DataFrame; array_type = :KeyedArray, drop_missing_rows = true)
predictors_forcing, targets = get_prediction_target_names(hm)
predictors, forcings, targets = get_prediction_target_names(hm)

all_predictor_cols = unique(vcat(values(predictors_forcing)...))
col_to_select = unique([all_predictor_cols; targets])
# all_predictor_cols = unique(vcat(values(predictors_forcing)...))
col_to_select = unique([predictors; forcings; targets])

# subset to only the cols we care about
sdf = data[!, col_to_select]
Expand Down Expand Up @@ -84,33 +87,36 @@ Returns a tuple of (predictors_forcing, targets) names.
function get_prediction_target_names(hm)
targets = hm.targets
predictors_forcing = Symbol[]
predictors = Symbol[]
forcings = Symbol[]
for prop in propertynames(hm)
if occursin("predictors", string(prop))
val = getproperty(hm, prop)
if isa(val, AbstractVector)
append!(predictors_forcing, val)
append!(predictors, val)
elseif isa(val, Union{NamedTuple, Tuple})
append!(predictors_forcing, unique(vcat(values(val)...)))
append!(predictors, unique(vcat(values(val)...)))
end
end
end
for prop in propertynames(hm)
if occursin("forcing", string(prop))
val = getproperty(hm, prop)
if isa(val, AbstractVector)
append!(predictors_forcing, val)
append!(forcings, val)
elseif isa(val, Union{Tuple, NamedTuple})
append!(predictors_forcing, unique(vcat(values(val)...)))
append!(forcings, unique(vcat(values(val)...)))
end
end
end
predictors_forcing = unique(predictors_forcing)
# predicto
# predictors_forcing = unique(predictors_forcing)
Comment on lines +112 to +113
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Remove leftover comments and incomplete words.


if isempty(predictors_forcing)
@warn "Note that you don't have predictors or forcing variables."
end
Comment on lines 115 to 117
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The warning check uses predictors_forcing, which is an empty Symbol[] initialized at line 89 and never updated. The logic should check if both predictors and forcings are empty instead.

    if isempty(predictors) && isempty(forcings)
        @warn "Note that you don't have predictors or forcing variables."
    end

if isempty(targets)
@warn "Note that you don't have target names."
end
return predictors_forcing, targets
return predictors, forcings, targets
end
30 changes: 22 additions & 8 deletions src/data/split_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,16 @@ function split_data(
)

if sequence_kwargs !== nothing
x_all, y_all = data_
(x_all, forcings_all), y_all = data_
sis_default = (; input_window = 10, output_window = 1, output_shift = 1, lead_time = 1)
sis = merge(sis_default, sequence_kwargs)
@info "Using split_into_sequences: $sis"
x_all, y_all = split_into_sequences(x_all, y_all; sis.input_window, sis.output_window, sis.output_shift, sis.lead_time)
x_all, y_all = filter_sequences(x_all, y_all)
else
x_all, y_all = data_
(x_all, forcings_all), y_all = data_
end


if split_by_id !== nothing && folds !== nothing

throw(ArgumentError("split_by_id and folds are not supported together; do the split when constructing folds"))
Expand All @@ -50,9 +49,9 @@ function split_data(
@info "Number of unique $(split_by_id): $(length(unique_ids))"
@info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))"

x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx)
x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx)
return (x_train, y_train), (x_val, y_val)
x_train, forcings_train, y_train = view_end_dim(x_all, train_idx),view_end_dim(forcings_all, train_idx), view_end_dim(y_all, train_idx)
x_val, forcings_val, y_val = view_end_dim(x_all, val_idx),view_end_dim(forcings_all, val_idx), view_end_dim(y_all, val_idx)
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)

elseif folds !== nothing || val_fold !== nothing
# --- Option B: external K-fold assignment ---
Expand All @@ -76,8 +75,8 @@ function split_data(

else
# --- Fallback: simple random/chronological split of prepared data ---
(x_train, y_train), (x_val, y_val) = splitobs((x_all, y_all); at = split_data_at, shuffle = shuffleobs)
return (x_train, y_train), (x_val, y_val)
(x_train, forcings_train, y_train), (x_val, forcings_val, y_val) = splitobs((x_all, forcings_all, y_all); at = split_data_at, shuffle = shuffleobs)
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)
end
end

Expand Down Expand Up @@ -118,6 +117,21 @@ end
function getbyname(ka::Union{KeyedArray, AbstractDimArray}, name::Symbol)
return @view ka[variable = At(name)]
end
function view_end_dim(x_all::AbstractArray{Float32}, idx)
return view(x_all, ntuple(_ -> :, ndims(x_all) - 1)..., idx)
end

# function view_end_dim(x_all::AbstractMatrix{Float32}, idx)
# return view(x_all, :, idx)
# end

# function view_end_dim(x_all::AbstractArray{Float32, 3}, idx)
# return view(x_all, :, :, idx)
# end

function view_end_dim(x_all::NamedTuple, idx)
return map(x -> view_end_dim(x, idx), x_all)
end

function view_end_dim(x_all::Union{KeyedArray{Float32, 2}, AbstractDimArray{Float32, 2}}, idx)
return view(x_all, :, idx)
Expand Down
4 changes: 2 additions & 2 deletions src/data/splits.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export prepare_splits, maybe_build_sequences

function prepare_splits(data, model, cfg::DataConfig)
(x_train, y_train), (x_val, y_val) = split_data(
((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val) = split_data(
data, model;
array_type = cfg.array_type,
shuffleobs = cfg.shuffleobs,
Expand All @@ -16,7 +16,7 @@ function prepare_splits(data, model, cfg::DataConfig)
@debug "Train size: $(size(x_train)), Val size: $(size(x_val))"
@debug "Data type: $(typeof(x_train))"

return (x_train, y_train), (x_val, y_val)
return ((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val)
end

function maybe_build_sequence_kwargs(cfg::DataConfig)
Expand Down
12 changes: 6 additions & 6 deletions src/io/checkpoints.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
function save_initial_state!(paths::TrainingPaths, model, ps, st, cfg::TrainConfig)
save_ps_st(paths.checkpoint, model, ps, st, cfg.tracked_params)
save_ps_st(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params)
save_train_val_loss!(paths.checkpoint, nothing, "training_loss", 0)
save_train_val_loss!(paths.checkpoint, nothing, "validation_loss", 0)
return nothing
end

function save_epoch!(paths::TrainingPaths, model, ps, st, snapshot::EpochSnapshot, epoch::Int, cfg::TrainConfig)
save_ps_st!(paths.checkpoint, model, ps, st, cfg.tracked_params, epoch)
save_ps_st!(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, epoch)
save_train_val_loss!(paths.checkpoint, snapshot.l_train, "training_loss", epoch)
save_train_val_loss!(paths.checkpoint, snapshot.l_val, "validation_loss", epoch)
return nothing
end

function save_final!(paths::TrainingPaths, model, ps, st, x_train, y_train, x_val, y_val, stopper::EarlyStopping, cfg::TrainConfig)
function save_final!(paths::TrainingPaths, model, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, stopper::EarlyStopping, cfg::TrainConfig)
target_names = model.targets
save_epoch = stopper.best_epoch == 0 ? 0 : stopper.best_epoch
save_ps_st!(paths.best_model, model, ps, st, cfg.tracked_params, save_epoch)
save_ps_st!(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch)

ŷ_train, αst_train = model(x_train, ps, LuxCore.testmode(st))
ŷ_val, αst_val = model(x_val, ps, LuxCore.testmode(st))
ŷ_train, αst_train = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st)))
ŷ_val, αst_val = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st)))

save_predictions!(paths.checkpoint, ŷ_train, αst_train, "training")
save_predictions!(paths.checkpoint, ŷ_val, αst_val, "validation")
Expand Down
2 changes: 1 addition & 1 deletion src/io/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function save_observations!(file_name, target_names, yobs, train_or_val_name)
end

function to_named_tuple(ka, target_names)
arrays = [Array(ka(variable = k)) for k in target_names]
arrays = [Array(ka[k]) for k in target_names]
return NamedTuple{Tuple(target_names)}(arrays)
end

Expand Down
13 changes: 7 additions & 6 deletions src/losses/compute_loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ Main loss function for hybrid models that handles both training and evaluation m
- `(loss_values, st, ŷ)`: NamedTuple of losses, state and predictions
"""
function compute_loss(
HM::LuxCore.AbstractLuxContainerLayer, ps, st, (x, (y_t, y_nan));
HM::LuxCore.AbstractLuxContainerLayer, ps, st, ((x, forcings), (y_t, y_nan));
logging::LoggingLoss
)

targets = HM.targets
ext_loss = extra_loss(logging)
if logging.train_mode
ŷ, st = HM(x, ps, st)
ŷ, st = HM((x, forcings), ps, st)
loss_value = _compute_loss(ŷ, y_t, y_nan, targets, training_loss(logging), logging.agg)
# Add extra_loss if provided
if ext_loss !== nothing
Expand All @@ -34,7 +34,7 @@ function compute_loss(
end
stats = NamedTuple()
else
ŷ, _ = HM(x, ps, LuxCore.testmode(st))
ŷ, _ = HM((x, forcings), ps, LuxCore.testmode(st))
loss_value = _compute_loss(ŷ, y_t, y_nan, targets, loss_types(logging), logging.agg)
# Add extra_loss entries if provided
if ext_loss !== nothing
Expand Down Expand Up @@ -105,9 +105,10 @@ _get_target_ŷ(ŷ, y_t, target) =
function assemble_loss(ŷ, y, y_nan, targets, loss_spec)
return [
begin
y_t = _get_target_y(y, target)
ŷ_t = _get_target_ŷ(ŷ, y_t, target)
_apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
y_t = y[target]# _get_target_y(y, target)
ŷ_t = ŷ[target]#_get_target_ŷ(ŷ, y_t, target)
_apply_loss(ŷ_t, y_t, y_nan, loss_spec)
# _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
Comment on lines +108 to +111
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There is a typo on line 109: the combining-character form ŷ (y + U+0302) is used instead of the argument ŷ (precomposed U+0177), which will cause an UndefVarError. Additionally, y_nan should be target-specific (e.g., y_nan[target]) to ensure that NaNs in one target do not invalidate observations for other targets during loss calculation.

                y_t = y[target]
                ŷ_t = ŷ[target]
                _apply_loss(ŷ_t, y_t, y_nan[target], loss_spec)

end
for target in targets
]
Expand Down
3 changes: 1 addition & 2 deletions src/losses/loss_fn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ function loss_fn(ŷ, y, y_nan, ::Val{:pearson})
return cor(ŷ[y_nan], y[y_nan])
end
function loss_fn(ŷ, y, y_nan, ::Val{:r2})
r = cor(ŷ[y_nan], y[y_nan])
return r * r
return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(ŷ[y_nan])).^2)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The $R^2$ formula is statistically incorrect. The denominator (Total Sum of Squares) should be calculated using the mean of the observed values y, not the predicted values ŷ.

    return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(y[y_nan])).^2)

end

function loss_fn(ŷ, y, y_nan, ::Val{:pearsonLoss})
Expand Down
7 changes: 4 additions & 3 deletions src/models/GenericHybridModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ end

# ───────────────────────────────────────────────────────────────────────────
# Forward pass for SingleNNHybridModel (optimized, no branching)
function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps, st)
function (m::SingleNNHybridModel)(ds_k, ps, st)
# 1) get features
predictors = toArray(ds_k, m.predictors)
predictors = ds_k[1]#toArray(ds_k, m.predictors)

parameters = m.parameters

Expand Down Expand Up @@ -407,11 +407,12 @@ function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps,
end

# 5) unpack forcing data
forcing_data = toNamedTuple(ds_k, m.forcing)
forcing_data = ds_k[2]#toNamedTuple(ds_k, m.forcing)

# 6) merge all parameters
all_params = merge(scaled_nn_params, global_params, fixed_params)
all_kwargs = merge(forcing_data, all_params)
# all_kwargs = merge(forcing_data, all_params)

# 7) physics
y_pred = m.mechanistic_model(; all_kwargs...)
Expand Down
22 changes: 11 additions & 11 deletions src/training/early_stopping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ mutable struct EarlyStopping
done::Bool
end

function EarlyStopping(init_loss, ps, st, patience::Int)
function EarlyStopping(init_loss, ps, st, cfg)
best_loss = extract_agg_loss(init_loss)
return EarlyStopping(best_loss, deepcopy(ps), deepcopy(st), 0, 0, patience, false)
return EarlyStopping(best_loss, deepcopy(cfg.cdev(ps)), deepcopy(cfg.cdev(st)), 0, 0, cfg.patience, false)
end

function update!(es::EarlyStopping, history::TrainingHistory, snapshot::EpochSnapshot, ps, st, epoch, cfg::TrainConfig)
Expand All @@ -23,8 +23,8 @@ function update!(es::EarlyStopping, history::TrainingHistory, snapshot::EpochSna

if isbetter(current_loss, es.best_loss, first(cfg.loss_types))
es.best_loss = current_loss
es.best_ps = deepcopy(ps)
es.best_st = deepcopy(st)
es.best_ps = deepcopy(cfg.cdev(ps))
es.best_st = deepcopy(cfg.cdev(st))
es.best_epoch = epoch
es.counter = 0
if !cfg.keep_history
Expand Down Expand Up @@ -70,16 +70,16 @@ function best_or_final(stopper::EarlyStopping, ps, st, cfg::TrainConfig)
end
end

function build_results(model, history::TrainingHistory, stopper::EarlyStopping, ps, st, x_train, y_train, x_val, y_val)
function build_results(model, history::TrainingHistory, stopper::EarlyStopping, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, cfg::TrainConfig)
target_names = model.targets

# final predictions in test mode
ŷ_train, _ = model(x_train, ps, LuxCore.testmode(st))
ŷ_val, _ = model(x_val, ps, LuxCore.testmode(st))
ŷ_train, _ = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(st))
ŷ_val, _ = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(st))

# observed vs predicted DataFrames
train_obs_pred = hcat(toDataFrame(y_train), toDataFrame(ŷ_train, target_names))
val_obs_pred = hcat(toDataFrame(y_val), toDataFrame(ŷ_val, target_names))
train_obs_pred = hcat(DataFrame(y_train), toDataFrame(ŷ_train, target_names))
val_obs_pred = hcat(DataFrame(y_val), toDataFrame(ŷ_val, target_names))

# extra predictions without observational counterparts
train_diffs, val_diffs = extract_diffs(ŷ_train, ŷ_val, target_names)
Expand All @@ -92,8 +92,8 @@ function build_results(model, history::TrainingHistory, stopper::EarlyStopping,
val_obs_pred,
train_diffs,
val_diffs,
ps,
st,
cfg.cdev(ps),
cfg.cdev(st),
stopper.best_epoch,
stopper.best_loss,
)
Expand Down
Loading
Loading