Skip to content

Utilities

XRegrid provides several utility functions for creating standard grids, loading ESMF-formatted files, and performing common spatial operations.

Grid Generation

create_global_grid

xregrid.create_global_grid(res_lat, res_lon, add_bounds=True, chunks=None)

Create a global rectilinear grid dataset.

Parameters:

Name Type Description Default
res_lat float

Latitude resolution in degrees.

required
res_lon float

Longitude resolution in degrees.

required
add_bounds bool

Whether to add cell boundary coordinates.

True
chunks int or dict

Chunk sizes for the resulting dask-backed dataset. If None (default), returns an eager NumPy-backed dataset.

None

Returns:

Type Description
Dataset

The global grid dataset containing 'lat' and 'lon'.

Source code in src/xregrid/utils.py
def create_global_grid(
    res_lat: float,
    res_lon: float,
    add_bounds: bool = True,
    chunks: Optional[Union[int, Dict[str, int]]] = None,
) -> xr.Dataset:
    """
    Build a rectilinear grid covering the whole globe.

    Parameters
    ----------
    res_lat : float
        Grid spacing along latitude, in degrees.
    res_lon : float
        Grid spacing along longitude, in degrees.
    add_bounds : bool, default True
        If True, attach cell boundary coordinates to the grid.
    chunks : int or dict, optional
        Chunking specification for a dask-backed result. When None
        (default), an eager NumPy-backed dataset is produced.

    Returns
    -------
    xr.Dataset
        Global grid dataset containing 'lat' and 'lon'.
    """
    # Delegate to the shared rectilinear-grid builder with full global
    # coverage: latitude spans pole to pole, longitude spans 0-360.
    provenance = f"Created global grid ({res_lat}x{res_lon}) using xregrid."
    return _create_rectilinear_grid(
        lat_range=(-90, 90),
        lon_range=(0, 360),
        res_lat=res_lat,
        res_lon=res_lon,
        add_bounds=add_bounds,
        chunks=chunks,
        history_msg=provenance,
    )

Create a global rectilinear grid dataset with a specified resolution.

from xregrid import create_global_grid

# Create a 1x1 degree global grid with bounds
ds = create_global_grid(res_lat=1.0, res_lon=1.0)

create_regional_grid

xregrid.create_regional_grid(lat_range, lon_range, res_lat, res_lon, add_bounds=True, chunks=None)

Create a regional rectilinear grid dataset.

Parameters:

Name Type Description Default
lat_range tuple of float

(min_lat, max_lat).

required
lon_range tuple of float

(min_lon, max_lon).

required
res_lat float

Latitude resolution in degrees.

required
res_lon float

Longitude resolution in degrees.

required
add_bounds bool

Whether to add cell boundary coordinates.

True
chunks int or dict

Chunk sizes for the resulting dask-backed dataset. If None (default), returns an eager NumPy-backed dataset.

None

Returns:

Type Description
Dataset

The regional grid dataset containing 'lat' and 'lon'.

Source code in src/xregrid/utils.py
def create_regional_grid(
    lat_range: Tuple[float, float],
    lon_range: Tuple[float, float],
    res_lat: float,
    res_lon: float,
    add_bounds: bool = True,
    chunks: Optional[Union[int, Dict[str, int]]] = None,
) -> xr.Dataset:
    """
    Build a rectilinear grid over a regional bounding box.

    Parameters
    ----------
    lat_range : tuple of float
        (min_lat, max_lat).
    lon_range : tuple of float
        (min_lon, max_lon).
    res_lat : float
        Grid spacing along latitude, in degrees.
    res_lon : float
        Grid spacing along longitude, in degrees.
    add_bounds : bool, default True
        If True, attach cell boundary coordinates to the grid.
    chunks : int or dict, optional
        Chunking specification for a dask-backed result. When None
        (default), an eager NumPy-backed dataset is produced.

    Returns
    -------
    xr.Dataset
        Regional grid dataset containing 'lat' and 'lon'.
    """
    # Same builder as the global variant, but with a caller-supplied box.
    provenance = f"Created regional grid ({res_lat}x{res_lon}) using xregrid."
    return _create_rectilinear_grid(
        lat_range=lat_range,
        lon_range=lon_range,
        res_lat=res_lat,
        res_lon=res_lon,
        add_bounds=add_bounds,
        chunks=chunks,
        history_msg=provenance,
    )

Create a regional rectilinear grid dataset for a specific geographic bounding box.

from xregrid import create_regional_grid

# Create a regional grid over Europe
ds = create_regional_grid(
    lat_range=(35, 70),
    lon_range=(-10, 40),
    res_lat=0.25,
    res_lon=0.25
)

create_grid_like

xregrid.create_grid_like(obj, res, add_bounds=True, chunks=None, extent=None, crs=None)

Create a new grid dataset with the same extent and CRS as an existing object.

Automatically detects the CRS and spatial extent of the input object. Supports both geographic (lat-lon) and projected coordinate systems.

Parameters:

Name Type Description Default
obj DataArray or Dataset

The input object to use as a template.

required
res float or tuple of float

New grid resolution in the coordinate system units. If tuple, (res_x, res_y) or (res_lon, res_lat).

required
add_bounds bool

Whether to add cell boundary coordinates.

True
chunks int or dict

Chunk sizes for the resulting dask-backed dataset.

None
extent tuple of float

Override the detected extent (min_x, max_x, min_y, max_y). Use this to avoid hidden dask.compute() if you already know the extent.

None
crs str, int, or pyproj.CRS

Override the detected CRS.

None

Returns:

Type Description
Dataset

The new grid dataset.

Source code in src/xregrid/utils.py
def _batched_minmax(a_da, b_da):
    """Return (a_min, a_max, b_min, b_max) as floats.

    When either array is dask-backed, all four reductions are folded into a
    single ``dask.compute`` call to avoid evaluating the graph repeatedly.
    """
    if dask is not None and (
        hasattr(a_da.data, "dask") or hasattr(b_da.data, "dask")
    ):
        vals = dask.compute(a_da.min(), a_da.max(), b_da.min(), b_da.max())
        return tuple(map(float, vals))
    return (
        float(a_da.min()),
        float(a_da.max()),
        float(b_da.min()),
        float(b_da.max()),
    )


