"""Base grid generation system for spatial parameter variations."""
from abc import ABC, abstractmethod
from collections.abc import Iterable
import numpy as np
from ..types import Position3D
def _generate_axis_values(center_coord: float, min_offset: float, max_offset: float, num_points: int) -> np.ndarray:
"""Generate values along a single axis."""
if min_offset == max_offset or num_points == 1:
return np.array([center_coord + min_offset])
return np.linspace(center_coord + min_offset, center_coord + max_offset, num_points)
[docs]
class GridGenerator(ABC):
"""Abstract base for 3D grid generation."""
[docs]
@abstractmethod
def generate_positions(self) -> Iterable[Position3D]:
"""Generate list of 3D positions."""
[docs]
class RegularGrid(GridGenerator):
"""Regular 3D grid generation with uniform spacing."""
[docs]
def __init__(
self, center: Position3D, dx: list[float], dy: list[float], dz: list[float], grid_size: list[int]
) -> None:
"""Initialize regular grid.
Args:
center: Grid center position
dx: [min_offset, max_offset] in x direction
dy: [min_offset, max_offset] in y direction
dz: [min_offset, max_offset] in z direction
grid_size: [nx, ny, nz] number of points per dimension
"""
self.center = center
self.dx = dx
self.dy = dy
self.dz = dz
self.grid_size = grid_size
self._validate_parameters()
def _validate_parameters(self) -> None:
"""Validate grid parameters."""
for name, param in [("dx", self.dx), ("dy", self.dy), ("dz", self.dz)]:
if len(param) != 2:
raise ValueError(f"{name} must have exactly 2 elements [min, max], got {len(param)}")
if len(self.grid_size) != 3:
raise ValueError(f"grid_size must have exactly 3 elements [nx, ny, nz], got {len(self.grid_size)}")
if any(n < 1 for n in self.grid_size):
raise ValueError(f"grid_size elements must be >= 1, got {self.grid_size}")
[docs]
def generate_positions(self) -> Iterable[Position3D]:
"""Generate regular grid positions."""
dx_min, dx_max = self.dx
dy_min, dy_max = self.dy
dz_min, dz_max = self.dz
nx, ny, nz = self.grid_size
x_values = _generate_axis_values(self.center.x, dx_min, dx_max, nx)
y_values = _generate_axis_values(self.center.y, dy_min, dy_max, ny)
z_values = _generate_axis_values(self.center.z, dz_min, dz_max, nz)
# For XZ plane: Z outer (slow), Y middle, X inner (fast)
# For XY plane: Y outer (slow), Z middle, X inner (fast)
for z in z_values:
for y in y_values:
for x in x_values:
yield Position3D(x, y, z)
[docs]
class RandomGrid(GridGenerator):
"""Random 3D positions within bounds."""
[docs]
def __init__(
self,
center: Position3D,
dx: list[float],
dy: list[float],
dz: list[float],
num_points: int,
seed: int | None = None,
) -> None:
"""Initialize random grid.
Args:
center: Grid center position
dx: [min_offset, max_offset] in x direction
dy: [min_offset, max_offset] in y direction
dz: [min_offset, max_offset] in z direction
num_points: Number of random points to generate
seed: Random seed for reproducibility
"""
self.center = center
self.dx = dx
self.dy = dy
self.dz = dz
self.num_points = num_points
self.seed = seed
self._validate_parameters()
def _validate_parameters(self) -> None:
"""Validate grid parameters."""
if len(self.dx) != 2:
raise ValueError(f"dx must have exactly 2 elements [min, max], got {len(self.dx)} elements: {self.dx}")
if len(self.dy) != 2:
raise ValueError(f"dy must have exactly 2 elements [min, max], got {len(self.dy)} elements: {self.dy}")
if len(self.dz) != 2:
raise ValueError(f"dz must have exactly 2 elements [min, max], got {len(self.dz)} elements: {self.dz}")
if self.num_points < 1:
raise ValueError(f"num_points must be >= 1, got {self.num_points}")
[docs]
def generate_positions(self) -> Iterable[Position3D]:
"""Generate random positions."""
rng = np.random.default_rng(self.seed)
dx_min, dx_max = self.dx
dy_min, dy_max = self.dy
dz_min, dz_max = self.dz
for _ in range(self.num_points):
x = self.center.x + rng.uniform(dx_min, dx_max)
y = self.center.y + rng.uniform(dy_min, dy_max)
z = self.center.z + rng.uniform(dz_min, dz_max)
yield Position3D(x, y, z)