Skip to content

PyTorch integration

gamfit.torch exposes gamfit's analytic primitives to PyTorch's autograd. Each function wraps the corresponding NumPy entry point in gamfit._api; no derivative math is reimplemented in Python. For primitives with an analytic Rust backward, gradients flow through loss.backward().

Installation

The torch dependency is optional:

pip install gamfit[torch]

Importing gamfit does not import torch. Importing gamfit.torch without torch installed raises ImportError with an install hint.

Differentiable Gaussian REML

The four closed-form REML primitives have analytic Rust VJPs. Each returns a GaussianRemlOutput (coefficients, fitted, lam, reml_score) and routes upstream gradients into the matching Rust backward:

import torch
import gamfit.torch as gt

x = torch.randn(40, 4, dtype=torch.float64, requires_grad=True)
y = torch.randn(40, 1, dtype=torch.float64)
penalty = torch.eye(4, dtype=torch.float64)

out = gt.gaussian_reml_fit(x, y, penalty)
loss = (out.fitted - y).pow(2).mean() + 0.01 * out.reml_score
loss.backward()

The same pattern applies to gaussian_reml_fit_batched, gaussian_reml_fit_positions, and gaussian_reml_fit_positions_batched (the position variants return GaussianRemlPositionOutput). Saved tensors are version-checked; in-place mutation between forward and backward raises RuntimeError.

Embedding a fitted model

from_fitted wraps a fitted gamfit.Model as a frozen nn.Module:

import gamfit
import gamfit.torch as gt

model = gamfit.fit_array(X_train, y_train, "y ~ s(x0) + x1 + x2")
frozen = gt.from_fitted(model)

X = torch.as_tensor(X_test, dtype=torch.float64)
preds = frozen(X)   # (N, P): typically [eta, mean]

The forward pass crosses the NumPy / Rust boundary. Gradients do not flow back through the inputs. The wrapped model has no trainable parameters (coefficients live in the saved bytes, not in nn.Parameters).

Response geometry

Pure-torch implementations of the simplex and unit-sphere transforms:

closure, clr, alr, inverse_alr, simplex_log_map, simplex_exp_map, simplex_frechet_mean, sphere_log_map, sphere_exp_map, sphere_frechet_mean. Autograd flows through these because they are written directly in torch.

Device and dtype

The gamfit Rust backend operates on f64 CPU buffers. Tensors are moved to CPU f64 for the call and returned on the caller's original device and dtype. For training-loop workloads, prefer the batched primitives (gaussian_reml_fit_batched, gaussian_reml_fit_positions_batched) over per-feature Python iteration.

Public API

Group Symbols
Closed-form REML gaussian_reml_fit, gaussian_reml_fit_batched, gaussian_reml_fit_positions, gaussian_reml_fit_positions_batched, GaussianRemlOutput, GaussianRemlPositionOutput, FrozenPositionPredictor
Basis evaluations bspline_basis, bspline_basis_derivative, duchon_basis_1d, duchon_basis_1d_derivative
Penalty / ridge smoothness_penalty, gaussian_weighted_ridge, gaussian_weighted_ridge_batch
Response geometry closure, clr, alr, inverse_alr, simplex_log_map, simplex_exp_map, simplex_frechet_mean, sphere_log_map, sphere_exp_map, sphere_frechet_mean
Fitted-model loader from_fitted

Value-producing primitives are bit-exact equal to their NumPy counterparts. Primitives with an analytic Rust backward are covered by torch.autograd.gradcheck.