Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/EasyHybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ using Static: False, True
@reexport begin
import LuxCore
using Lux: Lux, Dense, Chain, Dropout, relu, sigmoid, swish, tanh, Recurrence, LSTMCell
using Lux: gpu_device, cpu_device
using Random
using Statistics
using DataFrames
Expand Down
6 changes: 6 additions & 0 deletions src/config/DataConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,10 @@ cross-validation, and sequence construction for time-series training.

"Whether to apply batch normalization to the model inputs. Default: `false`."
input_batchnorm::Bool = false

"Select a gpu_device or default to cpu if none available"
gdev = gpu_device()

"Set the `cpu_device`, useful for sending back to the cpu model parameters"
cdev = cpu_device()
end
6 changes: 6 additions & 0 deletions src/config/TrainingConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ loss computation, data handling, output, and visualization.
"Whether to return gradients during training. Default: True()."
return_gradients = True()

"Select a gpu_device or default to cpu if none available"
gdev = gpu_device()

"Set the `cpu_device`, useful for sending back to the cpu model parameters"
cdev = cpu_device()

"Loss type to use during training. Default: `:mse`."
training_loss::Symbol = :mse

Expand Down
5 changes: 3 additions & 2 deletions src/data/loaders.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function build_loader(x_train, y_train, cfg::TrainConfig)
function build_loader(x_train, forcings_train, y_train, mask, cfg::TrainConfig)
loader = DataLoader(
(x_train, y_train);
((x_train, forcings_train), (y_train, mask));
parallel = true,
batchsize = cfg.batchsize,
shuffle = true,
)
Expand Down
38 changes: 23 additions & 15 deletions src/data/prepare_data.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
export prepare_data

"""
    prepare_data(hm, data::KeyedArray; cfg = DataConfig(), kwargs...)

Extract predictor, forcing and target variables from `data` for model `hm`.

Returns `((X_arr, forcings_nt), targets_nt)` where `X_arr` is a plain `Array`
of the predictor variables and `forcings_nt` / `targets_nt` are `NamedTuple`s
with one plain-`Array` entry per forcing / target variable.
"""
function prepare_data(hm, data::KeyedArray; cfg = DataConfig(), kwargs...)
    predictors, forcings, targets = get_prediction_target_names(hm)
    # KeyedArray: use () syntax for views that are differentiable; wrap in
    # `Array` so downstream code receives plain arrays.
    X_arr = Array(data(predictors))
    targets_nt = NamedTuple([target => Array(data(target)) for target in targets])
    forcings_nt = NamedTuple([forcing => Array(data(forcing)) for forcing in forcings])
    # NOTE(review): `cfg` is accepted for signature consistency with other
    # methods but is currently unused here — confirm intended.
    return ((X_arr, forcings_nt), targets_nt)
end

function prepare_data(hm, data::AbstractDimArray; kwargs...)
Expand All @@ -13,10 +16,10 @@ function prepare_data(hm, data::AbstractDimArray; kwargs...)
end

function prepare_data(hm, data::DataFrame; array_type = :KeyedArray, drop_missing_rows = true)
predictors_forcing, targets = get_prediction_target_names(hm)
predictors, forcings, targets = get_prediction_target_names(hm)

all_predictor_cols = unique(vcat(values(predictors_forcing)...))
col_to_select = unique([all_predictor_cols; targets])
# all_predictor_cols = unique(vcat(values(predictors_forcing)...))
col_to_select = unique([predictors; forcings; targets])