def _edge_ranges_from_centers(a_da, b_da, a_dim, b_dim):
    """Estimate cell-edge ranges from 1D center coordinates.

    The cell size is taken as the mean absolute spacing along ``a_dim`` /
    ``b_dim``; half a cell is padded onto each end of the center min/max.
    If ``b`` has a single point, it borrows ``a``'s spacing (and vice versa
    degrades to 0 padding). Dask-backed inputs are reduced in one batched
    ``dask.compute`` call.

    Returns
    -------
    tuple
        ((a_min_edge, a_max_edge), (b_min_edge, b_max_edge)).
    """
    if dask is not None and (
        hasattr(a_da.data, "dask") or hasattr(b_da.data, "dask")
    ):
        # Batch min/max and spacing reductions into a single graph evaluation.
        tasks = {
            "a_min": a_da.min(),
            "a_max": a_da.max(),
            "b_min": b_da.min(),
            "b_max": b_da.max(),
        }
        if a_da.size > 1:
            tasks["res_a"] = abs(a_da.diff(a_dim).mean())
        if b_da.size > 1:
            tasks["res_b"] = abs(b_da.diff(b_dim).mean())

        results = dask.compute(tasks)[0]
        a_min, a_max = float(results["a_min"]), float(results["a_max"])
        b_min, b_max = float(results["b_min"]), float(results["b_max"])
        res_a = float(results.get("res_a", 0))
        res_b = float(results.get("res_b", res_a if res_a else 0))
    else:
        res_a = abs(float(a_da.diff(a_dim).mean())) if a_da.size > 1 else 0
        res_b = (
            abs(float(b_da.diff(b_dim).mean())) if b_da.size > 1 else res_a
        )
        a_min, a_max = float(a_da.min()), float(a_da.max())
        b_min, b_max = float(b_da.min()), float(b_da.max())

    return (
        (a_min - res_a / 2, a_max + res_a / 2),
        (b_min - res_b / 2, b_max + res_b / 2),
    )


def create_grid_like(
    obj: Union[xr.DataArray, xr.Dataset],
    res: Union[float, Tuple[float, float]],
    add_bounds: bool = True,
    chunks: Optional[Union[int, Dict[str, int]]] = None,
    extent: Optional[Tuple[float, float, float, float]] = None,
    crs: Optional[Union[str, int, Any]] = None,
) -> xr.Dataset:
    """
    Create a new grid dataset with the same extent and CRS as an existing object.

    Automatically detects the CRS and spatial extent of the input object.
    Supports both geographic (lat-lon) and projected coordinate systems.

    Parameters
    ----------
    obj : xr.DataArray or xr.Dataset
        The input object to use as a template.
    res : float or tuple of float
        New grid resolution in the coordinate system units.
        If tuple, (res_x, res_y) or (res_lon, res_lat).
    add_bounds : bool, default True
        Whether to add cell boundary coordinates.
    chunks : int or dict, optional
        Chunk sizes for the resulting dask-backed dataset.
    extent : tuple of float, optional
        Override the detected extent (min_x, max_x, min_y, max_y).
        Use this to avoid hidden dask.compute() if you already know the extent.
    crs : str, int, or pyproj.CRS, optional
        Override the detected CRS.

    Returns
    -------
    xr.Dataset
        The new grid dataset.

    Raises
    ------
    ValueError
        If neither projected nor geographic spatial coordinates can be
        detected in ``obj``.
    """
    # Resolve the target CRS: an explicit override wins; otherwise detect it.
    if crs is not None:
        if pyproj is not None:
            crs_obj = pyproj.CRS(crs)
        else:
            # pyproj unavailable: pass the raw value through unparsed.
            crs_obj = crs
    else:
        crs_obj = get_crs_info(obj)

    # A scalar resolution means square cells.
    if isinstance(res, (int, float)):
        res_x = res_y = float(res)
    else:
        res_x, res_y = map(float, res)

    # Provenance message; ``.name`` may exist but be None on an unnamed
    # DataArray, so fall back to "input" for falsy names too.
    obj_name = getattr(obj, "name", None) or "input"
    history_msg_base = f"Created grid like {obj_name} using xregrid."
    if hasattr(obj, "attrs") and "history" in obj.attrs:
        history_msg_base += f"\nTemplate history:\n{obj.attrs['history']}"

    # Explicit extent override: skip coordinate discovery entirely.
    if extent is not None:
        if crs_obj is None or (
            hasattr(crs_obj, "is_geographic") and crs_obj.is_geographic
        ):
            # Lat-Lon: extent is (min_lon, max_lon, min_lat, max_lat).
            return _create_rectilinear_grid(
                (extent[2], extent[3]),  # lat_range
                (extent[0], extent[1]),  # lon_range
                res_y,  # res_lat
                res_x,  # res_lon
                add_bounds=add_bounds,
                chunks=chunks,
                crs=crs_obj.to_wkt() if hasattr(crs_obj, "to_wkt") else "EPSG:4326",
                history_msg=history_msg_base + " (Override Extent).",
            )
        else:
            # Projected
            return create_grid_from_crs(
                crs_obj, extent, (res_x, res_y), add_bounds=add_bounds, chunks=chunks
            )

    # 1. Try to find projected coordinates via cf-xarray.
    try:
        x_da = obj.cf["projection_x_coordinate"]
        y_da = obj.cf["projection_y_coordinate"]

        try:
            # Use bounds for the exact extent if available.
            x_b = obj.cf.get_bounds("projection_x_coordinate")
            y_b = obj.cf.get_bounds("projection_y_coordinate")
            extent = _batched_minmax(x_b, y_b)
        except Exception:
            # No bounds: estimate cell edges from the center coordinates.
            x_range, y_range = _edge_ranges_from_centers(
                x_da, y_da, x_da.dims[0], y_da.dims[0]
            )
            extent = (*x_range, *y_range)

        if crs_obj is None:
            # Fallback to generic geographic if no CRS found.
            crs_obj = "EPSG:4326"

        return create_grid_from_crs(
            crs_obj, extent, (res_x, res_y), add_bounds=add_bounds, chunks=chunks
        )

    except (KeyError, AttributeError, ValueError):
        pass

    # 2. Fallback to Geographic (Lat-Lon).
    try:
        lat_da = _find_coord(obj, "latitude")
        lon_da = _find_coord(obj, "longitude")
        if lat_da is None or lon_da is None:
            raise KeyError("Coordinates not found")

        try:
            # Prefer exact bounds when present.
            lat_b = obj.cf.get_bounds("latitude")
            lon_b = obj.cf.get_bounds("longitude")
            lat_min, lat_max, lon_min, lon_max = _batched_minmax(lat_b, lon_b)
            lat_range = (lat_min, lat_max)
            lon_range = (lon_min, lon_max)
        except Exception:
            # No bounds: estimate cell edges from centers. Longitude may be
            # 2D (curvilinear), hence diff along its last dimension.
            lat_range, lon_range = _edge_ranges_from_centers(
                lat_da, lon_da, lat_da.dims[0], lon_da.dims[-1]
            )

        return _create_rectilinear_grid(
            lat_range,
            lon_range,
            res_y,  # res_lat
            res_x,  # res_lon
            add_bounds=add_bounds,
            chunks=chunks,
            # Guard with hasattr (consistent with the override path above):
            # crs_obj may be a plain string or lack to_wkt entirely.
            crs=crs_obj.to_wkt() if hasattr(crs_obj, "to_wkt") else "EPSG:4326",
            history_msg=history_msg_base,
        )
    except (KeyError, AttributeError, ValueError):
        raise ValueError(
            "Could not detect spatial coordinates (latitude/longitude or "
            "projection_x/y) in input object."
        )

