"""Cornea model definitions for eye tracking simulation.
Defines abstract and concrete cornea models (spherical, conic) for anatomical and optical simulation.
"""
from abc import ABC, abstractmethod
from dataclasses import InitVar, dataclass, field
from typing import TYPE_CHECKING
import numpy as np
from pyetsimul.log import info, table, warning
from ..geometry import intersections
from ..optics import reflections, refractions
from ..types import Direction3D, IntersectionResult, Point3D, Position3D, Ray, TransformationMatrix, Vector3D
from .default_configs import CorneaDefaults
if TYPE_CHECKING:
from .eye import Eye
@dataclass
class Cornea(ABC):
    """Abstract base class for different corneal models.

    Defines the common interface (intersection, surface normal, glint
    reflection, refraction) that every concrete cornea model implements so
    that models are interchangeable.

    The cornea center may be supplied at construction time through
    ``center_init`` or assigned later through the ``center`` property.
    Reading ``center`` before it has been set raises ``ValueError``.
    """

    # Init-only argument: copied into ``_center`` by ``__post_init__``.
    center_init: InitVar[Position3D | None] = None
    # Backing storage for the ``center`` property; ``None`` means "not set yet".
    _center: Position3D | None = field(default=None, init=False)
    _cornea_depth_default: float = CorneaDefaults.CORNEA_DEPTH
    _cornea_center_to_rotation_center_default: float = CorneaDefaults.CENTER_TO_ROTATION

    def __post_init__(self, center_init: Position3D | None) -> None:
        """Store the cornea center if one was provided at construction."""
        if center_init is not None:
            self._center = center_init

    @property
    def center(self) -> Position3D:
        """Get the cornea center position.

        Raises:
            ValueError: If the center has not been set yet.
        """
        if self._center is None:
            raise ValueError("Center has not been initialized. Did you call setup_eye_geometry() already?")
        return self._center

    @center.setter
    def center(self, value: Position3D) -> None:
        self._center = value

    @property
    def cornea_center_to_rotation_center_default(self) -> float:
        """Get the default cornea center to rotation center distance."""
        return self._cornea_center_to_rotation_center_default

    @property
    @abstractmethod
    def cornea_type(self) -> str:
        """Return the type name of this cornea model."""

    @abstractmethod
    def intersect(self, ray: Ray) -> IntersectionResult | None:
        """Calculates the intersection point of a light ray with the cornea."""

    @abstractmethod
    def normal_at(self, point: Point3D) -> Direction3D:
        """Calculates the normal vector at a given point on the cornea's surface."""

    def point_within_cornea(self, p: Position3D, eye: "Eye") -> bool:
        """Tests whether a point lies within the cornea boundaries.

        Projects the vector from the corneal apex to ``p`` onto the
        apex-to-center axis and compares the projected length with the
        corneal depth.

        Args:
            p: Point to test in local eye coordinates
            eye: Eye object; the apex position and corneal depth are read
                from ``eye.cornea``, while the center comes from this instance

        Returns:
            bool: True if point lies within cornea boundaries, False otherwise

        Raises:
            ValueError: If this cornea's center has not been set.
        """
        apex_pos = eye.cornea.get_apex_position()
        center = self.center

        # Axis pointing from the apex towards the cornea center
        direction = Vector3D(center.x - apex_pos.x, center.y - apex_pos.y, center.z - apex_pos.z)
        diff = Vector3D(p.x - apex_pos.x, p.y - apex_pos.y, p.z - apex_pos.z)

        # Signed length of ``diff`` along the apex-to-center axis
        projection_distance = diff.dot(direction) / direction.magnitude()
        return projection_distance < eye.cornea.get_corneal_depth()

    @abstractmethod
    def find_reflection(
        self, light_pos: Position3D, camera_pos: Position3D, eye_transform: TransformationMatrix
    ) -> Point3D | None:
        """Finds position of a glint on the corneal surface."""

    @abstractmethod
    def find_refraction(
        self,
        camera_pos: Position3D,
        object_pos: Position3D,
        n_outside: float,
        n_cornea: float,
        eye_transform: TransformationMatrix,
    ) -> Point3D | None:
        """Finds position where refraction occurs on the corneal surface."""

    def __str__(self) -> str:
        """Basic string representation of the cornea.

        Reports ``unset`` for the center when it has not been initialized.
        """
        # Inspect the backing field directly: the ``center`` property raises
        # when the center is missing, which would make "unset" unreachable.
        if self._center is None:
            center_str = "unset"
        else:
            c = self._center
            center_str = f"({c.x:.1f}, {c.y:.1f}, {c.z:.1f})mm"
        return f"{self.__class__.__name__}(center={center_str}, type={self.cornea_type})"

    def pprint(self) -> None:
        """Print detailed cornea parameters in a formatted table.

        Works both before and after the center has been set; center-dependent
        rows are shown as ``unset`` or omitted while the center is missing.
        """
        # Use the backing field so that an unset center does not raise.
        center = self._center
        has_center = center is not None

        # Base parameters
        data = [
            ["Cornea type", self.cornea_type],
            [
                "Center (x,y,z) mm",
                f"({center.x:.3f}, {center.y:.3f}, {center.z:.3f})" if has_center else "unset",
            ],
        ]

        # SphericalCornea and ConicCornea parameters
        if isinstance(self, (SphericalCornea, ConicCornea)):
            data.extend([
                ["Anterior radius (mm)", f"{self.anterior_radius:.3f}"],
                ["Posterior radius (mm)", f"{self.posterior_radius:.3f}"],
                ["Refractive index", f"{self.refractive_index:.3f}"],
                ["Thickness offset (mm)", f"{self.thickness_offset:.3f}"],
                ["Corneal depth (mm)", f"{self.get_corneal_depth():.3f}"],
            ])
            # The apex is derived from the center, so it exists only once the center is set.
            if has_center:
                apex = self.get_apex_position()
                data.append([
                    "Apex position (x,y,z) mm",
                    f"({apex.x:.3f}, {apex.y:.3f}, {apex.z:.3f})",
                ])

        # ConicCornea-specific parameters
        if isinstance(self, ConicCornea):
            data.extend([
                ["Anterior k (conic)", f"{self.anterior_k:.3f}"],
                ["Posterior k (conic)", f"{self.posterior_k:.3f}"],
            ])

        headers = ["Parameter", "Value"]
        info(f"{self.__class__.__name__} Parameters:")
        table(data, headers=headers, tablefmt="grid")
