Source code for pyetsimul.visualization.eye_anatomy

"""Eye anatomy visualization using structured types and vector arithmetic.

Provides 3D visualization functions for eye anatomy using structured types.
Supports anatomical accuracy and vector-based transformations.
"""

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

if TYPE_CHECKING:
    from matplotlib.axes import Axes
import numpy as np

from ..core import Eye
from ..geometry.intersections import intersect_ray_conic, intersect_ray_sphere
from ..types import Direction3D, Position3D
from ..utils.eye_surface_points import generate_corneal_surface_points, get_transformed_corneal_landmarks
from ..utils.eyelid_surface_points import generate_eyelid_opening_edge_local, transform_eyelid_points_to_world
from .plot_config import create_plot_config
from .transforms import transform_surface


def _filter_points_by_eyelid_occlusion(points_world: np.ndarray, eye: Eye) -> np.ndarray:
    """Drop the points that the eyelid hides.

    Each world-space point is mapped into the eyelid's local frame with the
    inverse of ``eye.eyelid_trans``. It is then tested with
    ``eye.eyelid.point_within_eyelid``; points for which that test is true are
    removed.

    Args:
        points_world: Array of world-space points, one point per row (assumed
            (N, 3); N may be 0).
        eye: Eye whose ``eyelid`` and ``eyelid_trans`` define the occlusion.

    Returns:
        The rows of ``points_world`` that are not occluded. The input is
        returned unchanged when the eye has no eyelid or there are no points.
    """
    eyelid = eye.eyelid
    if eyelid is None or len(points_world) == 0:
        return points_world

    # World -> eyelid-local transform, applied to homogeneous (x, y, z, 1) rows.
    world_to_eyelid = np.linalg.inv(eye.eyelid_trans)
    homogeneous_rows = np.column_stack([points_world, np.ones(len(points_world))])
    local_xyz = (world_to_eyelid @ homogeneous_rows.T).T[:, :3]

    keep = np.array([not eyelid.point_within_eyelid(Position3D(x, y, z)) for x, y, z in local_xyz])
    return points_world[keep]


def _visible_corneal_surface_points(eye: Eye, n_points: int = 50) -> "tuple[np.ndarray, np.ndarray]":
    """Return the visible (anterior, posterior) corneal surface points.

    Points are generated by ``generate_corneal_surface_points`` using the
    ray/surface intersection matching ``eye.cornea.cornea_type``. They are kept
    only where ``eye.point_within_cornea`` accepts them, and then only where
    the eyelid does not hide them.

    Args:
        eye: Eye whose cornea is sampled.
        n_points: Sampling density forwarded to ``generate_corneal_surface_points``.

    Returns:
        Tuple ``(anterior, posterior)`` of point arrays (possibly empty).

    Raises:
        ValueError: If ``eye.cornea.cornea_type`` is neither "conic" nor "spherical".
    """
    intersection_funcs = {"conic": intersect_ray_conic, "spherical": intersect_ray_sphere}
    intersection_func = intersection_funcs.get(eye.cornea.cornea_type)
    if intersection_func is None:
        raise ValueError(f"Unknown cornea type: {eye.cornea.cornea_type}")

    visible = []
    for surface in ("anterior", "posterior"):
        points = generate_corneal_surface_points(eye, intersection_func, surface, n_points=n_points)
        # dtype=bool keeps this a valid boolean index even when no points were generated:
        # np.array([]) defaults to float64, and indexing with a float array raises IndexError.
        within_cornea = np.array(
            [eye.point_within_cornea(Position3D(p[0], p[1], p[2])) for p in points],
            dtype=bool,
        )
        visible.append(_filter_points_by_eyelid_occlusion(points[within_cornea], eye))
    return visible[0], visible[1]


def _masked_eye_globe_surface(
    eye: Eye, optical_axis_unit: np.ndarray, limbus_z_local: float
) -> "tuple[np.ndarray, np.ndarray, np.ndarray]":
    """Return world-space (x, y, z) grids of the eye globe with the corneal cap blanked out.

    The globe is a sphere of radius ``eye.axial_length / 2`` about the
    eye-local origin, mapped to world space by ``eye.trans``. Vertices whose
    projection onto ``optical_axis_unit`` (measured from ``eye.position``)
    lies beyond the limbus are set to NaN. Matplotlib leaves NaN faces
    undrawn, so the cornea is not covered by the globe.

    Args:
        eye: Eye providing ``axial_length``, ``trans`` and ``position``.
        optical_axis_unit: Unit optical-axis direction in world space, shape (3,).
        limbus_z_local: Eye-local z coordinate of the limbus plane.
    """
    radius = eye.axial_length / 2
    phi_grid, theta_grid = np.meshgrid(np.linspace(0, np.pi, 30), np.linspace(0, 2 * np.pi, 50))
    x_local = radius * np.sin(phi_grid) * np.cos(theta_grid)
    y_local = radius * np.sin(phi_grid) * np.sin(theta_grid)
    z_local = radius * np.cos(phi_grid)
    x_world, y_world, z_world = transform_surface(x_local, y_local, z_local, eye.trans)

    rotation_center = np.array(eye.position)[:3]
    limbus_point_world = eye.trans @ np.array([0, 0, limbus_z_local, 1])
    limbus_projection = np.dot(limbus_point_world[:3] - rotation_center, optical_axis_unit)

    # Signed distance of every grid vertex along the optical axis from the rotation center.
    offsets = np.stack([x_world, y_world, z_world]) - rotation_center.reshape(3, 1, 1)
    projections = np.einsum("i,ijk->jk", optical_axis_unit, offsets)
    inside = projections <= limbus_projection
    x_world[~inside] = np.nan
    y_world[~inside] = np.nan
    z_world[~inside] = np.nan
    return x_world, y_world, z_world