Create a new grid dataset with the same extent and CRS as an existing object.

from xregrid.utils import create_grid_like

# Create a 0.5 degree grid matching the extent of an existing dataset
new_grid = create_grid_like(ds, res=0.5)

create_grid_from_crs

xregrid.create_grid_from_crs(crs, extent, res, add_bounds=True, chunks=None)

Create a structured grid dataset from a CRS and extent.

Parameters:

Name Type Description Default
crs str, int, or pyproj.CRS

The CRS of the grid (Proj4 string, EPSG code, WKT, or CRS object).

required
extent Tuple[float, float, float, float]

Grid extent in CRS units: (min_x, max_x, min_y, max_y).

required
res float or Tuple[float, float]

Grid resolution in CRS units. If float, same resolution in x and y. If tuple, (res_x, res_y).

required
add_bounds bool

Whether to add cell boundary coordinates.

True
chunks int or Dict[str, int]

Chunk sizes for the resulting dask-backed dataset.

None

Returns:

Type Description
Dataset

The grid dataset containing 'lat', 'lon' and projected coordinates 'x', 'y'.

Source code in src/xregrid/utils.py
def create_grid_from_crs(
    crs: Union[str, int, Any],
    extent: Tuple[float, float, float, float],
    res: Union[float, Tuple[float, float]],
    add_bounds: bool = True,
    chunks: Optional[Union[int, Dict[str, int]]] = None,
) -> xr.Dataset:
    """
    Create a structured grid dataset from a CRS and extent.

    Parameters
    ----------
    crs : str, int, or pyproj.CRS
        The CRS of the grid (Proj4 string, EPSG code, WKT, or CRS object).
    extent : Tuple[float, float, float, float]
        Grid extent in CRS units: (min_x, max_x, min_y, max_y).
    res : float or Tuple[float, float]
        Grid resolution in CRS units. If float, same resolution in x and y.
        If tuple, (res_x, res_y).
    add_bounds : bool, default True
        Whether to add cell boundary coordinates.
    chunks : int or Dict[str, int], optional
        Chunk sizes for the resulting dask-backed dataset.

    Returns
    -------
    xr.Dataset
        The grid dataset containing 'lat', 'lon' and projected coordinates 'x', 'y'.

    Raises
    ------
    ImportError
        If pyproj is not installed.
    """
    # A scalar resolution means square cells.
    if isinstance(res, (int, float)):
        res_x = res_y = float(res)
    else:
        res_x, res_y = map(float, res)

    # Per-axis chunk sizes; -1 means one chunk spanning the whole axis.
    x_chunks = chunks.get("x", -1) if isinstance(chunks, dict) else chunks
    y_chunks = chunks.get("y", -1) if isinstance(chunks, dict) else chunks

    # Generate 1D coordinates in projected space
    # (cell centers start half a cell in from the lower extent edge).
    x = _lazy_arange(extent[0] + res_x / 2, extent[1], res_x, chunks=x_chunks)
    y = _lazy_arange(extent[2] + res_y / 2, extent[3], res_y, chunks=y_chunks)

    x_da = xr.DataArray(x, dims=["x"], name="x")
    y_da = xr.DataArray(y, dims=["y"], name="y")

    # Use xr.broadcast for lazy 2D arrays
    yy_da, xx_da = xr.broadcast(y_da, x_da)

    # Ensure (y, x) order
    yy_da = yy_da.transpose("y", "x")
    xx_da = xx_da.transpose("y", "x")

    if pyproj is None:
        raise ImportError("pyproj is required for create_grid_from_crs.")
    crs_obj = pyproj.CRS(crs)

    # Transform projected (x, y) to geographic coordinates; the unpacking
    # order here implies _transform_coords yields (lon, lat). dask="parallelized"
    # keeps the transform lazy for chunked inputs.
    lon, lat = xr.apply_ufunc(
        _transform_coords,
        xx_da,
        yy_da,
        kwargs={"crs_in": crs_obj},
        dask="parallelized",
        output_dtypes=[float, float],
        input_core_dims=[[], []],
        output_core_dims=[[], []],
    )

    # Axis units from the CRS definition; default to metres when unavailable.
    try:
        units = crs_obj.axis_info[0].unit_name or "m"
    except (IndexError, AttributeError):
        units = "m"

    # Assemble the grid with CF-style standard_name/units metadata.
    ds = xr.Dataset(
        coords={
            "y": (
                ["y"],
                y,
                {"units": units, "standard_name": "projection_y_coordinate"},
            ),
            "x": (
                ["x"],
                x,
                {"units": units, "standard_name": "projection_x_coordinate"},
            ),
            "lat": (
                ["y", "x"],
                lat.data,
                {"units": "degrees_north", "standard_name": "latitude"},
            ),
            "lon": (
                ["y", "x"],
                lon.data,
                {"units": "degrees_east", "standard_name": "longitude"},
            ),
        }
    )

    ds.attrs["crs"] = crs_obj.to_wkt()

    if add_bounds:
        # Build the four cell corners per center along an "nv" axis.
        # Corner order: lower-left, lower-right, upper-right, upper-left.
        # NOTE(review): `da` is presumably the dask.array module alias from
        # the file's imports — confirm; both branches build the same stacks.
        if chunks is not None and da is not None:
            x_b_raw = da.stack(
                [x - res_x / 2, x + res_x / 2, x + res_x / 2, x - res_x / 2]
            )
            y_b_raw = da.stack(
                [y - res_y / 2, y - res_y / 2, y + res_y / 2, y + res_y / 2]
            )
        else:
            x_b_raw = np.stack(
                [x - res_x / 2, x + res_x / 2, x + res_x / 2, x - res_x / 2]
            )
            y_b_raw = np.stack(
                [y - res_y / 2, y - res_y / 2, y + res_y / 2, y + res_y / 2]
            )

        x_b_da = xr.DataArray(x_b_raw, dims=["nv", "x"])
        y_b_da = xr.DataArray(y_b_raw, dims=["nv", "y"])

        yy_b_da, xx_b_da = xr.broadcast(y_b_da, x_b_da)

        # Project corner coordinates to geographic, same transform as centers.
        lon_b, lat_b = xr.apply_ufunc(
            _transform_coords,
            xx_b_da,
            yy_b_da,
            kwargs={"crs_in": crs_obj},
            dask="parallelized",
            output_dtypes=[float, float],
            input_core_dims=[[], []],
            output_core_dims=[[], []],
        )

        # Reorder (nv, y, x) -> (y, x, nv) for CF-style bounds layout.
        ds.coords["lat_b"] = (
            ["y", "x", "nv"],
            lat_b.data.transpose(1, 2, 0),
            {"units": "degrees_north"},
        )
        ds.coords["lon_b"] = (
            ["y", "x", "nv"],
            lon_b.data.transpose(1, 2, 0),
            {"units": "degrees_east"},
        )
        ds["lat"].attrs["bounds"] = "lat_b"
        ds["lon"].attrs["bounds"] = "lon_b"

        # Add 1D projected bounds using backend-agnostic xarray operations
        x_da_1d = xr.DataArray(x, dims=["x"])
        y_da_1d = xr.DataArray(y, dims=["y"])

        # Create (N, 2) bounds
        x_b_1d = xr.concat(
            [x_da_1d - res_x / 2, x_da_1d + res_x / 2], dim="nbounds"
        ).transpose("x", "nbounds")
        y_b_1d = xr.concat(
            [y_da_1d - res_y / 2, y_da_1d + res_y / 2], dim="nbounds"
        ).transpose("y", "nbounds")

        ds.coords["x_b"] = (["x", "nbounds"], x_b_1d.data, {"units": units})
        ds.coords["y_b"] = (["y", "nbounds"], y_b_1d.data, {"units": units})
        ds["x"].attrs["bounds"] = "x_b"
        ds["y"].attrs["bounds"] = "y_b"

    # Backend detection for provenance
    is_lazy = chunks is not None
    backend = "Lazy" if is_lazy else "Eager"

    # Record provenance, then apply the requested chunking to the full
    # dataset (so bounds variables are chunked as well).
    update_history(ds, f"Created grid from CRS {crs} using xregrid ({backend}).")
    if chunks is not None:
        ds = ds.chunk(chunks)
    return ds

