Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/EasyHybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ using Static: False, True
@reexport begin
import LuxCore
using Lux: Lux, Dense, Chain, Dropout, relu, sigmoid, swish, tanh, Recurrence, LSTMCell
using Lux: gpu_device, cpu_device
using Random
using Statistics
using DataFrames
Expand Down
6 changes: 6 additions & 0 deletions src/config/TrainingConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ loss computation, data handling, output, and visualization.
"Whether to return gradients during training. Default: True()."
return_gradients = True()

"Select a gpu_device or default to cpu if none available"
gdev = gpu_device()

"Set the `cpu_device`, useful for sending back to the cpu model parameters"
cdev = cpu_device()
Comment on lines +31 to +35
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The fields gdev and cdev are untyped in the TrainConfig struct. In Julia, untyped fields lead to type instability, which can significantly degrade performance because the compiler cannot specialize functions using these fields. Since these are used frequently for device transfers during training, it is highly recommended to provide type annotations, such as Lux.AbstractDevice or using type parameters.


"Loss type to use during training. Default: `:mse`."
training_loss::Symbol = :mse

Expand Down
5 changes: 3 additions & 2 deletions src/data/loaders.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function build_loader(x_train, y_train, cfg::TrainConfig)
function build_loader(x_train, forcings_train, y_train, mask, cfg::TrainConfig)
loader = DataLoader(
(x_train, y_train);
((x_train, forcings_train), (y_train, mask));
parallel = true,
batchsize = cfg.batchsize,
shuffle = true,
)
Expand Down
46 changes: 29 additions & 17 deletions src/data/prepare_data.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
export prepare_data

function prepare_data(hm, data::KeyedArray; kwargs...)
predictors_forcing, targets = get_prediction_target_names(hm)
function prepare_data(hm, data::KeyedArray; cfg=DataConfig(), kwargs...)
predictors, forcings, targets = get_prediction_target_names(hm)
# KeyedArray: use () syntax for views that are differentiable
return (data(predictors_forcing), data(targets))
X_arr = Array(data(predictors))
forcings_nt = NamedTuple([forcing => Array(data(forcing)) for forcing in forcings])
targets_nt = NamedTuple([target => Array(data(target)) for target in targets])
return ((X_arr, forcings_nt), targets_nt)
end

function prepare_data(hm, data::AbstractDimArray; kwargs...)
predictors_forcing, targets = get_prediction_target_names(hm)
predictors, forcings, targets = get_prediction_target_names(hm)
# KeyedArray: use () syntax for views that are differentiable
X_arr = data[variable = At(predictors)]
forcings_nt = NamedTuple([forcing => data[variable = At(forcing)] for forcing in forcings])
targets_nt = NamedTuple([target => data[variable = At(target)] for target in targets])
# DimArray: use [] syntax (copies, but differentiable)
return (data[variable = At(predictors_forcing)], data[variable = At(targets)])
return ((X_arr, forcings_nt), targets_nt)
end

function prepare_data(hm, data::DataFrame; array_type = :KeyedArray, drop_missing_rows = true)
predictors_forcing, targets = get_prediction_target_names(hm)
predictors, forcings, targets = get_prediction_target_names(hm)

all_predictor_cols = unique(vcat(values(predictors_forcing)...))
col_to_select = unique([all_predictor_cols; targets])
# all_predictor_cols = unique(vcat(values(predictors_forcing)...))
col_to_select = unique([predictors; forcings; targets])