# subset to only the cols we care about
sdf = data[!, col_to_select]
Expand Down Expand Up @@ -83,34 +86,39 @@ Returns a tuple of (predictors_forcing, targets) names.
"""
function get_prediction_target_names(hm)
    targets = hm.targets
    predictors = Symbol[]
    forcings = Symbol[]
    # Any property whose name contains "predictors" (resp. "forcing")
    # contributes variable names; a property matching both substrings feeds
    # both lists, mirroring the previous two-pass implementation.
    for prop in propertynames(hm)
        prop_name = string(prop)
        if occursin("predictors", prop_name)
            _append_variable_names!(predictors, getproperty(hm, prop))
        end
        if occursin("forcing", prop_name)
            _append_variable_names!(forcings, getproperty(hm, prop))
        end
    end
    # Drop duplicates that arise when several properties list the same
    # variable (restores the `unique` the old single-list code applied).
    unique!(predictors)
    unique!(forcings)
    if isempty(predictors)
        @warn "Note that you don't have predictors variables."
    end
    if isempty(forcings)
        @warn "Note that you don't have forcing variables."
    end
    if isempty(targets)
        @warn "Note that you don't have target names."
    end
    return predictors, forcings, targets
end

# Flatten one predictors/forcing property value into `dest`:
# vectors are appended as-is, (named) tuples of vectors are flattened and
# de-duplicated first, and any other value is ignored (as before).
_append_variable_names!(dest::Vector{Symbol}, val::AbstractVector) = append!(dest, val)
_append_variable_names!(dest::Vector{Symbol}, val::Union{NamedTuple, Tuple}) =
    append!(dest, unique(vcat(values(val)...)))
_append_variable_names!(dest::Vector{Symbol}, _) = dest
88 changes: 75 additions & 13 deletions src/data/split_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function split_data(
split_data_at::Real = 0.8,
sequence_kwargs::Union{Nothing, NamedTuple} = nothing,
array_type::Symbol = :KeyedArray,
cfg = DataConfig(),
kwargs...
)
data_ = prepare_data(
Expand All @@ -23,17 +24,16 @@ function split_data(
)

if sequence_kwargs !== nothing
x_all, y_all = data_
(x_all, forcings_all), y_all = data_
sis_default = (; input_window = 10, output_window = 1, output_shift = 1, lead_time = 1)
sis = merge(sis_default, sequence_kwargs)
@info "Using split_into_sequences: $sis"
x_all, y_all = split_into_sequences(x_all, y_all; sis.input_window, sis.output_window, sis.output_shift, sis.lead_time)
x_all, y_all = filter_sequences(x_all, y_all)
else
x_all, y_all = data_
(x_all, forcings_all), y_all = data_
end


if split_by_id !== nothing && folds !== nothing

throw(ArgumentError("split_by_id and folds are not supported together; do the split when constructing folds"))
Expand All @@ -50,9 +50,15 @@ function split_data(
@info "Number of unique $(split_by_id): $(length(unique_ids))"
@info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))"

x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx)
x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx)
return (x_train, y_train), (x_val, y_val)
x_train, forcings_train, y_train = collect_end_dim(x_all, train_idx),collect_end_dim(forcings_all, train_idx), collect_end_dim(y_all, train_idx)
x_val, forcings_val, y_val = collect_end_dim(x_all, val_idx),collect_end_dim(forcings_all, val_idx), collect_end_dim(y_all, val_idx)
x_train = x_train |> cfg.gdev
forcings_train = forcings_train |> cfg.gdev
y_train = y_train |> cfg.gdev
x_val = x_val |> cfg.gdev
forcings_val = forcings_val |> cfg.gdev
y_val = y_val |> cfg.gdev
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)

elseif folds !== nothing || val_fold !== nothing
# --- Option B: external K-fold assignment ---
Expand All @@ -70,14 +76,26 @@ function split_data(

@info "K-fold via external assignments: val_fold=$val_fold → train=$(length(train_idx)) val=$(length(val_idx))"

x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx)
x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx)
return (x_train, y_train), (x_val, y_val)
x_train, forcings_train, y_train = collect_end_dim(x_all, train_idx),collect_end_dim(forcings_all, train_idx), collect_end_dim(y_all, train_idx)
x_val, forcings_val, y_val = collect_end_dim(x_all, val_idx),collect_end_dim(forcings_all, val_idx), collect_end_dim(y_all, val_idx)
x_train = x_train |> cfg.gdev
forcings_train = forcings_train |> cfg.gdev
y_train = y_train |> cfg.gdev
x_val = x_val |> cfg.gdev
forcings_val = forcings_val |> cfg.gdev
y_val = y_val |> cfg.gdev
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)

else
# --- Fallback: simple random/chronological split of prepared data ---
(x_train, y_train), (x_val, y_val) = splitobs((x_all, y_all); at = split_data_at, shuffle = shuffleobs)
return (x_train, y_train), (x_val, y_val)
(x_train, forcings_train, y_train), (x_val, forcings_val, y_val) = splitobs((x_all, forcings_all, y_all); at = split_data_at, shuffle = shuffleobs)
x_train = x_train |> cfg.gdev
forcings_train = forcings_train |> cfg.gdev
y_train = y_train |> cfg.gdev
x_val = x_val |> cfg.gdev
forcings_val = forcings_val |> cfg.gdev
y_val = y_val |> cfg.gdev
return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val)
end
end

Expand Down Expand Up @@ -115,8 +133,28 @@ function getbyname(df::DataFrame, name::Symbol)
return df[!, name]
end

"""
    getbyname(ka::KeyedArray, name::Symbol)