Create a structured grid dataset from a Coordinate Reference System (CRS) and extent.

from xregrid import create_grid_from_crs

# Create a Lambert Conformal Conic grid over North America
extent = (-2500000, 2500000, -2000000, 2000000)
res = (12000, 12000) # 12km
crs = "+proj=lcc +lat_1=33 +lat_2=45 +lat_0=40 +lon_0=-97 +x_0=0 +y_0=0 +ellps=WGS84 +units=m +no_defs"

ds = create_grid_from_crs(crs, extent, res)

create_grid_from_ioapi

xregrid.create_grid_from_ioapi(metadata, add_bounds=True, chunks=None)

Create a structured grid dataset from IOAPI-compliant metadata.

Supports the following GDTYP codes:

- 1: Lat-Lon
- 2: Lambert Conformal
- 3: Mercator
- 4: Stereographic
- 5: UTM
- 6: Polar Stereographic
- 7: Equatorial Mercator
- 8: Transverse Mercator
- 9: Albers Equal Area
- 10: Lambert Azimuthal Equal Area
- 13: Sinusoidal

Parameters:

Name Type Description Default
metadata dict

IOAPI metadata containing GDTYP, P_ALP, P_BET, P_GAM, XCENT, YCENT, XORIG, YORIG, XCELL, YCELL, NCOLS, NROWS.

required
add_bounds bool

Whether to add cell boundary coordinates.

True
chunks int or dict

Chunk sizes for the resulting dask-backed dataset.

None

Returns:

Type Description
Dataset

The grid dataset.

Source code in src/xregrid/utils.py
def create_grid_from_ioapi(
    metadata: Dict[str, Any],
    add_bounds: bool = True,
    chunks: Optional[Union[int, Dict[str, int]]] = None,
) -> xr.Dataset:
    """
    Build a structured grid dataset from IOAPI-compliant metadata.

    The IOAPI GDTYP code is translated into a PROJ string (or "EPSG:4326"
    for lat-lon grids), then the grid construction is delegated to
    create_grid_from_crs.

    Supported GDTYP codes:
    - 1: Lat-Lon
    - 2: Lambert Conformal
    - 3: Mercator
    - 4: Stereographic
    - 5: UTM
    - 6: Polar Stereographic
    - 7: Equatorial Mercator
    - 8: Transverse Mercator
    - 9: Albers Equal Area
    - 10: Lambert Azimuthal Equal Area
    - 13: Sinusoidal

    Parameters
    ----------
    metadata : dict
        IOAPI metadata containing GDTYP, P_ALP, P_BET, P_GAM, XCENT, YCENT,
        XORIG, YORIG, XCELL, YCELL, NCOLS, NROWS.
    add_bounds : bool, default True
        Whether to add cell boundary coordinates.
    chunks : int or dict, optional
        Chunk sizes for the resulting dask-backed dataset.

    Returns
    -------
    xr.Dataset
        The grid dataset.

    Raises
    ------
    ValueError
        If GDTYP is not one of the supported codes.
    """
    gdtyp = metadata["GDTYP"]
    p_alp = metadata["P_ALP"]
    p_bet = metadata["P_BET"]
    xcent = metadata["XCENT"]
    ycent = metadata["YCENT"]
    xorig = metadata["XORIG"]
    yorig = metadata["YORIG"]
    xcell = metadata["XCELL"]
    ycell = metadata["YCELL"]
    ncols = metadata["NCOLS"]
    nrows = metadata["NROWS"]

    def _with_tail(head: str) -> str:
        # Every projected GDTYP except UTM shares this PROJ-string tail.
        return f"{head} +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs"

    if gdtyp == 1:  # Lat-Lon
        crs = "EPSG:4326"
    elif gdtyp == 5:  # UTM: the zone number is carried in P_ALP
        crs = f"+proj=utm +zone={int(p_alp)} +datum=WGS84 +units=m +no_defs"
    elif gdtyp == 6:  # Polar Stereographic: the sign of P_ALP selects the pole
        pole = 90.0 if p_alp > 0 else -90.0
        crs = _with_tail(f"+proj=stere +lat_0={pole} +lat_ts={p_bet} +lon_0={xcent}")
    else:
        # Remaining projections differ only in their PROJ-string head.
        heads = {
            # Lambert Conformal
            2: f"+proj=lcc +lat_1={p_alp} +lat_2={p_bet} +lat_0={ycent} +lon_0={xcent}",
            # Mercator
            3: f"+proj=merc +lat_ts={p_alp} +lon_0={xcent} +lat_0={ycent}",
            # Stereographic
            4: f"+proj=stere +lat_ts={p_alp} +lat_0={ycent} +lon_0={xcent}",
            # Equatorial Mercator
            7: f"+proj=merc +lat_ts={p_alp} +lon_0={xcent} +lat_0=0",
            # Transverse Mercator
            8: f"+proj=tmerc +lat_0={ycent} +k={p_bet} +lon_0={xcent}",
            # Albers Equal Area
            9: f"+proj=aea +lat_1={p_alp} +lat_2={p_bet} +lat_0={ycent} +lon_0={xcent}",
            # Lambert Azimuthal Equal Area
            10: f"+proj=laea +lat_0={ycent} +lon_0={xcent}",
            # Sinusoidal
            13: f"+proj=sinu +lon_0={xcent}",
        }
        if gdtyp not in heads:
            raise ValueError(f"Unsupported IOAPI GDTYP: {gdtyp}")
        crs = _with_tail(heads[gdtyp])

    # IOAPI's (XORIG, YORIG) is the lower-left grid corner in CRS units.
    extent = (xorig, xorig + ncols * xcell, yorig, yorig + nrows * ycell)
    cell_res = (xcell, ycell)

    ds = create_grid_from_crs(
        crs, extent, cell_res, add_bounds=add_bounds, chunks=chunks
    )

    # Attach the original IOAPI metadata for provenance.
    for key, value in metadata.items():
        ds.attrs[f"ioapi_{key}"] = value

    update_history(ds, f"Created grid from IOAPI metadata (GDTYP={gdtyp})")

    return ds