# subset to only the cols we care about
sdf = data[!, col_to_select]
Expand Down Expand Up @@ -83,34 +90,39 @@ Returns a tuple of (predictors_forcing, targets) names.
"""
function get_prediction_target_names(hm)
targets = hm.targets
predictors_forcing = Symbol[]
predictors = Symbol[]
forcings = Symbol[]
for prop in propertynames(hm)
if occursin("predictors", string(prop))
val = getproperty(hm, prop)
if isa(val, AbstractVector)
append!(predictors_forcing, val)
append!(predictors, val)
elseif isa(val, Union{NamedTuple, Tuple})
append!(predictors_forcing, unique(vcat(values(val)...)))
append!(predictors, unique(vcat(values(val)...)))
end
end
end
for prop in propertynames(hm)
if occursin("forcing", string(prop))
val = getproperty(hm, prop)
if isa(val, AbstractVector)
append!(predictors_forcing, val)
append!(forcings, val)
elseif isa(val, Union{Tuple, NamedTuple})
append!(predictors_forcing, unique(vcat(values(val)...)))
append!(forcings, unique(vcat(values(val)...)))
end
end
end
predictors_forcing = unique(predictors_forcing)
    # NOTE: predictors and forcings are now collected separately; the old
    # combined `predictors_forcing` list is no longer built.

if isempty(predictors_forcing)
@warn "Note that you don't have predictors or forcing variables."
if isempty(predictors)
@warn "Note that you don't have predictors variables."
end
if isempty(forcings)
@warn "Note that you don't have forcing variables."
end
if isempty(targets)
@warn "Note that you don't have target names."
end
return predictors_forcing, targets
return predictors, forcings, targets
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function now returns predictors and forcings separately, but the warning check at line 115 (visible in context) still references predictors_forcing. Since predictors_forcing is initialized as an empty array at line 89 and never populated in the new logic, this warning will be triggered on every call. The check should be updated to verify if both predictors and forcings are empty.

end
70 changes: 57 additions & 13 deletions src/data/split_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function split_data(
split_data_at::Real = 0.8,
sequence_kwargs::Union{Nothing, NamedTuple} = nothing,
array_type::Symbol = :KeyedArray,
cfg = DataConfig(),
kwargs...
)
data_ = prepare_data(
Expand All @@ -23,17 +24,16 @@ function split_data(
)

if sequence_kwargs !== nothing
x_all, y_all = data_
(x_all, forcings_all), y_all = data_
sis_default = (; input_window = 10, output_window = 1, output_shift = 1, lead_time = 1)
sis = merge(sis_default, sequence_kwargs)
@info "Using split_into_sequences: $sis"
x_all, y_all = split_into_sequences(x_all, y_all; sis.input_window, sis.output_window, sis.output_shift, sis.lead_time)
x_all, y_all = filter_sequences(x_all, y_all)
else
x_all, y_all = data_
(x_all, forcings_all), y_all = data_
end


if split_by_id !== nothing && folds !== nothing

throw(ArgumentError("split_by_id and folds are not supported together; do the split when constructing folds"))
Expand All @@ -50,9 +50,9 @@ function split_data(
@info "Number of unique $(split_by_id): $(length(unique_ids))"
@info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))"

x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx)
x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx)
return (x_train, y_train), (x_val, y_val)
x_train, forcings_train, y_train = collect_end_dim(x_all, train_idx),collect_end_dim(forcings_all, train_idx), collect_end_dim(y_all, train_idx)
x_val, forcings_val, y_val = collect_end_dim(x_all, val_idx),collect_end_dim(forcings_all, val_idx), collect_end_dim(y_all, val_idx)
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)

elseif folds !== nothing || val_fold !== nothing
# --- Option B: external K-fold assignment ---
Expand All @@ -70,14 +70,14 @@ function split_data(

@info "K-fold via external assignments: val_fold=$val_fold → train=$(length(train_idx)) val=$(length(val_idx))"

x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx)
x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx)
return (x_train, y_train), (x_val, y_val)
x_train, forcings_train, y_train = collect_end_dim(x_all, train_idx),collect_end_dim(forcings_all, train_idx), collect_end_dim(y_all, train_idx)
x_val, forcings_val, y_val = collect_end_dim(x_all, val_idx),collect_end_dim(forcings_all, val_idx), collect_end_dim(y_all, val_idx)
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)

else
# --- Fallback: simple random/chronological split of prepared data ---
(x_train, y_train), (x_val, y_val) = splitobs((x_all, y_all); at = split_data_at, shuffle = shuffleobs)
return (x_train, y_train), (x_val, y_val)
(x_train, forcings_train, y_train), (x_val, forcings_val, y_val) = splitobs((x_all, forcings_all, y_all); at = split_data_at, shuffle = shuffleobs)
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)
end
end

Expand Down Expand Up @@ -115,8 +115,28 @@ function getbyname(df::DataFrame, name::Symbol)
return df[!, name]
end

function getbyname(ka::Union{KeyedArray, AbstractDimArray}, name::Symbol)
return @view ka[variable = At(name)]
"""
    getbyname(ka::KeyedArray, name::Symbol)