Return the `variable = name` slice of `ka` using the KeyedArray call
(lookup) syntax.
"""
function getbyname(ka::KeyedArray, name::Symbol)
    return ka(variable = name)
end

# DimensionalData arrays: select the `variable` slice by exact key via `At`.
function getbyname(ka::AbstractDimArray, name::Symbol)
return ka[variable = At(name)]
end

"""
    view_end_dim(x_all::AbstractMatrix, idx)

Non-allocating view of the columns of `x_all` selected by `idx`
(the trailing dimension).
"""
view_end_dim(x_all::AbstractMatrix{T}, idx) where {T} = @view x_all[:, idx]

"""
    view_end_dim(x_all::AbstractVector, idx)

Non-allocating view of the entries of `x_all` selected by `idx`.
"""
view_end_dim(x_all::AbstractVector{T}, idx) where {T} = @view x_all[idx]

"""
    view_end_dim(x_all::NamedTuple, idx)

Apply `view_end_dim` to every value of `x_all`, preserving its keys.
"""
function view_end_dim(x_all::NamedTuple, idx)
    # `map` over a NamedTuple transforms the values and keeps the keys,
    # replacing the original quadratic repeated-`merge` loop.
    return map(v -> view_end_dim(v, idx), x_all)
end

function view_end_dim(x_all::Union{KeyedArray{Float32, 2}, AbstractDimArray{Float32, 2}}, idx)
Expand All @@ -126,3 +164,27 @@ end
# 3-D keyed/dim arrays: keep the first two dims, view into the trailing
# (observation) dimension selected by `idx`.
function view_end_dim(x_all::Union{KeyedArray{Float32, 3}, AbstractDimArray{Float32, 3}}, idx)
return view(x_all, :, :, idx)
end

"""
    collect_end_dim(x_all::AbstractMatrix, idx)

Materialized (copying) selection of the columns of `x_all` given by `idx`.
"""
collect_end_dim(x_all::AbstractMatrix{T}, idx) where {T} = collect(x_all[:, idx])

"""
    collect_end_dim(x_all::AbstractVector, idx)

Materialized (copying) selection of the entries of `x_all` given by `idx`.
"""
collect_end_dim(x_all::AbstractVector{T}, idx) where {T} = collect(x_all[idx])

"""
    collect_end_dim(x_all::NamedTuple, idx)

