2 changes: 1 addition & 1 deletion Dockerfile
@@ -35,7 +35,7 @@ RUN cuda=$(command -v nvcc > /dev/null && echo "true" || echo "false") \
# Switch to bash shell
SHELL ["/bin/bash", "-c"]

-# Set ${CONDA_ENV_NAME} to default virutal environment
+# Set ${CONDA_ENV_NAME} to default virtual environment
RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc

# Cp in the development directory and install
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@ models have now been added, as well as general models for assimilation and forec

The models implemented include:

-DeepMind's [Functional Generative Network (FGN)](https://storage.googleapis.com/deepmind-media/DeepMind.com/Blog/how-we-re-supporting-better-tropical-cyclone-prediction-with-ai/skillful-joint-probabilistic-weather-forecasting-from-marginals.pdf) for probablistic ensemble forecasting
+DeepMind's [Functional Generative Network (FGN)](https://storage.googleapis.com/deepmind-media/DeepMind.com/Blog/how-we-re-supporting-better-tropical-cyclone-prediction-with-ai/skillful-joint-probabilistic-weather-forecasting-from-marginals.pdf) for probabilistic ensemble forecasting

DeepMind's [GenCast](https://www.nature.com/articles/s41586-024-08252-9) for graph diffusion-based forecasting

2 changes: 1 addition & 1 deletion graph_weather/models/fgn/README.md
@@ -6,6 +6,6 @@ This is an unofficial implementation of the Functional Generative Network
outlined in [Skillful joint probabilistic weather forecsting from marginals](https://arxiv.org/abs/2506.10772).

This model is heavily based on GenCast, and is designed to make ensemble weather forecasts through a combination of
-mutliple trained models, and noise injected into the model parameters during inference.
+multiple trained models, and noise injected into the model parameters during inference.

As it does not use diffusion, it is significantly faster to run than GenCast, while outperforming it on nearly all metrics.
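The parameter-noise ensembling this README describes can be sketched in a few lines. This is a hypothetical, simplified interface (the function name, dict-of-floats parameters, and Gaussian perturbation scale are all illustrative assumptions, not the repository's actual API): each member picks one of the trained models, jitters its parameters, and runs a forecast, with no diffusion sampling in the loop.

```python
import random

def fgn_style_ensemble(models, forecast, num_members, noise_scale=0.01, seed=0):
    """Sketch of FGN-style ensembling: multiple trained models plus
    parameter noise at inference time (hypothetical API)."""
    rng = random.Random(seed)
    members = []
    for _ in range(num_members):
        # Pick one of the independently trained models for this member.
        params = rng.choice(models)
        # Inject Gaussian noise into the parameters (here: dict of floats).
        noisy = {k: v + rng.gauss(0.0, noise_scale) for k, v in params.items()}
        members.append(forecast(noisy))
    return members
```

Because each member is a single deterministic forward pass of a perturbed model, inference cost scales linearly with ensemble size rather than with diffusion steps.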
2 changes: 1 addition & 1 deletion graph_weather/models/gencast/graph/icosahedral_mesh.py
@@ -132,7 +132,7 @@ def get_icosahedron() -> TriangularMesh:
# / \ |
# / \ YO-----> X
# This results in:
-# (adjacent faceis now top plane)
+# (adjacent face is now top plane)
# ----------------------O\ (top arist)
# \
# \
8 changes: 4 additions & 4 deletions graph_weather/models/gencast/graph/model_utils.py
@@ -173,7 +173,7 @@ def lat_lon_deg_to_spherical(
node_lat: np.ndarray,
node_lon: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
"""Convert lat and lon to spherical coordiantes."""
"""Convert lat and lon to spherical coordinates."""
phi = np.deg2rad(node_lon)
theta = np.deg2rad(90 - node_lat)
return phi, theta
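The hunk shows the full convention; as a scalar sketch of the same mapping (the module's version operates on NumPy arrays — this standalone helper and its name are illustrative): longitude becomes the azimuth `phi`, and latitude becomes the colatitude `theta`, which is 0 at the north pole.

```python
import math

def to_spherical(lat_deg, lon_deg):
    # Same convention as lat_lon_deg_to_spherical above:
    # phi = azimuth from longitude, theta = colatitude (0 at north pole).
    phi = math.radians(lon_deg)
    theta = math.radians(90 - lat_deg)
    return phi, theta
```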
@@ -221,7 +221,7 @@ def get_relative_position_in_receiver_local_coordinates(

The relative positions will be computed in a rotated space for a local
coordinate system as defined by the receiver. The relative positions are
-simply obtained by subtracting sender position minues receiver position in
+simply obtained by subtracting sender position minus receiver position in
that local coordinate system after the rotation in R^3.

Args:
@@ -526,7 +526,7 @@ def get_bipartite_relative_position_in_receiver_local_coordinates(

The relative positions will be computed in a rotated space for a local
coordinate system as defined by the receiver. The relative positions are
-simply obtained by subtracting sender position minues receiver position in
+simply obtained by subtracting sender position minus receiver position in
that local coordinate system after the rotation in R^3.

Args:
@@ -641,7 +641,7 @@ def dataset_to_stacked(
) -> xarray.DataArray:
"""Converts an xarray.Dataset to a single stacked array.

-This takes each consistuent data_var, converts it into BHWC layout
+This takes each constituent data_var, converts it into BHWC layout
using `variable_to_stacked`, then concats them all along the channels axis.

Args:
2 changes: 1 addition & 1 deletion graph_weather/models/gencast/layers/modules.py
@@ -288,7 +288,7 @@ def __init__(
not be used inside TransformerConv.
concat (bool): if true concatenate the outputs of each head, otherwise average them.
Defaults to True.
-beta (bool): if true apply the beta weighting described in the paper. Defauls to True.
+beta (bool): if true apply the beta weighting described in the paper. Defaults to True.
activation_layer (torch.nn.Module, optional): activation function applied before
returning the output. If None skip the activation function. Defaults to nn.ReLU.
"""
2 changes: 1 addition & 1 deletion graph_weather/models/gencast/layers/processor.py
@@ -85,7 +85,7 @@ def __init__(
activate_final=False,
)

-# Tranformers Blocks
+# Transformers Blocks
self.cond_transformers = torch.nn.ModuleList()
if not sparse:
for _ in range(num_blocks - 1):
2 changes: 1 addition & 1 deletion graph_weather/models/gencast/utils/batching.py
@@ -35,7 +35,7 @@ def batch(senders, edge_index, edge_attr=None, batch_size=1):


def hetero_batch(senders, receivers, edge_index, edge_attr=None, batch_size=1):
"""Build big batched heterogenous graph.
"""Build big batched heterogeneous graph.

Returns nodes and edges of a big graph with batch_size disconnected copies of the original
graph, with features shape [(b n) f].
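The docstring describes `batch_size` disconnected copies of one graph living in a single big graph. The edge-index side of that idea can be sketched as follows (a hypothetical helper with plain lists instead of torch tensors; the repo's `batch`/`hetero_batch` also handle edge attributes): each copy reuses the same edges, with node ids shifted by `num_nodes` per copy so the copies never connect.

```python
def batch_edge_index(edge_index, num_nodes, batch_size):
    """Replicate a COO edge_index batch_size times with node-id offsets
    (illustrative sketch of the batching idea above)."""
    senders, receivers = edge_index
    out_s, out_r = [], []
    for b in range(batch_size):
        offset = b * num_nodes  # copy b occupies node ids [b*n, (b+1)*n)
        out_s.extend(s + offset for s in senders)
        out_r.extend(r + offset for r in receivers)
    return [out_s, out_r]
```

For the heterogeneous (bipartite) case, sender and receiver ids would be offset by their own node counts separately, since the two node sets are indexed independently.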
2 changes: 1 addition & 1 deletion graph_weather/models/gencast/utils/noise.py
@@ -17,7 +17,7 @@ def generate_isotropic_noise(num_lon: int, num_lat: int, num_samples=1, isotropi
Args:
num_lon (int): number of longitudes in the grid.
num_lat (int): number of latitudes in the grid.
-num_samples (int): number of indipendent samples. Defaults to 1.
+num_samples (int): number of independent samples. Defaults to 1.
isotropic (bool): if true generates isotropic noise, else flat noise. Defaults to True.

Returns:
4 changes: 2 additions & 2 deletions graph_weather/models/layers/assimilator_decoder.py
@@ -5,7 +5,7 @@
The Decoder maps back to physical data defined on a latitude/longitude grid. The underlying graph is
again bipartite, this time mapping icosahedron→lat/lon.
The inputs to the Decoder come from the Processor, plus a skip connection back to the original
-state of the 78 atmospheric variables onthe latitude/longitude grid.
+state of the 78 atmospheric variables on the latitude/longitude grid.
The output of the Decoder is the predicted 6-hour change in the 78 atmospheric variables,
which is then added to the initial state to produce the new state. We found 6 hours to be a good
balance between shorter time steps (simpler dynamics to model but more iterations required during
@@ -185,7 +185,7 @@ def forward(self, processor_features: torch.Tensor, batch_size: int) -> torch.Te
dim=1,
)

-# Readd nodes to match graph node number
+# Re-add nodes to match graph node number
features = einops.rearrange(processor_features, "(b n) f -> b n f", b=batch_size)
features = torch.cat(
[features, einops.repeat(self.latlon_nodes, "n f -> b n f", b=batch_size)], dim=1
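The `"(b n) f -> b n f"` rearrange in this hunk just splits the flat batched node axis back into separate batch and node axes. Assuming NumPy purely for illustration, a plain reshape is equivalent when the flat axis was built batch-major:

```python
import numpy as np

b, n, f = 2, 3, 4
flat = np.arange(b * n * f).reshape(b * n, f)  # [(b n) f] layout
batched = flat.reshape(b, n, f)                # [b n f], like the rearrange
```

Row `n` of `flat` is row 0 of batch 1 in `batched`, which is exactly the grouping the einops pattern expresses.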
4 changes: 2 additions & 2 deletions graph_weather/models/layers/assimilator_encoder.py
@@ -50,7 +50,7 @@ def __init__(
use_checkpointing: bool = False,
):
"""
-Encode the lat/lon data inot the isohedron graph
+Encode the lat/lon data into the isohedron graph

Args:
resolution: H3 resolution level
@@ -208,7 +208,7 @@ def create_input_graph(self, features: torch.Tensor, lat_lons_heights: torch.Ten
edge_targets.append(self.h3_mapping[lat_node])
edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)

-# Use homogenous graph to make it easier
+# Use homogeneous graph to make it easier
return Data(edge_index=edge_index, edge_attr=h3_distances)

def create_latent_graph(self) -> Data:
2 changes: 1 addition & 1 deletion graph_weather/models/layers/decoder.py
@@ -5,7 +5,7 @@
The Decoder maps back to physical data defined on a latitude/longitude grid. The underlying graph is
again bipartite, this time mapping icosahedron→lat/lon.
The inputs to the Decoder come from the Processor, plus a skip connection back to the original
-state of the 78 atmospheric variables onthe latitude/longitude grid.
+state of the 78 atmospheric variables on the latitude/longitude grid.
The output of the Decoder is the predicted 6-hour change in the 78 atmospheric variables,
which is then added to the initial state to produce the new state. We found 6 hours to be a good
balance between shorter time steps (simpler dynamics to model but more iterations required during
4 changes: 2 additions & 2 deletions graph_weather/models/layers/encoder.py
@@ -52,7 +52,7 @@ def __init__(
efficient_batching: bool = False,
):
"""
-Encode the lat/lon data inot the isohedron graph
+Encode the lat/lon data into the isohedron graph

Args:
lat_lons: List of (lat,lon) points
@@ -103,7 +103,7 @@ def __init__(
edge_targets.append(self.h3_mapping[lat_node])
edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)

-# Use homogenous graph to make it easier
+# Use homogeneous graph to make it easier
self.graph = Data(edge_index=edge_index, edge_attr=self.h3_distances)

self.latent_graph = self.create_latent_graph()
4 changes: 2 additions & 2 deletions graph_weather/models/layers/graph_net_block.py
@@ -173,7 +173,7 @@ def forward(

Args:
x: Input nodes
-edge_index: Edge indicies in COO format
+edge_index: Edge indices in COO format
edge_attr: Edge attributes
u: Global attributes, ignored
batch: Batch IDX, ignored
@@ -284,7 +284,7 @@ def forward(

Args:
x: Input nodes
-edge_index: Edge indicies in COO format
+edge_index: Edge indices in COO format
edge_attr: Edge attributes

Returns:
2 changes: 1 addition & 1 deletion train/pl_graph_weather.py
@@ -422,7 +422,7 @@ def configure_optimizers(self):
)
def run(num_blocks, hidden, batch, gpus):
"""
-Trainig process.
+Training process.

Args:
num_blocks : Number of blocks.