169 lines
5.6 KiB
Python
169 lines
5.6 KiB
Python
"""FastAPI server – serves simulation blocks (JSON or Arrow) with gzip compression."""
|
||
|
||
import json
|
||
from pathlib import Path
|
||
|
||
import pyarrow as pa
|
||
from fastapi import FastAPI, HTTPException, Query
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.middleware.gzip import GZipMiddleware
|
||
from fastapi.requests import Request
|
||
from fastapi.responses import Response
|
||
|
||
# Directory named "data" located next to this module; all served files live below it.
DATA_DIR = Path(__file__).parent.joinpath("data")

app = FastAPI(title="Trajectory Viewer API")

# Gzip-compress responses (minimum_size=1000 is the middleware's size threshold).
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Cross-origin requests are accepted from http://localhost:3000 only,
# with any method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["http://localhost:3000"],
)

# Parsed JSON files memoised by (sim_id, file name). metadata.json and .arrow
# files bypass this cache and are read from disk on every request, so
# regenerated ETL output is visible without restarting the server.
_meta_cache: dict[tuple[str, str], dict] = {}
|
||
|
||
|
||
def _sim_dir(sim_id: str) -> Path:
    """Resolve the data directory for a simulation.

    Lookup order:
      1. ``DATA_DIR/simulations/<sim_id>`` when that path exists.
      2. ``DATA_DIR`` itself when ``sim_id`` is ``"default"`` and it contains a
         ``metadata.json`` (backward compatibility with the flat layout).
      3. Otherwise the non-existent ``DATA_DIR/simulations/<sim_id>`` path, so
         that the callers' own existence checks report the data as missing.

    Args:
        sim_id: Simulation identifier; must be a single plain path component.

    Returns:
        The directory expected to hold the simulation's files.

    Raises:
        HTTPException: 400 when ``sim_id`` is empty, is ``.``/``..``, or
            contains a path separator or drive/root prefix.
    """
    # sim_id is supplied by clients as a query parameter. Joining an absolute
    # path or one containing ".." would let a request read files outside
    # DATA_DIR, so only a bare directory name is accepted.
    if sim_id in ("", ".", "..") or Path(sim_id).name != sim_id:
        raise HTTPException(400, f"Invalid simulation ID: {sim_id!r}")

    sim_path = DATA_DIR / "simulations" / sim_id
    if sim_path.exists():
        return sim_path
    if sim_id == "default" and (DATA_DIR / "metadata.json").exists():
        return DATA_DIR
    return sim_path
|
||
|
||
|
||
def _load(sim_id: str, name: str) -> dict | None:
    """Load and parse a JSON file from a simulation's data directory.

    ``metadata.json`` is parsed from disk on every call so that regenerated
    ETL output (e.g. time_start/time_end) is always served. Every other file
    is parsed once and then returned from ``_meta_cache``.

    Args:
        sim_id: Simulation identifier, resolved through ``_sim_dir``.
        name: File name inside the simulation directory.

    Returns:
        The parsed JSON content, or ``None`` when the file does not exist.
    """
    path = _sim_dir(sim_id) / name
    # Checked before the cache lookup so that a file removed from disk is
    # reported as missing even if an earlier parse is still cached.
    if not path.is_file():
        return None

    use_cache = name != "metadata.json"
    key = (sim_id, name)
    if use_cache and key in _meta_cache:
        return _meta_cache[key]

    try:
        # Decode explicitly as UTF-8 rather than relying on the platform's
        # locale-dependent default encoding.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # The file disappeared between the existence check and the open.
        return None

    if use_cache:
        _meta_cache[key] = data
    return data
|
||
|
||
|
||
def _block_to_arrow(sim_id: str, block_id: int) -> bytes:
    """Return one simulation block serialized as Arrow bytes.

    If ``block_<block_id>.arrow`` exists in the simulation directory its
    content is returned verbatim (it is not parsed or validated here).
    Otherwise ``block_<block_id>.json`` is loaded and converted: every row of
    every frame becomes one table row with the columns ``step``,
    ``particle_id``, ``lon``, ``lat``, ``u``, ``v`` and ``beached``, written
    in the Arrow IPC stream format.

    Args:
        sim_id: Simulation identifier.
        block_id: Index of the block to serve.

    Returns:
        The serialized block.

    Raises:
        HTTPException: 404 when neither the Arrow file nor the JSON block is
            available.
    """
    arrow_file = _sim_dir(sim_id) / f"block_{block_id}.arrow"
    if arrow_file.exists():
        return arrow_file.read_bytes()

    data = _load(sim_id, f"block_{block_id}.json")
    if not data:
        raise HTTPException(404, f"Block {block_id} not found")

    step_start = data["step_start"]
    # One (global step, particle id, row) record per particle per frame. The
    # particle id is the row's position within its frame.
    records = [
        (step_start + step_idx, pid, row)
        for step_idx, frame in enumerate(data["frames"])
        for pid, row in enumerate(frame)
    ]

    columns = {
        "step": pa.array([step for step, _, _ in records], type=pa.int32()),
        "particle_id": pa.array([pid for _, pid, _ in records], type=pa.int32()),
    }
    # (column name, index inside a frame row, Arrow type)
    row_fields = (
        ("lon", 0, pa.float32()),
        ("lat", 1, pa.float32()),
        ("u", 2, pa.float32()),
        ("v", 3, pa.float32()),
        ("beached", 4, pa.int8()),
    )
    for column_name, index, arrow_type in row_fields:
        columns[column_name] = pa.array(
            [row[index] for _, _, row in records], type=arrow_type
        )
    table = pa.table(columns)

    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)
    # getvalue() yields a pyarrow Buffer; convert it so the declared return
    # type (bytes) holds on both code paths.
    return sink.getvalue().to_pybytes()