Return the slice of `ka` keyed by `name` along the `variable` axis.
Uses the `ka(variable = name)` call syntax which, per the conventions in
`prepare_data`, yields a differentiable lookup for `KeyedArray`s.
"""
function getbyname(ka::KeyedArray, name::Symbol)
    return ka(variable = name)
end

"""
    getbyname(ka::AbstractDimArray, name::Symbol)

Return the slice of `ka` whose `variable` dimension matches `name`,
selected with `At(name)` bracket indexing (copies, but differentiable).
"""
function getbyname(ka::AbstractDimArray, name::Symbol)
    return ka[variable = At(name)]
end

"""
    view_end_dim(x_all, idx)

Return a non-allocating `view` of `x_all` selecting indices `idx` along the
last dimension (the observation/sample axis). For a `NamedTuple`, the
selection is applied to every value, preserving the keys.
"""
function view_end_dim(x_all::AbstractMatrix{T}, idx) where {T}
    # samples live along the second (last) dimension
    return view(x_all, :, idx)
end

function view_end_dim(x_all::AbstractVector{T}, idx) where {T}
    return view(x_all, idx)
end

function view_end_dim(x_all::NamedTuple, idx)
    # `map` over a NamedTuple preserves keys and order, and avoids the
    # type-unstable O(n^2) pattern of repeatedly `merge`-ing one-pair tuples.
    return map(v -> view_end_dim(v, idx), x_all)
end

function view_end_dim(x_all::Union{KeyedArray{Float32, 2}, AbstractDimArray{Float32, 2}}, idx)
Expand All @@ -126,3 +146,27 @@ end
# 3-D keyed/dimensional variant (e.g. feature × timestep × sample):
# view along the last (sample) dimension without copying.
function view_end_dim(x_all::Union{KeyedArray{Float32, 3}, AbstractDimArray{Float32, 3}}, idx)
    return view(x_all, :, :, idx)
end

"""
    collect_end_dim(x_all, idx)

