Source code for pyetsimul.optics.refractions

"""Light refraction calculation utilities for eye tracking simulation.

Implements Snell's law, ray-surface intersection, and optimization for refraction on spherical and conic surfaces.
"""

from typing import TYPE_CHECKING, cast

import numpy as np
from scipy.optimize import brentq, fsolve

from ..geometry.intersections import (
    conic_surface_normal,
    intersect_ray_conic,
    intersect_ray_sphere,
    point_on_conic_surface,
)
from ..types import Direction3D, IntersectionResult, Point3D, Position3D, Ray, TransformationMatrix

if TYPE_CHECKING:
    from ..core.cornea import Cornea
    from ..core.eye import Eye


def _refraction_objective_sphere(
    a: float,
    camera_pos: Position3D,
    object_pos: Position3D,
    sphere_center: Position3D,
    sphere_radius: float,
    n_outside: float,
    n_sphere: float,
) -> tuple[float, Point3D]:
    """Objective function for refraction finding on sphere.

    Uses interpolation between camera and object directions to find refraction point.
    Returns Snell's law difference for optimization.

    Args:
        a: Interpolation parameter between camera and object directions
        camera_pos: Camera position
        object_pos: Object position
        sphere_center: Sphere center position
        sphere_radius: Sphere radius
        n_outside: Refractive index outside sphere
        n_sphere: Refractive index of sphere

    Returns:
        Tuple of (diff, surface_point) where diff is Snell's law difference and surface_point is on sphere surface

    """
    # Compute vectors from sphere center to camera and object
    to_camera = (camera_pos - sphere_center).normalize()
    to_object = (object_pos - sphere_center).normalize()

    # Interpolate and normalize to get surface normal
    normal_vec = (to_camera * a + to_object * (1 - a)).normalize()

    # Compute point on surface of sphere
    surface_point = sphere_center.to_point3d() + (normal_vec * sphere_radius)

    # Compute angles with surface normal
    camera_to_surface = (camera_pos - surface_point).normalize()
    surface_to_object = (surface_point - object_pos.to_point3d()).normalize()

    cos_angle_c = normal_vec.dot(camera_to_surface)
    cos_angle_o = normal_vec.dot(surface_to_object)

    # Safe sqrt to handle numerical errors
    sin_angle_c = np.sqrt(max(0, 1 - cos_angle_c**2))
    sin_angle_o = np.sqrt(max(0, 1 - cos_angle_o**2))

    # Snell's law difference
    diff = n_outside * sin_angle_c - n_sphere * sin_angle_o

    return diff, surface_point