def _pupil_disk_points_world(eye: Eye, n_theta: int = 120, n_radial: int = 10) -> np.ndarray:
    """Return an (n_radial * n_theta, 3) array of world-space points filling the pupil disk.

    Points are ``pos_pupil + r * (x_pupil * cos(t) + y_pupil * sin(t))`` for
    ``r`` in ``linspace(0, 1, n_radial)`` (outer) and ``t`` in
    ``linspace(0, 2*pi, n_theta)`` (inner), then mapped to world space with
    ``eye.trans``.
    """
    center = np.array(eye.pupil.pos_pupil)[:3]
    x_axis = np.array(eye.pupil.x_pupil)[:3]
    y_axis = np.array(eye.pupil.y_pupil)[:3]

    theta = np.linspace(0, 2 * np.pi, n_theta)
    radial = np.linspace(0, 1, n_radial)

    # Rim offset per angle, shape (n_theta, 3); broadcasting against the radial factors
    # gives (n_radial, n_theta, 3), flattened radius-major like the original nested loop.
    rim = np.outer(np.cos(theta), x_axis) + np.outer(np.sin(theta), y_axis)
    points_local = (center + radial[:, None, None] * rim[None, :, :]).reshape(-1, 3)

    points_local_h = np.column_stack([points_local, np.ones(len(points_local))])
    return (eye.trans @ points_local_h.T).T[:, :3]


