# Source code for pyetsimul.core.eye_tracker

"""Eye tracker module.

This module provides the EyeTracker class that represents a complete eye tracking
system with cameras, lights, calibration points, and algorithm functions.
"""

from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any

import numpy as np

from pyetsimul.log import info, table, warning

from ..types import EyeMeasurement, GazePrediction, Point3D, Position3D, PupilData
from ..utils import validate_eye_camera_setup
from .camera import Camera
from .eye import Eye
from .light import Light


@dataclass
class EyeTracker(ABC):
    """Abstract base class for eye tracking systems.

    Provides a unified interface for different eye tracking algorithms. Manages
    cameras, lights, calibration points, and measurement collection, and
    implements the common workflow for calibration and gaze estimation.

    Subclasses must implement ``algorithm_name``, ``calibrate`` and
    ``predict_gaze``. Measurements are always taken with the first camera in
    ``cameras``; additional cameras are only checked by ``run_calibration``'s
    eye/camera setup validation.
    """

    # Physical components
    cameras: list[Camera] = field(default_factory=list)
    lights: list[Light] = field(default_factory=list)

    # Calibration points for gaze tracking
    calib_points: list[Position3D] = field(default_factory=list)

    # Algorithm state/parameters
    state: dict[str, Any] = field(default_factory=dict)

    # Refraction setting, forwarded to Camera.take_image(use_refraction=...)
    use_refraction: bool = True

    # Pupil center calculation method: "ellipse" (default) or "center_of_mass";
    # forwarded to Camera.take_image(center_method=...)
    pupil_center_method: str = "ellipse"

    # Calibration diagnostics: list of (point_number, position, reason) for failed
    # points. point_number is 1-based; the list is replaced on every data collection.
    failed_calibration_points: list[tuple[int, Position3D, str]] = field(default_factory=list)

    # MATLAB compatibility mode for eye rotation (passed to Eye.look_at(legacy=...))
    use_legacy_look_at: bool = False

    @property
    @abstractmethod
    def algorithm_name(self) -> str:
        """Algorithm name identifier - must be implemented by subclasses."""

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    def add_camera(self, camera: Camera) -> None:
        """Add a camera to the eye tracker.

        Manages camera collection for multi-camera eye tracking setups.

        Args:
            camera: Camera to append to ``cameras``.
        """
        self.cameras.append(camera)

    def add_light(self, light: Light) -> None:
        """Add a light to the eye tracker.

        Manages light collection for corneal reflection detection.

        Args:
            light: Light to append to ``lights``.
        """
        self.lights.append(light)

    def add_calibration_point(self, point: Position3D) -> None:
        """Add a calibration point to the eye tracker.

        Builds calibration grid for gaze tracking accuracy.

        Args:
            point: Calibration target to append to ``calib_points``.
        """
        self.calib_points.append(point)

    def set_calibration_points(self, points: Iterable[Position3D]) -> None:
        """Set all calibration points at once.

        Replaces the entire calibration grid with a new point collection. The
        points are copied into a fresh list, so any iterable (list, tuple,
        generator, ...) is accepted and later mutation of ``points`` does not
        affect the tracker.

        Args:
            points: Calibration targets, in the order they should be visited.
        """
        self.calib_points = list(points)

    # ------------------------------------------------------------------
    # Shared measurement helpers
    # ------------------------------------------------------------------

    def _require_camera(self, purpose: str) -> Camera:
        """Return the first configured camera or raise ``ValueError`` if none exist."""
        if not self.cameras:
            raise ValueError(f"No cameras available for {purpose}")
        return self.cameras[0]

    def _capture_image(self, camera: Camera, eye: Eye):
        """Take an image of ``eye`` with ``camera`` using this tracker's lights and settings."""
        return camera.take_image(
            eye, self.lights, use_refraction=self.use_refraction, center_method=self.pupil_center_method
        )

    @staticmethod
    def _pupil_data_from(camera_image) -> PupilData:
        """Build PupilData from the pupil boundary and center found in a camera image."""
        return PupilData(boundary_points=camera_image.pupil_boundary, center=camera_image.pupil_center)

    def _measure_at(self, eye: Eye, target: Position3D, camera: Camera) -> EyeMeasurement:
        """Rotate ``eye`` towards ``target`` and return an untimestamped measurement of it."""
        eye.look_at(target, legacy=self.use_legacy_look_at)
        camera_image = self._capture_image(camera, eye)
        return EyeMeasurement(
            camera_image=camera_image,
            pupil_data=self._pupil_data_from(camera_image),
            timestamp=None,  # Could add timestamp if needed
        )

    @staticmethod
    def _detection_failure_reason(camera_image) -> str | None:
        """Return why feature detection failed for an image, or None if it succeeded."""
        if camera_image.pupil_center is None:
            return "PUPIL CENTER not detected"
        reflections = camera_image.corneal_reflections
        # Only the first corneal reflection is required for calibration.
        if not reflections or reflections[0] is None:
            return "CR not detected"
        return None

    def _is_calibrated(self) -> bool:
        """Return the subclass's ``algorithm_state.is_calibrated`` flag, or False if absent."""
        try:
            return self.algorithm_state.is_calibrated
        except AttributeError:
            # ``algorithm_state`` is not a field of this base class; subclasses may define it.
            return False

    @staticmethod
    def _format_position(pos) -> str:
        """Format an object with x/y/z attributes as ``(x, y, z)`` with one decimal."""
        return f"({pos.x:.1f}, {pos.y:.1f}, {pos.z:.1f})"

    # ------------------------------------------------------------------
    # Calibration workflow
    # ------------------------------------------------------------------

    def run_calibration(self, eye: Eye) -> "EyeTracker":
        """Run the complete calibration workflow.

        Validates the eye/camera geometry for every configured camera, collects
        one measurement per calibration point, and passes them to the
        algorithm-specific ``calibrate``.

        Args:
            eye: Eye object to calibrate with.

        Returns:
            Self for method chaining.

        Raises:
            ValueError: If no cameras are configured.
        """
        for camera in self.cameras:
            validate_eye_camera_setup(eye.rest_orientation, camera.trans.get_rotation())

        calibration_measurements = self._collect_calibration_measurements(eye)
        self.calibrate(calibration_measurements)
        return self

    def _collect_calibration_measurements(self, eye: Eye) -> list[EyeMeasurement]:
        """Collect one measurement per calibration point and report detection failures.

        Every calibration point yields a measurement (in ``calib_points`` order),
        including points where the pupil center or corneal reflection was not
        detected. Such points are stored in ``failed_calibration_points`` and
        reported via ``warning``.

        Args:
            eye: Eye object that is rotated towards each calibration point.

        Returns:
            List of measurements, one per calibration point.

        Raises:
            ValueError: If no cameras are configured (checked before the eye is moved).
        """
        camera = self._require_camera("calibration")
        n_points = len(self.calib_points)
        measurements: list[EyeMeasurement] = []
        failed_points: list[tuple[int, Position3D, str]] = []

        info(f"Collecting calibration data at {n_points} points...")
        if len(self.cameras) > 1:
            warning(
                f"calibrate() uses only the first of {len(self.cameras)} configured cameras; "
                "multi-camera calibration is not implemented."
            )

        for point_number, calib_point in enumerate(self.calib_points, start=1):
            eye.look_at(calib_point, legacy=self.use_legacy_look_at)
            camera_image = self._capture_image(camera, eye)
            measurements.append(
                EyeMeasurement(
                    camera_image=camera_image,
                    pupil_data=self._pupil_data_from(camera_image),
                    gaze_direction=Point3D(calib_point.x, calib_point.y, calib_point.z),
                )
            )

            reason = self._detection_failure_reason(camera_image)
            if reason is not None:
                failed_points.append((point_number, calib_point, reason))

        # Store and report failed points
        self.failed_calibration_points = failed_points
        if failed_points:
            warning(f"\n{len(failed_points)}/{n_points} calibration points failed:")
            for point_num, point, reason in failed_points:
                warning(f"  Point {point_num} ({point.x:.0f}mm, {point.z:.0f}mm): {reason}")
            warning(f"  Calibration will proceed with {n_points - len(failed_points)} valid points.\n")

        return measurements

    @abstractmethod
    def calibrate(self, calibration_measurements: list[EyeMeasurement]) -> None:
        """Calibrate the eye tracker using collected data.

        Abstract interface for algorithm-specific calibration implementation.
        Each eye tracker type must implement its specific calibration algorithm.

        Args:
            calibration_measurements: List of eye measurements collected at each
                calibration point (failed detections included).
        """

    def test_calibration_fit(self, eye: Eye) -> list[tuple[Position3D, GazePrediction | None]]:
        """Test the calibrated algorithm by predicting each calibration point.

        Validates calibration quality by running the full pipeline on known
        targets: target -> eye movement -> camera -> ``predict_gaze``.

        Args:
            eye: Eye object to use for measurements.

        Returns:
            List of (target_position, prediction) tuples, one per calibration point.

        Raises:
            ValueError: If no cameras are configured.
        """
        camera = self._require_camera("calibration fit testing")
        return [
            (target_position, self.predict_gaze(self._measure_at(eye, target_position, camera)))
            for target_position in self.calib_points
        ]

    # ------------------------------------------------------------------
    # Gaze estimation
    # ------------------------------------------------------------------

    def estimate_gaze_at(self, eye: Eye, look_at_pos: Point3D) -> GazePrediction | None:
        """Estimate gaze position when eye looks at a target.

        Implements the complete gaze estimation pipeline: eye movement ->
        camera -> prediction, delegating to the algorithm-specific
        ``predict_gaze``. Only the first camera is used (TODO: multi-camera
        support is algorithm-dependent).

        Args:
            eye: Eye object.
            look_at_pos: 3D position where eye should look.

        Returns:
            GazePrediction with estimated gaze and intermediate values, or None
            if the prediction fails.

        Raises:
            ValueError: If no cameras are configured (checked before the eye is moved).
        """
        camera = self._require_camera("gaze estimation")
        target = Position3D(look_at_pos.x, look_at_pos.y, look_at_pos.z)
        return self.predict_gaze(self._measure_at(eye, target, camera))

    def calculate_gaze_error(self, eye: Eye, look_at_pos: Point3D) -> tuple[float, float]:
        """Calculate gaze estimation error.

        Evaluates gaze tracking accuracy by comparing the prediction to the
        known target.

        Args:
            eye: Eye object.
            look_at_pos: 3D position where eye should look.

        Returns:
            Tuple of (u, v) gaze error in mm (predicted minus target x and y),
            or (NaN, NaN) if estimation fails or yields no gaze point.

        Raises:
            ValueError: If no cameras are configured.
        """
        # NOTE(review): this compares x and y, while the calibration failure
        # warning reports x and z -- confirm which plane the targets lie in.
        gaze_prediction = self.estimate_gaze_at(eye, look_at_pos)
        if gaze_prediction is None or gaze_prediction.gaze_point is None:
            return np.nan, np.nan
        u = gaze_prediction.gaze_point.x - look_at_pos.x
        v = gaze_prediction.gaze_point.y - look_at_pos.y
        return u, v

    @abstractmethod
    def predict_gaze(self, measurement: EyeMeasurement) -> GazePrediction | None:
        """Predict gaze position from eye measurement.

        Abstract interface for algorithm-specific gaze prediction implementation.
        Each eye tracker type must implement its specific gaze prediction algorithm.

        Args:
            measurement: EyeMeasurement containing pupil and corneal reflection data.

        Returns:
            GazePrediction with estimated gaze position or None if prediction fails.
        """

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def __str__(self) -> str:
        """Basic string representation of the eye tracker."""
        return (
            f"{self.__class__.__name__}(algorithm={self.algorithm_name}, "
            f"cameras={len(self.cameras)}, lights={len(self.lights)}, "
            f"calibrated={self._is_calibrated()})"
        )

    def pprint(self, eye: "Eye | None" = None) -> None:
        """Print detailed eye tracker parameters in a formatted table.

        Args:
            eye: Optional Eye instance to include eye position in the summary.
        """
        data = []

        if eye is not None:
            data.append(["Eye position (mm)", self._format_position(eye.position)])

        data.extend([
            ["Algorithm", self.algorithm_name],
            ["Cameras", str(len(self.cameras))],
            ["Lights", str(len(self.lights))],
            ["Calibration points", str(len(self.calib_points))],
            ["Calibration status", "Calibrated" if self._is_calibrated() else "Not calibrated"],
            ["Use refraction", "Yes" if self.use_refraction else "No"],
            ["Pupil center method", self.pupil_center_method],
            ["Legacy look_at mode", "Yes" if self.use_legacy_look_at else "No"],
        ])

        # Algorithm-specific configuration, only if the subclass exposes it.
        if hasattr(self, "algorithm_state") and hasattr(self.algorithm_state, "config"):
            config = self.algorithm_state.config
            if hasattr(config, "method"):
                data.append(["Algorithm method", config.method])
            if hasattr(config, "degree"):
                data.append(["Polynomial degree", str(config.degree)])

        for i, cam in enumerate(self.cameras, start=1):
            data.append([f"Camera {i} position (mm)", self._format_position(cam.position)])

        for i, light in enumerate(self.lights, start=1):
            data.append([f"Light {i} position (mm)", self._format_position(light.position)])

        info(f"{self.__class__.__name__} Configuration:")
        table(data, headers=["Parameter", "Value"], tablefmt="grid")