diff --git a/Project.toml b/Project.toml index 6a8e075b..349749af 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ authors = ["Lazaro Alonso", "Bernhard Ahrens", "Markus Reichstein"] [deps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -44,6 +45,7 @@ EasyHybridMakie = "Makie" [compat] AxisKeys = "0.2" CSV = "0.10.15" +CUDA = "5.11.0" Chain = "0.6, 1" ChainRulesCore = "1.25.1" ComponentArrays = "0.15.28" diff --git a/src/EasyHybrid.jl b/src/EasyHybrid.jl index 9de434eb..99f43a3f 100644 --- a/src/EasyHybrid.jl +++ b/src/EasyHybrid.jl @@ -48,6 +48,7 @@ using Static: False, True @reexport begin import LuxCore using Lux: Lux, Dense, Chain, Dropout, relu, sigmoid, swish, tanh, Recurrence, LSTMCell + using Lux: gpu_device, cpu_device using Random using Statistics using DataFrames diff --git a/src/config/DataConfig.jl b/src/config/DataConfig.jl index 6df0b236..71c3bfea 100644 --- a/src/config/DataConfig.jl +++ b/src/config/DataConfig.jl @@ -56,4 +56,10 @@ cross-validation, and sequence construction for time-series training. "Whether to apply batch normalization to the model inputs. Default: `false`." input_batchnorm::Bool = false + + "Select a gpu_device or default to cpu if none available" + gdev = gpu_device() + + "Set the `cpu_device`, useful for sending back to the cpu model parameters" + cdev = cpu_device() end diff --git a/src/config/TrainingConfig.jl b/src/config/TrainingConfig.jl index 5c0e24e0..452701da 100644 --- a/src/config/TrainingConfig.jl +++ b/src/config/TrainingConfig.jl @@ -28,6 +28,12 @@ loss computation, data handling, output, and visualization. "Whether to return gradients during training. Default: True()." return_gradients = True() + "Select a gpu_device or default to cpu if none available" + gdev = gpu_device() + + "Set the `cpu_device`, useful for sending back to the cpu model parameters" + cdev = cpu_device() + "Loss type to use during training. Default: `:mse`." training_loss::Symbol = :mse diff --git a/src/data/loaders.jl b/src/data/loaders.jl index 5a675cce..8362bf81 100644 --- a/src/data/loaders.jl +++ b/src/data/loaders.jl @@ -1,6 +1,7 @@ -function build_loader(x_train, y_train, cfg::TrainConfig) +function build_loader(x_train, forcings_train, y_train, cfg::TrainConfig) loader = DataLoader( - (x_train, y_train); + ((x_train, forcings_train), y_train); + parallel = true, batchsize = cfg.batchsize, shuffle = true, ) diff --git a/src/data/prepare_data.jl b/src/data/prepare_data.jl index aa8e00fb..b2cde8ee 100644 --- a/src/data/prepare_data.jl +++ b/src/data/prepare_data.jl @@ -1,9 +1,12 @@ export prepare_data -function prepare_data(hm, data::KeyedArray; kwargs...) - predictors_forcing, targets = get_prediction_target_names(hm) +function prepare_data(hm, data::KeyedArray; cfg=DataConfig(), kwargs...) + predictors, forcings, targets = get_prediction_target_names(hm) # KeyedArray: use () syntax for views that are differentiable - return (data(predictors_forcing), data(targets)) + dev = cfg.gdev + targets_nt = NamedTuple([target => dev(Array(data(target))) for target in targets]) + forcings_nt = NamedTuple([forcing => dev(Array(data(forcing))) for forcing in forcings]) + return ((dev(Array(data(predictors))), forcings_nt), targets_nt) end function prepare_data(hm, data::AbstractDimArray; kwargs...) @@ -13,10 +16,10 @@ function prepare_data(hm, data::AbstractDimArray; kwargs...) end function prepare_data(hm, data::DataFrame; array_type = :KeyedArray, drop_missing_rows = true) - predictors_forcing, targets = get_prediction_target_names(hm) + predictors, forcings, targets = get_prediction_target_names(hm) - all_predictor_cols = unique(vcat(values(predictors_forcing)...)) - col_to_select = unique([all_predictor_cols; targets]) + # all_predictor_cols = unique(vcat(values(predictors_forcing)...)) + col_to_select = unique([predictors; forcings; targets]) # subset to only the cols we care about sdf = data[!, col_to_select] @@ -84,13 +87,15 @@ Returns a tuple of (predictors_forcing, targets) names. function get_prediction_target_names(hm) targets = hm.targets predictors_forcing = Symbol[] + predictors = Symbol[] + forcings = Symbol[] for prop in propertynames(hm) if occursin("predictors", string(prop)) val = getproperty(hm, prop) if isa(val, AbstractVector) - append!(predictors_forcing, val) + append!(predictors, val) elseif isa(val, Union{NamedTuple, Tuple}) - append!(predictors_forcing, unique(vcat(values(val)...))) + append!(predictors, unique(vcat(values(val)...))) end end end @@ -98,13 +103,14 @@ function get_prediction_target_names(hm) if occursin("forcing", string(prop)) val = getproperty(hm, prop) if isa(val, AbstractVector) - append!(predictors_forcing, val) + append!(forcings, val) elseif isa(val, Union{Tuple, NamedTuple}) - append!(predictors_forcing, unique(vcat(values(val)...))) + append!(forcings, unique(vcat(values(val)...))) end end end - predictors_forcing = unique(predictors_forcing) + # predicto + # predictors_forcing = unique(predictors_forcing) if isempty(predictors_forcing) @warn "Note that you don't have predictors or forcing variables." @@ -112,5 +118,5 @@ function get_prediction_target_names(hm) if isempty(targets) @warn "Note that you don't have target names." end - return predictors_forcing, targets + return predictors, forcings, targets end diff --git a/src/data/split_data.jl b/src/data/split_data.jl index eb50ff2d..347fc6ae 100644 --- a/src/data/split_data.jl +++ b/src/data/split_data.jl @@ -23,17 +23,16 @@ function split_data( ) if sequence_kwargs !== nothing - x_all, y_all = data_ + (x_all, forcings_all), y_all = data_ sis_default = (; input_window = 10, output_window = 1, output_shift = 1, lead_time = 1) sis = merge(sis_default, sequence_kwargs) @info "Using split_into_sequences: $sis" x_all, y_all = split_into_sequences(x_all, y_all; sis.input_window, sis.output_window, sis.output_shift, sis.lead_time) x_all, y_all = filter_sequences(x_all, y_all) else - x_all, y_all = data_ + (x_all, forcings_all), y_all = data_ end - if split_by_id !== nothing && folds !== nothing throw(ArgumentError("split_by_id and folds are not supported together; do the split when constructing folds")) @@ -50,9 +49,9 @@ function split_data( @info "Number of unique $(split_by_id): $(length(unique_ids))" @info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))" - x_train, y_train = view_end_dim(x_all, train_idx), view_end_dim(y_all, train_idx) - x_val, y_val = view_end_dim(x_all, val_idx), view_end_dim(y_all, val_idx) - return (x_train, y_train), (x_val, y_val) + x_train, forcings_train, y_train = view_end_dim(x_all, train_idx),view_end_dim(forcings_all, train_idx), view_end_dim(y_all, train_idx) + x_val, forcings_val, y_val = view_end_dim(x_all, val_idx),view_end_dim(forcings_all, val_idx), view_end_dim(y_all, val_idx) + return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val) elseif folds !== nothing || val_fold !== nothing # --- Option B: external K-fold assignment --- @@ -76,8 +75,8 @@ function split_data( else # --- Fallback: simple random/chronological split of prepared data --- - (x_train, y_train), (x_val, y_val) = splitobs((x_all, y_all); at = split_data_at, shuffle = shuffleobs) - return (x_train, y_train), (x_val, y_val) + (x_train, forcings_train, y_train), (x_val, forcings_val, y_val) = splitobs((x_all, forcings_all, y_all); at = split_data_at, shuffle = shuffleobs) + return ((x_train, forcings_train), y_train), ((x_val,forcings_val), y_val) end end diff --git a/src/data/splits.jl b/src/data/splits.jl index e0676de8..6c84d397 100644 --- a/src/data/splits.jl +++ b/src/data/splits.jl @@ -1,7 +1,7 @@ export prepare_splits, maybe_build_sequences function prepare_splits(data, model, cfg::DataConfig) - (x_train, y_train), (x_val, y_val) = split_data( + ((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val) = split_data( data, model; array_type = cfg.array_type, shuffleobs = cfg.shuffleobs, @@ -16,7 +16,7 @@ function prepare_splits(data, model, cfg::DataConfig) @debug "Train size: $(size(x_train)), Val size: $(size(x_val))" @debug "Data type: $(typeof(x_train))" - return (x_train, y_train), (x_val, y_val) + return ((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val) end function maybe_build_sequence_kwargs(cfg::DataConfig) diff --git a/src/io/checkpoints.jl b/src/io/checkpoints.jl index 9c696abb..17326908 100644 --- a/src/io/checkpoints.jl +++ b/src/io/checkpoints.jl @@ -1,24 +1,24 @@ function save_initial_state!(paths::TrainingPaths, model, ps, st, cfg::TrainConfig) - save_ps_st(paths.checkpoint, model, ps, st, cfg.tracked_params) + save_ps_st(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params) save_train_val_loss!(paths.checkpoint, nothing, "training_loss", 0) save_train_val_loss!(paths.checkpoint, nothing, "validation_loss", 0) return nothing end function save_epoch!(paths::TrainingPaths, model, ps, st, snapshot::EpochSnapshot, epoch::Int, cfg::TrainConfig) - save_ps_st!(paths.checkpoint, model, ps, st, cfg.tracked_params, epoch) + save_ps_st!(paths.checkpoint, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, epoch) save_train_val_loss!(paths.checkpoint, snapshot.l_train, "training_loss", epoch) save_train_val_loss!(paths.checkpoint, snapshot.l_val, "validation_loss", epoch) return nothing end -function save_final!(paths::TrainingPaths, model, ps, st, x_train, y_train, x_val, y_val, stopper::EarlyStopping, cfg::TrainConfig) +function save_final!(paths::TrainingPaths, model, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, stopper::EarlyStopping, cfg::TrainConfig) target_names = model.targets save_epoch = stopper.best_epoch == 0 ? 0 : stopper.best_epoch - save_ps_st!(paths.best_model, model, ps, st, cfg.tracked_params, save_epoch) + save_ps_st!(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch) - ŷ_train, αst_train = model(x_train, ps, LuxCore.testmode(st)) - ŷ_val, αst_val = model(x_val, ps, LuxCore.testmode(st)) + ŷ_train, αst_train = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st))) + ŷ_val, αst_val = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(cfg.cdev(st))) save_predictions!(paths.checkpoint, ŷ_train, αst_train, "training") save_predictions!(paths.checkpoint, ŷ_val, αst_val, "validation") diff --git a/src/io/save.jl b/src/io/save.jl index 4ec03aef..c8a08292 100644 --- a/src/io/save.jl +++ b/src/io/save.jl @@ -53,7 +53,7 @@ function save_observations!(file_name, target_names, yobs, train_or_val_name) end function to_named_tuple(ka, target_names) - arrays = [Array(ka(variable = k)) for k in target_names] + arrays = [Array(ka[k]) for k in target_names] return NamedTuple{Tuple(target_names)}(arrays) end diff --git a/src/losses/compute_loss.jl b/src/losses/compute_loss.jl index 05adb8eb..2fb5e541 100644 --- a/src/losses/compute_loss.jl +++ b/src/losses/compute_loss.jl @@ -18,14 +18,14 @@ Main loss function for hybrid models that handles both training and evaluation m - `(loss_values, st, ŷ)`: NamedTuple of losses, state and predictions """ function compute_loss( - HM::LuxCore.AbstractLuxContainerLayer, ps, st, (x, (y_t, y_nan)); + HM::LuxCore.AbstractLuxContainerLayer, ps, st, ((x, forcings), (y_t, y_nan)); logging::LoggingLoss ) targets = HM.targets ext_loss = extra_loss(logging) if logging.train_mode - ŷ, st = HM(x, ps, st) + ŷ, st = HM((x, forcings), ps, st) loss_value = _compute_loss(ŷ, y_t, y_nan, targets, training_loss(logging), logging.agg) # Add extra_loss if provided if ext_loss !== nothing @@ -34,7 +34,7 @@ function compute_loss( end stats = NamedTuple() else - ŷ, _ = HM(x, ps, LuxCore.testmode(st)) + ŷ, _ = HM((x, forcings), ps, LuxCore.testmode(st)) loss_value = _compute_loss(ŷ, y_t, y_nan, targets, loss_types(logging), logging.agg) # Add extra_loss entries if provided if ext_loss !== nothing @@ -105,9 +105,10 @@ _get_target_ŷ(ŷ, y_t, target) = function assemble_loss(ŷ, y, y_nan, targets, loss_spec) return [ begin - y_t = _get_target_y(y, target) - ŷ_t = _get_target_ŷ(ŷ, y_t, target) - _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec) + y_t = y[target]# _get_target_y(y, target) + ŷ_t = ŷ[target]#_get_target_ŷ(ŷ, y_t, target) + _apply_loss(ŷ_t, y_t, y_nan, loss_spec) + # _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec) end for target in targets ] diff --git a/src/losses/loss_fn.jl b/src/losses/loss_fn.jl index f369baa2..9b10e025 100644 --- a/src/losses/loss_fn.jl +++ b/src/losses/loss_fn.jl @@ -69,8 +69,7 @@ function loss_fn(ŷ, y, y_nan, ::Val{:pearson}) return cor(ŷ[y_nan], y[y_nan]) end function loss_fn(ŷ, y, y_nan, ::Val{:r2}) - r = cor(ŷ[y_nan], y[y_nan]) - return r * r + return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(ŷ[y_nan])).^2) end function loss_fn(ŷ, y, y_nan, ::Val{:pearsonLoss}) diff --git a/src/models/GenericHybridModel.jl b/src/models/GenericHybridModel.jl index 8037fd52..13c6af86 100644 --- a/src/models/GenericHybridModel.jl +++ b/src/models/GenericHybridModel.jl @@ -360,9 +360,9 @@ end # ─────────────────────────────────────────────────────────────────────────── # Forward pass for SingleNNHybridModel (optimized, no branching) -function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps, st) +function (m::SingleNNHybridModel)(ds_k, ps, st) # 1) get features - predictors = toArray(ds_k, m.predictors) + predictors = ds_k[1]#toArray(ds_k, m.predictors) parameters = m.parameters @@ -407,11 +407,12 @@ function (m::SingleNNHybridModel)(ds_k::Union{KeyedArray, AbstractDimArray}, ps, end # 5) unpack forcing data - forcing_data = toNamedTuple(ds_k, m.forcing) + forcing_data = ds_k[2]#toNamedTuple(ds_k, m.forcing) # 6) merge all parameters all_params = merge(scaled_nn_params, global_params, fixed_params) all_kwargs = merge(forcing_data, all_params) + # all_kwargs = merge(forcing_data, all_params) # 7) physics y_pred = m.mechanistic_model(; all_kwargs...) diff --git a/src/training/early_stopping.jl b/src/training/early_stopping.jl index 109dab5f..52f5a17a 100644 --- a/src/training/early_stopping.jl +++ b/src/training/early_stopping.jl @@ -8,9 +8,9 @@ mutable struct EarlyStopping done::Bool end -function EarlyStopping(init_loss, ps, st, patience::Int) +function EarlyStopping(init_loss, ps, st, cfg) best_loss = extract_agg_loss(init_loss) - return EarlyStopping(best_loss, deepcopy(ps), deepcopy(st), 0, 0, patience, false) + return EarlyStopping(best_loss, deepcopy(cfg.cdev(ps)), deepcopy(cfg.cdev(st)), 0, 0, cfg.patience, false) end function update!(es::EarlyStopping, history::TrainingHistory, snapshot::EpochSnapshot, ps, st, epoch, cfg::TrainConfig) @@ -23,8 +23,8 @@ function update!(es::EarlyStopping, history::TrainingHistory, snapshot::EpochSna if isbetter(current_loss, es.best_loss, first(cfg.loss_types)) es.best_loss = current_loss - es.best_ps = deepcopy(ps) - es.best_st = deepcopy(st) + es.best_ps = deepcopy(cfg.cdev(ps)) + es.best_st = deepcopy(cfg.cdev(st)) es.best_epoch = epoch es.counter = 0 if !cfg.keep_history @@ -70,16 +70,16 @@ function best_or_final(stopper::EarlyStopping, ps, st, cfg::TrainConfig) end end -function build_results(model, history::TrainingHistory, stopper::EarlyStopping, ps, st, x_train, y_train, x_val, y_val) +function build_results(model, history::TrainingHistory, stopper::EarlyStopping, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, cfg::TrainConfig) target_names = model.targets # final predictions in test mode - ŷ_train, _ = model(x_train, ps, LuxCore.testmode(st)) - ŷ_val, _ = model(x_val, ps, LuxCore.testmode(st)) + ŷ_train, _ = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(st)) + ŷ_val, _ = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(st)) # observed vs predicted DataFrames - train_obs_pred = hcat(toDataFrame(y_train), toDataFrame(ŷ_train, target_names)) - val_obs_pred = hcat(toDataFrame(y_val), toDataFrame(ŷ_val, target_names)) + train_obs_pred = hcat(DataFrame(y_train), toDataFrame(ŷ_train, target_names)) + val_obs_pred = hcat(DataFrame(y_val), toDataFrame(ŷ_val, target_names)) # extra predictions without observational counterparts train_diffs, val_diffs = extract_diffs(ŷ_train, ŷ_val, target_names) @@ -92,8 +92,8 @@ function build_results(model, history::TrainingHistory, stopper::EarlyStopping, val_obs_pred, train_diffs, val_diffs, - ps, - st, + cfg.cdev(ps), + cfg.cdev(st), stopper.best_epoch, stopper.best_loss, ) diff --git a/src/training/epoch.jl b/src/training/epoch.jl index c0127d16..ce1cd406 100644 --- a/src/training/epoch.jl +++ b/src/training/epoch.jl @@ -1,23 +1,26 @@ -function run_epoch!(loader, model, ps, st, opt_state, cfg::TrainConfig) +function run_epoch!(loader, model, ps, st, train_state, cfg::TrainConfig) loss_fn = build_loss_fn(model, cfg) - for (x, y) in loader - is_no_nan = valid_mask(y) - isnothing(is_no_nan) && continue + for (x, y) in cfg.gdev(loader) + is_no_nan = falses(length(first(y))) |> cfg.gdev + for vec in y + is_no_nan = is_no_nan.|| .!isnan.(vec) + end + isempty(is_no_nan) && continue - _, _, _, opt_state = Lux.Training.single_train_step!( + _, _, _, train_state = Lux.Training.single_train_step!( cfg.autodiff_backend, loss_fn, (x, (y, is_no_nan)), - opt_state; + train_state; return_gradients = cfg.return_gradients ) end - ps = opt_state.parameters - st = opt_state.states + ps = train_state.parameters + st = train_state.states - return ps, st, opt_state + return ps, st, train_state end function valid_mask(y) is_no_nan = .!isnan.(y) @@ -39,16 +42,22 @@ function build_loss_fn(model, cfg::TrainConfig) ) end -function evaluate_epoch(model, x_train, y_train, x_val, y_val, ps, st, init::EpochSnapshot, cfg::TrainConfig) - is_no_nan_t = .!isnan.(y_train) - is_no_nan_v = .!isnan.(y_val) +function evaluate_epoch(model, x_train, forcings_train, y_train, x_val, forcings_val, y_val, ps, st, init::EpochSnapshot, cfg::TrainConfig) + is_no_nan_t = falses(length(first(y_train))) |> cfg.gdev + for vec in y_train + is_no_nan_t = is_no_nan_t .|| .!isnan.(vec) + end + is_no_nan_v = falses(length(first(y_val))) |> cfg.gdev + for vec in y_val + is_no_nan_v = is_no_nan_v .|| .!isnan.(vec) + end l_train, _, ŷ_train = evaluate_acc( - model, x_train, y_train, is_no_nan_t, + model, x_train, forcings_train, y_train, is_no_nan_t, ps, st, cfg.loss_types, cfg.training_loss, cfg.extra_loss, cfg.agg ) l_val, _, ŷ_val = evaluate_acc( - model, x_val, y_val, is_no_nan_v, + model, x_val, forcings_val, y_val, is_no_nan_v, ps, st, cfg.loss_types, cfg.training_loss, cfg.extra_loss, cfg.agg ) diff --git a/src/training/initialization.jl b/src/training/initialization.jl index b25e9601..756208f2 100644 --- a/src/training/initialization.jl +++ b/src/training/initialization.jl @@ -16,15 +16,16 @@ end function init_model_state(model, cfg::TrainConfig) if isnothing(cfg.train_from) - ps, st = LuxCore.setup(Random.default_rng(), model) + ps, st = LuxCore.setup(Random.default_rng(), model) |> cfg.gdev + ps = ps |> cfg.gdev else - ps, st = get_ps_st(cfg.train_from) + ps, st = get_ps_st(cfg.train_from) |> cfg.gdev end - ps = ComponentArray(ps) - opt_state = Lux.Training.TrainState(model, ps, st, cfg.opt) + # ps = ComponentArray(ps) + train_state = Lux.Training.TrainState(model, ps, st, cfg.opt) - return ps, st, opt_state + return ps, st, train_state end struct EpochSnapshot @@ -35,16 +36,22 @@ struct EpochSnapshot end -function compute_initial_state(model, x_train, y_train, x_val, y_val, ps, st, cfg::TrainConfig) - is_no_nan_t = .!isnan.(y_train) - is_no_nan_v = .!isnan.(y_val) +function compute_initial_state(model, x_train, forcings_train, y_train, x_val, forcings_val, y_val, ps, st, cfg::TrainConfig) + is_no_nan_t = falses(length(first(y_train))) |> cfg.gdev + for vec in y_train + is_no_nan_t = is_no_nan_t .|| .!isnan.(vec) + end + is_no_nan_v = falses(length(first(y_val))) |> cfg.gdev + for vec in y_val + is_no_nan_v = is_no_nan_v .|| .!isnan.(vec) + end l_train, _, ŷ_train = evaluate_acc( - model, x_train, y_train, is_no_nan_t, + model, x_train, forcings_train, y_train, is_no_nan_t, ps, st, cfg.loss_types, cfg.training_loss, cfg.extra_loss, cfg.agg ) l_val, _, ŷ_val = evaluate_acc( - model, x_val, y_val, is_no_nan_v, + model, x_val, forcings_val, y_val, is_no_nan_v, ps, st, cfg.loss_types, cfg.training_loss, cfg.extra_loss, cfg.agg ) diff --git a/src/training/train.jl b/src/training/train.jl index a6295418..f8b5b71a 100644 --- a/src/training/train.jl +++ b/src/training/train.jl @@ -40,13 +40,13 @@ function train(model, data; train_cfg::TrainConfig = TrainConfig(), data_cfg::Da ext = load_makie_extension(train_cfg) seed!(train_cfg.random_seed) - (x_train, y_train), (x_val, y_val) = prepare_splits(data, model, data_cfg) - loader = build_loader(x_train, y_train, train_cfg) - ps, st, opt_state = init_model_state(model, train_cfg) - - init = compute_initial_state(model, x_train, y_train, x_val, y_val, ps, st, train_cfg) + ((x_train, forcings_train), y_train), ((x_val, forcings_val), y_val) = prepare_splits(data, model, data_cfg) + loader = build_loader(x_train, forcings_train, y_train, train_cfg) + ps, st, train_state = init_model_state(model, train_cfg) + + init = compute_initial_state(model, x_train, forcings_train, y_train, x_val, forcings_val, y_val, ps, st, train_cfg) history = TrainingHistory(init) - stopper = EarlyStopping(init.l_val, ps, st, train_cfg.patience) + stopper = EarlyStopping(init.l_val, ps, st, train_cfg) paths = resolve_paths(train_cfg) prog = build_progress(train_cfg) dashboard = init_dashboard(ext, init, train_cfg, y_train, y_val, model.targets) @@ -55,8 +55,8 @@ function train(model, data; train_cfg::TrainConfig = TrainConfig(), data_cfg::Da record_or_run(ext, paths, train_cfg) do io for epoch in 1:train_cfg.nepochs - ps, st, opt_state = run_epoch!(loader, model, ps, st, opt_state, train_cfg) - snapshot = evaluate_epoch(model, x_train, y_train, x_val, y_val, ps, st, init, train_cfg) + ps, st, train_state = run_epoch!(loader, model, ps, st, train_state, train_cfg) + snapshot = evaluate_epoch(model, x_train, forcings_train, y_train, x_val, forcings_val, y_val, ps, st, init, train_cfg) update!(stopper, history, snapshot, ps, st, epoch, train_cfg) save_epoch!(paths, model, ps, st, snapshot, epoch, train_cfg) @@ -69,9 +69,9 @@ function train(model, data; train_cfg::TrainConfig = TrainConfig(), data_cfg::Da save_dashboard_img!(dashboard, ext, paths, stopper.best_epoch) ps, st = best_or_final(stopper, ps, st, train_cfg) - save_final!(paths, model, ps, st, x_train, y_train, x_val, y_val, stopper, train_cfg) + save_final!(paths, model, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, stopper, train_cfg) - return build_results(model, history, stopper, ps, st, x_train, y_train, x_val, y_val) + return build_results(model, history, stopper, ps, st, x_train, forcings_train, y_train, x_val, forcings_val, y_val, train_cfg) end function train(model, data, save_ps; kwargs...) @@ -159,8 +159,8 @@ function rename_deprecated_kwargs(kwargs) return NamedTuple(pairs) end -function evaluate_acc(ghm, x, y, y_no_nan, ps, st, loss_types, training_loss, extra_loss, agg) - loss_val, sts, ŷ = compute_loss(ghm, ps, st, (x, (y, y_no_nan)), logging = LoggingLoss(train_mode = false, loss_types = loss_types, training_loss = training_loss, extra_loss = extra_loss, agg = agg)) +function evaluate_acc(ghm, x, forcings, y, y_no_nan, ps, st, loss_types, training_loss, extra_loss, agg) + loss_val, sts, ŷ = compute_loss(ghm, ps, st, ((x, forcings), (y, y_no_nan)), logging = LoggingLoss(train_mode = false, loss_types = loss_types, training_loss = training_loss, extra_loss = extra_loss, agg = agg)) return loss_val, sts, ŷ end diff --git a/test/Project.toml b/test/Project.toml index f2051aa7..a244f8c2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,7 +6,9 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/test_split_data_train.jl b/test/test_split_data_train.jl index ceb5a415..6a4ff0e6 100644 --- a/test/test_split_data_train.jl +++ b/test/test_split_data_train.jl @@ -5,6 +5,9 @@ using DataFrames using Statistics using DimensionalData using ChainRulesCore +# using GPUArraysCore + +# GPUArraysCore.allowscalar(false) # ------------------------------------------------------------------------------ # Synthetic data similar to the example's columns (no network calls)