202 lines
7.2 KiB
Python
202 lines
7.2 KiB
Python
"""
|
||
Read trajectory NetCDF and return normalized arrays: frames[step][particle] = [lon, lat, u, v, beached].
|
||
Supports common CF / Parcels-style dimensions (time, traj) or (obs, particle).
|
||
Reads actual time axis when present for correct calendar display (e.g. 27/8–3/9 instead of 83 days).
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import numpy as np
|
||
|
||
try:
|
||
import xarray as xr
|
||
except ImportError:
|
||
xr = None
|
||
|
||
from .schema import (
|
||
LON_NAMES,
|
||
LAT_NAMES,
|
||
U_NAMES,
|
||
V_NAMES,
|
||
BEACHED_NAMES,
|
||
TIME_NAMES,
|
||
)
|
||
|
||
|
||
def _find_var(ds, names: tuple[str, ...]):
|
||
for n in names:
|
||
if n in ds:
|
||
return ds[n]
|
||
return None
|
||
|
||
|
||
# Lower-cased dimension names recognised as the step axis / the particle axis.
_TIME_DIM_NAMES = ("time", "obs", "step", "t", "nt")
_PARTICLE_DIM_NAMES = ("particle", "traj", "n_particles", "trajectory", "pid")

# Used only when the file has no velocity variables and u/v are derived from
# position differences.
_METERS_PER_DEGREE = 111_320.0
_ASSUMED_STEP_SECONDS = 3600.0  # NOTE(review): assumes hourly output steps


def _resolve_dims(ds):
    """Return ``(time_dim, particle_dim)`` dimension names for *ds*.

    Known names are matched case-insensitively (the last match wins). A
    dimension that could not be matched by name falls back to the first
    dimension that is not already used for the other role; either value may
    be ``None`` when the dataset has a single dimension.
    """
    dim_names = list(ds.dims)
    if not dim_names:
        raise ValueError("Could not identify time and particle dimensions in NetCDF")

    time_dim = None
    particle_dim = None
    for name in dim_names:
        lowered = str(name).lower()
        if lowered in _TIME_DIM_NAMES:
            time_dim = name
        elif lowered in _PARTICLE_DIM_NAMES:
            particle_dim = name

    # Never let the fallback pick the dimension already assigned to the
    # other role.
    if time_dim is None:
        time_dim = next((d for d in dim_names if d != particle_dim), None)
    if particle_dim is None:
        particle_dim = next((d for d in dim_names if d != time_dim), None)
    return time_dim, particle_dim


def _oriented(var, time_dim, particle_dim, dtype=None):
    """Load *var* as an ndarray, transposing (particle, time) data to (time, particle)."""
    arr = np.asarray(var)
    if dtype is not None:
        arr = arr.astype(dtype)
    if arr.ndim == 2 and tuple(getattr(var, "dims", None) or ()) == (particle_dim, time_dim):
        arr = arr.T
    return arr


def _fit_to_grid(arr, shape, label):
    """Return *arr* shaped as *shape*, flattening and truncating when it differs."""
    if arr.shape == shape:
        return arr
    needed = shape[0] * shape[1]
    flat = arr.ravel()
    if flat.size < needed:
        raise ValueError(
            f"{label} has {flat.size} values, expected at least {needed} "
            f"for shape {shape}"
        )
    return flat[:needed].reshape(shape)


def _derive_velocity(lon, lat):
    """Estimate (u, v) in m/s from step-to-step position differences.

    The first step has no predecessor and is left at zero. Degrees are
    converted to metres with a constant metres-per-degree factor; the zonal
    component is additionally scaled by cos(latitude).
    """
    u = np.zeros(lon.shape, dtype=np.float32)
    v = np.zeros(lon.shape, dtype=np.float32)
    if lon.shape[0] > 1:
        scale = _METERS_PER_DEGREE / _ASSUMED_STEP_SECONDS
        mid_lat = np.radians(0.5 * (lat[1:] + lat[:-1]))
        u[1:] = (lon[1:] - lon[:-1]) * np.cos(mid_lat) * scale
        v[1:] = (lat[1:] - lat[:-1]) * scale
    return u, v


def _read_beached(beached_var, time_dim, particle_dim, shape):
    """Return an int8 (steps, particles) flag array: 1 = beached, 0 = not."""
    flags = np.zeros(shape, dtype=np.int8)
    if beached_var is None:
        return flags
    raw = _oriented(beached_var, time_dim, particle_dim)
    if raw.ndim == 1 and shape[1] == 1 and raw.shape[0] == shape[0]:
        raw = raw.reshape(shape)  # single-trajectory file
    if raw.shape != shape:
        return flags
    # Fill values (e.g. -2e9, NaN) are not > 0.5 / not finite, so they map to 0.
    return (np.isfinite(raw) & (raw > 0.5)).astype(np.int8)


def _as_text(value):
    """Convert a NetCDF string-ish scalar to ``str`` (decoding bytes)."""
    if isinstance(value, bytes):
        return value.decode("utf-8", "replace")
    return str(value)


def _read_origins_and_seeds(ds, num_particles):
    """Return ``(seed_names, origins)`` from the optional origin/seed metadata.

    Seed names are only looked up when an ``origin`` variable exists.
    """
    seed_names = []
    origins = [0] * num_particles
    if "origin" not in ds:
        return seed_names, origins

    raw_origins = np.asarray(ds["origin"]).ravel()
    if len(raw_origins) >= num_particles:
        origins = raw_origins[:num_particles].astype(int).tolist()

    if "seed_names" in ds.attrs:
        attr = ds.attrs["seed_names"]
        if isinstance(attr, (str, bytes)):
            # A lone string is one name, not a sequence of characters.
            seed_names = [_as_text(attr)]
        else:
            seed_names = [_as_text(x) for x in np.atleast_1d(attr).ravel()]
    elif "seed_name" in ds:
        names_arr = np.asarray(ds["seed_name"])
        if names_arr.ndim >= 1:
            seed_names = [_as_text(x) for x in names_arr.ravel()[:num_particles]]
            # Keep every origin index resolvable by padding with placeholders.
            highest = max(origins, default=-1)
            seed_names.extend(f"Seed {i}" for i in range(len(seed_names), highest + 1))
    return seed_names, origins


def _read_time_axis(time_var, time_dim, particle_dim, num_steps, num_particles):
    """Return ``(start_iso, end_iso, ref_seconds)`` from a datetime time variable.

    Best effort: returns ``(None, None, None)`` when the variable is not
    datetime-typed, has an unexpected shape, or contains no valid times.
    """
    nothing = (None, None, None)
    try:
        t = np.asarray(time_var)
        var_dims = tuple(getattr(time_var, "dims", None) or ())
        if var_dims == (particle_dim, time_dim):
            t = t.T  # (num_steps, num_particles)

        shared_axis = (
            t.ndim == 1 and t.shape[0] == num_steps and var_dims != (particle_dim,)
        )
        if shared_axis:
            t = t.reshape(num_steps, 1)  # one time value per step for all particles
        elif t.shape != (num_steps, num_particles):
            return nothing

        missing = np.isnat(t)  # TypeError for non-datetime data
        valid = t[~missing]
        if valid.size == 0:
            return nothing

        start_iso = str(valid.min())
        end_iso = str(valid.max())

        # Earliest valid time per step, in epoch seconds; NaN where a step
        # has no valid time at all.
        secs = t.astype("datetime64[s]").astype(np.float64)
        secs[missing] = np.inf
        ref = secs.min(axis=1)
        ref[np.isinf(ref)] = np.nan
        return start_iso, end_iso, ref
    except (TypeError, ValueError):
        # The time axis is optional; an unparseable one is reported as absent
        # rather than returned half-filled.
        return nothing


def load_nc(
    path: str,
) -> tuple[
    np.ndarray, int, int, list[int], list[str], list[int],
    str | None, str | None, np.ndarray | None,
]:
    """
    Load a trajectory NetCDF file.

    Position variables may be 2D in either (time, particle) or
    (particle, time) order, or 1D for a single trajectory. Velocities are
    read from the file when both components are present; otherwise they are
    estimated in m/s from position differences, assuming hourly steps.

    Returns:
        frames: shape (num_steps, num_particles, 5) with [lon, lat, u, v, beached]
        num_particles: int
        num_steps: int
        release_steps: list of length num_particles (index of the first step
            with a finite longitude, or 0 when there is none)
        seed_names: list of origin names (may be empty)
        origins: list of length num_particles (index into seed_names)
        time_start_iso: ISO datetime string of first time (or None)
        time_end_iso: ISO datetime string of last time (or None)
        time_ref_sec: 1D array (num_steps,) of epoch seconds per step for
            subsampling, NaN where a step has no valid time (or None)

    Raises:
        ImportError: if xarray is not installed.
        ValueError: if the dataset has no dimensions, lon/lat are missing or
            have incompatible shapes, or a velocity variable is too small.
    """
    if xr is None:
        raise ImportError("xarray is required for NetCDF ETL. pip install xarray netCDF4")

    # The context manager closes the file even when validation fails below.
    with xr.open_dataset(path) as ds:
        time_dim, particle_dim = _resolve_dims(ds)

        lon_var = _find_var(ds, LON_NAMES)
        lat_var = _find_var(ds, LAT_NAMES)
        if lon_var is None or lat_var is None:
            raise ValueError("NetCDF must contain lon/lat (or longitude/latitude) variables")

        lon = _oriented(lon_var, time_dim, particle_dim, np.float64)
        lat = _oriented(lat_var, time_dim, particle_dim, np.float64)

        if lon.ndim == 1 and lat.ndim == 1:
            if len(lon) != len(lat):
                raise ValueError("lon/lat 1D but different lengths")
            # Single trajectory: one particle, one row per step.
            lon = lon.reshape(-1, 1)
            lat = lat.reshape(-1, 1)
        elif lon.ndim == 2:
            if lat.shape != lon.shape:
                raise ValueError("lon/lat must have the same shape")
        else:
            raise ValueError("lon/lat must be 1D or 2D")

        num_steps, num_particles = lon.shape
        grid = (num_steps, num_particles)

        u_var = _find_var(ds, U_NAMES)
        v_var = _find_var(ds, V_NAMES)
        if u_var is not None and v_var is not None:
            u = _fit_to_grid(_oriented(u_var, time_dim, particle_dim, np.float32), grid, "u")
            v = _fit_to_grid(_oriented(v_var, time_dim, particle_dim, np.float32), grid, "v")
        else:
            u, v = _derive_velocity(lon, lat)

        beached = _read_beached(
            _find_var(ds, BEACHED_NAMES), time_dim, particle_dim, grid
        )

        frames = np.stack(
            [lon.astype(np.float32), lat.astype(np.float32), u, v, beached],
            axis=-1,
        )

        # Release step: first step with a finite longitude. argmax over a
        # boolean column yields the first True, and 0 when there is none.
        if num_steps:
            release_steps = np.isfinite(frames[:, :, 0]).argmax(axis=0).tolist()
        else:
            release_steps = [0] * num_particles

        seed_names, origins = _read_origins_and_seeds(ds, num_particles)

        # Optional: actual time axis for calendar display and hourly subsampling
        time_start_iso: str | None = None
        time_end_iso: str | None = None
        time_ref_sec: np.ndarray | None = None
        time_var = _find_var(ds, ("time",) + TIME_NAMES)
        if time_var is not None:
            time_start_iso, time_end_iso, time_ref_sec = _read_time_axis(
                time_var, time_dim, particle_dim, num_steps, num_particles
            )

    return (
        frames,
        num_particles,
        num_steps,
        release_steps,
        seed_names,
        origins,
        time_start_iso,
        time_end_iso,
        time_ref_sec,
    )
|