@dataclass
class SphericalCornea(Cornea):
    """Represents a cornea with dual spherical surfaces.

    Uses anatomical scaling based on Boff and Lincoln [1988] parameters.
    The posterior radius, thickness offset and corneal depth are scaled
    proportionally to ``anterior_radius`` relative to the reference radius.

    Attributes:
        anterior_radius (float): The radius of the anterior (outer) corneal surface.
        refractive_index (float): Refractive index of the cornea.
        posterior_radius (float): Read-only, scaled radius of the posterior (inner) corneal surface.
        thickness_offset (float): Read-only, scaled thickness offset.
        thickness (float): Read-only central corneal thickness.
        center (Position3D): Inherited from parent. If unset, it is calculated by
            ``setup_eye_geometry`` based on anatomical scaling.
    """

    anterior_radius: float = CorneaDefaults.ANTERIOR_RADIUS
    refractive_index: float = CorneaDefaults.REFRACTIVE_INDEX

    # Reference values for scaling (from Boff and Lincoln [1988])
    _posterior_radius_default: float = CorneaDefaults.POSTERIOR_RADIUS
    _thickness_offset_default: float = CorneaDefaults.THICKNESS_OFFSET

    # Sphere-specific constants from Boff and Lincoln [1988, Section 1.210]
    _r_cornea_default: float = CorneaDefaults.ANTERIOR_RADIUS

    @property
    def cornea_type(self) -> str:
        """Return the type name of this cornea model."""
        return "spherical"

    @property
    def posterior_radius(self) -> float:
        """Calculate the scaled posterior corneal radius.

        Returns:
            Posterior corneal radius in mm (scaled from reference)
        """
        return self.get_scale_factor() * self._posterior_radius_default

    @property
    def thickness_offset(self) -> float:
        """Calculate the scaled thickness offset.

        Returns:
            Thickness offset in mm (scaled from reference)
        """
        return self.get_scale_factor() * self._thickness_offset_default

    @property
    def thickness(self) -> float:
        """Calculate the central corneal thickness based on radii and offset.

        Computed as ``abs(anterior_radius - posterior_radius - thickness_offset)``,
        the same term that separates the two surface centers in
        ``get_posterior_center``.

        Returns:
            Central corneal thickness in mm
        """
        return abs(self.anterior_radius - self.posterior_radius - self.thickness_offset)

    def get_posterior_center(self) -> Position3D:
        """Calculate the center of the posterior surface based on thickness.

        Raises:
            ValueError: If the cornea center has not been set.
        """
        center = self.center
        thickness_term = self.anterior_radius - self.posterior_radius - self.thickness_offset
        return Position3D(center.x, center.y, center.z - thickness_term)

    @staticmethod
    def calculate_center_position(
        scale: float, axial_length: float, cornea_center_to_rotation_center: float
    ) -> Position3D:
        """Calculate the center position for spherical cornea based on anatomical parameters.

        This implements the original MATLAB/Eye logic for positioning the spherical cornea center.

        Args:
            scale: Scaling factor based on corneal radius
            axial_length: Total axial length of eye (mm)
            cornea_center_to_rotation_center: Distance from corneal center to rotation center (mm)

        Returns:
            Cornea center position
        """
        cornea_z_offset = axial_length - 2 * cornea_center_to_rotation_center
        return Position3D(0, 0, -scale * cornea_z_offset)

    def get_apex_position(self) -> Position3D:
        """Calculate the apex position for spherical cornea.

        The apex is the center shifted by ``-anterior_radius`` along z.

        Returns:
            Corneal apex position

        Raises:
            ValueError: If the cornea center has not been set.
        """
        center = self.center
        return Position3D(center.x, center.y, center.z - self.anterior_radius)

    def get_scale_factor(self) -> float:
        """Calculate the scaling factor for this spherical cornea.

        The scale factor is used to proportionally scale all eye dimensions
        based on how this cornea's radius differs from the reference radius.

        Returns:
            Scale factor (dimensionless)
        """
        return self.anterior_radius / self._r_cornea_default

    def get_corneal_depth(self) -> float:
        """Calculate the scaled corneal depth for this spherical cornea.

        Returns:
            Corneal depth in mm (scaled from reference depth)
        """
        return self.get_scale_factor() * self._cornea_depth_default

    def setup_eye_geometry(self, axial_length: float) -> dict:
        """Setup all sphere-specific eye geometry parameters.

        Assigns a default center (derived from ``axial_length``) when none has
        been set, then reports the derived geometry.

        Args:
            axial_length: Total axial length of the eye (general eye parameter)

        Returns:
            Dictionary with keys ``scale``, ``corneal_depth``, ``apex_position``
            and ``cornea_center_to_rotation_center``
        """
        scale = self.get_scale_factor()

        # Calculate center position if not already set
        if self._center is None:
            self.center = self.calculate_center_position(
                scale, axial_length, self._cornea_center_to_rotation_center_default
            )

        return {
            "scale": scale,
            "corneal_depth": self.get_corneal_depth(),
            "apex_position": self.get_apex_position(),
            "cornea_center_to_rotation_center": self._cornea_center_to_rotation_center_default,
        }

    def intersect(self, ray: Ray) -> IntersectionResult | None:
        """Calculates intersection for a spherical cornea.

        Returns the first result reported by ``intersect_ray_sphere`` for the
        anterior surface (described upstream as the one closer to the ray origin).
        """
        intersection_result, _ = intersections.intersect_ray_sphere(ray, self.center, self.anterior_radius)
        return intersection_result

    def normal_at(self, point: Point3D) -> Direction3D:
        """Calculates the normal vector for a spherical surface.

        The normal is the normalized direction from the cornea center to ``point``.
        """
        center = self.center
        normal_vec = Direction3D(point.x - center.x, point.y - center.y, point.z - center.z)
        return normal_vec.normalize()

    def find_reflection(
        self, light_pos: Position3D, camera_pos: Position3D, eye_transform: TransformationMatrix
    ) -> Point3D | None:
        """Finds position of a glint on the spherical corneal surface.

        The cornea center is mapped through ``eye_transform`` before the
        sphere reflection solver is invoked.
        """
        world_center_homogeneous = eye_transform @ np.array(self.center)
        world_center = Position3D.from_array(world_center_homogeneous)
        return reflections.find_reflection_sphere(light_pos, camera_pos, world_center, self.anterior_radius)

    def find_refraction(
        self,
        camera_pos: Position3D,
        object_pos: Position3D,
        n_outside: float,
        n_cornea: float,
        eye_transform: TransformationMatrix,
    ) -> Point3D | None:
        """Finds position where refraction occurs on the spherical corneal surface.

        The cornea center is mapped through ``eye_transform`` before the
        sphere refraction solver is invoked.
        """
        world_center_homogeneous = eye_transform @ np.array(self.center)
        world_center = Position3D.from_array(world_center_homogeneous)
        return refractions.find_refraction_sphere(
            camera_pos, object_pos, world_center, self.anterior_radius, n_outside, n_cornea
        )

    def serialize(self) -> dict:
        """Serialize to dictionary representation.

        The ``center`` entry is ``None`` when the center has not been set.
        """
        # Read the backing field: the ``center`` property raises for an unset
        # center, which would make serializing an un-positioned cornea fail.
        center = self._center
        return {
            "cornea_type": self.cornea_type,
            "center": center.serialize() if center is not None else None,
            "anterior_radius": float(self.anterior_radius),
            "refractive_index": float(self.refractive_index),
        }

    @classmethod
    def deserialize(cls, data: dict) -> "SphericalCornea":
        """Deserialize from dictionary representation.

        A missing or empty ``center`` entry leaves the center unset.
        """
        cornea = cls(anterior_radius=data["anterior_radius"], refractive_index=data["refractive_index"])
        center_data = data.get("center")
        if center_data:
            cornea.center = Position3D.deserialize(center_data)
        return cornea
