
Ac/gpu support #256

Open
Qfl3x wants to merge 4 commits into main from ac/gpu_support

Conversation


@Qfl3x commented Apr 2, 2026

No description provided.

lazarusA and others added 3 commits March 20, 2026 12:32
…at needs to be done in the outer loop, refactoring genericHybrid is needed for that

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces GPU support and refactors data handling to separate predictors and forcings throughout the training pipeline. Key changes include adding CUDA dependencies, updating configuration objects with device selectors, and modifying data loaders, splitters, and model forward passes to accommodate a new nested tuple input structure. Feedback highlights a mathematical error in the R-squared calculation, potential shape mismatches and incorrect NaN masking in the epoch loop, and several instances of dead code or typos. Additionally, a logic error was identified in a warning check within the data preparation module.

    function loss_fn(ŷ, y, y_nan, ::Val{:r2})
        r = cor(ŷ[y_nan], y[y_nan])
        return r * r
        return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(ŷ[y_nan])).^2)

high

The R-squared calculation is incorrect. The denominator should use the mean of the observed values (y), not the predicted values (ŷ). The standard definition of R² is $1 - SS_{res}/SS_{tot}$, where $SS_{tot}$ is calculated relative to the mean of the observations.

    return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(y[y_nan])).^2)
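A standalone sketch makes the difference concrete (the function name `r2_score` and the data are illustrative, not the package's `loss_fn`): for a prediction with a constant bias, the two denominators give different scores, and only the y-mean form is the standard R².

```julia
using Statistics

# Hypothetical sketch contrasting the denominators: the total sum of squares
# must be taken around mean(y), the observations, not mean(ŷ).
function r2_score(ŷ, y, mask)
    yv, ŷv = y[mask], ŷ[mask]
    return 1 - sum(abs2, yv .- ŷv) / sum(abs2, yv .- mean(yv))
end

y = [1.0, 2.0, 3.0, 4.0]
ŷ = y .+ 0.5                 # constant-offset prediction
mask = trues(length(y))
r2_score(ŷ, y, mask)         # 0.8; the ŷ-mean variant yields ≈0.833 on this data
```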

Comment on lines +5 to +8
    is_no_nan = falses(length(first(y))) |> cfg.gdev
    for vec in y
        is_no_nan = is_no_nan .|| .!isnan.(vec)
    end

high

This logic has two significant issues:

  1. Shape Mismatch: falses(length(first(y))) creates a 1D array. If the targets are multi-dimensional (e.g., (time, batch)), the bitwise OR operation .|| will fail. Use size(first(y)) instead of length.
  2. Incorrect Masking: Computing a single global is_no_nan mask by ORing all targets is problematic. If target A has a NaN at an index where target B is valid, the global mask will be true at that index. Consequently, the loss for target A will be computed using the NaN value, resulting in a NaN total loss. Masks should be computed and applied per-target.
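The per-target approach can be sketched as follows (the `Dict` layout and the `:gpp`/`:reco` names are illustrative, not the package's data structures): each target gets its own mask, so a NaN in one target cannot poison another target's loss.

```julia
# Hypothetical sketch of per-target NaN masking.
function masked_loss(preds, targets)
    total = 0.0
    for (name, y) in targets
        mask = .!isnan.(y)                       # per-target mask, not a global OR
        total += sum(abs2, y[mask] .- preds[name][mask])  # only valid entries
    end
    return total
end

targets = Dict(:gpp => [1.0, NaN, 3.0], :reco => [NaN, 2.0, 3.0])
preds   = Dict(:gpp => [1.1, 2.0, 2.9], :reco => [0.9, 2.1, 3.2])
masked_loss(preds, targets)   # finite (≈0.07); a single OR-ed mask over both
                              # targets would pull NaN entries into the sum
```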

Comment on lines +112 to +113
    # predicto
    # predictors_forcing = unique(predictors_forcing)

medium

These lines contain an incomplete comment and commented-out code that should be removed to maintain code cleanliness.

    @warn "Note that you don't have target names."
    end
    return predictors_forcing, targets
    return predictors, forcings, targets

medium

The function now returns predictors and forcings separately, but the warning check at line 115 (visible in context) still references predictors_forcing. Since predictors_forcing is initialized as an empty array at line 89 and never populated in the new logic, this warning will be triggered on every call. The check should be updated to verify if both predictors and forcings are empty.
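A hedged sketch of the suggested fix (the function name `check_names` and the exact signature are assumed from the review context, not the package API): warn only when both new collections are empty, instead of checking the now-unpopulated `predictors_forcing`.

```julia
# Illustrative fix: the emptiness check follows the new split return values.
function check_names(predictors, forcings, targets)
    if isempty(predictors) && isempty(forcings)
        @warn "Note that you don't have predictor or forcing names."
    end
    if isempty(targets)
        @warn "Note that you don't have target names."
    end
    return predictors, forcings, targets
end
```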

Comment on lines +108 to +109
    y_t = y[target]# _get_target_y(y, target)
    ŷ_t = ŷ[target]#_get_target_ŷ(ŷ, y_t, target)

medium

There is a typo in line 109: ŷ written as y followed by U+0302 (combining circumflex) is used instead of the argument ŷ (the precomposed character U+0177) defined at line 105. While Julia normalizes identifiers to NFC, mixing these characters is confusing and can lead to issues in environments with different normalization rules. Additionally, the commented-out code should be removed.

                y_t = y[target]
                ŷ_t = ŷ[target]
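The two encodings can be checked directly (an illustrative snippet, not part of the codebase): the strings differ codepoint-for-codepoint but agree after NFC normalization, which is why the mix-up compiles yet misleads readers and byte-level tooling.

```julia
using Unicode

precomposed = "\u0177"      # ŷ as one codepoint (U+0177)
combining   = "y\u0302"     # ŷ as y plus combining circumflex (U+0302)
precomposed == combining                            # false: different codepoints
Unicode.normalize(combining, :NFC) == precomposed   # true after NFC
```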

Comment on lines +19 to +20
    ps, st = LuxCore.setup(Random.default_rng(), model) |> cfg.gdev
    ps = ps |> cfg.gdev

medium

The line ps = ps |> cfg.gdev is redundant. The parameters ps and state st are already moved to the device as part of the piped operation in line 19.

        ps, st = LuxCore.setup(Random.default_rng(), model) |> cfg.gdev
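A minimal sketch of why the second move is a no-op, using `identity` as a stand-in for `cfg.gdev` so it runs without a GPU (device functors such as those returned by `gpu_device()` are applied to the whole piped tuple, moving both elements):

```julia
# `gdev` stands in for cfg.gdev; one pipe already covers ps and st.
gdev = identity
setup = (ps = rand(2, 2), st = (;))
ps, st = (setup.ps, setup.st) |> gdev   # both "moved" in a single pipe
ps === (ps |> gdev)                     # true: re-moving ps changes nothing
```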

@lazarusA mentioned this pull request Apr 3, 2026