2026-05-04 09:44:56 +02:00

202 lines
7.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Read trajectory NetCDF and return normalized arrays: frames[step][particle] = [lon, lat, u, v, beached].
Supports common CF / Parcels-style dimensions (time, traj) or (obs, particle).
Reads the actual time axis when present for correct calendar display (e.g. 27/8 to 3/9 instead of 83 days).
"""
from __future__ import annotations
import numpy as np
try:
import xarray as xr
except ImportError:
xr = None
from .schema import (
LON_NAMES,
LAT_NAMES,
U_NAMES,
V_NAMES,
BEACHED_NAMES,
TIME_NAMES,
)
def _find_var(ds, names: tuple[str, ...]):
for n in names:
if n in ds:
return ds[n]
return None
def _resolve_dims(ds):
    """Return ``(time_dim, particle_dim)`` dimension names for *ds*.

    Dimension names are matched case-insensitively against known aliases.
    When an alias is not found, the first remaining dimension is used, taking
    care never to pick the same dimension for both roles.

    Raises:
        ValueError: if two distinct dimensions cannot be identified.
    """
    time_aliases = ("time", "obs", "step", "t", "nt")
    particle_aliases = ("particle", "traj", "n_particles", "trajectory", "pid")

    time_dim = None
    particle_dim = None
    for dim in ds.dims:
        lowered = str(dim).lower()
        if lowered in time_aliases:
            time_dim = dim
        elif lowered in particle_aliases:
            particle_dim = dim

    all_dims = list(ds.dims)
    if time_dim is None:
        # Fall back to the first dimension that is not already the particle one.
        time_dim = next((d for d in all_dims if d != particle_dim), None)
    if particle_dim is None:
        particle_dim = next((d for d in all_dims if d != time_dim), None)
    if time_dim is None or particle_dim is None:
        raise ValueError("Could not identify time and particle dimensions in NetCDF")
    return time_dim, particle_dim


def _steps_by_particles(var, particle_dim, time_dim, dtype=None):
    """Return *var* as an ndarray, transposed to (step, particle) if stored (particle, step)."""
    arr = np.asarray(var)
    if dtype is not None:
        arr = arr.astype(dtype)
    if arr.ndim == 2 and getattr(var, "dims", ()) == (particle_dim, time_dim):
        arr = arr.T
    return arr


def _fit_to_shape(arr, shape, name):
    """Return *arr* reshaped to *shape*, truncating surplus trailing values.

    Raises:
        ValueError: if *arr* holds fewer values than *shape* requires.
    """
    if arr.shape == shape:
        return arr
    needed = shape[0] * shape[1]
    if arr.size < needed:
        raise ValueError(
            f"{name} has {arr.size} values, expected at least {needed} for shape {shape}"
        )
    return arr.ravel()[:needed].reshape(shape)


def _read_velocity(u_var, v_var, lon, lat, particle_dim, time_dim):
    """Return float32 ``(u, v)`` arrays shaped like *lon*.

    Uses the stored variables when both are present; otherwise derives them
    from step-to-step position differences.
    """
    shape = lon.shape
    if u_var is not None and v_var is not None:
        u = _steps_by_particles(u_var, particle_dim, time_dim, np.float32)
        v = _steps_by_particles(v_var, particle_dim, time_dim, np.float32)
        return _fit_to_shape(u, shape, "u"), _fit_to_shape(v, shape, "v")

    # NOTE(review): deg_per_m and 111320 cancel, so the result is degrees per
    # second under an assumed one-hour step -- not m/s as originally commented.
    # The arithmetic is kept unchanged because consumers may rely on this
    # scale; confirm the intended unit before altering it.
    deg_per_m = 1.0 / 111_320
    u = np.zeros(shape, dtype=np.float32)
    v = np.zeros(shape, dtype=np.float32)
    u[1:] = (lon[1:] - lon[:-1]) * deg_per_m * 111320 / 3600
    v[1:] = (lat[1:] - lat[:-1]) * deg_per_m * 111320 / 3600
    return u, v


def _read_beached(beached_var, shape, particle_dim, time_dim):
    """Return an int8 array of *shape*: 1 where the particle is beached, else 0.

    A missing variable, or one whose shape does not match, yields all zeros.
    """
    flags = np.zeros(shape, dtype=np.int8)
    if beached_var is None:
        return flags
    raw = _steps_by_particles(beached_var, particle_dim, time_dim)
    if raw.ndim == 1 and shape[1] == 1 and raw.shape[0] == shape[0]:
        # Single-trajectory file: promote the 1D flag series to a column.
        raw = raw.reshape(shape)
    if raw.shape != shape:
        return flags
    # Normalize: 1 = beached, 0 = not. Fill values (NaN, large negatives) count as 0.
    return np.where(np.isfinite(raw) & (raw > 0.5), 1, 0).astype(np.int8)


def _to_text(value):
    """Return *value* as ``str``, decoding bytes as UTF-8 instead of using their repr."""
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return str(value)


def _read_seeds(ds, num_particles):
    """Return ``(seed_names, origins)`` read from the optional origin/seed metadata."""
    origins = [0] * num_particles
    if "origin" in ds:
        raw_origins = np.asarray(ds["origin"]).ravel()
        if len(raw_origins) >= num_particles:
            origins = raw_origins[:num_particles].astype(int).tolist()

    seed_names: list[str] = []
    if "seed_names" in ds.attrs:
        attr = ds.attrs["seed_names"]
        if isinstance(attr, (str, bytes)):
            # A plain string attribute must not be split into characters.
            # Assumes comma-separated names -- TODO confirm against the writer.
            seed_names = [part.strip() for part in _to_text(attr).split(",") if part.strip()]
        else:
            seed_names = [_to_text(x) for x in np.atleast_1d(attr).ravel()]
    elif "seed_name" in ds:
        names_arr = np.asarray(ds["seed_name"])
        if names_arr.ndim >= 1:
            try:
                seed_names = [_to_text(x) for x in names_arr.ravel()[:num_particles]]
            except (TypeError, ValueError):
                # Best-effort fallback: generic labels covering every origin index.
                seed_names = [f"Seed {i}" for i in range(max(origins, default=-1) + 1)]
    return seed_names, origins


def _read_time_axis(time_var, particle_dim, time_dim, num_steps, num_particles):
    """Return ``(start_iso, end_iso, ref_sec)`` from a datetime64 time variable.

    Accepts a 2D per-particle time array (either orientation) or a 1D axis
    along the time dimension. Returns ``(None, None, None)`` when the variable
    is not datetime64, has an unexpected shape, or holds no valid timestamp.
    ``ref_sec[s]`` is the earliest valid timestamp at step ``s`` in epoch
    seconds, or NaN when every particle is NaT at that step.
    """
    nothing = (None, None, None)
    t = np.asarray(time_var)
    if not np.issubdtype(t.dtype, np.datetime64):
        return nothing
    var_dims = getattr(time_var, "dims", ())
    if t.ndim == 2 and var_dims == (particle_dim, time_dim):
        t = t.T  # (num_steps, num_particles)
    if t.ndim == 1 and var_dims == (time_dim,) and t.shape[0] == num_steps:
        t = t.reshape(num_steps, 1)  # one shared timestamp per step
    elif t.shape != (num_steps, num_particles):
        return nothing

    missing = np.isnat(t)
    if missing.all():
        return nothing
    valid = t[~missing]
    start_iso = str(valid.min())
    end_iso = str(valid.max())

    secs = t.astype("datetime64[s]").astype(np.int64).astype(np.float64)
    secs[missing] = np.inf  # keep NaT out of the per-step minimum
    ref_sec = secs.min(axis=1)
    ref_sec[np.isinf(ref_sec)] = np.nan
    return start_iso, end_iso, ref_sec


def load_nc(
    path: str,
) -> tuple[
    np.ndarray, int, int, list[int], list[str], list[int],
    str | None, str | None, np.ndarray | None,
]:
    """
    Load a trajectory NetCDF file.

    The dataset is opened with xarray, read fully into memory and closed
    again before returning, including when an error is raised.

    Args:
        path: location of the NetCDF file, passed to ``xarray.open_dataset``.

    Returns:
        frames: float32 array (num_steps, num_particles, 5) with [lon, lat, u, v, beached]
        num_particles: int
        num_steps: int
        release_steps: list of length num_particles (index of the first step with a
            finite longitude, or 0 when the particle never has one)
        seed_names: list of origin names (may be empty)
        origins: list of length num_particles (index into seed_names)
        time_start_iso: ISO datetime string of first time (or None)
        time_end_iso: ISO datetime string of last time (or None)
        time_ref_sec: 1D array (num_steps,) of epoch seconds per step for subsampling (or None)

    Raises:
        ImportError: if xarray is not installed.
        ValueError: if dimensions or lon/lat variables cannot be identified, or
            if variable shapes are inconsistent.
    """
    if xr is None:
        raise ImportError("xarray is required for NetCDF ETL. pip install xarray netCDF4")

    # The context manager guarantees the file handle is released on every path.
    with xr.open_dataset(path) as ds:
        time_dim, particle_dim = _resolve_dims(ds)

        lon_var = _find_var(ds, LON_NAMES)
        lat_var = _find_var(ds, LAT_NAMES)
        if lon_var is None or lat_var is None:
            raise ValueError("NetCDF must contain lon/lat (or longitude/latitude) variables")
        lon = _steps_by_particles(lon_var, particle_dim, time_dim, np.float64)
        lat = _steps_by_particles(lat_var, particle_dim, time_dim, np.float64)

        if lon.ndim == 1 and lat.ndim == 1:
            if len(lon) != len(lat):
                raise ValueError("lon/lat 1D but different lengths")
            # Single trajectory: treat it as one particle.
            lon = lon.reshape(-1, 1)
            lat = lat.reshape(-1, 1)
        elif lon.ndim != 2:
            raise ValueError("lon/lat must be 1D or 2D")
        if lat.shape != lon.shape:
            raise ValueError(f"lon/lat shapes differ: {lon.shape} vs {lat.shape}")
        num_steps, num_particles = lon.shape

        u, v = _read_velocity(
            _find_var(ds, U_NAMES), _find_var(ds, V_NAMES),
            lon, lat, particle_dim, time_dim,
        )
        beached = _read_beached(
            _find_var(ds, BEACHED_NAMES), lon.shape, particle_dim, time_dim
        )
        frames = np.stack(
            [lon.astype(np.float32), lat.astype(np.float32), u, v, beached], axis=-1
        )

        # Release step: first step with a finite longitude. argmax over an
        # all-False column yields 0, matching the "never released" default.
        if num_steps:
            release_steps = np.argmax(np.isfinite(frames[:, :, 0]), axis=0).tolist()
        else:
            release_steps = [0] * num_particles

        seed_names, origins = _read_seeds(ds, num_particles)

        # Optional: actual time axis for calendar display and hourly subsampling.
        time_start_iso: str | None = None
        time_end_iso: str | None = None
        time_ref_sec: np.ndarray | None = None
        time_var = _find_var(ds, ("time",) + TIME_NAMES)
        if time_var is not None:
            try:
                time_start_iso, time_end_iso, time_ref_sec = _read_time_axis(
                    time_var, particle_dim, time_dim, num_steps, num_particles
                )
            except (TypeError, ValueError, OverflowError):
                # Time metadata is optional; an undecodable axis must not
                # prevent the trajectories themselves from loading.
                pass

    return (
        frames,
        num_particles,
        num_steps,
        release_steps,
        seed_names,
        origins,
        time_start_iso,
        time_end_iso,
        time_ref_sec,
    )