[docs] def find_refraction_sphere( camera_pos: Position3D, object_pos: Position3D, sphere_center: Position3D, sphere_radius: float, n_outside: float, n_sphere: float, ) -> Point3D | None: """Find refraction point on sphere surface. Uses optimization to find point where object ray refracts to camera position. Implements Snell's law using numerical root finding. Args: camera_pos: Camera/observer position object_pos: Object position inside sphere sphere_center: Sphere center position sphere_radius: Sphere radius n_outside: Refractive index outside sphere n_sphere: Refractive index of sphere Returns: Position on sphere surface where refraction occurs, or None if not found. """ try: alpha = brentq( lambda x: _refraction_objective_sphere( x, camera_pos, object_pos, sphere_center, sphere_radius, n_outside, n_sphere )[0], 0, 1, ) _, result = _refraction_objective_sphere( cast("float", alpha), camera_pos, object_pos, sphere_center, sphere_radius, n_outside, n_sphere ) return result except (ValueError, RuntimeError): return None
def _refraction_snell_conic_1d( alpha: float, to_camera: np.ndarray, to_object: np.ndarray, camera_pos: Position3D, object_pos: Position3D, conic_center: Position3D, radius: float, conic_constant: float, n_outside: float, n_conic: float, ) -> float: """1D Snell's law residual on conic surface for in-plane search (beta=0). Evaluates n_outside * sin(theta_cam) - n_conic * sin(theta_obj) at the conic surface point determined by alpha. Used with brentq on [0, 1] to find the in-plane refraction point directly on the conic (not spherical approximation). """ direction = to_camera * alpha + to_object * (1 - alpha) norm = np.linalg.norm(direction) if norm < 1e-15: return 1e10 direction /= norm n_vec = Direction3D(direction[0], direction[1], direction[2]) intersection = point_on_conic_surface(conic_center, n_vec, radius, conic_constant) if intersection is None: return 1e10 surface_normal = conic_surface_normal(intersection, conic_center, radius, conic_constant) n = np.array([surface_normal.x, surface_normal.y, surface_normal.z]) d_obj = np.array([intersection.x - object_pos.x, intersection.y - object_pos.y, intersection.z - object_pos.z]) d_obj /= np.linalg.norm(d_obj) d_cam = np.array([camera_pos.x - intersection.x, camera_pos.y - intersection.y, camera_pos.z - intersection.z]) d_cam /= np.linalg.norm(d_cam) cos_cam = np.dot(n, d_cam) cos_obj = -np.dot(n, d_obj) sin_cam = np.sqrt(max(0, 1 - cos_cam**2)) sin_obj = np.sqrt(max(0, 1 - cos_obj**2)) return n_outside * sin_cam - n_conic * sin_obj def _refraction_residuals_conic( params: np.ndarray, to_camera: np.ndarray, to_object: np.ndarray, perp: np.ndarray, camera_pos: Position3D, object_pos: Position3D, conic_center: Position3D, radius: float, conic_constant: float, n_outside: float, n_conic: float, ) -> np.ndarray: """Residual function for 2D refraction finding on conic surface. Uses (alpha, beta) parameterization where alpha interpolates between camera and object directions, and beta adds an out-of-plane component. This 2D search is necessary because aspherical conic surfaces (k != 0) can have refraction points outside the camera-center-object plane. Returns a 2-element residual vector: [0]: Snell's law condition (n_outside * sin_camera - n_conic * sin_object) [1]: Coplanarity condition (N dot (d_object x d_camera)) """ alpha, beta = params # Search direction: in-plane interpolation + out-of-plane offset direction = to_camera * alpha + to_object * (1 - alpha) + perp * beta norm = np.linalg.norm(direction) if norm < 1e-15: return np.array([1e10, 1e10]) direction /= norm n_vec = Direction3D(direction[0], direction[1], direction[2]) intersection = point_on_conic_surface(conic_center, n_vec, radius, conic_constant) if intersection is None: return np.array([1e10, 1e10]) surface_normal = conic_surface_normal(intersection, conic_center, radius, conic_constant) n = np.array([surface_normal.x, surface_normal.y, surface_normal.z]) # Direction from object to refraction point (incident ray inside conic) d_obj = np.array([intersection.x - object_pos.x, intersection.y - object_pos.y, intersection.z - object_pos.z]) d_obj /= np.linalg.norm(d_obj) # Direction from refraction point to camera (outgoing ray) d_cam = np.array([camera_pos.x - intersection.x, camera_pos.y - intersection.y, camera_pos.z - intersection.z]) d_cam /= np.linalg.norm(d_cam) # Condition 1: Snell's law — n_outside * sin(θ_cam) = n_conic * sin(θ_obj) cos_cam = np.dot(n, d_cam) cos_obj = -np.dot(n, d_obj) # negate because d_obj points away from interior sin_cam = np.sqrt(max(0, 1 - cos_cam**2)) sin_obj = np.sqrt(max(0, 1 - cos_obj**2)) residual_snell = n_outside * sin_cam - n_conic * sin_obj # Condition 2: coplanarity — N · (d_obj x d_cam) = 0 residual_coplanar = np.dot(n, np.cross(d_obj, d_cam)) return np.array([residual_snell, residual_coplanar])
[docs] def find_refraction_conic( camera_pos: Position3D, object_pos: Position3D, conic_center: Position3D, radius: float, conic_constant: float, n_outside: float, n_conic: float, ) -> Point3D | None: """Find refraction point on conic surface. Uses a two-stage approach to find the point where an object ray refracts toward the camera through the conic surface: Stage 1 (brentq): Solve the 1D in-plane Snell's law residual with beta=0. The Snell residual is monotonic and sign-changing in alpha ∈ [0, 1], so brentq is guaranteed to find the unique correct root. This works directly on the conic surface — no spherical approximation needed. Stage 2 (fsolve): Starting from (alpha_brentq, 0), refine with the full 2D system (Snell + coplanarity) to find the small out-of-plane beta correction needed for aspherical surfaces (k != 0). This approach works for any conic constant k because Stage 1 is bounded (no wrong roots) and Stage 2 starts very close to the solution (won't drift). Args: camera_pos: Camera/observer position object_pos: Object position inside conic conic_center: Conic center position (typically corneal apex) radius: Radius of curvature at apex (mm) conic_constant: Conic constant (k < 0 for prolate, k = 0 for sphere, k > 0 for oblate) n_outside: Refractive index outside conic n_conic: Refractive index of conic Returns: Position on conic surface where refraction occurs, or None if not found. """ try: to_camera_dir = (camera_pos - conic_center).normalize() to_object_dir = (object_pos - conic_center).normalize() to_camera = np.array([to_camera_dir.x, to_camera_dir.y, to_camera_dir.z]) to_object = np.array([to_object_dir.x, to_object_dir.y, to_object_dir.z]) # Stage 1: Find in-plane solution using brentq on the conic Snell residual. # alpha ∈ [0, 1] interpolates between object and camera directions — # the residual changes sign across this interval, guaranteeing a unique root. args_1d = ( to_camera, to_object, camera_pos, object_pos, conic_center, radius, conic_constant, n_outside, n_conic, ) alpha_0 = brentq(_refraction_snell_conic_1d, 0, 1, args=args_1d) # Stage 2: Refine with 2D fsolve for the small out-of-plane (beta) correction. # For aspherical surfaces the refraction point may deviate slightly from # the camera-center-object plane, captured by the beta parameter. perp = np.cross(to_camera, to_object) perp_norm = np.linalg.norm(perp) if perp_norm < 1e-15: arb = np.array([1.0, 0.0, 0.0]) if abs(to_camera[0]) < 0.9 else np.array([0.0, 1.0, 0.0]) perp = np.cross(to_camera, arb) perp /= np.linalg.norm(perp) else: perp /= perp_norm args_2d = ( to_camera, to_object, perp, camera_pos, object_pos, conic_center, radius, conic_constant, n_outside, n_conic, ) solution, info, ier, _msg = fsolve(_refraction_residuals_conic, [alpha_0, 0.0], args=args_2d, full_output=True) if ier != 1 and np.max(np.abs(info["fvec"])) > 1e-8: return None alpha, beta = solution direction = to_camera * alpha + to_object * (1 - alpha) + perp * beta direction /= np.linalg.norm(direction) n_vec = Direction3D(direction[0], direction[1], direction[2]) return point_on_conic_surface(conic_center, n_vec, radius, conic_constant) except (ValueError, TypeError): return None
[docs] def refract_ray_sphere( ray: Ray, sphere_center: Position3D, sphere_radius: float, n_outside: float, n_sphere: float ) -> tuple[IntersectionResult | None, Ray | None]: """Refract ray through sphere surface. Finds intersection point and computes refracted ray direction using Snell's law. Handles total internal reflection when critical angle is exceeded. Args: ray: Input ray with origin and direction sphere_center: Sphere center position sphere_radius: Sphere radius n_outside: Refractive index outside sphere n_sphere: Refractive index of sphere Returns: Tuple of (intersection_result, refracted_ray) where intersection_result contains the intersection point and refracted_ray is the refracted ray. Returns (None, None) if no intersection or total internal reflection. """ # Find point of intersection intersection_result, _ = intersect_ray_sphere(ray, sphere_center, sphere_radius) if intersection_result is None or not intersection_result.intersects: return None, None intersection_point = cast("Point3D", intersection_result.point) # Find surface normal at point of intersection (pointing inwards) normal_vec = (sphere_center.to_point3d() - intersection_point).to_direction3d().normalize() # Calculate angles incident_normalized = ray.direction.normalize() costh1 = incident_normalized.dot(normal_vec) costh2_squared = 1 - (n_outside / n_sphere) ** 2 * (1 - costh1**2) # Check for total internal reflection if costh2_squared < 0: return intersection_result, None costh2 = np.sqrt(costh2_squared) # Snell's law refraction formula n_ratio = n_outside / n_sphere refracted_direction = incident_normalized * n_ratio + normal_vec * (costh2 - n_ratio * costh1) refracted_ray = Ray(origin=intersection_point, direction=refracted_direction) return intersection_result, refracted_ray
[docs] def refract_ray_conic( ray: Ray, conic_center: Position3D, radius: float, conic_constant: float, n_outside: float, n_conic: float ) -> tuple[IntersectionResult | None, Ray | None]: """Refract ray through conic surface. Finds intersection point and computes refracted ray direction using Snell's law. Uses proper conic surface normal calculation for accurate refraction. Handles total internal reflection when critical angle is exceeded. Args: ray: Input ray with origin and direction conic_center: Conic center position (typically corneal apex) radius: Radius parameter (R in the formula, mm) conic_constant: Conic constant (k < 0 for prolate, k = 0 for sphere, k > 0 for oblate) n_outside: Refractive index outside conic (e.g., air = 1.0) n_conic: Refractive index of conic (e.g., cornea = 1.376) Returns: Tuple of (intersection_result, refracted_ray) where: - intersection_result: Contains intersection point on conic surface - refracted_ray: Refracted ray Returns (None, None) if no intersection or total internal reflection. """ # Find intersection point intersection_result, _ = intersect_ray_conic(ray, conic_center, radius, conic_constant) if intersection_result is None or not intersection_result.intersects: return None, None intersection_point = cast("Point3D", intersection_result.point) # Calculate surface normal at intersection point surface_normal = conic_surface_normal(intersection_point, conic_center, radius, conic_constant) # For refraction, we need inward-pointing normal (toward conic interior) center_to_point = intersection_point - conic_center.to_point3d() if surface_normal.dot(center_to_point) > 0: # Normal points outward surface_normal *= -1 # Flip to point inward # Apply Snell's law incident_normalized = ray.direction.normalize() costh1 = incident_normalized.dot(surface_normal) costh2_squared = 1 - (n_outside / n_conic) ** 2 * (1 - costh1**2) # Check for total internal reflection if costh2_squared < 0: return intersection_result, None costh2 = np.sqrt(costh2_squared) # Snell's law refraction formula n_ratio = n_outside / n_conic refracted_direction = incident_normalized * n_ratio + surface_normal * (costh2 - n_ratio * costh1) refracted_ray = Ray(origin=intersection_point, direction=refracted_direction) return intersection_result, refracted_ray
[docs] def refract_ray_dual_surface( eye: "Eye", ray_origin: Point3D, ray_direction: Direction3D ) -> tuple[Point3D | None, Point3D | None, Direction3D | None]: """Computes refraction through both anterior and posterior corneal surfaces. Models complete corneal optical path by calculating refraction at both: 1. Anterior surface: air (n=1.0) → cornea (n=1.376) 2. Posterior surface: cornea (n=1.376) → aqueous humor (n=1.336) This provides more accurate modeling of light rays passing through the cornea compared to single-surface refraction which only considers the anterior surface. Args: eye: Eye object containing corneal geometry and refractive indices ray_origin: Ray origin (Position3D) ray_direction: Ray direction (3D vector) Returns: Tuple of (anterior_point, posterior_point, final_direction) where: - anterior_point: Point where ray strikes anterior corneal surface - posterior_point: Point where ray strikes posterior corneal surface - final_direction: Direction of ray after exiting posterior surface Returns (None, None, None) if ray doesn't intersect with cornea. """ # Get corneal center in world coordinates cornea_center_homogeneous = eye.trans @ np.array(eye.cornea.center) cornea_center = Position3D.from_array(cornea_center_homogeneous) # Refraction at outer surface of cornea ray = Ray(origin=ray_origin, direction=ray_direction) intersection_result, refracted_ray = refract_ray_sphere( ray, cornea_center, eye.cornea.anterior_radius, 1.0, # Air refractive index eye.cornea.refractive_index, ) if intersection_result is None or refracted_ray is None: return None, None, None outer_point = intersection_result.point intermediate_direction = refracted_ray.direction if outer_point is None or intermediate_direction is None: return None, None, None # Refraction at inner surface of cornea posterior_center_homogeneous = eye.trans @ np.array(eye.cornea.get_posterior_center()) posterior_center = Position3D.from_array(posterior_center_homogeneous) ray2 = Ray(origin=outer_point, direction=intermediate_direction) intersection_result2, refracted_ray2 = refract_ray_sphere( ray2, posterior_center, eye.cornea.posterior_radius, eye.cornea.refractive_index, eye.n_aqueous_humor, ) if intersection_result2 is None or refracted_ray2 is None: return outer_point, None, None inner_point = intersection_result2.point final_direction = refracted_ray2.direction return outer_point, inner_point, final_direction
[docs] def find_refraction_point( cornea: "Cornea", eye_transform: TransformationMatrix, camera_position: Position3D, object_position: Position3D ) -> Position3D | None: """Computes observed position of intraocular objects through corneal refraction. Pure function that calculates where camera observes intraocular object through corneal refraction. Determines corneal surface point where object ray refracts to camera. Note: This function does not check corneal boundaries - that should be done by the caller if needed (e.g., using Eye.point_within_cornea()). Args: cornea: Cornea object with find_refraction method eye_transform: Eye transformation matrix camera_position: Camera position (Position3D) object_position: Object position inside eye (Position3D) Returns: Position3D on corneal surface where refraction occurs, or None if no solution exists """ # Find refraction point on corneal surface using cornea's refraction method refraction_point = cornea.find_refraction( camera_position, object_position, 1.0, # Air refractive index cornea.refractive_index, eye_transform, ) return None if refraction_point is None else refraction_point.to_position3d()