2026-05-04 09:44:56 +02:00

138 lines
5.0 KiB
Python

"""
Write trajectory frames to block-wise Arrow files and metadata.json.
Frontend expects: metadata.json, block_0.arrow, block_1.arrow, ...
Arrow table per block: columns step, particle_id, lon, lat, u, v, beached; rows ordered by step then particle_id.
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import pyarrow as pa
from .schema import BLOCK_SIZE
def build_accumulation(lons: np.ndarray, lats: np.ndarray, res: float = 0.1) -> list[list[float]]:
    """Count particle positions per grid cell of spacing ``res``.

    ``lons`` and ``lats`` are flattened and paired element-wise (``zip`` stops at the
    shorter of the two). Pairs where either coordinate is NaN or infinite are skipped.
    Each coordinate is snapped to the nearest multiple of ``res`` with Python's
    ``round`` (ties go to the even index).

    Returns:
        One ``[lon_cell, lat_cell, count]`` entry per occupied cell, where the cell
        coordinates are ``index * res`` and cells are listed in order of first
        occurrence. Returns ``[]`` when ``lons`` is empty.
    """
    if lons.size == 0:
        return []
    flat_lons = lons.ravel().tolist()
    flat_lats = lats.ravel().tolist()
    tally: dict[tuple[int, int], int] = {}
    for lon, lat in zip(flat_lons, flat_lats):
        if np.isfinite(lon) and np.isfinite(lat):
            cell = (round(lon / res), round(lat / res))
            tally[cell] = tally.get(cell, 0) + 1
    return [[ix * res, iy * res, count] for (ix, iy), count in tally.items()]
def _arrow_block_table(chunk: np.ndarray, step_start: int, num_particles: int) -> pa.Table:
    """Flatten one block of frames into an Arrow table, rows ordered by step then particle_id.

    ``chunk`` is indexed [step, particle, field] with fields [lon, lat, u, v, beached].
    Only the first ``num_particles`` particles are exported. Step numbers are global,
    starting at ``step_start``.
    """
    n_steps = chunk.shape[0]
    block = chunk[:, :num_particles]

    def float_column(field: int) -> np.ndarray:
        # C-ordered copy, so ravel() walks step-major then particle -- the same row
        # order as the step / particle_id columns built below.
        return np.ascontiguousarray(block[:, :, field], dtype=np.float32).ravel()

    steps = np.repeat(np.arange(step_start, step_start + n_steps, dtype=np.int32), num_particles)
    particle_ids = np.tile(np.arange(num_particles, dtype=np.int32), n_steps)
    # The beached flag may be stored as a float: values > 0.5 become 1, everything else (incl. NaN) 0.
    beached = (block[:, :, 4] > 0.5).astype(np.int8).ravel()
    return pa.table({
        "step": pa.array(steps, type=pa.int32()),
        "particle_id": pa.array(particle_ids, type=pa.int32()),
        "lon": pa.array(float_column(0), type=pa.float32()),
        "lat": pa.array(float_column(1), type=pa.float32()),
        "u": pa.array(float_column(2), type=pa.float32()),
        "v": pa.array(float_column(3), type=pa.float32()),
        "beached": pa.array(beached, type=pa.int8()),
    })


def _json_block_payload(chunk: np.ndarray, block_id: int, step_start: int, step_end: int) -> dict:
    """Build the JSON variant of one block: raw frames plus per-step beached accumulation cells."""
    # Same threshold as the Arrow "beached" column, so both outputs agree on which
    # particles are beached (a plain bool cast would also count NaN or e.g. 0.1).
    beached_mask = chunk[:, :, 4] > 0.5
    accumulation = [
        build_accumulation(step_frames[mask, 0], step_frames[mask, 1])
        for step_frames, mask in zip(chunk, beached_mask)
    ]
    return {
        "block": block_id,
        "step_start": step_start,
        "step_end": step_end,
        # NOTE(review): json.dump writes non-finite floats as NaN/Infinity, which strict
        # JSON parsers reject -- confirm the consumer tolerates this if frames contain NaN.
        "frames": chunk.astype(float).tolist(),
        "accumulation": accumulation,
    }


def write_blocks_arrow(
    out_dir: Path,
    frames: np.ndarray,
    num_particles: int,
    num_steps: int,
    release_steps: list[int],
    seed_names: list[str],
    origins: list[int],
    write_json: bool = False,
    time_start_iso: str | None = None,
    time_end_iso: str | None = None,
) -> dict:
    """
    Write trajectory frames as block-wise Arrow IPC streams plus metadata.json.

    Steps are split into consecutive blocks of ``BLOCK_SIZE``. Block ``k`` is written to
    ``block_{k}.arrow``, and also to ``block_{k}.json`` when ``write_json`` is true.
    Any ``block_*.arrow`` / ``block_*.json`` files left in ``out_dir`` by a previous run
    are deleted first, so a shorter re-run leaves no stale blocks behind.

    Args:
        out_dir: Output directory; created if missing.
        frames: Array indexed [step, particle, field] with fields
            [lon, lat, u, v, beached]. Only the first ``num_particles`` particles go
            into the Arrow files.
        num_particles: Number of particles per step to export.
        num_steps: Number of steps to export. Blocks reaching past ``frames.shape[0]``
            are written with fewer (possibly zero) rows.
        release_steps: Copied verbatim into the metadata.
        seed_names: Display names for origins. When empty, names are generated as
            "Origin {i}" for ``0..max(origins)``; if ``origins`` is empty too, the
            metadata gets ``["Default"]``.
        origins: Copied verbatim into the metadata.
        write_json: Also write a JSON variant of each block (frames + accumulation).
        time_start_iso: Recorded in the metadata only when ``time_end_iso`` is also given.
        time_end_iso: Recorded in the metadata only when ``time_start_iso`` is also given.

    Returns:
        The metadata dict that is also written to ``metadata.json``.

    Raises:
        ValueError: If ``num_steps > 0`` and ``frames`` is not 3-D with at least
            ``num_particles`` particles and 5 fields. This is checked before any
            existing files are removed.
    """
    out_dir = Path(out_dir)
    frames = np.asarray(frames)
    if num_steps > 0 and (
        frames.ndim != 3 or frames.shape[1] < num_particles or frames.shape[2] < 5
    ):
        raise ValueError(
            f"frames must have shape (steps, >= {num_particles} particles, >= 5 fields); "
            f"got {frames.shape}"
        )

    out_dir.mkdir(parents=True, exist_ok=True)
    # Remove any existing block files from a previous run (e.g. after re-ETL with fewer steps)
    for stale in out_dir.glob("block_*"):
        if stale.is_file() and stale.suffix in (".arrow", ".json"):
            stale.unlink()

    num_blocks = (num_steps + BLOCK_SIZE - 1) // BLOCK_SIZE
    # Caller-supplied names always win. max(origins) is only evaluated when origins is
    # non-empty, which is why the generated-name fallback sits behind that check.
    if seed_names:
        names = seed_names
    elif origins:
        names = [f"Origin {i}" for i in range(max(origins) + 1)]
    else:
        names = ["Default"]
    metadata: dict = {
        "num_particles": num_particles,
        "num_steps": num_steps,
        "num_blocks": num_blocks,
        "block_size": BLOCK_SIZE,
        "dt_hours": 1,
        "seed_names": names,
        "origins": origins,
        "release_steps": release_steps,
    }
    if time_start_iso and time_end_iso:
        metadata["time_start"] = time_start_iso
        metadata["time_end"] = time_end_iso

    for block_id in range(num_blocks):
        step_start = block_id * BLOCK_SIZE
        step_end = min(step_start + BLOCK_SIZE, num_steps)
        chunk = frames[step_start:step_end]

        table = _arrow_block_table(chunk, step_start, num_particles)
        arrow_path = out_dir / f"block_{block_id}.arrow"
        with open(arrow_path, "wb") as f, pa.ipc.new_stream(f, table.schema) as writer:
            writer.write_table(table)

        if write_json:
            block_json = _json_block_payload(chunk, block_id, step_start, step_end)
            with open(out_dir / f"block_{block_id}.json", "w") as f:
                json.dump(block_json, f, separators=(",", ":"))

    # metadata.json is written last, after every block file has been written.
    with open(out_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, separators=(",", ":"))
    return metadata