@dataclass
class ConicCornea(Cornea):
    """Represents a cornea with dual conic surfaces.

    Uses conic section geometry with formula: (x-cx)² + (y-cy)² + (1+k)(z-cz)² - 2*R*(z-cz) = 0
    Implements absolute dimensions without scaling for mathematical consistency.
    Default parameters are 30-year-old values from Goncharov & Dainty (2007).

    Attributes:
        center (Position3D): Inherited from parent; the center of the anterior conic.
        anterior_radius (float): Anterior surface radius of curvature at apex in mm.
        anterior_k (float): Anterior surface conic constant.
        posterior_radius (float): Posterior surface radius of curvature at apex in mm.
        posterior_k (float): Posterior surface conic constant.
        refractive_index (float): Refractive index of cornea.
        thickness_offset (float): Corneal thickness offset.

    k-value meanings:
        - k = 0: Perfect sphere
        - k < 0: Prolate ellipsoid (typical cornea, flattens toward periphery)
        - k > 0: Oblate ellipsoid (steepens toward periphery)
    """

    # Anterior surface (30-year defaults from Goncharov & Dainty 2007)
    anterior_radius: float = CorneaDefaults.CONIC_ANTERIOR_RADIUS
    anterior_k: float = CorneaDefaults.CONIC_ANTERIOR_K

    # Posterior surface (30-year defaults from Goncharov & Dainty 2007)
    posterior_radius: float = CorneaDefaults.CONIC_POSTERIOR_RADIUS
    posterior_k: float = CorneaDefaults.CONIC_POSTERIOR_K

    # Corneal properties
    thickness_offset: float = CorneaDefaults.CONIC_THICKNESS_OFFSET
    refractive_index: float = CorneaDefaults.REFRACTIVE_INDEX

    def __post_init__(self, center_init: Position3D | None) -> None:
        """Initialize conic cornea and warn about suspicious k parameters.

        Only warnings are emitted; construction never fails here.
        """
        super().__post_init__(center_init)

        # Validate k parameter ranges for both surfaces
        if self.anterior_k < -1:
            warning(f"anterior_k = {self.anterior_k} < -1 may represent unusual corneal geometry")
        if self.posterior_k < -1:
            warning(f"posterior_k = {self.posterior_k} < -1 may represent unusual corneal geometry")

        # (1+k) is used as a divisor in the apex computations, so flag non-positive values
        anterior_1_plus_k = 1 + self.anterior_k
        posterior_1_plus_k = 1 + self.posterior_k
        if anterior_1_plus_k <= 0:
            warning(f"anterior (1+k) = {anterior_1_plus_k} ≤ 0 may cause numerical issues in conic calculations")
        if posterior_1_plus_k <= 0:
            warning(f"posterior (1+k) = {posterior_1_plus_k} ≤ 0 may cause numerical issues in conic calculations")

    @property
    def cornea_type(self) -> str:
        """Return the type name of this cornea model."""
        return "conic"

    def get_posterior_center(self) -> Position3D:
        """Calculate the center of the posterior surface based on conic geometry and thickness.

        Raises:
            ValueError: If the cornea center has not been set.
        """
        center = self.center

        # Apex offsets relative to each surface's own center: -R/(1+k),
        # i.e. toward -z for positive R and positive (1+k).
        anterior_apex_offset = -self.anterior_radius / (1 + self.anterior_k)
        posterior_apex_offset = -self.posterior_radius / (1 + self.posterior_k)

        # The posterior apex sits thickness_offset away from the anterior apex
        # in the +z direction (away from the -z facing apex):
        #   anterior_apex_z    = center_z + anterior_apex_offset
        #   posterior_apex_z   = anterior_apex_z + thickness_offset
        #   posterior_center_z = posterior_apex_z - posterior_apex_offset
        anterior_apex_z = center.z + anterior_apex_offset
        posterior_apex_z = anterior_apex_z + self.thickness_offset
        posterior_center_z = posterior_apex_z - posterior_apex_offset
        return Position3D(center.x, center.y, posterior_center_z)

    def get_apex_position(self) -> Position3D:
        """Calculate the apex position for conic cornea.

        For conic cornea, apex is mathematically at z = -R/(1+k) from center.
        This is the foremost point along the -Z axis.

        Returns:
            Corneal apex position

        Raises:
            ValueError: If the cornea center has not been set.
        """
        center = self.center
        # Mathematical apex position: z = -R/(1+k) from center
        apex_z = -self.anterior_radius / (1 + self.anterior_k)
        return Position3D(center.x, center.y, center.z + apex_z)

    def get_corneal_depth(self) -> float:
        """Calculate the corneal depth for conic cornea.

        For consistency with spherical cornea, the same reference depth is
        used (unscaled), so point_within_cornea behaves consistently between models.

        Returns:
            Corneal depth in mm
        """
        return self._cornea_depth_default

    def get_scale_factor(self) -> float:  # noqa: PLR6301
        """Get scale factor for conic cornea.

        Conic cornea uses absolute dimensions, so scale factor is always 1.0.

        Returns:
            Scale factor of 1.0 (no scaling)
        """
        return 1.0

    @staticmethod
    def calculate_center_position(axial_length: float, cornea_center_to_rotation_center: float) -> Position3D:
        """Calculate the center position for conic cornea based on anatomical parameters (no scaling).

        Args:
            axial_length: Total axial length of eye (mm)
            cornea_center_to_rotation_center: Distance from corneal center to rotation center (mm)

        Returns:
            Cornea center position
        """
        cornea_z_offset = axial_length - 2 * cornea_center_to_rotation_center
        return Position3D(0, 0, -cornea_z_offset)

    def setup_eye_geometry(self, axial_length: float) -> dict:
        """Setup conic cornea geometry parameters.

        Unlike spherical cornea, conic cornea does not use scaling - it uses absolute dimensions.

        Args:
            axial_length: Total axial length of the eye; used only to derive the
                default center when no center has been set

        Returns:
            Dictionary with keys ``scale``, ``corneal_depth``, ``apex_position``
            and ``cornea_center_to_rotation_center``
        """
        # Set default center from the anatomical offset if not already set (no scaling applied)
        if self._center is None:
            self.center = self.calculate_center_position(axial_length, self._cornea_center_to_rotation_center_default)

        return {
            "scale": 1.0,  # No scaling for conic cornea
            "corneal_depth": self.get_corneal_depth(),
            "apex_position": self.get_apex_position(),
            # NOTE(review): reported as 0.0 even though the default center above is
            # derived from _cornea_center_to_rotation_center_default - confirm with callers.
            "cornea_center_to_rotation_center": 0.0,  # Not applicable for conic #CHECK
        }

    def intersect(self, ray: Ray) -> IntersectionResult | None:
        """Calculates intersection for the anterior conic surface.

        Returns the first result reported by ``intersect_ray_conic``
        (described upstream as the one closer to the ray origin).
        """
        intersection_result, _ = intersections.intersect_ray_conic(
            ray, self.center, self.anterior_radius, self.anterior_k
        )
        return intersection_result

    def normal_at(self, point: Point3D) -> Direction3D:
        """Calculates the normal vector for the anterior conic surface."""
        return intersections.conic_surface_normal(point, self.center, self.anterior_radius, self.anterior_k)

    def find_reflection(
        self, light_pos: Position3D, camera_pos: Position3D, eye_transform: TransformationMatrix
    ) -> Point3D | None:
        """Finds position of a glint on the anterior conic surface.

        Transforms light and camera positions into eye-local coordinates where
        the conic axis is aligned with the Z-axis, runs the reflection solver,
        then transforms the result back to world coordinates.

        Returns:
            The glint position in world coordinates, or None if the solver finds none.
        """
        inv_transform = np.linalg.inv(eye_transform)
        local_light = Position3D.from_array(inv_transform @ np.array([light_pos.x, light_pos.y, light_pos.z, 1.0]))
        local_camera = Position3D.from_array(inv_transform @ np.array([camera_pos.x, camera_pos.y, camera_pos.z, 1.0]))

        local_glint = reflections.find_reflection_conic(
            local_light, local_camera, self.center, self.anterior_radius, self.anterior_k
        )
        if local_glint is None:
            return None

        world_glint = eye_transform @ np.array([local_glint.x, local_glint.y, local_glint.z, 1.0])
        return Point3D(world_glint[0], world_glint[1], world_glint[2])

    def find_refraction(
        self,
        camera_pos: Position3D,
        object_pos: Position3D,
        n_outside: float,
        n_cornea: float,
        eye_transform: TransformationMatrix,
    ) -> Point3D | None:
        """Finds position where refraction occurs on the anterior conic surface.

        Transforms positions into eye-local coordinates where the conic axis
        is aligned with the Z-axis, runs the refraction solver, then
        transforms the result back to world coordinates.

        Returns:
            The refraction point in world coordinates, or None if the solver finds none.
        """
        inv_transform = np.linalg.inv(eye_transform)
        local_camera = Position3D.from_array(inv_transform @ np.array([camera_pos.x, camera_pos.y, camera_pos.z, 1.0]))
        local_object = Position3D.from_array(inv_transform @ np.array([object_pos.x, object_pos.y, object_pos.z, 1.0]))

        local_refraction = refractions.find_refraction_conic(
            local_camera, local_object, self.center, self.anterior_radius, self.anterior_k, n_outside, n_cornea
        )
        if local_refraction is None:
            return None

        world_refraction = eye_transform @ np.array([local_refraction.x, local_refraction.y, local_refraction.z, 1.0])
        return Point3D(world_refraction[0], world_refraction[1], world_refraction[2])

    def serialize(self) -> dict:
        """Serialize to dictionary representation.

        The ``center`` entry is ``None`` when the center has not been set.
        """
        # Read the backing field: the ``center`` property raises for an unset
        # center, which would make serializing an un-positioned cornea fail.
        center = self._center
        return {
            "cornea_type": self.cornea_type,
            "center": center.serialize() if center is not None else None,
            "anterior_radius": float(self.anterior_radius),
            "anterior_k": float(self.anterior_k),
            "posterior_radius": float(self.posterior_radius),
            "posterior_k": float(self.posterior_k),
            "refractive_index": float(self.refractive_index),
            "thickness_offset": float(self.thickness_offset),
        }

    @classmethod
    def deserialize(cls, data: dict) -> "ConicCornea":
        """Deserialize from dictionary representation.

        A missing or empty ``center`` entry leaves the center unset.
        """
        cornea = cls(
            anterior_radius=data["anterior_radius"],
            anterior_k=data["anterior_k"],
            posterior_radius=data["posterior_radius"],
            posterior_k=data["posterior_k"],
            refractive_index=data["refractive_index"],
            thickness_offset=data["thickness_offset"],
        )
        center_data = data.get("center")
        if center_data:
            cornea.center = Position3D.deserialize(center_data)
        return cornea
def create_cornea(cornea_model_type: str, center: Position3D | None = None, **kwargs: float) -> Cornea:
    """Factory function to create a cornea object of specified type.

    Provides unified interface for creating different corneal models.
    Supports both spherical and conic corneal geometries.

    Args:
        cornea_model_type (str): The type of cornea model to create.
            Supported types: "spherical", "conic".
        center (Position3D | None): The center of the cornea. May be omitted
            (or None) to leave the center unset so it can be assigned later,
            e.g. by the model's ``setup_eye_geometry``.
        **kwargs: Additional parameters forwarded to the specific cornea model.
            For "spherical": anterior_radius, refractive_index
            For "conic": anterior_radius, anterior_k, posterior_radius, posterior_k, refractive_index, thickness_offset

    Returns:
        Cornea: An instance of the specified Cornea subclass.

    Raises:
        ValueError: If an unsupported cornea_model_type is provided.
    """
    if cornea_model_type == "spherical":
        return SphericalCornea(center_init=center, **kwargs)
    if cornea_model_type == "conic":
        return ConicCornea(center_init=center, **kwargs)
    raise ValueError(f"Unknown cornea model type: '{cornea_model_type}'")