|
||
|
||
|
||
@app.get("/api/simulations")
def list_simulations():
    """List available simulation IDs (for future N simulations).

    Returns:
        A list of ``{"id": ..., "name": ...}`` entries. A ``"default"`` entry
        comes first when ``DATA_DIR/metadata.json`` exists, followed by every
        sub-directory of ``DATA_DIR/simulations`` (sorted) that contains a
        ``metadata.json``. Each ID appears at most once.
    """
    out = []
    has_flat_default = (DATA_DIR / "metadata.json").exists()
    if has_flat_default:
        out.append({"id": "default", "name": "Default"})

    sims_dir = DATA_DIR / "simulations"
    # is_dir() rather than exists(): iterdir() would raise if "simulations"
    # happened to be a regular file.
    if sims_dir.is_dir():
        for d in sorted(sims_dir.iterdir()):
            # Avoid listing "default" twice when both the flat layout and
            # simulations/default are present.
            if has_flat_default and d.name == "default":
                continue
            if d.is_dir() and (d / "metadata.json").exists():
                out.append({"id": d.name, "name": d.name})
    return out
|
||
|
||
|
||
@app.get("/api/metadata")
def metadata(sim: str = Query("default", description="Simulation ID")):
    """Serve the parsed ``metadata.json`` of the requested simulation.

    Responds with HTTP 404 when the file is missing or parses to an empty
    value.
    """
    payload = _load(sim, "metadata.json")
    if payload:
        return payload
    raise HTTPException(
        404,
        "metadata.json not found – run generate_dummy_data.py or scripts/run_etl for this simulation",
    )
|
||
|
||
|
||
@app.get("/api/block/{block_id}")
def block(
    block_id: int,
    request: Request,
    sim: str = Query("default", description="Simulation ID"),
):
    """Serve one block as JSON, optionally without its frames.

    With the query parameter ``frames=false`` only a summary (``block``,
    ``step_start``, ``step_end``, ``accumulation``) is returned. For a block
    that exists only as an ``.arrow`` file, that summary is derived from
    ``metadata.json`` with an empty ``accumulation``; the full JSON form is
    not available for such blocks.

    Raises:
        HTTPException: 404 when the block exists neither as JSON nor as
            Arrow, or when ``metadata.json`` is needed but missing; 400 when
            the full JSON is requested for an Arrow-only block.
    """
    summary_only = request.query_params.get("frames") == "false"

    data = _load(sim, f"block_{block_id}.json")
    if data:
        if not summary_only:
            return data
        return {
            "block": data["block"],
            "step_start": data["step_start"],
            "step_end": data["step_end"],
            "accumulation": data["accumulation"],
        }

    # No JSON for this block: it may still exist as an Arrow file.
    arrow_file = _sim_dir(sim) / f"block_{block_id}.arrow"
    if not arrow_file.exists():
        raise HTTPException(404, f"Block {block_id} not found")
    if not summary_only:
        raise HTTPException(
            400,
            "Block available only as Arrow; use /api/block/{id}/arrow (JSON not generated for this simulation)",
        )

    meta = _load(sim, "metadata.json")
    if not meta:
        raise HTTPException(404, "metadata.json not found")
    # Step range of the block, clamped to the simulation's total step count.
    block_size = meta["block_size"]
    step_start = block_id * block_size
    step_end = min(step_start + block_size, meta["num_steps"])
    return {
        "block": block_id,
        "step_start": step_start,
        "step_end": step_end,
        "accumulation": [],
    }
|
||
|
||
|
||
@app.get("/api/block/{block_id}/arrow")
def block_arrow(
    block_id: int,
    sim: str = Query("default", description="Simulation ID"),
):
    """Serve one block in Arrow form as a binary response.

    The payload comes from ``_block_to_arrow`` and is sent with the Arrow
    stream media type and a one-hour public ``Cache-Control`` header.
    """
    # NOTE(review): clients may cache this response for an hour, so
    # regenerated blocks might not be seen immediately — confirm this is intended.
    payload = bytes(_block_to_arrow(sim, block_id))
    cache_headers = {"Cache-Control": "public, max-age=3600"}
    return Response(
        content=payload,
        media_type="application/vnd.apache.arrow.stream",
        headers=cache_headers,
    )
|