def plot_eye_anatomy(eye: Eye, ax: "Axes | None" = None) -> "Axes":
    """Plot 3D eye anatomy using structured types and vector arithmetic.

    Draws the eye globe (with the corneal cap removed), the visible anterior and
    posterior corneal surfaces, and the rotation center. It also draws the
    corneal centers, the fovea, the pupil disk, and the optical and visual axes.
    When ``eye.eyelid`` is set, it draws the eyelid opening and an openness
    annotation. Finally it draws the current target point. Surface and pupil
    points hidden by the eyelid are omitted.

    Assumes the eye is already oriented as desired.

    Args:
        eye: Eye object to plot (required) - should already be oriented as desired.
        ax: Matplotlib 3D axis (optional, creates new if None).

    Returns:
        The axis that was drawn on.

    Raises:
        ValueError: If eye has no current target point (call eye.look_at() first),
            or if ``eye.cornea.cornea_type`` is neither "conic" nor "spherical".
    """
    target_point = eye.current_target_point
    if target_point is None:
        raise ValueError(
            "Eye has no current target point. Call eye.look_at(target_position) first "
            "to orient the eye and set a target for visualization."
        )

    eye_rotation_center = eye.position
    fovea_world = Position3D.from_array(eye.trans @ np.array(eye.fovea_position))

    apex_pos = eye.cornea.get_apex_position()
    limbus_z_local = apex_pos.z + eye.cornea.get_corneal_depth()

    anterior_limited, posterior_limited = _visible_corneal_surface_points(eye)

    corneal_landmarks = get_transformed_corneal_landmarks(eye)
    cornea_center_world = corneal_landmarks["anterior_center"]
    cornea_inner_center_world = corneal_landmarks["posterior_center"]

    # Optical axis: eye-local -z direction (w=0, so translation is ignored) in world space.
    optical_axis_world = eye.trans @ np.array([0, 0, -1, 0])
    optical_axis_unit = (optical_axis_world / np.linalg.norm(optical_axis_world))[:3]

    x_eye_world, y_eye_world, z_eye_world = _masked_eye_globe_surface(eye, optical_axis_unit, limbus_z_local)

    axis_length = 20  # mm
    optical_axis_end = eye_rotation_center + Direction3D.from_array(optical_axis_unit) * axis_length
    # Visual axis is drawn from the fovea through the rotation center and axis_length beyond it.
    visual_axis_direction = (eye_rotation_center - fovea_world).normalize()
    visual_axis_end = eye_rotation_center + visual_axis_direction * axis_length

    pupil_points_world_filtered = _filter_points_by_eyelid_occlusion(_pupil_disk_points_world(eye), eye)

    config = create_plot_config()

    if ax is None:
        fig = plt.figure(figsize=config.layout.anatomy_detail)
        ax = fig.add_subplot(111, projection="3d")

    ax.plot_surface(
        x_eye_world,
        y_eye_world,
        z_eye_world,
        alpha=config.lines.grid_alpha,
        color=config.colors.eye_globe,
        label="Eye Globe",
    )

    if len(anterior_limited) > 0:
        ax.scatter(
            anterior_limited[:, 0],
            anterior_limited[:, 1],
            anterior_limited[:, 2],
            alpha=config.lines.primary_alpha,
            color=config.colors.cornea_outer,
            s=config.markers.cornea_surface_anterior,
            label="Cornea outer surface",
        )
    if len(posterior_limited) > 0:
        ax.scatter(
            posterior_limited[:, 0],
            posterior_limited[:, 1],
            posterior_limited[:, 2],
            alpha=config.lines.primary_alpha,
            color=config.colors.cornea_inner,
            s=config.markers.cornea_surface_posterior,
            label="Cornea inner surface",
        )

    ax.scatter(
        eye_rotation_center.x,
        eye_rotation_center.y,
        eye_rotation_center.z,
        color=config.colors.rotation_center,
        s=config.markers.small_details,
        marker="o",
        label="Rotation Center",
    )
    ax.scatter(
        cornea_center_world.x,
        cornea_center_world.y,
        cornea_center_world.z,
        color=config.colors.cornea_outer,
        s=config.markers.cornea_center_outer,
        marker="^",
        label="Cornea center (outer)",
    )
    ax.scatter(
        cornea_inner_center_world.x,
        cornea_inner_center_world.y,
        cornea_inner_center_world.z,
        color=config.colors.cornea_inner,
        s=config.markers.cornea_center_inner,
        marker="^",
        label="Cornea center (inner)",
    )
    ax.scatter(
        fovea_world.x,
        fovea_world.y,
        fovea_world.z,
        color=config.colors.fovea,
        s=config.markers.small_details + 30,
        marker="*",
        label="Fovea",
    )

    if len(pupil_points_world_filtered) > 0:
        ax.scatter(
            pupil_points_world_filtered[:, 0],
            pupil_points_world_filtered[:, 1],
            pupil_points_world_filtered[:, 2],
            c=config.colors.pupil,
            s=config.markers.surface_points,
            alpha=config.lines.primary_alpha,
            label="Pupil Opening",
        )

    ax.plot(
        [eye_rotation_center.x, optical_axis_end.x],
        [eye_rotation_center.y, optical_axis_end.y],
        [eye_rotation_center.z, optical_axis_end.z],
        color=config.colors.optical_axis,
        linestyle=config.lines.dashed,
        linewidth=config.lines.standard_lines,
        label="Optical Axis",
    )
    ax.plot(
        [fovea_world.x, visual_axis_end.x],
        [fovea_world.y, visual_axis_end.y],
        [fovea_world.z, visual_axis_end.z],
        color=config.colors.visual_axis,
        linestyle=config.lines.dashed,
        linewidth=config.lines.standard_lines,
        label="Visual Axis",
    )

    if eye.eyelid is not None:
        opening_edge_local = generate_eyelid_opening_edge_local(eye.eyelid, n_edge_points=160)
        if len(opening_edge_local) > 0:
            # Repeat the first vertex so the drawn outline is closed.
            edge_local_closed = np.vstack([opening_edge_local, opening_edge_local[0]])
            edge_world = transform_eyelid_points_to_world(edge_local_closed, eye.eyelid_trans)
            ax.plot(
                edge_world[:, 0],
                edge_world[:, 1],
                edge_world[:, 2],
                color=config.colors.eyelid,
                linewidth=config.elements.eyelid_width,
                label="Eyelid Opening",
            )

    ax.scatter(
        target_point.x,
        target_point.y,
        target_point.z,
        color=config.colors.target,
        s=config.markers.landmarks,
        marker="x",
        label="Target",
    )

    ax.set_xlabel("X (mm)")
    ax.set_ylabel("Y (mm)")
    ax.set_zlabel("Z (mm)")
    ax.set_title("Eye Anatomy")

    # Overlay eye openness percentage (top-left, axes-relative coordinates).
    if eye.eyelid is not None:
        openness_pct = 100.0 * float(eye.eyelid.openness)
        ax.text2D(0.02, 0.96, f"Eye openness: {openness_pct:.0f}%", transform=ax.transAxes)

    ax.legend()
    ax.set_box_aspect([1, 1, 1])
    return ax