Create a structured grid dataset from IOAPI-compliant metadata.

from xregrid.utils import create_grid_from_ioapi

metadata = {
    "GDTYP": 2,
    "P_ALP": 30.0,
    "P_BET": 60.0,
    "XCENT": -97.0,
    "YCENT": 40.0,
    "XORIG": -1000.0,
    "YORIG": -1000.0,
    "XCELL": 500.0,
    "YCELL": 500.0,
    "NCOLS": 100,
    "NROWS": 100,
}

ds = create_grid_from_ioapi(metadata)

create_mesh_from_coords

xregrid.utils.create_mesh_from_coords(x, y, crs, chunks=None)

Create an unstructured mesh dataset from coordinates and a CRS.

Parameters:

Name Type Description Default
x ndarray or DataArray

1D array of x coordinates in CRS units.

required
y ndarray or DataArray

1D array of y coordinates in CRS units.

required
crs str, int, or pyproj.CRS

The CRS of the coordinates.

required
chunks int or dict

Chunk sizes for the resulting dask-backed dataset. If None (default), returns an eager NumPy-backed dataset.

None

Returns:

Type Description
Dataset

The mesh dataset containing 'lat', 'lon' and 'x', 'y' as 1D arrays.

Source code in src/xregrid/utils.py
def create_mesh_from_coords(
    x: Union[np.ndarray, xr.DataArray],
    y: Union[np.ndarray, xr.DataArray],
    crs: Union[str, int, Any],
    chunks: Optional[Union[int, Dict[str, int]]] = None,
) -> xr.Dataset:
    """
    Create an unstructured mesh dataset from coordinates and a CRS.

    Parameters
    ----------
    x : np.ndarray or xr.DataArray
        1D array of x coordinates in CRS units.
    y : np.ndarray or xr.DataArray
        1D array of y coordinates in CRS units.
    crs : str, int, or pyproj.CRS
        The CRS of the coordinates.
    chunks : int or dict, optional
        Chunk sizes for the resulting dask-backed dataset.
        If None (default), returns an eager NumPy-backed dataset.
        Dict specs are remapped onto the canonical 'n_pts' dimension:
        an explicit 'n_pts' key is respected, otherwise the first value
        is used; an empty dict disables chunking.

    Returns
    -------
    xr.Dataset
        The mesh dataset containing 'lat', 'lon' and 'x', 'y' as 1D arrays.

    Raises
    ------
    ImportError
        If pyproj is not installed.
    """
    if pyproj is None:
        raise ImportError(
            "pyproj is required for create_mesh_from_coords. "
            "Install it with `pip install pyproj`."
        )
    crs_obj = pyproj.CRS(crs)

    # Force n_pts dimension to avoid alignment/broadcasting issues in apply_ufunc
    # Preserve backend (data) while dropping stale coordinates
    if isinstance(x, xr.DataArray):
        x_da = xr.DataArray(x.data, dims=["n_pts"], name="x", attrs=x.attrs)
    else:
        x_da = xr.DataArray(x, dims=["n_pts"], name="x")

    if isinstance(y, xr.DataArray):
        y_da = xr.DataArray(y.data, dims=["n_pts"], name="y", attrs=y.attrs)
    else:
        y_da = xr.DataArray(y, dims=["n_pts"], name="y")

    # Normalize dict chunk specs onto the canonical dimension name. Prefer an
    # explicit 'n_pts' entry over an arbitrary first value, and treat an empty
    # dict as "no chunking" (previously this raised StopIteration).
    if isinstance(chunks, dict):
        if "n_pts" in chunks:
            chunks = {"n_pts": chunks["n_pts"]}
        elif chunks:
            chunks = {"n_pts": next(iter(chunks.values()))}
        else:
            chunks = None

    if chunks is not None:
        x_da = x_da.chunk(chunks)
        y_da = y_da.chunk(chunks)

    # Backend detection for provenance (input arrays may already be dask-backed)
    is_lazy = chunks is not None or hasattr(x_da.data, "dask")
    backend = "Lazy" if is_lazy else "Eager"

    # Transform to geographic coordinates; dask='parallelized' keeps it lazy.
    lon, lat = xr.apply_ufunc(
        _transform_coords,
        x_da,
        y_da,
        kwargs={"crs_in": crs_obj},
        dask="parallelized",
        output_dtypes=[float, float],
        input_core_dims=[[], []],
        output_core_dims=[[], []],
    )

    # Conditional CF metadata based on CRS type
    if crs_obj.is_geographic:
        x_std, y_std = "longitude", "latitude"
        x_units, y_units = "degrees_east", "degrees_north"
    else:
        x_std, y_std = "projection_x_coordinate", "projection_y_coordinate"
        try:
            x_units = y_units = crs_obj.axis_info[0].unit_name or "m"
        except (IndexError, AttributeError):
            x_units = y_units = "m"

    ds = xr.Dataset(
        coords={
            "n_pts": (["n_pts"], np.arange(x_da.size)),
            "x": (
                ["n_pts"],
                x_da.data,
                {"units": x_units, "standard_name": x_std},
            ),
            "y": (
                ["n_pts"],
                y_da.data,
                {"units": y_units, "standard_name": y_std},
            ),
            "lat": (
                ["n_pts"],
                lat.data,
                {"units": "degrees_north", "standard_name": "latitude"},
            ),
            "lon": (
                ["n_pts"],
                lon.data,
                {"units": "degrees_east", "standard_name": "longitude"},
            ),
        }
    )

    # Add grid_mapping variable for CF-compliance
    if crs_obj.is_projected:
        gm_name = "spatial_ref"
        ds[gm_name] = ([], 0, crs_obj.to_cf())
        ds.attrs["grid_mapping"] = gm_name
        for var in ["x", "y", "lat", "lon"]:
            ds[var].attrs["grid_mapping"] = gm_name

    ds.attrs["crs"] = crs_obj.to_wkt()

    # Capture extents for provenance; skipped for lazy arrays to avoid
    # triggering a hidden compute.
    if not is_lazy:
        extent = (
            float(x_da.min()),
            float(x_da.max()),
            float(y_da.min()),
            float(y_da.max()),
        )
        extent_msg = f" Extent: {extent}."
    else:
        extent_msg = ""

    update_history(
        ds,
        f"Created mesh from coordinates and CRS {crs} using xregrid ({backend}).{extent_msg}",
    )

    return ds