Apply `collect_end_dim` to every value of `x_all`, preserving its keys.
"""
function collect_end_dim(x_all::NamedTuple, idx)
    # `map` over a NamedTuple transforms the values and keeps the keys,
    # replacing the original quadratic repeated-`merge` loop.
    return map(v -> collect_end_dim(v, idx), x_all)
end

# 2-D keyed/dim arrays: materialize (copy) the columns selected by `idx`.
function collect_end_dim(x_all::Union{KeyedArray{Float32, 2}, AbstractDimArray{Float32, 2}}, idx)
return collect(getindex(x_all, :, idx))
end

# 3-D keyed/dim arrays: materialize (copy) the trailing-dim slice given by `idx`.
function collect_end_dim(x_all::Union{KeyedArray{Float32, 3}, AbstractDimArray{Float32, 3}}, idx)
return collect(getindex(x_all, :, :, idx))
end
4 changes: 2 additions & 2 deletions src/data/splits.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export prepare_splits, maybe_build_sequences

function prepare_splits(data, model, cfg::DataConfig)
(x_train, y_train), (x_val, y_val) = split_data(
((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val) = split_data(
data, model;
array_type = cfg.array_type,
shuffleobs = cfg.shuffleobs,
Expand All @@ -16,7 +16,7 @@ function prepare_splits(data, model, cfg::DataConfig)
@debug "Train size: $(size(x_train)), Val size: $(size(x_val))"
@debug "Data type: $(typeof(x_train))"

return (x_train, y_train), (x_val, y_val)
return ((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val)
end

function maybe_build_sequence_kwargs(cfg::DataConfig)
Expand Down
12 changes: 6 additions & 6 deletions src/io/checkpoints.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
"""
    save_initial_state!(paths::TrainingPaths, model, ps, st, cfg::TrainConfig)

Persist the untrained parameters/state and seed empty training and validation
loss entries (epoch 0) in the checkpoint file.
"""
function save_initial_state!(paths::TrainingPaths, model, ps, st, cfg::TrainConfig)
    # Move `ps`/`st` to the CPU (`cfg.cdev`) before writing so device (GPU)
    # arrays are not serialized directly.
    save_ps_st(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params)
    save_train_val_loss!(paths.checkpoint, nothing, "training_loss", 0)
    save_train_val_loss!(paths.checkpoint, nothing, "validation_loss", 0)
    return nothing
end

"""
    save_epoch!(paths::TrainingPaths, model, ps, st, snapshot::EpochSnapshot, epoch::Int, cfg::TrainConfig)

Persist the parameters/state and the epoch's training/validation losses to the
checkpoint file.
"""
function save_epoch!(paths::TrainingPaths, model, ps, st, snapshot::EpochSnapshot, epoch::Int, cfg::TrainConfig)
    # Move `ps`/`st` to the CPU (`cfg.cdev`) before writing so device (GPU)
    # arrays are not serialized directly.
    save_ps_st!(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, epoch)
    save_train_val_loss!(paths.checkpoint, snapshot.l_train, "training_loss", epoch)
    save_train_val_loss!(paths.checkpoint, snapshot.l_val, "validation_loss", epoch)
    return nothing
end

function save_final!(paths::TrainingPaths, model, ps, st, x_train, y_train, x_val, y_val, stopper::EarlyStopping, cfg::TrainConfig)
function save_final!(paths::TrainingPaths, model, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, stopper::EarlyStopping, cfg::TrainConfig)
target_names = model.targets
save_epoch = stopper.best_epoch == 0 ? 0 : stopper.best_epoch
save_ps_st!(paths.best_model, model, ps, st, cfg.tracked_params, save_epoch)
save_ps_st!(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch)

ŷ_train, αst_train = model(x_train, ps, LuxCore.testmode(st))
ŷ_val, αst_val = model(x_val, ps, LuxCore.testmode(st))
ŷ_train, αst_train = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st)))
ŷ_val, αst_val = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st)))

save_predictions!(paths.checkpoint, ŷ_train, αst_train, "training")
save_predictions!(paths.checkpoint, ŷ_val, αst_val, "validation")
Expand Down
2 changes: 1 addition & 1 deletion src/io/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function save_observations!(file_name, target_names, yobs, train_or_val_name)
end

"""
    to_named_tuple(ka, target_names)

