""" Read trajectory NetCDF and return normalized arrays: frames[step][particle] = [lon, lat, u, v, beached]. Supports common CF / Parcels-style dimensions (time, traj) or (obs, particle). Reads actual time axis when present for correct calendar display (e.g. 27/8–3/9 instead of 83 days). """ from __future__ import annotations import numpy as np try: import xarray as xr except ImportError: xr = None from .schema import ( LON_NAMES, LAT_NAMES, U_NAMES, V_NAMES, BEACHED_NAMES, TIME_NAMES, ) def _find_var(ds, names: tuple[str, ...]): for n in names: if n in ds: return ds[n] return None def load_nc( path: str, ) -> tuple[ np.ndarray, int, int, list[int], list[str], list[int], str | None, str | None, np.ndarray | None, ]: """ Load a trajectory NetCDF file. Returns: frames: shape (num_steps, num_particles, 5) with [lon, lat, u, v, beached] num_particles: int num_steps: int release_steps: list of length num_particles (step index when particle appears) seed_names: list of origin names (may be empty) origins: list of length num_particles (index into seed_names) time_start_iso: ISO datetime string of first time (or None) time_end_iso: ISO datetime string of last time (or None) time_ref_sec: 1D array (num_steps,) of epoch seconds per step for subsampling (or None) """ if xr is None: raise ImportError("xarray is required for NetCDF ETL. pip install xarray netCDF4") ds = xr.open_dataset(path) # Resolve dimensions: time/obs and particle/traj time_dim = None particle_dim = None for d in ds.dims: dlo = d.lower() if dlo in ("time", "obs", "step", "t", "nt"): time_dim = d elif dlo in ("particle", "traj", "n_particles", "trajectory", "pid"): particle_dim = d if time_dim is None: time_dim = list(ds.dims)[0] if particle_dim is None: dims = [d for d in ds.dims if d != time_dim] particle_dim = dims[0] if dims else None if particle_dim is None: raise ValueError("Could not identify time and particle dimensions in NetCDF") lon_var = _find_var(ds, LON_NAMES) lat_var = _find_var(ds, LAT_NAMES) if lon_var is None or lat_var is None: raise ValueError("NetCDF must contain lon/lat (or longitude/latitude) variables") lon = np.asarray(lon_var).astype(np.float64) lat = np.asarray(lat_var).astype(np.float64) dims = getattr(lon_var, "dims", None) or () if lon.ndim == 1 and lat.ndim == 1: if len(lon) == len(lat): num_particles = 1 num_steps = len(lon) lon = lon.reshape(num_steps, 1) lat = lat.reshape(num_steps, 1) else: raise ValueError("lon/lat 1D but different lengths") elif lon.ndim == 2: if dims == (particle_dim, time_dim): lon = lon.T lat = lat.T num_steps, num_particles = lon.shape else: raise ValueError("lon/lat must be 1D or 2D") u_var = _find_var(ds, U_NAMES) v_var = _find_var(ds, V_NAMES) if u_var is not None and v_var is not None: u = np.asarray(u_var).astype(np.float32) v = np.asarray(v_var).astype(np.float32) if u.ndim == 2 and getattr(u_var, "dims", ()) == (particle_dim, time_dim): u, v = u.T, v.T if u.shape != (num_steps, num_particles): u = np.broadcast_to(u.ravel()[: num_steps * num_particles].reshape(num_steps, num_particles), (num_steps, num_particles)) if v.shape != (num_steps, num_particles): v = np.broadcast_to(v.ravel()[: num_steps * num_particles].reshape(num_steps, num_particles), (num_steps, num_particles)) else: # Derive u, v from lon, lat differences (m/s scale: deg/hour -> rough m/s) deg_per_m = 1.0 / 111_320 u = np.zeros((num_steps, num_particles), dtype=np.float32) v = np.zeros((num_steps, num_particles), dtype=np.float32) u[1:] = (lon[1:] - lon[:-1]) * deg_per_m * 111320 / 3600 v[1:] = (lat[1:] - lat[:-1]) * deg_per_m * 111320 / 3600 beached_var = _find_var(ds, BEACHED_NAMES) if beached_var is not None: raw = np.asarray(beached_var) if raw.ndim == 2 and getattr(beached_var, "dims", ()) == (particle_dim, time_dim): raw = raw.T if raw.shape != (num_steps, num_particles): beached = np.zeros((num_steps, num_particles), dtype=np.int8) else: # Normalize: 1 = beached, 0 = not. Ignore fill values (e.g. -2e9, NaN) beached = np.where( (np.isfinite(raw)) & (raw > 0.5), 1, 0, ).astype(np.int8) else: beached = np.zeros((num_steps, num_particles), dtype=np.int8) frames = np.stack([lon.astype(np.float32), lat.astype(np.float32), u, v, beached], axis=-1) # Release step: first non-NaN or first step release_steps = [] for p in range(num_particles): valid = np.where(np.isfinite(frames[:, p, 0]))[0] release_steps.append(int(valid[0]) if len(valid) else 0) seed_names = [] origins = [0] * num_particles # Optional: origin/seed from NetCDF if "origin" in ds: o = np.asarray(ds["origin"]).ravel() if len(o) >= num_particles: origins = o[:num_particles].astype(int).tolist() if "seed_names" in ds.attrs: seed_names = list(ds.attrs["seed_names"]) elif "seed_name" in ds: sn = np.asarray(ds["seed_name"]) if sn.ndim >= 1: try: seed_names = [str(x) for x in sn.values.ravel()[:num_particles]] except Exception: seed_names = [f"Seed {i}" for i in range(max(origins) + 1)] # Optional: actual time axis for calendar display and hourly subsampling time_start_iso: str | None = None time_end_iso: str | None = None time_ref_sec: np.ndarray | None = None time_var = _find_var(ds, ("time",) + TIME_NAMES) if time_var is not None: try: t = np.asarray(time_var) if time_var.dims == (particle_dim, time_dim): t = t.T # (num_steps, num_particles) if t.shape == (num_steps, num_particles): t_flat = t.ravel() valid = t_flat[~np.isnat(t_flat)] if len(valid) > 0: time_start_iso = str(np.nanmin(valid)) time_end_iso = str(np.nanmax(valid)) # Reference time per step (min over particles) in epoch seconds for subsampling time_ref_sec = np.full(num_steps, np.nan, dtype=np.float64) for s in range(num_steps): row = t[s, :] row = row[~np.isnat(row)] if len(row) > 0: secs = row.astype("datetime64[s]").astype(np.float64) time_ref_sec[s] = float(np.min(secs)) except Exception: pass ds.close() return ( frames, num_particles, num_steps, release_steps, seed_names, origins, time_start_iso, time_end_iso, time_ref_sec, )