Create an unstructured mesh dataset from 1D coordinates and a CRS.

from xregrid.utils import create_mesh_from_coords
import numpy as np

lons = np.random.uniform(0, 360, 1000)
lats = np.random.uniform(-90, 90, 1000)
ds_mesh = create_mesh_from_coords(lons, lats, crs="EPSG:4326")

Spatial Operations

spatial_slice

xregrid.utils.spatial_slice(obj, extent, crs=None, buffer=0.0)

Slice an xarray object to a spatial extent, handling longitude wrapping.

This function identifies spatial dimensions via cf-xarray and performs a backend-agnostic slice. For geographic coordinates, it robustly handles longitude wrapping (e.g., slicing a 0-360 grid with a -20 to 20 extent).

Parameters:

Name Type Description Default
obj DataArray or Dataset

The input object to slice.

required
extent tuple of float

Spatial extent as (min_x, max_x, min_y, max_y).

required
crs str, int, or pyproj.CRS

The CRS of the provided extent. If None, assumes the same CRS as obj.

None
buffer float

Extra buffer to add around the extent in coordinate units.

0.0

Returns:

Type Description
DataArray or Dataset

The spatially sliced object.

Notes

For longitude wrapping, if the requested extent crosses the grid's discontinuity, the result will be concatenated along the longitude dimension.

Source code in src/xregrid/utils.py
def spatial_slice(
    obj: Union[xr.DataArray, xr.Dataset],
    extent: Tuple[float, float, float, float],
    crs: Optional[Union[str, int, Any]] = None,
    buffer: float = 0.0,
) -> Union[xr.DataArray, xr.Dataset]:
    """
    Slice an xarray object to a spatial extent, handling longitude wrapping.

    This function identifies spatial dimensions via cf-xarray and performs
    a backend-agnostic slice. For geographic coordinates, it robustly
    handles longitude wrapping (e.g., slicing a 0-360 grid with a -20 to 20 extent).

    Parameters
    ----------
    obj : xr.DataArray or xr.Dataset
        The input object to slice.
    extent : tuple of float
        Spatial extent as (min_x, max_x, min_y, max_y).
    crs : str, int, or pyproj.CRS, optional
        The CRS of the provided extent. If None, assumes the same CRS as obj.
    buffer : float, default 0.0
        Extra buffer to add around the extent in coordinate units.

    Returns
    -------
    xr.DataArray or xr.Dataset
        The spatially sliced object.

    Raises
    ------
    ValueError
        If no spatial coordinates (lat/lon or x/y) can be detected on ``obj``.
    ImportError
        If ``crs`` is given but pyproj is not installed.

    Notes
    -----
    For longitude wrapping, if the requested extent crosses the grid's
    discontinuity, the result will be concatenated along the longitude dimension.
    """
    # 1. Coordinate and Dimension Discovery: prefer geographic lat/lon,
    # fall back to projected x/y found through the cf-xarray accessor.
    lat_da = _find_coord(obj, "latitude")
    lon_da = _find_coord(obj, "longitude")

    if lat_da is None or lon_da is None:
        try:
            x_da = obj.cf["projection_x_coordinate"]
            y_da = obj.cf["projection_y_coordinate"]
            is_geographic = False
        except (KeyError, AttributeError):
            raise ValueError(
                "Could not detect spatial coordinates (lat/lon or x/y) for slicing. "
                "Ensure your data has CF-compliant coordinates."
            )
    else:
        x_da, y_da = lon_da, lat_da
        is_geographic = True

    # 2. CRS Transformation: reproject the requested extent into the grid's CRS.
    if crs is not None:
        if pyproj is None:
            raise ImportError(
                "pyproj is required for CRS-aware slicing. "
                "Install it with `pip install pyproj`."
            )
        target_crs = get_crs_info(obj) or pyproj.CRS("EPSG:4326")
        transformer = pyproj.Transformer.from_crs(crs, target_crs, always_xy=True)

        # Transform the bbox by projecting all 4 corners; the axis-aligned
        # envelope of the projected corners serves as the new extent.
        x_pts = [extent[0], extent[1], extent[1], extent[0]]
        y_pts = [extent[2], extent[2], extent[3], extent[3]]
        xx, yy = transformer.transform(x_pts, y_pts)
        extent = (min(xx), max(xx), min(yy), max(yy))

    min_x, max_x, min_y, max_y = extent
    min_x -= buffer
    max_x += buffer
    min_y -= buffer
    max_y += buffer

    # 3. Y-Slicing (Latitude or Projection Y).
    # NOTE(review): assumes 1-D dimension coordinates (dims[0] is the indexed
    # dimension); 2-D curvilinear coordinates are not handled here -- confirm.
    y_dim = y_da.dims[0]
    if obj.indexes[y_dim].is_monotonic_increasing:
        obj = obj.sel({y_dim: slice(min_y, max_y)})
    else:
        # Descending coordinate: slice bounds must be given high-to-low.
        obj = obj.sel({y_dim: slice(max_y, min_y)})

    # 4. X-Slicing (Longitude or Projection X)
    x_dim = x_da.dims[0]
    if not is_geographic:
        # Standard slice for projected coordinates (no wrapping possible).
        if obj.indexes[x_dim].is_monotonic_increasing:
            obj = obj.sel({x_dim: slice(min_x, max_x)})
        else:
            obj = obj.sel({x_dim: slice(max_x, min_x)})
        # Record provenance on this exit too, consistent with the
        # geographic path below.
        update_history(obj, f"Spatially sliced to extent {extent}.")
        return obj

    # 5. Longitude Wrapping Logic
    # Get the grid convention from the eager index.
    lon_grid = obj.indexes[x_dim]
    g_min = lon_grid.min()

    # Normalize the requested extent into the grid's range [g_min, g_min + 360].
    norm_min_x = (min_x - g_min) % 360 + g_min
    norm_max_x = (max_x - g_min) % 360 + g_min

    # If normalization inverted the bounds, the extent crosses the grid's
    # longitude discontinuity and two slices must be stitched together.
    if norm_min_x > norm_max_x:
        # Crosses the grid boundary.
        if lon_grid.is_monotonic_increasing:
            part1 = obj.sel({x_dim: slice(norm_min_x, g_min + 360)})
            part2 = obj.sel({x_dim: slice(g_min, norm_max_x)})
        else:
            part1 = obj.sel({x_dim: slice(g_min + 360, norm_min_x)})
            part2 = obj.sel({x_dim: slice(norm_max_x, g_min)})

        # Concatenate the two halves along the longitude dimension.
        res = xr.concat([part1, part2], dim=x_dim)
    else:
        # Simple non-wrapped slice.
        if lon_grid.is_monotonic_increasing:
            res = obj.sel({x_dim: slice(norm_min_x, norm_max_x)})
        else:
            res = obj.sel({x_dim: slice(norm_max_x, norm_min_x)})

    # Metadata update for provenance.
    msg = f"Spatially sliced to extent {extent} (wrapped={norm_min_x > norm_max_x})"
    update_history(res, msg)

    return res