Materialize each target variable of `ka` (looked up as `ka[k]`) into a plain
`Array` and pack the results into a `NamedTuple` keyed by `target_names`.
"""
function to_named_tuple(ka, target_names)
    arrays = [Array(ka[k]) for k in target_names]
    return NamedTuple{Tuple(target_names)}(arrays)
end

Expand Down
17 changes: 11 additions & 6 deletions src/losses/compute_loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ Main loss function for hybrid models that handles both training and evaluation m
- `(loss_values, st, ŷ)`: NamedTuple of losses, state and predictions
"""
function compute_loss(
HM::LuxCore.AbstractLuxContainerLayer, ps, st, (x, (y_t, y_nan));
HM::LuxCore.AbstractLuxContainerLayer, ps, st, ((x, forcings), (y_t, y_nan));
logging::LoggingLoss
)

targets = HM.targets
ext_loss = extra_loss(logging)
if logging.train_mode
ŷ, st = HM(x, ps, st)
ŷ, st = HM((x, forcings), ps, st)
loss_value = _compute_loss(ŷ, y_t, y_nan, targets, training_loss(logging), logging.agg)
# Add extra_loss if provided
if ext_loss !== nothing
Expand All @@ -34,7 +34,7 @@ function compute_loss(
end
stats = NamedTuple()
else
ŷ, _ = HM(x, ps, LuxCore.testmode(st))
ŷ, _ = HM((x, forcings), ps, LuxCore.testmode(st))
loss_value = _compute_loss(ŷ, y_t, y_nan, targets, loss_types(logging), logging.agg)
# Add extra_loss entries if provided
if ext_loss !== nothing
Expand Down Expand Up @@ -105,9 +105,11 @@ _get_target_ŷ(ŷ, y_t, target) =
function assemble_loss(ŷ, y, y_nan, targets, loss_spec)
return [
begin
y_t = _get_target_y(y, target)
ŷ_t = _get_target_ŷ(ŷ, y_t, target)
_apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
y_t = y[target]# _get_target_y(y, target)
ŷ_t = ŷ[target]#_get_target_ŷ(ŷ, y_t, target)
Comment on lines +108 to +109
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a typo in line 109: (y with combining circumflex) is used instead of the argument ŷ (U+0177) defined at line 105. While Julia normalizes identifiers to NFC, mixing these characters is confusing and can lead to issues in environments with different normalization rules. Additionally, the commented-out code should be removed.

                y_t = y[target]
                ŷ_t = ŷ[target]

y_nan_t = y_nan[target]
_apply_loss(ŷ_t, y_t, y_nan_t, loss_spec)
# _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
end
for target in targets
]
Expand Down Expand Up @@ -163,6 +165,9 @@ Helper function to apply the appropriate loss function based on the specificatio
"""
function _apply_loss end

_get_target_y(y::NamedTuple, target) = y[target]
_get_target_nan(y_nan::NamedTuple, target) = y_nan[target]

_get_target_y(y, target) = y(target)
_get_target_nan(y_nan, target) = y_nan(target)

Expand Down
3 changes: 1 addition & 2 deletions src/losses/loss_fn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ function loss_fn(ŷ, y, y_nan, ::Val{:pearson})
return cor(ŷ[y_nan], y[y_nan])
end
"""
    loss_fn(ŷ, y, y_nan, ::Val{:r2})

Coefficient of determination `1 - SS_res / SS_tot` (Nash–Sutcliffe form),
computed over the entries selected by the boolean mask `y_nan`.
Unlike squared Pearson correlation, this penalizes bias and scale errors.
"""
function loss_fn(ŷ, y, y_nan, ::Val{:r2})
    # Mask once so each vector is materialized a single time.
    yv = y[y_nan]
    ŷv = ŷ[y_nan]
    ȳ = mean(yv)
    return 1 - sum(abs2, yv .- ŷv) / sum(abs2, yv .- ȳ)
end

function loss_fn(ŷ, y, y_nan, ::Val{:pearsonLoss})
Expand Down
Loading
Loading