"""Gaussian Process error correction for homography normalization.
This module provides a wrapper around scikit-learn's GaussianProcessRegressor
to implement the error correction model described in Hansen et al. (2010).
"""
import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
class GaussianProcessErrorCorrection:
    """Gaussian Process error correction using scikit-learn.

    Models systematic errors by wrapping scikit-learn's
    ``GaussianProcessRegressor``. Each of the two output dimensions gets its
    own regressor (``gp_x`` and ``gp_y``); both are configured with a squared
    exponential kernel (``RBF``) plus a noise kernel (``WhiteKernel``).
    """

    def __init__(
        self,
        length_scale: float = 100.0,
        noise_level: float = 1.0,
        length_scale_bounds: tuple[float, float] = (10.0, 500.0),
        noise_level_bounds: tuple[float, float] = (0.01, 10.0),
    ) -> None:
        """Initialize the GP error correction model.

        Args:
            length_scale: Initial length scale of the RBF kernel (mm).
                Default 100 mm is suitable for screen-scale coordinates.
            noise_level: Initial noise level of the WhiteKernel. scikit-learn
                interprets this value as the noise *variance*, so it is
                expressed in squared units (mm^2). The default 1.0 corresponds
                to a noise standard deviation of 1 mm.
            length_scale_bounds: Min/max bounds for length scale optimization
                (mm). Default (10, 500) mm.
            noise_level_bounds: Min/max bounds for noise level optimization
                (variance, mm^2). Default (0.01, 10), i.e. a noise standard
                deviation of roughly 0.1 to 3.2 mm.
        """
        smooth_part = RBF(
            length_scale=length_scale,
            length_scale_bounds=length_scale_bounds,
        )
        noise_part = WhiteKernel(
            noise_level=noise_level,
            noise_level_bounds=noise_level_bounds,
        )
        kernel = smooth_part + noise_part

        # Multi-output regression is handled by fitting one regressor per
        # output dimension. The kernel object acts as a shared prototype:
        # scikit-learn optimizes a clone of it during fit (exposed as
        # ``kernel_``), so the two regressors keep independent hyperparameters.
        self.gp_x = GaussianProcessRegressor(kernel=kernel, random_state=0)
        self.gp_y = GaussianProcessRegressor(kernel=kernel, random_state=0)

    def fit(self, X: "np.ndarray", y: "np.ndarray") -> None:
        """Fit the GP model to the calibration residuals.

        Args:
            X: Training inputs, passed unchanged to both regressors
                (presumably an Nx2 array of screen positions, matching the
                layout expected by ``predict``).
            y: 2-D array-like of residuals with at least two columns.
                Column 0 is the target for ``gp_x`` and column 1 the target
                for ``gp_y``.

        Raises:
            IndexError: If ``y`` is not 2-D or has fewer than two columns.
                The check happens before any fitting, so a malformed ``y``
                never leaves the model half-updated.
        """
        # Accept nested sequences as well as ndarrays, and pull out both
        # target columns up front so that bad input fails before either
        # regressor has been touched.
        residuals = np.asarray(y)
        residuals_x = residuals[:, 0]
        residuals_y = residuals[:, 1]

        # Suppress ConvergenceWarning, which is common when the optimizer
        # hits the bounds of the hyperparameter search space. The filter is
        # scoped to this call only.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            self.gp_x.fit(X, residuals_x)
            self.gp_y.fit(X, residuals_y)

    def predict(self, X: "np.ndarray") -> "np.ndarray":
        """Predict the error correction for new gaze points.

        Args:
            X: Mx2 array of query screen positions.

        Returns:
            Mx2 array of predicted error vectors; column 0 comes from
            ``gp_x`` and column 1 from ``gp_y``.
        """
        pred_x = self.gp_x.predict(X)
        pred_y = self.gp_y.predict(X)
        # Each regressor yields one value per query row; place them side by
        # side as the two columns of the result.
        return np.column_stack([pred_x, pred_y])