Slice an xarray object to a spatial extent, robustly handling longitude wrapping.

from xregrid.utils import spatial_slice

# Slice a 0-360 grid to a region crossing the dateline (-20 to 20 lon)
subset = spatial_slice(ds, extent=(-20, 20, 30, 50))

unstructured_to_scrip

xregrid.utils.unstructured_to_scrip(ds)

Canonicalize an unstructured dataset (UGRID or MPAS) to SCRIP format.

Extracts connectivity information to build explicit boundary coordinates (lat_b, lon_b) on a flat 'grid_size' dimension. This enables conservative and bilinear regridding for unstructured grids that only provide connectivity.

Parameters:

Name Type Description Default
ds Dataset

The input unstructured dataset.

required

Returns:

Type Description
Dataset

A CF-compliant SCRIP-style dataset.

Source code in src/xregrid/utils.py
def unstructured_to_scrip(ds: xr.Dataset) -> xr.Dataset:
    """
    Canonicalize an unstructured dataset (UGRID or MPAS) to SCRIP format.

    Extracts connectivity information to build explicit boundary coordinates
    (lat_b, lon_b) on a flat 'grid_size' dimension. This enables conservative
    and bilinear regridding for unstructured grids that only provide connectivity.

    Parameters
    ----------
    ds : xr.Dataset
        The input unstructured dataset.

    Returns
    -------
    xr.Dataset
        A CF-compliant SCRIP-style dataset.

    Raises
    ------
    ValueError
        If latitude/longitude centers cannot be found, or if connectivity
        extraction fails (the original error is chained).
    """
    from .grid import _get_unstructured_mesh_info

    # 1. Get centers via _find_coord (robust to naming conventions).
    lat_c = _find_coord(ds, "latitude")
    lon_c = _find_coord(ds, "longitude")

    if lat_c is None or lon_c is None:
        raise ValueError("Could not find latitude/longitude centers in dataset.")

    # 2. Extract connectivity and vertices.
    try:
        (
            node_lon,
            node_lat,
            element_conn,
            element_types,
            element_ids,
            orig_cell_index,
        ) = _get_unstructured_mesh_info(ds, method="conservative")
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise ValueError(f"Failed to extract unstructured connectivity: {e}") from e

    # 3. Reshape connectivity to SCRIP-style (N, 3 for triangles).
    # _get_unstructured_mesh_info always triangulates.
    n_tris = len(element_conn) // 3
    conn_2d = element_conn.reshape(n_tris, 3)

    # 4. Map node indices to per-cell corner coordinates.
    lat_b = node_lat[conn_2d]
    lon_b = node_lon[conn_2d]

    # 5. Mapping back to original cell centers: if a polygon grid (MPAS,
    # UGRID faces) was triangulated, orig_cell_index maps each triangle back
    # to its source cell so the centers line up with the triangle corners.
    # The triangulated mesh is the primary representation returned here.

    # 6. Ensure attributes are CF-compliant for centers (never clobbering
    # attributes the input already carries).
    lat_attrs = dict(lat_c.attrs)
    lon_attrs = dict(lon_c.attrs)
    lat_attrs.setdefault("standard_name", "latitude")
    lon_attrs.setdefault("standard_name", "longitude")
    lat_attrs.setdefault("units", "degrees_north")
    lon_attrs.setdefault("units", "degrees_east")

    scrip_ds = xr.Dataset(
        coords={
            "lat": (
                ["grid_size"],
                lat_c.data[orig_cell_index]
                if orig_cell_index is not None
                else lat_c.data,
                lat_attrs,
            ),
            "lon": (
                ["grid_size"],
                lon_c.data[orig_cell_index]
                if orig_cell_index is not None
                else lon_c.data,
                lon_attrs,
            ),
            "lat_b": (
                ["grid_size", "nv"],
                lat_b,
                {"units": "degrees_north", "standard_name": "latitude_bounds"},
            ),
            "lon_b": (
                ["grid_size", "nv"],
                lon_b,
                {"units": "degrees_east", "standard_name": "longitude_bounds"},
            ),
        },
        attrs=ds.attrs,
    )

    # 7. Link bounds variables and record provenance.
    scrip_ds["lat"].attrs["bounds"] = "lat_b"
    scrip_ds["lon"].attrs["bounds"] = "lon_b"

    update_history(scrip_ds, "Canonicalized unstructured grid to SCRIP-style format.")

    # Scientific hygiene: mark centers as face locations so downstream
    # regridders identify this as an unstructured grid.
    scrip_ds["lat"].attrs["location"] = "face"
    scrip_ds["lon"].attrs["location"] = "face"

    return scrip_ds

Canonicalize an unstructured dataset (UGRID or MPAS) to SCRIP format.

from xregrid.utils import unstructured_to_scrip

scrip_ds = unstructured_to_scrip(ds_ugrid)

mpas_to_scrip

xregrid.utils.mpas_to_scrip(ds)

Convert an MPAS-native dataset to a CF-compliant SCRIP-style format.

A thin wrapper around unstructured_to_scrip that adds MPAS-specific validation.

Parameters:

Name Type Description Default
ds Dataset

The MPAS dataset.

required

Returns:

Type Description
Dataset

SCRIP-style dataset.

Source code in src/xregrid/utils.py
def mpas_to_scrip(ds: xr.Dataset) -> xr.Dataset:
    """
    Convert an MPAS-native dataset to a CF-compliant SCRIP-style format.

    Thin wrapper around ``unstructured_to_scrip`` that first verifies the
    dataset actually looks like an MPAS grid.

    Parameters
    ----------
    ds : xr.Dataset
        The MPAS dataset; must carry an 'nCells' dimension.

    Returns
    -------
    xr.Dataset
        SCRIP-style dataset.

    Raises
    ------
    ValueError
        If the dataset has no 'nCells' dimension.
    """
    if "nCells" in ds.dims:
        return unstructured_to_scrip(ds)
    raise ValueError("Dataset does not appear to be an MPAS grid (missing nCells).")

Convert an MPAS-native dataset to a CF-compliant SCRIP-style format.