Materialize (copy) the selection of `x_all` at indices `idx` along the last
dimension, returning a plain `Array` that owns its memory — unlike
`view_end_dim`, which returns a view. Owning copies are useful e.g. before a
device transfer (TODO confirm against the `gdev`/`cdev` callers). For a
`NamedTuple`, the selection is applied to every value, preserving keys.
"""
function collect_end_dim(x_all::AbstractMatrix{T}, idx) where {T}
    # bracket indexing already copies; `collect` guarantees a plain Array
    return collect(x_all[:, idx])
end

function collect_end_dim(x_all::AbstractVector{T}, idx) where {T}
    return collect(x_all[idx])
end

function collect_end_dim(x_all::NamedTuple, idx)
    # `map` preserves keys and avoids repeated NamedTuple `merge`s.
    return map(v -> collect_end_dim(v, idx), x_all)
end

# 2-D keyed/dimensional variant: copy the slice at `idx` along the sample
# axis; `collect` drops axis keys/dims down to a plain `Array`.
function collect_end_dim(x_all::Union{KeyedArray{Float32, 2}, AbstractDimArray{Float32, 2}}, idx)
    return collect(getindex(x_all, :, idx))
end

# 3-D variant (e.g. feature × timestep × sample).
function collect_end_dim(x_all::Union{KeyedArray{Float32, 3}, AbstractDimArray{Float32, 3}}, idx)
    return collect(getindex(x_all, :, :, idx))
end
4 changes: 2 additions & 2 deletions src/data/splits.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export prepare_splits, maybe_build_sequences

function prepare_splits(data, model, cfg::DataConfig)
(x_train, y_train), (x_val, y_val) = split_data(
((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val) = split_data(
data, model;
array_type = cfg.array_type,
shuffleobs = cfg.shuffleobs,
Expand All @@ -16,7 +16,7 @@ function prepare_splits(data, model, cfg::DataConfig)
@debug "Train size: $(size(x_train)), Val size: $(size(x_val))"
@debug "Data type: $(typeof(x_train))"

return (x_train, y_train), (x_val, y_val)
return ((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val)
end

function maybe_build_sequence_kwargs(cfg::DataConfig)
Expand Down
12 changes: 6 additions & 6 deletions src/io/checkpoints.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
function save_initial_state!(paths::TrainingPaths, model, ps, st, cfg::TrainConfig)
save_ps_st(paths.checkpoint, model, ps, st, cfg.tracked_params)
save_ps_st(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params)
save_train_val_loss!(paths.checkpoint, nothing, "training_loss", 0)
save_train_val_loss!(paths.checkpoint, nothing, "validation_loss", 0)
return nothing
end

function save_epoch!(paths::TrainingPaths, model, ps, st, snapshot::EpochSnapshot, epoch::Int, cfg::TrainConfig)
save_ps_st!(paths.checkpoint, model, ps, st, cfg.tracked_params, epoch)
save_ps_st!(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, epoch)
save_train_val_loss!(paths.checkpoint, snapshot.l_train, "training_loss", epoch)
save_train_val_loss!(paths.checkpoint, snapshot.l_val, "validation_loss", epoch)
return nothing
end

function save_final!(paths::TrainingPaths, model, ps, st, x_train, y_train, x_val, y_val, stopper::EarlyStopping, cfg::TrainConfig)
function save_final!(paths::TrainingPaths, model, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, stopper::EarlyStopping, cfg::TrainConfig)
target_names = model.targets
save_epoch = stopper.best_epoch == 0 ? 0 : stopper.best_epoch
save_ps_st!(paths.best_model, model, ps, st, cfg.tracked_params, save_epoch)
save_ps_st!(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch)

ŷ_train, αst_train = model(x_train, ps, LuxCore.testmode(st))
ŷ_val, αst_val = model(x_val, ps, LuxCore.testmode(st))
ŷ_train, αst_train = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st)))
ŷ_val, αst_val = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st)))

save_predictions!(paths.checkpoint, ŷ_train, αst_train, "training")
save_predictions!(paths.checkpoint, ŷ_val, αst_val, "validation")
Expand Down
2 changes: 1 addition & 1 deletion src/io/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function save_observations!(file_name, target_names, yobs, train_or_val_name)
end

function to_named_tuple(ka, target_names)
arrays = [Array(ka(variable = k)) for k in target_names]
arrays = [Array(ka[k]) for k in target_names]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The change from ka(variable = k) to ka[k] will break if ka is a KeyedArray, as KeyedArray indexing typically requires dimension names or positional indices. While this might work if ka is now a NamedTuple due to changes in prepare_data, it makes the function less robust if it's still intended to handle KeyedArray inputs.

return NamedTuple{Tuple(target_names)}(arrays)
end

Expand Down
17 changes: 11 additions & 6 deletions src/losses/compute_loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ Main loss function for hybrid models that handles both training and evaluation m
- `(loss_values, st, ŷ)`: NamedTuple of losses, state and predictions
"""
function compute_loss(
HM::LuxCore.AbstractLuxContainerLayer, ps, st, (x, (y_t, y_nan));
HM::LuxCore.AbstractLuxContainerLayer, ps, st, ((x, forcings), (y_t, y_nan));
logging::LoggingLoss
)