from xregrid.utils import mpas_to_scrip

scrip_ds = mpas_to_scrip(ds_mpas)

ESMF File Support

load_esmf_file

xregrid.load_esmf_file(filepath)

Load an ESMF mesh, mosaic, or grid file into an xarray Dataset.

Automatically recognizes SCRIP/ESMF standard variable names and renames them to 'lat', 'lon', 'lat_b', 'lon_b' while adding CF attributes.

Parameters:

Name Type Description Default
filepath str

Path to the ESMF file.

required

Returns:

Type Description
Dataset

The dataset representation of the ESMF file.

Source code in src/xregrid/utils.py
def load_esmf_file(filepath: str) -> xr.Dataset:
    """
    Load an ESMF mesh, mosaic, or grid file into an xarray Dataset.

    SCRIP/ESMF standard variable names are recognized and renamed to
    'lat', 'lon', 'lat_b', 'lon_b', and missing CF attributes are filled
    in so the coordinates are discoverable via cf-xarray.

    Parameters
    ----------
    filepath : str
        Path to the ESMF file.

    Returns
    -------
    xr.Dataset
        The dataset representation of the ESMF file.
    """
    ds = xr.open_dataset(filepath)

    # SCRIP/ESMF canonical names mapped to the names xregrid expects.
    scrip_names = {
        "grid_center_lat": "lat",
        "grid_center_lon": "lon",
        "grid_corner_lat": "lat_b",
        "grid_corner_lon": "lon_b",
        "grid_imask": "mask",
    }
    present = {old: new for old, new in scrip_names.items() if old in ds}

    if present:
        ds = ds.rename(present)
        message = f"Loaded ESMF file and renamed standard variables: {present}"
    else:
        message = f"Loaded ESMF file from {filepath}."

    # Fill in CF attributes where missing (never clobbering existing ones)
    # and link any boundary variables to their center coordinates.
    for name, units, std_name in (
        ("lat", "degrees_north", "latitude"),
        ("lon", "degrees_east", "longitude"),
    ):
        if name in ds:
            ds[name].attrs.setdefault("units", units)
            ds[name].attrs.setdefault("standard_name", std_name)
            if f"{name}_b" in ds:
                ds[name].attrs["bounds"] = f"{name}_b"

    update_history(ds, message)

    return ds

Load an ESMF mesh, mosaic, or grid file into an xarray Dataset.

from xregrid import load_esmf_file

# Load an ESMF mesh file
ds = load_esmf_file("path/to/mesh.nc")

High-Performance Computing

get_rdhpcs_cluster

xregrid.utils.get_rdhpcs_cluster(machine=None, account=None, **kwargs)

Create a dask-jobqueue SLURMCluster for NOAA RDHPCS systems.

This helper automatically detects the machine if not provided and sets up reasonable defaults for Hera, Jet, Gaea (C5/C6), and Ursa.

Parameters:

Name Type Description Default
machine str

Machine name ('hera', 'jet', 'gaea-c5', 'gaea-c6', 'ursa'). If None, attempts to detect based on hostname.

None
account str

SLURM account/project for charging.

None
**kwargs Any

Additional keyword arguments passed to SLURMCluster.

{}

Returns:

Type Description
SLURMCluster

The configured cluster object.

Source code in src/xregrid/utils.py
def get_rdhpcs_cluster(
    machine: Optional[str] = None,
    account: Optional[str] = None,
    **kwargs: Any,
) -> Any:
    """
    Build a dask-jobqueue SLURMCluster configured for NOAA RDHPCS systems.

    When ``machine`` is omitted, the host name is inspected to guess which
    system we are running on, and per-machine defaults are applied for
    Hera, Jet, Gaea (C5/C6), and Ursa.

    Parameters
    ----------
    machine : str, optional
        Machine name ('hera', 'jet', 'gaea-c5', 'gaea-c6', 'ursa').
        If None, attempts to detect based on hostname.
    account : str, optional
        SLURM account/project for charging. Falls back to the SACCOUNT
        environment variable when not given.
    **kwargs
        Additional keyword arguments passed to SLURMCluster; these override
        the machine defaults.

    Returns
    -------
    dask_jobqueue.SLURMCluster
        The configured cluster object.

    Raises
    ------
    ImportError
        If dask-jobqueue is not installed.
    ValueError
        If the machine cannot be detected from the hostname.
    """
    try:
        from dask_jobqueue import SLURMCluster
    except ImportError:
        raise ImportError(
            "dask-jobqueue is required for get_rdhpcs_cluster. "
            "Install it with `pip install dask-jobqueue`."
        )

    hostname = socket.gethostname()
    if machine is None:
        # Hostname-based detection; the checks run most-specific first.
        if "ufe" in hostname or "ursa" in hostname:
            machine = "ursa"
        elif "hfe" in hostname or "heralogin" in hostname:
            machine = "hera"
        elif "fe" in hostname and "jet" in hostname:
            machine = "jet"
        elif "gaea" in hostname:
            # C5 vs C6 usually cannot be told apart from the hostname alone.
            machine = "gaea-c5"
        else:
            raise ValueError(
                f"Could not detect NOAA RDHPCS machine from hostname '{hostname}'. "
                "Please specify 'machine' explicitly."
            )

    # Baseline options shared by every machine.
    opts = {
        "account": account or os.environ.get("SACCOUNT"),
        "walltime": "01:00:00",
    }

    if machine == "hera":
        opts.update(
            queue="hera",
            cores=40,
            processes=40,
            memory="160GB",
            job_extra_directives=["--exclusive"],
        )
    elif machine == "jet":
        opts.update(
            queue="batch",
            cores=24,
            processes=12,
            memory="120GB",
        )
    elif machine.startswith("gaea"):
        ver = machine.split("-")[-1] if "-" in machine else "c5"
        is_c5 = ver == "c5"
        opts.update(
            queue="batch",
            cores=128 if is_c5 else 192,
            processes=16,
            memory="256GB" if is_c5 else "384GB",
            job_extra_directives=[f"-M {ver}"],
        )
    elif machine == "ursa":
        opts.update(
            queue="u1-compute",
            cores=192,
            processes=32,
            memory="384GB",
            job_extra_directives=["--exclusive"],
        )

    # User-supplied kwargs win over the machine defaults.
    opts.update(kwargs)

    if opts["account"] is None:
        import warnings

        warnings.warn(
            "No SLURM account specified. Please provide 'account' or set SACCOUNT environment variable."
        )

    return SLURMCluster(**opts)

Create a dask-jobqueue SLURMCluster for NOAA RDHPCS systems (Hera, Jet, Gaea, Ursa).

from xregrid.utils import get_rdhpcs_cluster
from distributed import Client

# Automatically detect machine and setup cluster
cluster = get_rdhpcs_cluster(account="your_account")
cluster.scale(jobs=4)
client = Client(cluster)