targets = HM.targets
ext_loss = extra_loss(logging)
if logging.train_mode
ŷ, st = HM(x, ps, st)
ŷ, st = HM((x, forcings), ps, st)
loss_value = _compute_loss(ŷ, y_t, y_nan, targets, training_loss(logging), logging.agg)
# Add extra_loss if provided
if ext_loss !== nothing
Expand All @@ -34,7 +34,7 @@ function compute_loss(
end
stats = NamedTuple()
else
ŷ, _ = HM(x, ps, LuxCore.testmode(st))
ŷ, _ = HM((x, forcings), ps, LuxCore.testmode(st))
loss_value = _compute_loss(ŷ, y_t, y_nan, targets, loss_types(logging), logging.agg)
# Add extra_loss entries if provided
if ext_loss !== nothing
Expand Down Expand Up @@ -105,9 +105,11 @@ _get_target_ŷ(ŷ, y_t, target) =
function assemble_loss(ŷ, y, y_nan, targets, loss_spec)
return [
begin
y_t = _get_target_y(y, target)
ŷ_t = _get_target_ŷ(ŷ, y_t, target)
_apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
y_t = y[target]# _get_target_y(y, target)
ŷ_t = ŷ[target]#_get_target_ŷ(ŷ, y_t, target)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a typo in line 109: the decomposed form ŷ (`y` U+0079 followed by U+0302 combining circumflex) is used instead of the precomposed argument ŷ (U+0177) defined at line 105. While Julia normalizes identifiers to NFC, mixing these characters is confusing and can lead to issues in environments with different normalization rules. Additionally, the commented-out code should be removed.

                y_t = y[target]
                ŷ_t = ŷ[target]

y_nan_t = y_nan[target]
_apply_loss(ŷ_t, y_t, y_nan_t, loss_spec)
# _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
Comment on lines +110 to +112
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of _get_target_y for a variable named y_nan is semantically confusing. It is better to use _get_target_nan, which was specifically defined for this purpose. Also, please remove the commented-out code to keep the codebase clean.

                y_nan_t = _get_target_nan(y_nan, target)
                _apply_loss(ŷ_t, y_t, y_nan_t, loss_spec)

end
for target in targets
]
Expand Down Expand Up @@ -163,6 +165,9 @@ Helper function to apply the appropriate loss function based on the specificatio
"""
function _apply_loss end

_get_target_y(y::NamedTuple, target) = y[target]
_get_target_nan(y_nan::NamedTuple, target) = y_nan[target]

_get_target_y(y, target) = y(target)
_get_target_nan(y_nan, target) = y_nan(target)

Expand Down
3 changes: 1 addition & 2 deletions src/losses/loss_fn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ function loss_fn(ŷ, y, y_nan, ::Val{:pearson})
return cor(ŷ[y_nan], y[y_nan])
end
function loss_fn(ŷ, y, y_nan, ::Val{:r2})
r = cor(ŷ[y_nan], y[y_nan])
return r * r
return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(y[y_nan])).^2)
end

function loss_fn(ŷ, y, y_nan, ::Val{:pearsonLoss})
Expand Down
7 changes: 4 additions & 3 deletions src/models/GenericHybridModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ end

# ───────────────────────────────────────────────────────────────────────────
# Forward pass for SingleNNHybridModel (optimized, no branching)
function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps, st)
function (m::SingleNNHybridModel)(ds_k, ps, st)
# 1) get features
predictors = toArray(ds_k, m.predictors)
predictors = ds_k[1]#toArray(ds_k, m.predictors)

parameters = m.parameters

Expand Down Expand Up @@ -407,11 +407,12 @@ function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps,
end

# 5) unpack forcing data
forcing_data = toNamedTuple(ds_k, m.forcing)
forcing_data = ds_k[2]#toNamedTuple(ds_k, m.forcing)

# 6) merge all parameters
all_params = merge(scaled_nn_params, global_params, fixed_params)
all_kwargs = merge(forcing_data, all_params)
# all_kwargs = merge(forcing_data, all_params)

# 7) physics
y_pred = m.mechanistic_model(; all_kwargs...)
Expand Down
Loading
Loading