Analysis — Dynamic Tools

btorch.analysis.dynamic_tools.attractor_dynamics

Functions

calculate_kaplan_yorke_dimension(lyapunov_spectrum)

Calculate the Kaplan-Yorke Dimension (D_KY), also known as the Lyapunov Dimension.

Formula: D_KY = k + sum(lambda_i for i=1 to k) / |lambda_{k+1}| where k is the max index such that the sum of the first k exponents is non-negative.

Parameters:

Name Type Description Default
lyapunov_spectrum ndarray

Array of Lyapunov exponents, sorted in descending order.

required

Returns:

Name Type Description
float

The Kaplan-Yorke dimension. Returns 0 if the system is stable (all lambda < 0). Returns the number of exponents if the sum of all is positive (unbounded/hyperchaos).

Source code in btorch/analysis/dynamic_tools/attractor_dynamics.py
def calculate_kaplan_yorke_dimension(lyapunov_spectrum: np.ndarray):
    """Calculate the Kaplan-Yorke Dimension (D_KY), also known as the Lyapunov
    Dimension.

    Formula: D_KY = k + sum(lambda_i for i=1 to k) / |lambda_{k+1}|
    where k is the max index such that the sum of the first k exponents is non-negative.

    Args:
        lyapunov_spectrum (np.ndarray): Array of Lyapunov exponents, sorted in
        descending order.

    Returns:
        float: The Kaplan-Yorke dimension. Returns 0 if the system is stable
            (all lambda < 0). Returns the number of exponents if the sum of all
            is positive (unbounded/hyperchaos).
    """
    # Ensure sorted descending
    ls = np.sort(lyapunov_spectrum)[::-1]

    n = len(ls)

    # Calculate cumulative sums
    cum_sum = np.cumsum(ls)

    # Find k: max index such that sum >= 0
    # We look for the last index where cum_sum >= 0
    positive_sums = np.where(cum_sum >= 0)[0]

    if len(positive_sums) == 0:
        # All cumulative sums are negative.
        # This usually means the first exponent is negative (stable fixed point).
        return 0.0

    k = positive_sums[-1]

    # Check if k is the last element (sum of all is positive)
    if k == n - 1:
        return float(n)

    # Apply formula
    # Note: indices are 0-based in Python, so k corresponds to the (k+1)-th
    # element in 1-based math notation.
    # The formula uses 1-based k.
    # Let's map carefully:
    # Python index k is the index of the last element included in the sum.
    # So we have summed ls[0]...ls[k].
    # The next element is ls[k+1].
    # The integer part of dimension is (k + 1).

    sum_lambda = cum_sum[k]
    lambda_next = ls[k + 1]

    if lambda_next == 0:
        # Avoid division by zero, though theoretically lambda_{k+1} should be
        # negative here.
        return float(k + 1)

    d_ky = (k + 1) + sum_lambda / abs(lambda_next)

    return d_ky
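
A minimal usage sketch (not from the library's own docs): the classic Lorenz-63 spectrum is approximately (0.906, 0, -14.57), for which the formula gives D_KY = 2 + 0.906 / 14.57 ≈ 2.06.

import numpy as np
from btorch.analysis.dynamic_tools.attractor_dynamics import (
    calculate_kaplan_yorke_dimension,
)

# Approximate literature values for the Lorenz-63 Lyapunov spectrum
spectrum = np.array([0.906, 0.0, -14.57])
d_ky = calculate_kaplan_yorke_dimension(spectrum)
print(f"D_KY = {d_ky:.3f}")  # -> ~2.062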

calculate_structural_eigenvalue_outliers(weight_matrix, spectral_radius=None)

Analyze the eigenvalues of the weight matrix to identify structural outliers.

According to the circular law, eigenvalues of a random matrix are distributed within a disk of radius R. Outliers outside this radius indicate structural enforcement of specific oscillatory modes (stable dynamics) rather than random chaos.

Parameters:

Name Type Description Default
weight_matrix ndarray

The connectivity weight matrix (N x N).

required
spectral_radius float

The theoretical spectral radius of the random component. If None, it is estimated as std(W) * sqrt(N).

None

Returns:

Name Type Description
dict

Dictionary containing:
- 'eigenvalues': All eigenvalues.
- 'max_eigenvalue': Largest eigenvalue magnitude (the true spectral radius).
- 'outliers': Eigenvalues outside the spectral radius.
- 'outlier_count': Number of outliers.
- 'spectral_radius': The radius used for thresholding (bulk radius).

Source code in btorch/analysis/dynamic_tools/attractor_dynamics.py
def calculate_structural_eigenvalue_outliers(
    weight_matrix: np.ndarray, spectral_radius: float | None = None
):
    """Analyze the eigenvalues of the weight matrix to identify structural
    outliers.

    According to the circular law, eigenvalues of a random matrix are distributed
    within a disk of radius R. Outliers outside this radius indicate structural
    enforcement of specific oscillatory modes (stable dynamics) rather than
    random chaos.

    Args:
        weight_matrix (np.ndarray): The connectivity weight matrix (N x N).
        spectral_radius (float, optional): The theoretical spectral radius of the
            random component. If None, it is estimated as std(W) * sqrt(N).

    Returns:
        dict: Dictionary containing:
            - 'eigenvalues': All eigenvalues.
            - 'max_eigenvalue': Largest eigenvalue magnitude (the true
              spectral radius).
            - 'outliers': Eigenvalues outside the spectral radius.
            - 'outlier_count': Number of outliers.
            - 'spectral_radius': The radius used for thresholding (bulk radius).
    """
    # Ensure numpy array
    W = np.array(weight_matrix)
    N = W.shape[0]

    if W.shape[0] != W.shape[1]:
        raise ValueError("Weight matrix must be square.")

    # Compute eigenvalues
    eigenvalues = np.linalg.eigvals(W)

    # Determine spectral radius if not provided
    if spectral_radius is None:
        # Estimate radius based on random matrix theory
        # For entries with variance sigma^2/N, radius is sigma.
        # Here we have entries with variance var(W).
        # If W_ij ~ N(0, sigma^2), then radius R = sigma * sqrt(N).
        # std(W) corresponds to sigma.
        sigma = np.std(W)
        spectral_radius = sigma * np.sqrt(N)

    # Identify outliers
    magnitudes = np.abs(eigenvalues)
    outlier_indices = np.where(magnitudes > spectral_radius)[0]
    outliers = eigenvalues[outlier_indices]

    return {
        "eigenvalues": eigenvalues,
        # True Spectral Radius
        "max_eigenvalue": np.max(magnitudes) if len(magnitudes) > 0 else 0.0,
        "outliers": outliers,
        "outlier_count": len(outliers),
        "spectral_radius": spectral_radius,  # Bulk Radius (Threshold)
    }
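
A hedged sketch of the expected behavior: a random Gaussian bulk with entry std sigma/sqrt(N) has bulk radius ≈ sigma, and a planted rank-1 mode with strength above that radius should be flagged as an outlier.

import numpy as np
from btorch.analysis.dynamic_tools.attractor_dynamics import (
    calculate_structural_eigenvalue_outliers,
)

rng = np.random.default_rng(0)
N = 500
# Random bulk: entry std 1/sqrt(N) -> bulk radius ~ 1
W = rng.normal(0.0, 1.0 / np.sqrt(N), size=(N, N))
# Plant a rank-1 structural mode with eigenvalue ~ 2 (outside the bulk)
u = np.ones((N, 1)) / np.sqrt(N)
W = W + 2.0 * (u @ u.T)

res = calculate_structural_eigenvalue_outliers(W)
print(res["spectral_radius"])  # ~ 1.0
print(res["outlier_count"])    # >= 1 (the planted mode)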

btorch.analysis.dynamic_tools.complexity

Functions

calculate_gain_stability_sensitivity(model, dataloader, g_values=None, dt=1.0, device='cuda')

Calculate the Gain-Stability Sensitivity (Susceptibility) slope.

Definition: The slope of the curve of the Maximum Lyapunov Exponent (lambda_max) as a function of global synaptic gain scaling (g).

Parameters:

Name Type Description Default
model

The Brain model.

required
dataloader

DataLoader providing input.

required
g_values

List of gain scaling factors. Default np.linspace(0.5, 5.0, 10).

None
dt

Simulation time step.

1.0
device

Device to run on.

'cuda'

Returns:

Name Type Description
tuple

(slope, intercept, g_values, lambda_values)
- slope: The slope of lambda_max vs g.
- intercept: The intercept of the fit.
- g_values: The gain values used.
- lambda_values: The computed max Lyapunov exponents.

Source code in btorch/analysis/dynamic_tools/complexity.py
def calculate_gain_stability_sensitivity(
    model, dataloader, g_values=None, dt=1.0, device="cuda"
):
    """Calculate the Gain-Stability Sensitivity (Susceptibility) slope.

    Definition: The slope of the curve of the Maximum Lyapunov Exponent (lambda_max)
    as a function of global synaptic gain scaling (g).

    Args:
        model: The Brain model.
        dataloader: DataLoader providing input.
        g_values: List of gain scaling factors. Default np.linspace(0.5, 5.0, 10).
        dt: Simulation time step.
        device: Device to run on.

    Returns:
        tuple: (slope, intercept, g_values, lambda_values)
            - slope: The slope of lambda_max vs g.
            - intercept: The intercept of the fit.
            - g_values: The gain values used.
            - lambda_values: The computed max Lyapunov exponents.
    """
    from model import functional, init

    if g_values is None:
        g_values = np.linspace(0.5, 5.0, 10)
    # Allow list inputs; boolean indexing below requires an ndarray
    g_values = np.asarray(g_values, dtype=float)

    # Access linear layer
    # Assuming model is Brain, model.brain is RecurrentNN, model.brain.synapse
    # is Synapse model.brain.synapse.linear is the layer
    try:
        linear_layer = model.brain.synapse.linear
    except AttributeError:
        print("Could not find linear layer at model.brain.synapse.linear")
        return 0.0

    original_magnitude = linear_layer.magnitude.data.clone()

    lambda_values = []

    model.eval()
    model.to(device)

    # Get one batch of input
    try:
        batch = next(iter(dataloader))
    except StopIteration:
        print("Dataloader is empty.")
        return 0.0

    inputs = batch["input"]
    # inputs: (Batch, Time, ...) -> (Time, Batch, ...)
    inputs = inputs.transpose(0, 1).to(device)

    # We can use the first sample in the batch.
    input_sample = inputs[:, 0:1, ...]  # Keep batch dim 1

    for g in g_values:
        # Scale weights
        linear_layer.magnitude.data = original_magnitude * g

        # Reset state
        functional.reset_net(model, device=device)
        init.uniform_v_(model.brain.neuron, set_reset_value=True, batch_size=1)

        # Run
        with torch.no_grad():
            _, brain_out = model(input_sample)
            spikes = brain_out["neuron"]["spike"]  # (Time, Batch, Neurons)

        # Convert to rate
        # spikes: (Time, 1, Neurons) -> (Time, Neurons)
        spikes_sq = spikes.squeeze(1)

        # Continuous rate
        rates = get_continuous_spiking_rate(spikes_sq, dt=dt)

        # Mean population rate for LE calculation
        mean_rate = rates.mean(axis=1)

        # Compute LE
        try:
            le = compute_max_lyapunov_exponent(mean_rate)
        except Exception as e:
            print(f"Error computing LE for g={g}: {e}")
            le = np.nan  # excluded from the linear fit below

        lambda_values.append(le)

    # Restore weights
    linear_layer.magnitude.data = original_magnitude

    # Calculate slope
    # Fit line: lambda = slope * g + intercept
    # Handle potential NaNs or Infs
    valid_indices = np.isfinite(lambda_values)
    if np.sum(valid_indices) < 2:
        return 0.0, 0.0, g_values, np.array(lambda_values)

    g_valid = g_values[valid_indices]
    lambda_valid = np.array(lambda_values)[valid_indices]

    slope, intercept = np.polyfit(g_valid, lambda_valid, 1)

    return slope, intercept, g_values, np.array(lambda_values)
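
The model-dependent parts aside, the final step is a plain linear regression of lambda_max against g. A self-contained sketch of that fit on synthetic values (not a real model run):

import numpy as np

# Synthetic lambda_max(g) values crossing zero near g = 1 (edge of chaos)
g_values = np.linspace(0.5, 5.0, 10)
rng = np.random.default_rng(1)
lambda_values = 0.4 * (g_values - 1.0) + 0.02 * rng.normal(size=g_values.size)

slope, intercept = np.polyfit(g_values, lambda_values, 1)
print(f"susceptibility slope ~ {slope:.2f}")  # ~ 0.4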

calculate_lyapunov_exponent(spike_train, dt=0.1)

Calculate the maximum Lyapunov exponent for a given spike train.

Parameters:

Name Type Description Default
spike_train Tensor

The spike train data. Shape (time_steps, num_neurons).

required
dt float

Time bin size in milliseconds. Default is 0.1 ms.

0.1

Returns:

Name Type Description
float

The maximum Lyapunov exponent.

Source code in btorch/analysis/dynamic_tools/complexity.py
def calculate_lyapunov_exponent(spike_train: torch.Tensor, dt: float = 0.1) -> float:
    """Calculate the maximum Lyapunov exponent for a given spike train.

    Args:
        spike_train (torch.Tensor): The spike train data. Shape (time_steps,
        num_neurons).
        dt (float): Time bin size in milliseconds. Default is 0.1 ms.

    Returns:
        float: The maximum Lyapunov exponent.
    """
    # Ensure spike_train is a 2D tensor
    if spike_train.ndim != 2:
        raise ValueError(
            "spike_train must be a 2D tensor with shape (time_steps, num_neurons)"
        )

    # 1. Calculate the continuous spiking rate using a Gaussian kernel (smooth
    # the spike train) This is effectively a form of kernel density estimation.
    # We use a small bandwidth, as the original dynamics should be captured at a
    # fine timescale.
    bandwidth = 5.0  # in ms, this may need adjustment
    continuous_rate = get_continuous_spiking_rate(spike_train, dt, bandwidth)

    # 2. Estimate the maximum Lyapunov exponent from the smoothed rates.
    # nolds expects a 1D series, so reduce to the mean population rate.
    # Note: dt must not be passed positionally to
    # compute_max_lyapunov_exponent, whose second argument is emb_dim.
    mean_rate = continuous_rate.mean(axis=1)
    lyapunov_exponent = compute_max_lyapunov_exponent(mean_rate)

    return lyapunov_exponent
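
A hypothetical end-to-end call (the raster here is random, so the resulting exponent is not meaningful; real usage would pass recorded network activity):

import torch
from btorch.analysis.dynamic_tools.complexity import calculate_lyapunov_exponent

torch.manual_seed(0)
# (time_steps, num_neurons) binary raster at dt = 0.1 ms
spikes = (torch.rand(5000, 100) < 0.05).float()
le = calculate_lyapunov_exponent(spikes, dt=0.1)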

calculate_pcist(response, baseline, threshold_factor=3.0)

Calculate the Perturbational Complexity Index based on State Transitions (PCIst).

Definition: A measure of the spatiotemporal complexity of the network's response to a specific perturbation.

Steps:
1. Perturb (assumed done, input is response).
2. Measure (input is response matrix).
3. Decompose: Perform PCA on the response matrix.
4. Recurrence: Calculate state transitions on the principal components.
5. Sum significant state transitions weighted by the component's Signal-to-Noise ratio.

Parameters:

Name Type Description Default
response Tensor

The network response to perturbation. Shape (time_steps, num_neurons).

required
baseline Tensor

The baseline activity before perturbation. Shape (time_steps_base, num_neurons).

required
threshold_factor float

Factor of baseline std dev to define significant state excursion.

3.0

Returns:

Name Type Description
float

The PCIst score.

Source code in btorch/analysis/dynamic_tools/complexity.py
def calculate_pcist(
    response: torch.Tensor, baseline: torch.Tensor, threshold_factor: float = 3.0
) -> float:
    """Calculate the Perturbational Complexity Index based on State Transitions
    (PCIst).

    Definition: A measure of the spatiotemporal complexity of the network's
    response to a specific perturbation.

    Steps:
    1. Perturb (assumed done, input is response).
    2. Measure (input is response matrix).
    3. Decompose: Perform PCA on the response matrix.
    4. Recurrence: Calculate state transitions on the principal components.
    5. Sum significant state transitions weighted by the component's
    Signal-to-Noise ratio.

    Args:
        response (torch.Tensor): The network response to perturbation. Shape
        (time_steps, num_neurons).
        baseline (torch.Tensor): The baseline
        activity before perturbation. Shape (time_steps_base, num_neurons).
        threshold_factor (float): Factor of baseline std dev to define
        significant state excursion. Default 3.0.

    Returns:
        float: The PCIst score.
    """
    # Ensure inputs are float tensors
    if not isinstance(response, torch.Tensor):
        response = torch.tensor(response, dtype=torch.float32)
    if not isinstance(baseline, torch.Tensor):
        baseline = torch.tensor(baseline, dtype=torch.float32)

    response = response.float()
    baseline = baseline.float()

    # Handle batch dimension: if 3D, assume (batch, time, neurons) and
    # return the mean PCIst over the batch.
    if response.ndim == 3:
        batch_size = response.shape[0]
        pcist_values = []
        for i in range(batch_size):
            # Handle corresponding baseline
            b_sample = baseline[i] if baseline.ndim == 3 else baseline
            pcist_values.append(
                calculate_pcist(response[i], b_sample, threshold_factor)
            )
        return sum(pcist_values) / len(pcist_values)

    # 1. Center data based on baseline mean
    mean_base = baseline.mean(dim=0)
    response_centered = response - mean_base
    baseline_centered = baseline - mean_base

    # 2. PCA on Response
    # We use SVD for PCA: X = U S V^T. Principal components (scores) are X V = U S.
    # response_centered shape: (T, N)
    try:
        # full_matrices=False ensures we get min(T, N) components
        U, S, Vh = torch.linalg.svd(response_centered, full_matrices=False)
    except RuntimeError:
        # Fallback for singular matrices or convergence issues
        return 0.0

    V = Vh.T  # (N, K)

    # Project data onto Principal Components
    # Scores shape: (T, K)
    scores_response = torch.matmul(response_centered, V)
    scores_baseline = torch.matmul(baseline_centered, V)

    # 3. Calculate SNR for each component
    # SNR = Variance(Response) / Variance(Baseline)
    # Add epsilon to avoid division by zero
    var_response = scores_response.var(dim=0)
    var_baseline = scores_baseline.var(dim=0)

    epsilon = 1e-9
    snr = var_response / (var_baseline + epsilon)

    # 4. Calculate State Transitions
    # A state transition is defined as crossing a threshold defined by baseline noise.
    # Threshold for component k: threshold_factor * std(baseline_k)

    std_baseline = scores_baseline.std(dim=0)
    thresholds = threshold_factor * std_baseline  # (K,)

    # Binarize response: 1 if |score| > threshold, else 0
    # We are looking for "significant excursions"
    # Shape: (T, K)
    active_states = (torch.abs(scores_response) > thresholds.unsqueeze(0)).float()

    # Count transitions: change from 0 to 1 or 1 to 0
    # diff along time dimension
    transitions = torch.abs(active_states[1:] - active_states[:-1])

    # Sum transitions for each component
    num_transitions = transitions.sum(dim=0)  # (K,)

    # 5. Weighted sum: significant state transitions weighted by each
    # component's Signal-to-Noise ratio. Low-SNR components contribute
    # little by construction, so no explicit SNR filtering is applied.

    pcist_score = (num_transitions * snr).sum()

    return pcist_score.item()
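
A small synthetic check, with the tensors below standing in for real recordings: a structured, spatially diverse response should score much higher than baseline-like noise.

import torch
from btorch.analysis.dynamic_tools.complexity import calculate_pcist

torch.manual_seed(0)
baseline = 0.1 * torch.randn(200, 50)
# Structured response: travelling sinusoid across neurons, plus noise
t = torch.linspace(0, 4 * torch.pi, 300).unsqueeze(1)
response = torch.sin(t + 0.2 * torch.arange(50)) + 0.1 * torch.randn(300, 50)

print(calculate_pcist(response, baseline))                    # large: many transitions
print(calculate_pcist(0.1 * torch.randn(300, 50), baseline))  # near 0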

calculate_ra(spike_initial, spike_final)

Calculate Representation Alignment (RA) using spike data.

RA = Trace(G_final * G_initial) / (||G_final|| * ||G_initial||) where G = S * S^T (Gram matrix of spike activity)

If inputs are 3D tensors, they are assumed to be (batch_size, time_steps, num_neurons) and will be averaged over the time dimension (dim=1) to obtain firing rates.

Parameters:

Name Type Description Default
spike_initial Tensor

Initial spike activity. Shape (batch_size, num_neurons) or (batch_size, time_steps, num_neurons).

required
spike_final Tensor

Final spike activity. Shape (batch_size, num_neurons) or (batch_size, time_steps, num_neurons).

required

Returns:

Name Type Description
float

The Representation Alignment (RA) score.
- Low RA -> Rich Regime (radical restructuring).
- High RA -> Lazy Regime (little change in internal structure).

Source code in btorch/analysis/dynamic_tools/complexity.py
def calculate_ra(spike_initial: torch.Tensor, spike_final: torch.Tensor) -> float:
    """Calculate Representation Alignment (RA) using spike data.

    RA = Trace(G_final * G_initial) / (||G_final|| * ||G_initial||)
    where G = S * S^T (Gram matrix of spike activity)

    If inputs are 3D tensors, they are assumed to be (batch_size, time_steps,
    num_neurons) and will be averaged over the time dimension (dim=1) to obtain
    firing rates.

    Args:
        spike_initial (torch.Tensor): Initial spike activity. Shape
            (batch_size, num_neurons) or (batch_size, time_steps, num_neurons).
        spike_final (torch.Tensor): Final spike activity. Shape
            (batch_size, num_neurons) or (batch_size, time_steps, num_neurons).

    Returns:
        float: The Representation Alignment (RA) score.
               Low RA -> Rich Regime (Radical restructuring)
               High RA -> Lazy Regime (Little change in internal structure)
    """
    # Ensure inputs are float tensors
    if not isinstance(spike_initial, torch.Tensor):
        spike_initial = torch.tensor(spike_initial, dtype=torch.float32)
    if not isinstance(spike_final, torch.Tensor):
        spike_final = torch.tensor(spike_final, dtype=torch.float32)

    spike_initial = spike_initial.float()
    spike_final = spike_final.float()

    # Handle 3D input: (batch, time, neurons) -> (batch, neurons)
    if spike_initial.ndim == 3:
        spike_initial = spike_initial.mean(dim=1)
    if spike_final.ndim == 3:
        spike_final = spike_final.mean(dim=1)

    # 1. Compute Gram matrix G = S * S^T
    # Shape: (batch_size, batch_size)
    g_initial = torch.matmul(spike_initial, spike_initial.T)
    g_final = torch.matmul(spike_final, spike_final.T)

    # 2. Calculate Trace(G_final @ G_initial).
    # Trace(A @ B) = sum(A * B^T); since Gram matrices are symmetric this
    # reduces to the element-wise product summed, avoiding a full matmul.
    numerator = (g_final * g_initial).sum()

    # 3. Calculate norms ||G||
    # Assuming Frobenius norm as is standard for matrix alignment
    norm_initial = torch.norm(g_initial, p="fro")
    norm_final = torch.norm(g_final, p="fro")

    # 4. Calculate RA
    if norm_initial == 0 or norm_final == 0:
        return 0.0  # Avoid division by zero

    ra = numerator / (norm_final * norm_initial)

    return ra.item()
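
A worked sanity check, assuming (batch, neurons) rate matrices: comparing a representation with itself must give RA = 1, while an unrelated random representation gives a lower value.

import torch
from btorch.analysis.dynamic_tools.complexity import calculate_ra

torch.manual_seed(0)
s_init = torch.rand(32, 100)
print(calculate_ra(s_init, s_init))               # 1.0 by construction
print(calculate_ra(s_init, torch.rand(32, 100)))  # < 1.0 (restructured)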

compute_max_lyapunov_exponent(time_series, emb_dim=6, lag=1, tau=1)

Compute the largest Lyapunov exponent of a given time series using the nolds library.

Parameters:
- time_series: A 1D numpy array representing the time series data.
- emb_dim: Embedding dimension for phase-space reconstruction (default 6).
- lag: Embedding lag (default 1).
- tau: Time step between samples; scales the returned exponent (default 1).

Returns:
- lyapunov_exponent: The estimated largest Lyapunov exponent.

Source code in btorch/analysis/dynamic_tools/lyapunov_dynamics.py
def compute_max_lyapunov_exponent(time_series, emb_dim=6, lag=1, tau=1):
    """Compute the largest Lyapunov exponent of a given time series using the
    nolds library.

    Parameters:
    - time_series: A 1D numpy array representing the time series data.
    - emb_dim: Embedding dimension for phase-space reconstruction (default 6).
    - lag: Embedding lag (default 1).
    - tau: Time step between samples; scales the returned exponent (default 1).

    Returns:
    - lyapunov_exponent: The estimated largest Lyapunov exponent.
    """
    lyapunov_exponent = nolds.lyap_r(time_series, emb_dim=emb_dim, lag=lag, tau=tau)
    return lyapunov_exponent
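
A sanity check on a system with a known exponent (an illustrative example, not from the library): the logistic map at r = 4 has largest Lyapunov exponent ln 2 ≈ 0.693 per iteration, so the estimate should come out positive and near that value.

import numpy as np
from btorch.analysis.dynamic_tools.lyapunov_dynamics import (
    compute_max_lyapunov_exponent,
)

x = np.empty(2000)
x[0] = 0.3
for i in range(1, x.size):
    x[i] = 4.0 * x[i - 1] * (1.0 - x[i - 1])

le = compute_max_lyapunov_exponent(x)
print(le)  # positive; ~0.69 in the ideal case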

get_continuous_spiking_rate(spikes, dt, sigma=20.0)

Convert discrete spike trains into continuous firing rates using Gaussian smoothing.

Parameters:

Name Type Description Default
spikes ndarray or Tensor

Spike matrix of shape (time_steps, n_neurons).

required
dt float

Simulation time step in ms.

required
sigma float

Standard deviation of the Gaussian kernel in ms. Default 20ms.

20.0

Returns:

Type Description

np.ndarray: Continuous firing rate traces of shape (time_steps, n_neurons).

Source code in btorch/analysis/dynamic_tools/lyapunov_dynamics.py
def get_continuous_spiking_rate(spikes, dt, sigma=20.0):
    """Convert discrete spike trains into continuous firing rates using
    Gaussian smoothing.

    Args:
        spikes (np.ndarray or torch.Tensor): Spike matrix of shape (time_steps,
        n_neurons).
        dt (float): Simulation time step in ms.
        sigma (float): Standard deviation of the Gaussian kernel in ms. Default 20ms.

    Returns:
        np.ndarray: Continuous firing rate traces of shape (time_steps, n_neurons).
    """
    if isinstance(spikes, torch.Tensor):
        spikes = spikes.detach().cpu().numpy()

    # Convert sigma from ms to bins
    sigma_bins = sigma / dt

    # Apply Gaussian filter along the time axis (axis 0)
    rates = gaussian_filter1d(spikes.astype(float), sigma=sigma_bins, axis=0)

    return rates
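
A quick check of the smoothing semantics: gaussian_filter1d uses a unit-sum kernel, so a single spike spreads into a Gaussian bump whose integral stays one spike.

import numpy as np
from btorch.analysis.dynamic_tools.lyapunov_dynamics import (
    get_continuous_spiking_rate,
)

spikes = np.zeros((1000, 1))
spikes[500, 0] = 1.0
rates = get_continuous_spiking_rate(spikes, dt=1.0, sigma=20.0)
print(rates[500, 0])  # peak ~ 1 / (sqrt(2*pi) * 20) ~ 0.02
print(rates.sum())    # ~ 1.0 (total spike count preserved)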

btorch.analysis.dynamic_tools.criticality

Attributes

HAS_NOLDS = True module-attribute

Functions

_fit_distribution(data)

Helper to fit power law distribution using powerlaw package.

Source code in btorch/analysis/dynamic_tools/criticality.py
def _fit_distribution(data):
    """Helper to fit power law distribution using powerlaw package."""
    if len(data) < 10:
        return np.nan, None
    try:
        # discrete=True because sizes/durations are counts (integers)
        fit = powerlaw.Fit(data, discrete=True, verbose=False)
        return fit.alpha, fit
    except Exception as e:
        warnings.warn(f"Failed to fit power law distribution: {e}")
        return np.nan, None

_fit_scaling(x, y)

Helper to fit power law scaling y = a * x^gamma using curve_fit.

Source code in btorch/analysis/dynamic_tools/criticality.py
def _fit_scaling(x, y):
    """Helper to fit power law scaling y = a * x^gamma using curve_fit."""
    if len(x) < 3:
        return np.nan, None

    try:
        # Initial guess: a=1, gamma=1.5
        popt, pcov = curve_fit(_power_law_func, x, y, p0=[1, 1.5], maxfev=2000)
        gamma = popt[1]

        # Calculate R^2
        residuals = y - _power_law_func(x, *popt)
        ss_res = np.sum(residuals**2)
        ss_tot = np.sum((y - np.mean(y)) ** 2)
        r_squared = 1 - (ss_res / ss_tot)

        stats = {"r_squared": r_squared, "popt": popt, "pcov": pcov}
        return gamma, stats
    except Exception as e:
        warnings.warn(f"Failed to fit scaling relation: {e}")
        return np.nan, None

_power_law_func(x, a, gamma)

Source code in btorch/analysis/dynamic_tools/criticality.py
def _power_law_func(x, a, gamma):
    return a * np.power(x, gamma)

calculate_dfa(spike_train, bin_size=1)

Calculate Detrended Fluctuation Analysis (DFA) exponent alpha.

Meaning of alpha:
- 0.5: White noise (no memory)
- 0.5 < alpha < 1.0: Long-range memory (fractal structure)
- 1.0: 1/f noise (Pink noise)
- 1.5: Brownian motion (Random walk)

Parameters:

Name Type Description Default
spike_train ndarray

Binary spike matrix of shape (time_steps, n_neurons).

required
bin_size int

Width of time bin in number of time steps.

1

Returns:

Name Type Description
float

The DFA exponent alpha.

Source code in btorch/analysis/dynamic_tools/criticality.py
def calculate_dfa(spike_train: np.ndarray, bin_size: int = 1):
    """Calculate Detrended Fluctuation Analysis (DFA) exponent alpha.

    Meaning of alpha:
    - 0.5: White noise (no memory)
    - 0.5 < alpha < 1.0: Long-range memory (fractal structure)
    - 1.0: 1/f noise (Pink noise)
    - 1.5: Brownian motion (Random walk)

    Args:
        spike_train (np.ndarray): Binary spike matrix of shape (time_steps, n_neurons).
        bin_size (int): Width of time bin in number of time steps.

    Returns:
        float: The DFA exponent alpha.
    """
    if not HAS_NOLDS:
        raise ImportError(
            "nolds package is required for DFA analysis. "
            "Install with: pip install nolds"
        )

    # Ensure input is numpy array
    spike_train = np.array(spike_train)

    # 1. Calculate population activity (sum spikes across neurons)
    if spike_train.ndim == 2:
        population_activity = np.sum(spike_train, axis=1)
    else:
        population_activity = spike_train

    # 2. Binning
    if bin_size > 1:
        n_bins = len(population_activity) // bin_size
        population_activity = population_activity[: n_bins * bin_size]
        population_activity = population_activity.reshape(-1, bin_size).sum(axis=1)

    # 3. Calculate DFA using nolds
    # nolds.dfa expects the time series (it performs integration internally)
    try:
        alpha = nolds.dfa(population_activity)
        return alpha
    except Exception as e:
        warnings.warn(f"Failed to calculate DFA: {e}")
        return np.nan
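
A sanity check, assuming nolds is installed: memoryless (uncorrelated) activity should yield alpha close to 0.5.

import numpy as np
from btorch.analysis.dynamic_tools.criticality import calculate_dfa

rng = np.random.default_rng(0)
# Uncorrelated Bernoulli raster: no temporal memory
spikes = (rng.random((10000, 64)) < 0.02).astype(float)
alpha = calculate_dfa(spikes, bin_size=10)
print(alpha)  # ~ 0.5 (white noise)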

compute_avalanche_statistics(spike_train, bin_size=1)

Calculate avalanche size (S) and duration (T) distributions and their power-law exponents.

Definition: An avalanche is defined as a continuous sequence of time bins (width bin_size) containing at least one spike, flanked by empty bins.

Parameters:

Name Type Description Default
spike_train ndarray

Binary spike matrix of shape (time_steps, n_neurons).

required
bin_size int

Width of time bin in number of time steps.

1

Returns:

Name Type Description
dict

Dictionary containing:
- 'tau': Power-law exponent for avalanche size distribution P(S) ~ S^-tau
- 'alpha': Power-law exponent for avalanche duration distribution P(T) ~ T^-alpha
- 'gamma': Power-law exponent for average size vs duration <S>(T) ~ T^gamma
- 'gamma_pred': Predicted gamma based on tau and alpha: (alpha - 1) / (tau - 1)
- 'CCC': Criticality Consistency Coefficient: 1 - |gamma - gamma_pred| / gamma
- 'sizes': List of avalanche sizes
- 'durations': List of avalanche durations
- 'avg_size_by_duration': Tuple (unique_durations, mean_sizes)
- 'fit_S': powerlaw.Fit object for sizes
- 'fit_T': powerlaw.Fit object for durations

Source code in btorch/analysis/dynamic_tools/criticality.py
def compute_avalanche_statistics(spike_train: np.ndarray, bin_size: int = 1):
    """Calculate avalanche size (S) and duration (T) distributions and their
    power-law exponents.

    Definition: An avalanche is defined as a continuous sequence of time bins
    (width bin_size) containing at least one spike, flanked by empty bins.

    Args:
        spike_train (np.ndarray): Binary spike matrix of shape (time_steps, n_neurons).
        bin_size (int): Width of time bin in number of time steps.

    Returns:
        dict: Dictionary containing:
            - 'tau': Power-law exponent for avalanche size distribution P(S) ~
              S^-tau
            - 'alpha': Power-law exponent for avalanche duration distribution
              P(T) ~ T^-alpha
            - 'gamma': Power-law exponent for average size vs duration <S>(T) ~
              T^gamma
            - 'gamma_pred': Predicted gamma based on tau and alpha:
              (alpha-1)/(tau-1)
            - 'CCC': Criticality Consistency Coefficient: 1 - |gamma -
              gamma_pred| / gamma
            - 'sizes': List of avalanche sizes
            - 'durations': List of avalanche durations
            - 'avg_size_by_duration': Tuple (unique_durations, mean_sizes)
            - 'fit_S': powerlaw.Fit object for sizes
            - 'fit_T': powerlaw.Fit object for durations
    """
    # Ensure input is numpy array
    spike_train = np.array(spike_train)

    # Check dimensions. We expect (Time, Neurons).
    if spike_train.ndim != 2:
        raise ValueError("spike_train must be a 2D matrix (time_steps, n_neurons)")

    # 1. Calculate population activity (sum spikes across neurons)
    population_activity = np.sum(spike_train, axis=1)  # Shape: (T,)

    # 2. Binning
    if bin_size > 1:
        n_bins = len(population_activity) // bin_size
        # Truncate to multiple of bin_size
        population_activity = population_activity[: n_bins * bin_size]
        # Reshape and sum
        population_activity = population_activity.reshape(-1, bin_size).sum(axis=1)

    # 3. Identify avalanches
    # Active bins are those with > 0 spikes
    is_active = population_activity > 0

    # Find continuous sequences of active bins
    # Pad with False to detect start/end at boundaries
    padded_active = np.concatenate(([False], is_active, [False]))
    diff = np.diff(padded_active.astype(int))

    starts = np.where(diff == 1)[0]
    ends = np.where(diff == -1)[0]

    sizes = []
    durations = []

    for start, end in zip(starts, ends):
        # segment from start to end (exclusive)
        segment = population_activity[start:end]

        # Size (S): Total number of spikes in the avalanche
        s = np.sum(segment)

        # Duration (T): Number of time bins the avalanche lasts
        t = len(segment)  # equivalent to end - start

        sizes.append(s)
        durations.append(t)

    sizes = np.array(sizes)
    durations = np.array(durations)

    results = {
        "sizes": sizes,
        "durations": durations,
        "tau": np.nan,
        "alpha": np.nan,
        "gamma": np.nan,
        "gamma_pred": np.nan,
        "CCC": np.nan,
        "fit_S": None,
        "fit_T": None,
    }

    if len(sizes) < 10:
        warnings.warn(
            f"Not enough avalanches to fit power law. Found {len(sizes)} avalanches."
        )
        return results

    # 4. Fit power laws using MLE (powerlaw package)
    results["tau"], results["fit_S"] = _fit_distribution(sizes)
    results["alpha"], results["fit_T"] = _fit_distribution(durations)

    # 5. Average Size vs. Duration Scaling (<S>(T) ~ T^gamma)
    if len(durations) > 0:
        # Use bincount for fast grouping by integer duration
        counts = np.bincount(durations)
        sum_sizes = np.bincount(durations, weights=sizes)

        # Filter out durations that didn't occur
        mask = counts > 0
        unique_durations = np.arange(len(counts))[mask]
        mean_sizes = sum_sizes[mask] / counts[mask]

        results["avg_size_by_duration"] = (unique_durations, mean_sizes)

        # Fit scaling relation using curve_fit (non-linear least squares)
        results["gamma"], results["gamma_stats"] = _fit_scaling(
            unique_durations, mean_sizes
        )

    # 6. Calculate Criticality Consistency Coefficient (CCC)
    # gamma_pred = (alpha - 1) / (tau - 1)
    # CCC = 1 - |gamma_obs - gamma_pred| / gamma_obs
    if (
        not np.isnan(results["tau"])
        and not np.isnan(results["alpha"])
        and not np.isnan(results["gamma"])
    ):
        try:
            if results["tau"] != 1:
                gamma_pred = (results["alpha"] - 1) / (results["tau"] - 1)
                results["gamma_pred"] = gamma_pred

                if results["gamma"] != 0:
                    ccc = 1 - abs(results["gamma"] - gamma_pred) / results["gamma"]
                    results["CCC"] = ccc
        except Exception as e:
            warnings.warn(f"Failed to calculate CCC: {e}")

    return results
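
A minimal run on a synthetic raster (uncorrelated activity, so the fitted exponents will not satisfy the critical scaling relation; real critical data would):

import numpy as np
from btorch.analysis.dynamic_tools.criticality import compute_avalanche_statistics

rng = np.random.default_rng(0)
spikes = (rng.random((50000, 100)) < 0.002).astype(int)

stats = compute_avalanche_statistics(spikes, bin_size=1)
print(len(stats["sizes"]), "avalanches")
print(stats["tau"], stats["alpha"], stats["gamma"], stats["CCC"])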

btorch.analysis.dynamic_tools.ei_balance

E/I balance analysis tools for spiking neural networks.

Functions

_compute_eci(I_e, I_i, *, I_ext=None, batch_axis=None, dtype=None)

Source code in btorch/analysis/dynamic_tools/ei_balance.py
def _compute_eci(
    I_e: torch.Tensor | np.ndarray,
    I_i: torch.Tensor | np.ndarray,
    *,
    I_ext: torch.Tensor | np.ndarray | None = None,
    batch_axis: tuple[int, ...] | int | None = None,
    dtype: torch.dtype | np.dtype | None = None,
) -> torch.Tensor | np.ndarray:
    if isinstance(batch_axis, int):
        batch_axis = (batch_axis,)

    # Check if all inputs are zero - return ones in that case
    if (I_e == 0).all() and (I_i == 0).all() and (I_ext is None or (I_ext == 0).all()):
        # Determine which axes are aggregated to compute output shape
        if batch_axis is not None:
            agg_axes = {0} | set(batch_axis)  # time (0) + batch axes
        else:
            agg_axes = set(range(I_e.ndim - 1))  # all except last (neurons)
        output_shape = tuple(I_e.shape[i] for i in range(I_e.ndim) if i not in agg_axes)
        if len(output_shape) == 0:
            output_shape = (1,)
        if isinstance(I_e, torch.Tensor):
            return torch.ones(output_shape, dtype=I_e.dtype, device=I_e.device)
        return np.ones(output_shape, dtype=I_e.dtype)

    # Handle I_ext by splitting into excitatory and inhibitory parts
    if I_ext is not None:
        if isinstance(I_e, torch.Tensor):
            I_ext_pos = torch.clamp(I_ext, min=0)
            I_ext_neg = torch.clamp(I_ext, max=0)
        else:
            I_ext_pos = np.clip(I_ext, 0, None)
            I_ext_neg = np.clip(I_ext, None, 0)
        I_e_eff = I_e + I_ext_pos
        I_i_eff = I_i + I_ext_neg
    else:
        I_e_eff = I_e
        I_i_eff = I_i

    # Compute recurrent current
    I_rec = I_e_eff + I_i_eff

    # Determine axes/dims for aggregation
    if batch_axis is not None:
        agg_dims = (0,) + tuple(batch_axis)
    else:
        agg_dims = tuple(range(I_e.ndim - 1))

    if isinstance(I_e, torch.Tensor):
        numer = torch.abs(I_rec).mean(dim=agg_dims, dtype=dtype)
        denom = (torch.abs(I_e_eff) + torch.abs(I_i_eff)).mean(
            dim=agg_dims, dtype=dtype
        )
        denom = denom + torch.finfo(I_e.dtype).eps
        return numer / denom
    else:
        numer = np.abs(I_rec).mean(axis=agg_dims, dtype=dtype)
        denom = (np.abs(I_e_eff) + np.abs(I_i_eff)).mean(axis=agg_dims, dtype=dtype)
        denom = denom + np.finfo(I_e.dtype).eps
        return numer / denom

_compute_lag_correlation(x, y, *, dt=1.0, max_lag_ms=30.0, batch_axis=None, use_fft=True, dtype=None)

Source code in btorch/analysis/dynamic_tools/ei_balance.py
def _compute_lag_correlation(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    *,
    dt: float = 1.0,
    max_lag_ms: float = 30.0,
    batch_axis: tuple[int, ...] | int | None = None,
    use_fft: bool = True,
    dtype: torch.dtype | np.dtype | None = None,
):
    assert max_lag_ms >= dt
    if isinstance(x, torch.Tensor):
        y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
    else:
        y = y.numpy(force=True) if isinstance(y, torch.Tensor) else np.asarray(y)

    is_torch = isinstance(x, torch.Tensor)
    T = x.shape[0]
    n_neurons = x.shape[-1]
    lag_bins = int(max_lag_ms / dt)
    max_lag = min(lag_bins, T - 1)

    if isinstance(batch_axis, int):
        batch_axis = (batch_axis,)

    # Common reshape logic (works for both torch and numpy)
    x_3d = x.reshape(T, -1, n_neurons)
    y_3d = y.reshape(T, -1, n_neurons)
    flat_x = x_3d.reshape(T, -1)
    flat_y = y_3d.reshape(T, -1)

    # Use the FFT path whenever requested, and always for long signals
    # where the direct method would be slow.
    if use_fft or T > 100:
        corr_flat = _cross_correlation_fft(flat_x, flat_y, max_lag, dtype)
    else:
        corr_flat = _cross_correlation_direct(flat_x, flat_y, max_lag, dtype)

    n_lags = corr_flat.shape[0]
    corr_3d = corr_flat.reshape(n_lags, -1, n_neurons)

    if is_torch:
        if batch_axis is not None:
            corr_agg = corr_3d.mean(dim=1, dtype=dtype)
        else:
            corr_agg = corr_3d.reshape(n_lags, -1)

        peak_corr, best_lag_idx = torch.max(corr_agg, dim=0)
        max_lag_actual = min(lag_bins, T - 1)
        best_lags = best_lag_idx - max_lag_actual
        best_lag_ms = best_lags * dt

        info = {
            "corr_over_lags": corr_agg,
            "lag_values_ms": torch.arange(
                -max_lag_actual, max_lag_actual + 1, device=x.device
            )
            * dt,
        }
        return peak_corr, best_lag_ms, info
    else:
        if batch_axis is not None:
            corr_agg = corr_3d.mean(axis=1)
        else:
            corr_agg = corr_3d.reshape(n_lags, -1)

        best_lag_idx = np.argmax(corr_agg, axis=0)
        peak_corr = corr_agg[best_lag_idx, np.arange(corr_agg.shape[1])]
        max_lag_actual = min(lag_bins, T - 1)
        best_lags = best_lag_idx - max_lag_actual
        best_lag_ms = best_lags * dt

        info = {
            "corr_over_lags": corr_agg,
            "lag_values_ms": np.arange(-max_lag_actual, max_lag_actual + 1) * dt,
        }
        return peak_corr, best_lag_ms, info

_cross_correlation_direct(x, y, max_lag, dtype=None)

Direct cross-correlation (simpler for short signals).

Source code in btorch/analysis/dynamic_tools/ei_balance.py
def _cross_correlation_direct(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    max_lag: int,
    dtype: torch.dtype | np.dtype | None = None,
) -> torch.Tensor | np.ndarray:
    """Direct cross-correlation (simpler for short signals)."""
    is_torch = isinstance(x, torch.Tensor)
    T, N = x.shape

    if is_torch:
        device = x.device
        # Normalize
        x_mean = x.mean(dim=0, keepdim=True)
        y_mean = y.mean(dim=0, keepdim=True)
        x_std = x.std(dim=0, keepdim=True) + torch.finfo(x.dtype).eps
        y_std = y.std(dim=0, keepdim=True) + torch.finfo(x.dtype).eps
        x_norm = (x - x_mean) / x_std
        y_norm = (y - y_mean) / y_std
        # Compute correlations for each lag
        max_lag_actual = min(max_lag, T - 1)
        n_lags = 2 * max_lag_actual + 1
        corr = torch.zeros(n_lags, N, device=device, dtype=x.dtype)
        for i, lag in enumerate(range(-max_lag_actual, max_lag_actual + 1)):
            if lag < 0:
                c = (x_norm[:lag] * y_norm[-lag:]).mean(dim=0, dtype=dtype)
            elif lag > 0:
                c = (x_norm[lag:] * y_norm[:-lag]).mean(dim=0, dtype=dtype)
            else:
                c = (x_norm * y_norm).mean(dim=0, dtype=dtype)
            corr[i, :] = c
        return corr
    else:
        # Normalize
        x_norm = (x - x.mean(axis=0)) / (x.std(axis=0) + np.finfo(x.dtype).eps)
        y_norm = (y - y.mean(axis=0)) / (y.std(axis=0) + np.finfo(x.dtype).eps)
        # Compute correlations for each lag
        n_lags = 2 * max_lag + 1
        corr = np.zeros((n_lags, N))
        for i, lag in enumerate(range(-max_lag, max_lag + 1)):
            if lag < 0:
                c = (x_norm[:lag] * y_norm[-lag:]).mean(axis=0, dtype=dtype)
            elif lag > 0:
                c = (x_norm[lag:] * y_norm[:-lag]).mean(axis=0, dtype=dtype)
            else:
                c = (x_norm * y_norm).mean(axis=0, dtype=dtype)
            corr[i, :] = c
        return corr

_cross_correlation_fft(x, y, max_lag, dtype=None)

FFT-based cross-correlation.

Source code in btorch/analysis/dynamic_tools/ei_balance.py
def _cross_correlation_fft(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    max_lag: int,
    dtype: torch.dtype | np.dtype | None = None,
) -> torch.Tensor | np.ndarray:
    """FFT-based cross-correlation."""
    is_torch = isinstance(x, torch.Tensor)
    T, N = x.shape
    n_fft = 2 * T

    if is_torch:
        # Demean and compute FFT
        x_demean = x - x.mean(dim=0, keepdim=True)
        y_demean = y - y.mean(dim=0, keepdim=True)
        X = torch.fft.rfft(x_demean.float(), n=n_fft, dim=0)
        Y = torch.fft.rfft(y_demean.float(), n=n_fft, dim=0)
        cross_spec = X * Y.conj()
        corr_full = torch.fft.irfft(cross_spec, n=n_fft, dim=0)
        # Normalize
        x_std = x.std(dim=0, keepdim=True) + torch.finfo(x.dtype).eps
        y_std = y.std(dim=0, keepdim=True) + torch.finfo(x.dtype).eps
        corr_norm = corr_full / (x_std * y_std * T)
        # Extract valid lags
        max_lag_actual = min(max_lag, T - 1)
        neg_lags = corr_norm[-max_lag_actual:, :]
        pos_lags = corr_norm[: max_lag_actual + 1, :]
        return torch.cat([neg_lags, pos_lags], dim=0)
    else:
        # Demean and compute FFT
        x_demean = x - x.mean(axis=0, keepdims=True, dtype=dtype)
        y_demean = y - y.mean(axis=0, keepdims=True, dtype=dtype)
        X = np.fft.rfft(x_demean, n=n_fft, axis=0)
        Y = np.fft.rfft(y_demean, n=n_fft, axis=0)
        cross_spec = X * np.conj(Y)
        corr_full = np.fft.irfft(cross_spec, n=n_fft, axis=0)
        # Normalize
        x_std = x.std(axis=0, keepdims=True, dtype=dtype) + np.finfo(x.dtype).eps
        y_std = y.std(axis=0, keepdims=True, dtype=dtype) + np.finfo(x.dtype).eps
        corr_norm = corr_full / (x_std * y_std * T)
        # Extract valid lags
        neg_lags = corr_norm[-max_lag:, :]
        pos_lags = corr_norm[: max_lag + 1, :]
        return np.concatenate([neg_lags, pos_lags], axis=0)

compute_eci(I_e, I_i, *, I_ext=None, batch_axis=None, dtype=None)

Compute Excitatory-Inhibitory Cancellation Index (ECI).

ECI measures the degree of cancellation between excitatory and inhibitory currents. ECI = 0 indicates perfect cancellation, ECI = 1 indicates no cancellation.

Formula: ECI = |I_rec + I_ext| / (|I_e| + |I_i|) where I_rec = I_e + I_i

This function is decorated with @use_stats and @use_percentiles. See use_stats() and use_percentiles() for detailed usage.

Parameters:

Name Type Description Default
I_e Tensor | ndarray

Excitatory current [T,..., N]

required
I_i Tensor | ndarray

Inhibitory current [T,..., N]. Note: assumed to be negative (inhibitory).

required
I_ext Tensor | ndarray | None

External current [T,..., N] (optional)

None
batch_axis tuple[int, ...] | int | None

Axes to aggregate over (e.g., trials) in addition to the time axis. If None, averages over all non-neuron dimensions.

None
dtype dtype | dtype | None

Data type for aggregation.

None
stat

Aggregation statistic to return instead of per-neuron values. Options: "mean", "median", "max", "min", "std", "var", "argmax", "argmin", "cv". See use_stats().

optional
stat_info

Additional statistics to compute and store in info dict. See use_stats().

optional
nan_policy

How to handle NaN values ("skip", "warn", "assert"). See use_stats().

optional
inf_policy

How to handle Inf values ("propagate", "skip", "warn", "assert"). See use_stats().

optional
percentiles

Percentile level(s) in [0, 100] to compute. See use_percentiles().

None

Returns:

Name Type Description
eci Tensor | ndarray

ECI values per neuron (shape depends on batch_axis). If stat is provided, returns the aggregated statistic instead.

info Tensor | ndarray

Dictionary with additional statistics and optional percentile data.

Source code in btorch/analysis/dynamic_tools/ei_balance.py
@use_percentiles(value_key="eci")
@use_stats(value_key="eci")
def compute_eci(
    I_e: torch.Tensor | np.ndarray,
    I_i: torch.Tensor | np.ndarray,
    *,
    I_ext: torch.Tensor | np.ndarray | None = None,
    batch_axis: tuple[int, ...] | int | None = None,
    dtype: torch.dtype | np.dtype | None = None,
) -> torch.Tensor | np.ndarray:
    """Compute Excitatory-Inhibitory Cancellation Index (ECI).

    ECI measures the degree of cancellation between excitatory and inhibitory
    currents. ECI = 0 indicates perfect cancellation, ECI = 1 indicates no
    cancellation.

    Formula: ECI = |I_rec + I_ext| / (|I_e| + |I_i|)
    where I_rec = I_e + I_i

    This function is decorated with `@use_stats` and `@use_percentiles`.
    See [`use_stats()`](btorch/analysis/statistics.py:483) and
    [`use_percentiles()`](btorch/analysis/statistics.py:777) for detailed usage.

    Args:
        I_e: Excitatory current [T,..., N]
        I_i: Inhibitory current [T,..., N]. Note: assumed to be
            negative (inhibitory).
        I_ext: External current [T,..., N] (optional)
        batch_axis: Axes to aggregate over (e.g., trials) in addition to the time axis.
            If None, averages over all non-neuron dimensions.
        dtype: Data type for aggregation.
        stat: Aggregation statistic to return instead of per-neuron values.
            Options: "mean", "median", "max", "min", "std", "var", "argmax",
            "argmin", "cv". See [`use_stats()`](btorch/analysis/statistics.py:483).
        stat_info: Additional statistics to compute and store in info dict.
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        nan_policy: How to handle NaN values ("skip", "warn", "assert").
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        inf_policy: How to handle Inf values ("propagate", "skip", "warn",
            "assert"). See [`use_stats()`](btorch/analysis/statistics.py:483).
        percentiles: Percentile level(s) in [0, 100] to compute.
            See [`use_percentiles()`](btorch/analysis/statistics.py:777).

    Returns:
        eci: ECI values per neuron (shape depends on batch_axis).
            If `stat` is provided, returns the aggregated statistic instead.
        info: Dictionary with additional statistics and optional percentile data.
    """
    return _compute_eci(I_e, I_i, I_ext=I_ext, batch_axis=batch_axis, dtype=dtype)
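
Two limiting cases pin down the index, assuming 2D (T, N) current arrays and the documented (values, info) return of the decorated function: perfectly tracking inhibition cancels excitation (ECI -> 0), while absent inhibition leaves nothing to cancel (ECI -> 1).

import torch
from btorch.analysis.dynamic_tools.ei_balance import compute_eci

torch.manual_seed(0)
T, N = 1000, 10
I_e = torch.rand(T, N)  # excitatory current (positive)

eci_balanced, _ = compute_eci(I_e, -I_e)                  # ~ 0 per neuron
eci_unbalanced, _ = compute_eci(I_e, torch.zeros(T, N))   # ~ 1 per neuron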

compute_ei_balance(I_e, I_i, *, I_ext=None, dt=1.0, max_lag_ms=30.0, batch_axis=None)

Compute E/I balance metrics including ECI and lag correlation.

This function is decorated with @use_stats and @use_percentiles. See use_stats() and use_percentiles() for detailed usage.

Parameters:

Name Type Description Default
I_e Tensor | ndarray

Excitatory current [T, ..., N]

required
I_i Tensor | ndarray

Inhibitory current [T, ..., N]

required
I_ext Tensor | ndarray | None

External current [T, ..., N] (optional)

None
dt float

Time step in ms

1.0
max_lag_ms float

Maximum lag for correlation analysis

30.0
batch_axis tuple[int, ...] | int | None

Axes to aggregate over (e.g., trials). If None, averages over all non-time dimensions.

None
stat

Aggregation statistic per return position. See use_stats().

optional
stat_info

Additional statistics per position. See use_stats().

optional
nan_policy

How to handle NaN values ("skip", "warn", "assert"). See use_stats().

optional
inf_policy

How to handle Inf values ("propagate", "skip", "warn", "assert"). See use_stats().

optional
percentiles

Percentile level(s) in [0, 100] to compute per position. See use_percentiles().

(10, 50, 90)

Returns:

Name Type Description
eci

ECI values per neuron

peak_corr

Peak correlation between E and I currents

best_lag_ms

Best lag in ms (positive = I lags E)

info

Dictionary with detailed analysis results

Source code in btorch/analysis/dynamic_tools/ei_balance.py
@use_percentiles(
    value_key={0: "eci", 1: "peak_corr", 2: "best_lag_ms"},
    default_percentiles=(10, 50, 90),
)
@use_stats(
    value_key={0: "eci", 1: "peak_corr", 2: "best_lag_ms"},
    default_stat_info={0: "mean", 1: "mean", 2: "mean"},
)
def compute_ei_balance(
    I_e: torch.Tensor | np.ndarray,
    I_i: torch.Tensor | np.ndarray,
    *,
    I_ext: torch.Tensor | np.ndarray | None = None,
    dt: float = 1.0,
    max_lag_ms: float = 30.0,
    batch_axis: tuple[int, ...] | int | None = None,
):
    """Compute E/I balance metrics including ECI and lag correlation.

    This function is decorated with `@use_stats` and `@use_percentiles`.
    See [`use_stats()`](btorch/analysis/statistics.py:483) and
    [`use_percentiles()`](btorch/analysis/statistics.py:777) for detailed usage.

    Args:
        I_e: Excitatory current [T, ..., N]
        I_i: Inhibitory current [T, ..., N]
        I_ext: External current [T, ..., N] (optional)
        dt: Time step in ms
        max_lag_ms: Maximum lag for correlation analysis
        batch_axis: Axes to aggregate over (e.g., trials). If None, averages
            over all non-time dimensions.
        stat: Aggregation statistic per return position.
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        stat_info: Additional statistics per position.
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        nan_policy: How to handle NaN values ("skip", "warn", "assert").
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        inf_policy: How to handle Inf values ("propagate", "skip", "warn",
            "assert"). See [`use_stats()`](btorch/analysis/statistics.py:483).
        percentiles: Percentile level(s) in [0, 100] to compute per position.
            See [`use_percentiles()`](btorch/analysis/statistics.py:777).

    Returns:
        eci: ECI values per neuron
        peak_corr: Peak correlation between E and I currents
        best_lag_ms: Best lag in ms (positive = I lags E)
        info: Dictionary with detailed analysis results
    """
    # Compute ECI
    eci, eci_info = compute_eci(I_e, I_i, I_ext=I_ext, batch_axis=batch_axis, stat=None)

    # Compute lag correlation between E and I
    peak_corr, best_lag_ms, lag_info = compute_lag_correlation(
        I_e, -I_i, dt=dt, max_lag_ms=max_lag_ms, batch_axis=batch_axis, stat=None
    )

    info = {
        "eci_info": eci_info,
        "lag_info": lag_info,
    }

    return eci, peak_corr, best_lag_ms, info
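
A usage sketch with tightly balanced synthetic currents, assuming the 4-tuple return documented above:

import torch
from btorch.analysis.dynamic_tools.ei_balance import compute_ei_balance

torch.manual_seed(0)
T, N = 2000, 8
I_e = torch.rand(T, N)
I_i = -I_e + 0.05 * torch.randn(T, N)  # inhibition tightly tracking excitation

eci, peak_corr, best_lag_ms, info = compute_ei_balance(I_e, I_i, dt=1.0)
# Tight balance: eci near 0, peak_corr near 1, best_lag_ms near 0 per neuron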

compute_lag_correlation(x, y, *, dt=1.0, max_lag_ms=30.0, batch_axis=None, use_fft=True)

Compute lagged cross-correlation between two signals.

Uses FFT-based correlation for efficiency. Returns correlation values and best lag per neuron.

This function is decorated with @use_stats and @use_percentiles. See use_stats() and use_percentiles() for detailed usage.

Parameters:

Name Type Description Default
x Tensor | ndarray

First signal [T, ...] or [T, B, ...]

required
y Tensor | ndarray

Second signal [T, ...] or [T, B, ...]

required
dt float

Time step in ms

1.0
max_lag_ms float

Maximum lag for correlation in ms

30.0
batch_axis tuple[int, ...] | int | None

Axes to aggregate over (e.g., trials). If None, averages over all non-time dimensions.

None
use_fft bool

If True, use FFT-based correlation (faster for long signals)

True
stat

Aggregation statistic per return position. Can be a single stat or dict mapping position to stat (e.g., {0: "mean", 1: "median"}). See use_stats().

optional
stat_info

Additional statistics to compute and store in info dict. Can be a single stat, iterable, or dict mapping position to stat(s). See use_stats().

optional
nan_policy

How to handle NaN values ("skip", "warn", "assert"). See use_stats().

optional
inf_policy

How to handle Inf values ("propagate", "skip", "warn", "assert"). See use_stats().

optional
percentiles

Percentile level(s) in [0, 100] to compute per position. Can be a single value or dict mapping position to percentile(s). See use_percentiles().

None

Returns:

Name Type Description
peak_corr

Correlation values at best lag per neuron. If stat is provided, returns aggregated value(s) instead.

best_lag_ms

Best lag in ms per neuron. If stat is provided, returns aggregated value(s) instead.

info

Dictionary with correlation over lags, best lags, etc.

Example

# Get per-neuron values
peak, lag, info = compute_lag_correlation(x, y)

# Aggregate: max peak correlation, mean best lag
peak_max, lag_mean, info = compute_lag_correlation(
    x, y, stat={0: "max", 1: "mean"}
)
Source code in btorch/analysis/dynamic_tools/ei_balance.py
@use_stats(value_key={0: "peak_corr", 1: "best_lag"})
@use_percentiles(value_key={0: "peak_corr", 1: "best_lag"})
def compute_lag_correlation(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    *,
    dt: float = 1.0,
    max_lag_ms: float = 30.0,
    batch_axis: tuple[int, ...] | int | None = None,
    use_fft: bool = True,
):
    """Compute lagged cross-correlation between two signals.

    Uses FFT-based correlation for efficiency. Returns correlation values
    and best lag per neuron.

    This function is decorated with `@use_stats` and `@use_percentiles`.
    See [`use_stats()`](btorch/analysis/statistics.py:483) and
    [`use_percentiles()`](btorch/analysis/statistics.py:777) for detailed usage.

    Args:
        x: First signal [T, ...] or [T, B, ...]
        y: Second signal [T, ...] or [T, B, ...]
        dt: Time step in ms
        max_lag_ms: Maximum lag for correlation in ms
        batch_axis: Axes to aggregate over (e.g., trials). If None, averages
            over all non-time dimensions.
        use_fft: If True, use FFT-based correlation (faster for long signals)
        stat: Aggregation statistic per return position. Can be a single stat
            or dict mapping position to stat (e.g., {0: "mean", 1: "median"}).
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        stat_info: Additional statistics to compute and store in info dict.
            Can be a single stat, iterable, or dict mapping position to stat(s).
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        nan_policy: How to handle NaN values ("skip", "warn", "assert").
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        inf_policy: How to handle Inf values ("propagate", "skip", "warn",
            "assert"). See [`use_stats()`](btorch/analysis/statistics.py:483).
        percentiles: Percentile level(s) in [0, 100] to compute per position.
            Can be a single value or dict mapping position to percentile(s).
            See [`use_percentiles()`](btorch/analysis/statistics.py:777).

    Returns:
        peak_corr: Correlation values at best lag per neuron.
            If `stat` is provided, returns aggregated value(s) instead.
        best_lag_ms: Best lag in ms per neuron.
            If `stat` is provided, returns aggregated value(s) instead.
        info: Dictionary with correlation over lags, best lags, etc.

    Example:
        # Get per-neuron values
        peak, lag, info = compute_lag_correlation(x, y)

        # Aggregate: max peak correlation, mean best lag
        peak_max, lag_mean, info = compute_lag_correlation(
            x, y, stat={0: "max", 1: "mean"}
        )
    """
    return _compute_lag_correlation(
        x, y, dt=dt, max_lag_ms=max_lag_ms, batch_axis=batch_axis, use_fft=use_fft
    )

use_percentiles(func=None, *, value_key='values', default_percentiles=None)

Decorator to add percentiles arg and optionally compute percentiles.

This decorator adds a percentiles parameter to a function that returns per-neuron values. Percentiles are only computed if percentiles is not None. Results are stored in info[f"{value_key}_percentiles"] (with the matching levels under info[f"{value_key}_levels"]).

Can also accept a dict mapping return positions to labels for functions returning multiple values (e.g., {1: "eci", 3: "lag"}).

The decorated function should return either: - A tuple of (values, info_dict) where values are per-neuron metrics - Just the per-neuron values (will be wrapped in a tuple with empty dict) - A tuple of multiple values with info as the last element

Parameters:

Name Type Description Default
func Callable | None

The function to decorate (or None if using with parentheses)

None
value_key str | dict[int, str]

Key to use in info dict for the percentile result

'values'
default_percentiles float | tuple[float, ...] | None

Default percentile level(s) applied when the caller does not pass percentiles.

None

Returns:

Type Description
Callable

Decorated function with added percentiles parameter

Example
@use_percentiles
def compute_metric(data, *, percentiles=None):
    values = some_computation(data)  # per-neuron values
    return values, {"raw": values}

# Usage:
values, info = compute_metric(data)  # no percentiles computed
values, info = compute_metric(data, percentiles=50)  # compute median
values, info = compute_metric(
    data, percentiles=(25, 50, 75)
)  # compute quartiles
Source code in btorch/analysis/statistics.py
def use_percentiles(
    func: Callable | None = None,
    *,
    value_key: str | dict[int, str] = "values",
    default_percentiles: float | tuple[float, ...] | None = None,
) -> Callable:
    """Decorator to add percentiles arg and optionally compute percentiles.

    This decorator adds a `percentiles` parameter to a function that returns
    per-neuron values. Percentiles are only computed if `percentiles` is not
    None. Results are stored in info[f"{value_key}_percentiles"], with the
    requested levels in info[f"{value_key}_levels"].

    `value_key` can also be a dict mapping return positions to labels for
    functions returning multiple values (e.g., {1: "eci", 3: "lag"}).

    The decorated function should return either:
    - A tuple of (values, info_dict) where values are per-neuron metrics
    - Just the per-neuron values (will be wrapped in a tuple with empty dict)
    - A tuple of multiple values with info as the last element

    Args:
        func: The function to decorate (or None if using with parentheses)
        value_key: Key to use in info dict for the percentile result
        default_percentiles: Default percentile level(s) used when the caller
            does not pass `percentiles`

    Returns:
        Decorated function with added percentiles parameter

    Example:
        ```python
        @use_percentiles
        def compute_metric(data, *, percentiles=None):
            values = some_computation(data)  # per-neuron values
            return values, {"raw": values}

        # Usage:
        values, info = compute_metric(data)  # no percentiles computed
        values, info = compute_metric(data, percentiles=0.5)  # compute median
        values, info = compute_metric(
            data, percentiles=(0.25, 0.5, 0.75)
        )  # compute quartiles
        ```
    """

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def wrapper(
            *args,
            percentiles: float
            | tuple[float, ...]
            | dict[int, float | tuple[float, ...]]
            | None = default_percentiles,
            **kwargs,
        ) -> tuple[Any, ...]:
            # Call the original function
            result = f(*args, **kwargs)
            if percentiles is None:
                return result

            # Unpack result using shared helper
            values_tuple, info = _unpack_result(result, value_key)

            # Ensure info is a dict
            if info is None:
                info = {}

            updated_info = dict(info)

            # Helper to get value key name for a position
            def _get_value_key_name(pos: int) -> str:
                if isinstance(value_key, dict):
                    return value_key.get(pos, f"values{pos}")
                return f"{value_key}{pos}"

            # Helper to get values at a position
            def _get_values(pos: int) -> Any:
                if pos < 0 or pos >= len(values_tuple):
                    raise IndexError(
                        f"Position {pos} out of range for return tuple "
                        f"of length {len(values_tuple)}"
                    )
                return values_tuple[pos]

            # Compute percentiles only if requested
            if isinstance(percentiles, dict):
                # Dict format: {position: percentile_value(s)}
                # Allows different percentiles for different return values
                for pos, perc_value in percentiles.items():
                    values = _get_values(pos)
                    key_name = _get_value_key_name(pos)
                    perc_result = compute_percentiles(values, perc_value)
                    updated_info[f"{key_name}_percentiles"] = perc_result["percentiles"]
                    updated_info[f"{key_name}_levels"] = perc_result["levels"]
            elif isinstance(value_key, dict):
                # Dict value_key with single percentiles value:
                # Apply same percentiles to all positions in value_key
                for pos, label in value_key.items():
                    values = _get_values(pos)
                    perc_result = compute_percentiles(values, percentiles)
                    updated_info[f"{label}_percentiles"] = perc_result["percentiles"]
                    updated_info[f"{label}_levels"] = perc_result["levels"]
            else:
                # Single percentiles format - apply to position 0
                values = _get_values(0)
                perc_result = compute_percentiles(values, percentiles)
                updated_info[f"{value_key}_percentiles"] = perc_result["percentiles"]
                updated_info[f"{value_key}_levels"] = perc_result["levels"]

            if len(values_tuple) > 1:
                return values_tuple + (updated_info,)
            else:
                return values_tuple[0], updated_info

        return wrapper

    if func is None:
        return decorator
    return decorator(func)
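When the decorated function returns multiple values, `percentiles` itself can be a dict keyed by return position. A minimal sketch, where `compute_pair` and its two per-neuron outputs are illustrative stand-ins:

```python
import numpy as np


@use_percentiles(value_key={0: "eci", 1: "lag"})
def compute_pair(data, *, percentiles=None):
    eci = data.mean(axis=0)  # hypothetical per-neuron metric
    lag = data.std(axis=0)  # hypothetical per-neuron metric
    return eci, lag, {}


data = np.random.rand(1000, 32)

# Different percentile levels per return position (levels follow the same
# convention as the docstring example above, e.g. 0.5 = median):
eci, lag, info = compute_pair(data, percentiles={0: 0.5, 1: (0.25, 0.75)})
# info now holds "eci_percentiles"/"eci_levels" and "lag_percentiles"/"lag_levels".
```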

use_stats(func=None, *, value_key='values', dim=None, default_stat=None, default_stat_info=None, default_nan_policy='skip', default_inf_policy='propagate')

Decorator to add stat and stat_info args for aggregation.

This decorator adds stat, stat_info, nan_policy, and inf_policy parameters to a function that returns per-neuron values.

  • stat: If not None, returns the aggregated value instead of per-neuron values. The aggregation is stored in info[f"{value_key}_{stat}"]. Can be a single StatChoice, or a dict mapping return position to a stat (e.g., {0: "mean", 1: "median"}) for functions returning multiple values.
  • stat_info: Additional stats to compute and store in the info dict without affecting the return value. Can be a single StatChoice, an Iterable of StatChoice, a dict mapping position to stat(s), or None. If a dict, the format is {position: stat_or_stats} where stat_or_stats can be a single StatChoice or an Iterable of StatChoice.
  • dim: Dimension(s) to aggregate over. Can be:
    • None: Flatten all dimensions (default)
    • int: Aggregate over this dimension for all outputs
    • tuple[int, ...]: Aggregate over these dimensions for all outputs
    • dict[int, int | tuple[int, ...] | None]: Different dim for each output position (e.g., {0: 1, 1: 2, 2: None, 3: (1, 3, 4)})
  • nan_policy: How to handle NaN values:
    • "skip": Ignore NaN values (default)
    • "warn": Warn if NaN values found but continue
    • "assert": Raise error if NaN values found
  • inf_policy: How to handle Inf values:
    • "propagate": Keep Inf values (default)
    • "skip": Ignore Inf values
    • "warn": Warn if Inf values found but continue
    • "assert": Raise error if Inf values found

The decorated function should return either:

  • A tuple of (values, info_dict) where values are per-neuron metrics
  • Just the per-neuron values (will be wrapped in a tuple with an empty dict)
  • A tuple of multiple values with info as the last element

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `func` | `Callable \| None` | The function to decorate (or None if using with parentheses) | `None` |
| `value_key` | `str \| dict[int, str]` | Key prefix to use in info dict for stat results | `'values'` |
| `dim` | `int \| tuple[int, ...] \| dict[int, int \| tuple[int, ...] \| None] \| None` | Dimension(s) to aggregate over for each output | `None` |
| `default_stat` | `StatChoice \| dict[int, StatChoice] \| None` | Default stat for this decorated function | `None` |
| `default_stat_info` | `StatChoice \| Iterable[StatChoice] \| dict[int, StatChoice \| Iterable[StatChoice]] \| None` | Default stat_info for this decorated function | `None` |
| `default_nan_policy` | `Literal['skip', 'warn', 'assert']` | Default nan_policy for this decorated function | `'skip'` |
| `default_inf_policy` | `Literal['propagate', 'skip', 'warn', 'assert']` | Default inf_policy for this decorated function | `'propagate'` |

Returns:

| Type | Description |
| --- | --- |
| `Callable` | Decorated function with added `stat`, `stat_info`, `nan_policy`, and `inf_policy` parameters |

Example
@use_stats
def compute_metric(
    data,
    *,
    stat=None,
    stat_info=None,
    nan_policy="skip",
    inf_policy="propagate",
):
    values = some_computation(data)  # per-neuron values
    return values, {"raw": values}

# Usage:
values, info = compute_metric(data)  # returns per-neuron values
mean_val, info = compute_metric(data, stat="mean")  # returns aggregated
values, info = compute_metric(
    data, stat_info=["mean", "max"]
)  # extra stats in info

# Multi-value return with dict stat:
@use_stats
def compute_multiple(data, *, stat=None, stat_info=None):
    eci = compute_eci(data)  # per-neuron
    lag = compute_lag(data)  # per-neuron
    return eci, lag, {}  # multiple values

# Aggregate specific positions with dict stat:
eci_mean, lag_mean, info = compute_multiple(
    data, stat={0: "mean", 1: "mean"}
)
Source code in btorch/analysis/statistics.py
def use_stats(
    func: Callable | None = None,
    *,
    value_key: str | dict[int, str] = "values",
    dim: int | tuple[int, ...] | dict[int, int | tuple[int, ...] | None] | None = None,
    default_stat: StatChoice | dict[int, StatChoice] | None = None,
    default_stat_info: (
        StatChoice
        | Iterable[StatChoice]
        | dict[int, StatChoice | Iterable[StatChoice]]
        | None
    ) = None,
    default_nan_policy: Literal["skip", "warn", "assert"] = "skip",
    default_inf_policy: Literal["propagate", "skip", "warn", "assert"] = "propagate",
) -> Callable:
    """Decorator to add stat and stat_info args for aggregation.

    This decorator adds `stat`, `stat_info`, `nan_policy`, and `inf_policy`
    parameters to a function that returns per-neuron values.

    - `stat`: If not None, returns the aggregated value instead of per-neuron
      values. The aggregation is stored in info[f"{value_key}_{stat}"].
      Can be a single StatChoice, or a dict mapping return position to a stat
      (e.g., {0: "mean", 1: "median"}) for functions returning multiple values.
    - `stat_info`: Additional stats to compute and store in info dict without
      affecting the return value. Can be a single StatChoice, Iterable of
      StatChoice, a dict mapping position to stat(s), or None.
      If dict, format is {position: stat_or_stats} where stat_or_stats can be
      a single StatChoice or Iterable of StatChoice.
    - `dim`: Dimension(s) to aggregate over. Can be:
        - None: Flatten all dimensions (default)
        - int: Aggregate over this dimension for all outputs
        - tuple[int, ...]: Aggregate over these dimensions for all outputs
        - dict[int, int | tuple[int, ...] | None]: Different dim for each
          output position (e.g., {0: 1, 1: 2, 2: None, 3: (1, 3, 4)})
    - `nan_policy`: How to handle NaN values:
        - "skip": Ignore NaN values (default)
        - "warn": Warn if NaN values found but continue
        - "assert": Raise error if NaN values found
    - `inf_policy`: How to handle Inf values:
        - "propagate": Keep Inf values (default)
        - "skip": Ignore Inf values
        - "warn": Warn if Inf values found but continue
        - "assert": Raise error if Inf values found

    The decorated function should return either:
    - A tuple of (values, info_dict) where values are per-neuron metrics
    - Just the per-neuron values (will be wrapped in a tuple with empty dict)
    - A tuple of multiple values with info as the last element

    Args:
        func: The function to decorate (or None if using with parentheses)
        value_key: Key prefix to use in info dict for stat results
        dim: Dimension(s) to aggregate over for each output
        default_nan_policy: Default nan_policy for this decorated function
        default_inf_policy: Default inf_policy for this decorated function
        default_stat: Default stat for this decorated function
        default_stat_info: Default stat_info for this decorated function

    Returns:
        Decorated function with added stat, stat_info, nan_policy, and
        inf_policy parameters

    Example:
        ```python
        @use_stats
        def compute_metric(
            data,
            *,
            stat=None,
            stat_info=None,
            nan_policy="skip",
            inf_policy="propagate",
        ):
            values = some_computation(data)  # per-neuron values
            return values, {"raw": values}

        # Usage:
        values, info = compute_metric(data)  # returns per-neuron values
        mean_val, info = compute_metric(data, stat="mean")  # returns aggregated
        values, info = compute_metric(
            data, stat_info=["mean", "max"]
        )  # extra stats in info

        # Multi-value return with dict stat:
        @use_stats
        def compute_multiple(data, *, stat=None, stat_info=None):
            eci = compute_eci(data)  # per-neuron
            lag = compute_lag(data)  # per-neuron
            return eci, lag, {}  # multiple values

        # Aggregate specific positions with dict stat:
        eci_mean, lag_mean, info = compute_multiple(
            data, stat={0: "mean", 1: "mean"}
        )
        ```
    """

    def decorator(f: Callable) -> Callable:
        # Inspect the wrapped function to determine what arguments it accepts
        sig = inspect.signature(f)
        f_accepts_nan_policy = "nan_policy" in sig.parameters
        f_accepts_inf_policy = "inf_policy" in sig.parameters

        @wraps(f)
        def wrapper(
            *args,
            stat: StatChoice | dict[int, StatChoice] | None = default_stat,
            stat_info: StatChoice
            | Iterable[StatChoice]
            | dict[int, StatChoice | Iterable[StatChoice]]
            | None = default_stat_info,
            nan_policy: Literal["skip", "warn", "assert"] | None = None,
            inf_policy: Literal["propagate", "skip", "warn", "assert"] | None = None,
            **kwargs,
        ) -> tuple[Any, ...]:
            # Use effective policies (passed value > decorator default > "skip")
            effective_nan_policy = (
                nan_policy if nan_policy is not None else default_nan_policy
            )
            effective_inf_policy = (
                inf_policy if inf_policy is not None else default_inf_policy
            )

            # Pass policies to the wrapped function if it accepts them
            if f_accepts_nan_policy:
                kwargs["nan_policy"] = effective_nan_policy
            if f_accepts_inf_policy:
                kwargs["inf_policy"] = effective_inf_policy

            # Call the original function
            result = f(*args, **kwargs)

            # Unpack result using shared helper
            values_tuple, info = _unpack_result(result, value_key)

            # Ensure info is a dict
            if info is None:
                info = {}

            updated_info = dict(info)

            # Helper to get value key name for a position
            def _get_value_key_name(pos: int) -> str:
                if isinstance(value_key, dict):
                    return value_key.get(pos, f"values{pos}")
                if len(values_tuple) > 1:
                    return f"{value_key}{pos}"
                return value_key

            # Helper to get values at a position
            def _get_values(pos: int) -> Any:
                if pos < 0 or pos >= len(values_tuple):
                    raise IndexError(
                        f"Position {pos} out of range for return tuple "
                        f"of length {len(values_tuple)}"
                    )
                return values_tuple[pos]

            # Helper to get effective dim for a position
            def _get_dim_for_pos(pos: int) -> int | tuple[int, ...] | None:
                if dim is None:
                    return None
                if isinstance(dim, dict):
                    return dim.get(pos, None)
                return dim

            # Handle stat parameter
            if stat is not None:
                # Check if stat is a dict mapping positions to stats
                if isinstance(stat, dict):
                    # Multiple position aggregation with dict stat
                    results = []
                    for pos, stat_choice in stat.items():
                        values = _get_values(pos)
                        key_name = _get_value_key_name(pos)
                        effective_dim = _get_dim_for_pos(pos)
                        stat_value = _compute_stat(
                            values,
                            stat_choice,
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,  # type: ignore
                        )
                        results.append(stat_value)
                        updated_info[key_name] = values
                        updated_info[f"{key_name}_{stat_choice}"] = stat_value
                    return tuple(results) + (updated_info,)
                else:
                    # Single stat - apply to position 0
                    values = _get_values(0)
                    key_name = _get_value_key_name(0)
                    effective_dim = _get_dim_for_pos(0)
                    stat_value = _compute_stat(
                        values,
                        stat,
                        effective_nan_policy,
                        effective_inf_policy,
                        effective_dim,  # type: ignore
                    )
                    updated_info[key_name] = values
                    updated_info[f"{key_name}_{stat}"] = stat_value
                    return stat_value, updated_info

            # Handle stat_info parameter
            if stat_info is not None:
                # Check if stat_info is a dict mapping positions to stats
                if isinstance(stat_info, dict):
                    # Dict format: {position: stat_or_stats}
                    for pos, stats in stat_info.items():
                        # Normalize to iterable
                        if isinstance(stats, str):
                            stats_list = [stats]
                        else:
                            stats_list = list(stats)

                        # Use batch computation for efficiency
                        values = _get_values(pos)
                        key_name = _get_value_key_name(pos)
                        effective_dim = _get_dim_for_pos(pos)
                        if len(stats_list) > 1:
                            batch_results = _compute_stats_batch(
                                values,
                                [str(s) for s in stats_list],
                                effective_nan_policy,
                                effective_inf_policy,
                                effective_dim,
                            )
                            for s in stats_list:
                                updated_info[f"{key_name}_{s}"] = batch_results[str(s)]
                        else:
                            # Single stat - no need for batch optimization
                            stat_value = _compute_stat(
                                values,
                                stats_list[0],
                                effective_nan_policy,
                                effective_inf_policy,
                                effective_dim,  # type: ignore
                            )
                            updated_info[f"{key_name}_{stats_list[0]}"] = stat_value
                else:
                    # Original format: apply to position 0
                    # Normalize to iterable
                    if isinstance(stat_info, str):
                        stat_info_list = [stat_info]
                    else:
                        stat_info_list = list(stat_info)

                    # Use batch computation for efficiency (reuses mean/std for cv)
                    values = _get_values(0)
                    key_name = _get_value_key_name(0)
                    effective_dim = _get_dim_for_pos(0)
                    if len(stat_info_list) > 1:
                        batch_results = _compute_stats_batch(
                            values,
                            [str(s) for s in stat_info_list],
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,
                        )
                        for s in stat_info_list:
                            updated_info[f"{key_name}_{s}"] = batch_results[str(s)]
                    else:
                        # Single stat - no need for batch optimization
                        stat_value = _compute_stat(
                            values,
                            stat_info_list[0],
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,  # type: ignore
                        )
                        updated_info[f"{key_name}_{stat_info_list[0]}"] = stat_value

                # Return original values with updated info
                if len(values_tuple) > 1:
                    return values_tuple + (updated_info,)
                else:
                    return (values_tuple[0], updated_info)

            # No stat or stat_info - return original values with info
            if len(values_tuple) > 1:
                return values_tuple + (updated_info,)
            else:
                return (values_tuple[0], updated_info)

        return wrapper

    if func is None:
        return decorator
    return decorator(func)
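The `dim` option is easiest to see with array-valued outputs. A minimal sketch, assuming `_compute_stat` reduces over the given dimension as documented above (the metric and shapes are illustrative only):

```python
import numpy as np


@use_stats(value_key="rate", dim=0)  # aggregate over the trial axis
def trial_rates(data, *, stat=None, stat_info=None,
                nan_policy="skip", inf_policy="propagate"):
    # hypothetical per-(trial, neuron) metric
    return data.mean(axis=0), {}


data = np.random.rand(100, 20, 64)  # (time, trials, neurons)

# Per-(trial, neuron) values, untouched:
rates, info = trial_rates(data)  # shape (20, 64)

# Mean over dim=0 (trials), one value per neuron:
rate_mean, info = trial_rates(data, stat="mean")  # shape (64,)
```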

btorch.analysis.dynamic_tools.lyapunov_dynamics

Functions

compute_expansion_to_contraction_ratio(lyapunov_spectrum)

Compute the ratio of expansion to contraction from the Lyapunov spectrum.

Parameters:

- lyapunov_spectrum: A list or numpy array of Lyapunov exponents.

Returns:

- ratio: The ratio of the sum of positive exponents to the absolute sum of negative exponents.

Source code in btorch/analysis/dynamic_tools/lyapunov_dynamics.py
def compute_expansion_to_contraction_ratio(lyapunov_spectrum):
    """Compute the ratio of expansion to contraction from the Lyapunov
    spectrum.

    Parameters:
    - lyapunov_spectrum: A list or numpy array of Lyapunov exponents.

    Returns:
    - ratio: The ratio of the sum of positive exponents to the absolute sum of
    negative exponents.
    """
    lyapunov_spectrum = np.array(lyapunov_spectrum)
    positive_sum = np.sum(lyapunov_spectrum[lyapunov_spectrum > 0])
    negative_sum = np.sum(np.abs(lyapunov_spectrum[lyapunov_spectrum < 0]))

    if negative_sum == 0:
        return np.inf  # Avoid division by zero; indicates pure expansion

    ratio = positive_sum / negative_sum
    return ratio
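A quick worked example: a zero exponent contributes to neither sum, so for the spectrum below the ratio is 0.5 / 1.5 ≈ 0.33:

```python
spectrum = [0.5, 0.0, -1.5]  # one expanding, one neutral, one contracting direction
ratio = compute_expansion_to_contraction_ratio(spectrum)
print(ratio)  # 0.333... (< 1: contraction dominates)
```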

compute_ks_entropy(time_series, emb_dim=6, lag=1)

Compute the Kolmogorov-Sinai (KS) entropy of a given time series using the nolds library.

Parameters:

- time_series: A 1D numpy array representing the time series data.
- emb_dim: Embedding dimension used by the estimator.
- lag: Lag between successive embedding components.

Returns:

- ks_entropy: The estimated KS entropy (sample entropy is used as the estimator).

Source code in btorch/analysis/dynamic_tools/lyapunov_dynamics.py
def compute_ks_entropy(time_series, emb_dim=6, lag=1):
    """Compute the Kolmogorov-Sinai (KS) entropy of a given time series using
    the nolds library.

    Parameters:
    - time_series: A 1D numpy array representing the time series data.
    - emb_dim: Embedding dimension used by the estimator.
    - lag: Lag between successive embedding components.

    Returns:
    - ks_entropy: The estimated KS entropy (sample entropy is used as the
      estimator).
    """
    ks_entropy = nolds.sampen(time_series, emb_dim=emb_dim, lag=lag)
    return ks_entropy

compute_lyapunov_exponent_spectrum(time_series, emb_dim=6, matrix_dim=4, tau=1)

Compute the full Lyapunov spectrum of a given time series using the nolds library.

Parameters:

- time_series: A 1D numpy array representing the time series data.
- emb_dim: Embedding dimension for phase-space reconstruction.
- matrix_dim: Matrix dimension used by the Eckmann et al. estimator (the number of exponents returned).
- tau: Step size between samples, used to scale the exponents.

Returns:

- lyapunov_spectrum: A list of estimated Lyapunov exponents.

Source code in btorch/analysis/dynamic_tools/lyapunov_dynamics.py
def compute_lyapunov_exponent_spectrum(time_series, emb_dim=6, matrix_dim=4, tau=1):
    """Compute the full Lyapunov spectrum of a given time series using the
    nolds library.

    Parameters:
    - time_series: A 1D numpy array representing the time series data.
    - emb_dim: Embedding dimension for phase-space reconstruction.
    - matrix_dim: Matrix dimension used by the Eckmann et al. estimator (the
      number of exponents returned).
    - tau: Step size between samples, used to scale the exponents.

    Returns:
    - lyapunov_spectrum: A list of estimated Lyapunov exponents.
    """
    lyapunov_spectrum = nolds.lyap_e(
        time_series, emb_dim=emb_dim, matrix_dim=matrix_dim, tau=tau
    )
    return lyapunov_spectrum

compute_max_lyapunov_exponent(time_series, emb_dim=6, lag=1, tau=1)

Compute the largest Lyapunov exponent of a given time series using the nolds library.

Parameters:

- time_series: A 1D numpy array representing the time series data.
- emb_dim: Embedding dimension for phase-space reconstruction.
- lag: Lag between successive embedding components.
- tau: Step size between samples, used to scale the exponent.

Returns:

- lyapunov_exponent: The estimated largest Lyapunov exponent.

Source code in btorch/analysis/dynamic_tools/lyapunov_dynamics.py
def compute_max_lyapunov_exponent(time_series, emb_dim=6, lag=1, tau=1):
    """Compute the largest Lyapunov exponent of a given time series using the
    nolds library.

    Parameters:
    - time_series: A 1D numpy array representing the time series data.
    - emb_dim: Embedding dimension for phase-space reconstruction.
    - lag: Lag between successive embedding components.
    - tau: Step size between samples, used to scale the exponent.

    Returns:
    - lyapunov_exponent: The estimated largest Lyapunov exponent.
    """
    lyapunov_exponent = nolds.lyap_r(time_series, emb_dim=emb_dim, lag=lag, tau=tau)
    return lyapunov_exponent
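A usage sketch for these nolds wrappers on a synthetic chaotic series. The logistic map at r = 4 is a standard test signal (largest exponent ln 2 ≈ 0.693 per iteration); exact estimates depend on nolds' settings, and `emb_dim` is chosen here so that (emb_dim - 1) is a multiple of (matrix_dim - 1), which nolds' `lyap_e` expects:

```python
import numpy as np

# Logistic map at r = 4: chaotic, largest Lyapunov exponent ln 2 per iteration
x = np.empty(5000)
x[0] = 0.3
for i in range(1, len(x)):
    x[i] = 4.0 * x[i - 1] * (1.0 - x[i - 1])

lle = compute_max_lyapunov_exponent(x, emb_dim=6, lag=1, tau=1)
spectrum = compute_lyapunov_exponent_spectrum(x, emb_dim=10, matrix_dim=4, tau=1)
print(lle, spectrum)
```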

get_continuous_spiking_rate(spikes, dt, sigma=20.0)

Convert discrete spike trains into continuous firing rates using Gaussian smoothing.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `spikes` | `ndarray or Tensor` | Spike matrix of shape (time_steps, n_neurons) | required |
| `dt` | `float` | Simulation time step in ms | required |
| `sigma` | `float` | Standard deviation of the Gaussian kernel in ms. Default 20 ms | `20.0` |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Continuous firing rate traces of shape (time_steps, n_neurons) |

Source code in btorch/analysis/dynamic_tools/lyapunov_dynamics.py
def get_continuous_spiking_rate(spikes, dt, sigma=20.0):
    """Convert discrete spike trains into continuous firing rates using
    Gaussian smoothing.

    Args:
        spikes (np.ndarray or torch.Tensor): Spike matrix of shape (time_steps,
        n_neurons).
        dt (float): Simulation time step in ms.
        sigma (float): Standard deviation of the Gaussian kernel in ms. Default 20ms.

    Returns:
        np.ndarray: Continuous firing rate traces of shape (time_steps, n_neurons).
    """
    if isinstance(spikes, torch.Tensor):
        spikes = spikes.detach().cpu().numpy()

    # Convert sigma from ms to bins
    sigma_bins = sigma / dt

    # Apply Gaussian filter along the time axis (axis 0)
    rates = gaussian_filter1d(spikes.astype(float), sigma=sigma_bins, axis=0)

    return rates
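A small sanity check of the smoothing, sketched with 0.5 ms bins: sigma = 20 ms corresponds to a 40-bin kernel, and Gaussian filtering roughly preserves each neuron's total spike count, so column means stay near the raw per-bin spike probability (the output is a smoothed spikes-per-bin trace rather than a rate in Hz):

```python
import numpy as np

rng = np.random.default_rng(0)
spikes = (rng.random((2000, 8)) < 0.02).astype(float)  # ~2% spike prob per bin

rates = get_continuous_spiking_rate(spikes, dt=0.5, sigma=20.0)  # 20 / 0.5 = 40 bins
print(spikes.mean(axis=0)[:3])
print(rates.mean(axis=0)[:3])  # approximately equal to the raw means
```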

btorch.analysis.dynamic_tools.micro_scale

Functions

calculate_cv_isi(spikes, dt=1.0)

Compute the CV_ISI (coefficient of variation of inter-spike intervals) of each neuron in the population and summarize its distribution.

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| `spikes` | (Time, Neurons) spike matrix | required |
| `dt` | Simulation time step (ms) | `1.0` |

Returns: dict: {'cv_isi': array, 'mean': float}

Source code in btorch/analysis/dynamic_tools/micro_scale.py
def calculate_cv_isi(spikes, dt=1.0):
    """计算群体中每个神经元的CV_ISI,并统计其分布特征。

    Args:
        spikes: (Time, Neurons) 脉冲矩阵
        dt: 仿真步长(ms)
    Returns:
        dict: {'cv_isi': array, 'mean': float}
    """
    if isinstance(spikes, torch.Tensor):
        spikes = spikes.detach().cpu().numpy()

    num_neurons = spikes.shape[1]
    cv_isi_list = []

    for n in range(num_neurons):
        spike_times = np.where(spikes[:, n] > 0)[0] * dt  # 转换为ms
        if len(spike_times) < 2:
            cv_isi_list.append(np.nan)  # 不足两个脉冲,无法计算ISI
            continue

        isis = np.diff(spike_times)  # 计算ISI
        if np.mean(isis) == 0:
            cv_isi_list.append(np.nan)
            continue

        cv_isi = np.std(isis) / np.mean(isis)
        cv_isi_list.append(cv_isi)

    cv_isi_array = np.array(cv_isi_list)
    mean_cv_isi = np.nanmean(cv_isi_array)  # 忽略NaN值计算均值

    return {
        "cv_isi": cv_isi_array,  # 每个神经元的CV_ISI分布
        "mean": mean_cv_isi,  # 均值
    }
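Two sanity checks, sketched below: Bernoulli (Poisson-like) trains have near-geometric ISIs with CV_ISI just under 1, while a perfectly regular train has CV_ISI = 0:

```python
import numpy as np

rng = np.random.default_rng(0)
poisson_spikes = (rng.random((20000, 50)) < 0.02).astype(float)  # irregular
regular_spikes = np.zeros((20000, 1))
regular_spikes[::50] = 1.0  # one spike every 50 ms

print(calculate_cv_isi(poisson_spikes, dt=1.0)["mean"])  # ≈ 1 (slightly below)
print(calculate_cv_isi(regular_spikes, dt=1.0)["mean"])  # 0.0
```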

calculate_fr_distribution(spikes, dt=1.0)

Compute the population-average firing rate at each time point and summarize its distribution.

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| `spikes` | (Time, Neurons) spike matrix | required |
| `dt` | Simulation time step (ms) | `1.0` |

Returns: dict: {'rates': array, 'mean': float, 'skew': float, 'kurt': float}

Source code in btorch/analysis/dynamic_tools/micro_scale.py
def calculate_fr_distribution(spikes, dt=1.0):
    """计算群体中每个时刻的平均发放率,并统计其分布特征。

    Args:
        spikes: (Time, Neurons) 脉冲矩阵
        dt: 仿真步长(ms)
    Returns:
        dict: {'rates': array, 'mean': float, 'skew': float, 'kurt': float}
    """
    if isinstance(spikes, torch.Tensor):
        spikes = spikes.detach().cpu().numpy()

    # 计算每个时间点的平均发放率 (Hz)
    window_size = 5
    kernel = np.ones(window_size) / (window_size * dt / 1000.0)  # 转换为Hz
    pop_spikes = spikes.mean(axis=1)  # (T,)
    rates = np.convolve(pop_spikes, kernel, mode="same")

    return {
        "rates": rates,  # 每个神经元的发放率分布
        "mean": np.mean(rates),  # 均值
        "skew": skew(rates),  # 偏度
        "kurt": kurtosis(rates),  # 峰度
    }
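Unit check for the rate trace: with spike probability p per 1 ms bin, the population rate should average p / 0.001 s, i.e. p × 1000 Hz. A sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
spikes = (rng.random((10000, 100)) < 0.01).astype(float)  # p = 0.01 per 1 ms bin

out = calculate_fr_distribution(spikes, dt=1.0)
print(out["mean"])  # ≈ 10 Hz (0.01 spikes/bin ÷ 0.001 s/bin)
```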

calculate_spike_distance(spikes, dt=1.0, subset_size=100, seed=None)

Compute the SPIKE-distance (Kreuz et al., 2013).

Measures the degree of asynchrony between spike trains; 0 indicates perfect synchrony.

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| `spikes` | (Time, Neurons) spike matrix | required |
| `dt` | Simulation time step (ms) | `1.0` |
| `subset_size` | Number of randomly sampled neurons used for pairwise distances | `100` |
| `seed` | Random seed for the neuron subsampling | `None` |

Returns: float: mean SPIKE-distance

Source code in btorch/analysis/dynamic_tools/micro_scale.py
def calculate_spike_distance(spikes, dt=1.0, subset_size=100, seed=None):
    """计算 SPIKE-distance (Kreuz et al., 2013)。

    衡量脉冲序列之间的不同步程度。0表示完全同步。

    Args:
        spikes: (Time, Neurons) 脉冲矩阵
        dt: 仿真步长(ms)
        subset_size: 随机抽样的神经元数量,用于计算成对距离
    Returns:
        float: 平均 SPIKE-distance
    """
    if isinstance(spikes, torch.Tensor):
        spikes = spikes.detach().cpu().numpy()

    T_steps, N = spikes.shape
    times = np.arange(T_steps) * dt

    # Randomly subsample neurons
    if N > subset_size:
        if seed is not None:
            np.random.seed(seed)
        indices = np.random.choice(N, subset_size, replace=False)
        selected_spikes = spikes[:, indices]
        N_subset = subset_size
    else:
        selected_spikes = spikes
        N_subset = N

    # Precompute t_prev, t_next, and isi for each neuron
    # shape: (N_subset, T_steps)
    t_prev = np.zeros((N_subset, T_steps))
    t_next = np.zeros((N_subset, T_steps))
    isi = np.zeros((N_subset, T_steps))

    for n in range(N_subset):
        spike_indices = np.where(selected_spikes[:, n] > 0)[0]
        spike_times = spike_indices * dt

        if len(spike_times) == 0:
            # Handle neurons with no spikes: span the whole interval
            t_prev[n, :] = 0
            t_next[n, :] = times[-1]
            isi[n, :] = times[-1]
            continue

        # Use searchsorted to find the surrounding spikes for each time point
        # indices_next gives, for each t in times, the index in spike_times of
        # the first spike at or after t
        indices_next = np.searchsorted(spike_times, times)

        # Handle boundaries
        indices_next = np.clip(indices_next, 0, len(spike_times) - 1)
        # indices_prev = np.clip(indices_next - 1, 0, len(spike_times) - 1)

        # Correct the searchsorted result so that t_prev <= t <= t_next.
        # When t falls exactly on a spike time, searchsorted may return either
        # the current or the next spike. Conventions used here:
        # t_next[t] is the first spike >= t
        # t_prev[t] is the last spike <= t

        # More precisely:
        # t_prev: max(s | s <= t)
        # t_next: min(s | s > t)  (SPIKE-distance definitions usually require
        # strictly greater, or >=)

        # A simple loop fill is used below (slower than a fully vectorized
        # approach, but exact); for a step function, diff-based filling would
        # also work.

        # Fast fill:
        # t_prev
        curr_spike = 0.0
        spike_idx = 0
        for t_idx, t in enumerate(times):
            if spike_idx < len(spike_times) and t >= spike_times[spike_idx]:
                curr_spike = spike_times[spike_idx]
                # If this is not the last spike, check whether the next one has arrived
                if spike_idx < len(spike_times) - 1 and t >= spike_times[spike_idx + 1]:
                    spike_idx += 1
                    curr_spike = spike_times[spike_idx]
            t_prev[n, t_idx] = curr_spike

        # t_next
        curr_spike = times[-1]
        spike_idx = len(spike_times) - 1
        for t_idx in range(T_steps - 1, -1, -1):
            t = times[t_idx]
            if spike_idx >= 0 and t <= spike_times[spike_idx]:
                curr_spike = spike_times[spike_idx]
                if spike_idx > 0 and t <= spike_times[spike_idx - 1]:
                    spike_idx -= 1
                    curr_spike = spike_times[spike_idx]
            t_next[n, t_idx] = curr_spike

        isi[n, :] = t_next[n, :] - t_prev[n, :]

        isi[n, isi[n, :] == 0] = dt

    # Pairwise SPIKE-distance:
    # S(t) = ( |dt_p1 - dt_p2| * isi2 + |dt_f1 - dt_f2| * isi1 )
    #        / ( 0.5 * (isi1 + isi2)**2 )

    dt_p = times[None, :] - t_prev  # (N, T)
    dt_f = t_next - times[None, :]  # (N, T)

    pairwise_distances = []

    # Compute all pairs. If N_subset is large this can be slow; with the
    # default subset_size=100 there are 100 * 99 / 2 = 4950 pairs, acceptable.
    for i in range(N_subset):
        for j in range(i + 1, N_subset):
            isi1 = isi[i]
            isi2 = isi[j]

            avg_isi_sq = 0.5 * (isi1 + isi2) ** 2
            # Avoid division by zero
            avg_isi_sq[avg_isi_sq == 0] = 1.0

            term1 = np.abs(dt_p[i] - dt_p[j]) * isi2
            term2 = np.abs(dt_f[i] - dt_f[j]) * isi1

            s_t = (term1 + term2) / avg_isi_sq

            # Time integral (average over time)
            dist = np.mean(s_t)
            pairwise_distances.append(dist)

    if not pairwise_distances:
        return 0.0

    return np.mean(pairwise_distances)
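Sanity checks, sketched below: duplicating one train across two channels yields a distance of 0, while independent trains yield a positive value:

```python
import numpy as np

rng = np.random.default_rng(0)
train = (rng.random(5000) < 0.02).astype(float)

identical = np.stack([train, train], axis=1)  # (T, 2), perfectly synchronous
independent = (rng.random((5000, 2)) < 0.02).astype(float)

print(calculate_spike_distance(identical, dt=1.0))  # 0.0
print(calculate_spike_distance(independent, dt=1.0))  # > 0
```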

btorch.analysis.dynamic_tools.spiking

Advanced Fano factor methods for spike train analysis.

This module provides methods to compensate for firing rate effects on the Fano factor, implementing:

  1. Operational Time Method: Transforms spike trains to a rate-independent reference frame by evaluating Fano factor at normalized rate (λ=1). Reference: Rajdl et al. (2020), Front. Comput. Neurosci.

  2. Mean Matching Method: Selects data points where mean spike counts are matched across conditions before computing FF. Reference: Churchland et al. (2010), Nature Neurosci.

  3. Model-Based Approaches:

     • Modulated Poisson with multiplicative noise (Goris et al., 2014)
     • Flexible overdispersion model (Charles et al., 2018)

All methods support both NumPy and PyTorch inputs following the btorch conventions, with GPU acceleration where applicable.

Functions

_compute_mean_matching_weights_numpy(means, n_bins=10)

Compute mean-matching weights for each condition/time point.

The mean matching method ensures that the distribution of mean counts is matched across conditions/times by computing weights that equalize the histograms.

Args:
    means: Array of mean spike counts [n_conditions, n_neurons]
    n_bins: Number of bins for mean count histogram

Returns:
    Dictionary with weights and matching info
Source code in btorch/analysis/dynamic_tools/spiking.py
def _compute_mean_matching_weights_numpy(
    means: np.ndarray,
    n_bins: int = 10,
) -> dict[str, np.ndarray]:
    """Compute mean-matching weights for each condition/time point.

        The mean matching method ensures that the distribution of mean counts
        is matched across conditions/times by computing weights that equalize
    the histograms.

        Args:
            means: Array of mean spike counts [n_conditions, n_neurons]
            n_bins: Number of bins for mean count histogram

        Returns:
            Dictionary with weights and matching info
    """
    n_conditions, n_neurons = means.shape

    # Compute histogram for each condition
    mean_min, mean_max = np.nanmin(means), np.nanmax(means)
    if mean_min == mean_max:
        # All means are the same
        return {
            "weights": np.ones_like(means),
            "bin_edges": np.array([mean_min, mean_max + 1]),
            "counts_per_condition": np.full((n_conditions, 1), n_neurons),
            "min_counts_per_bin": np.full(1, n_neurons),
        }

    bin_edges = np.linspace(mean_min, mean_max, n_bins + 1)

    # Count per bin for each condition
    counts_per_condition = np.zeros((n_conditions, n_bins))
    for c in range(n_conditions):
        valid_means = means[c, ~np.isnan(means[c])]
        if len(valid_means) > 0:
            counts_per_condition[c], _ = np.histogram(valid_means, bins=bin_edges)

    # Find minimum count per bin (greatest common distribution)
    min_counts_per_bin = np.min(counts_per_condition, axis=0)

    # Compute weights for each data point
    weights = np.ones_like(means)

    for c in range(n_conditions):
        for b in range(n_bins):
            mask = (means[c] >= bin_edges[b]) & (means[c] < bin_edges[b + 1])
            if b == n_bins - 1:  # Include right edge for last bin
                mask = mask | (means[c] == bin_edges[b + 1])

            n_in_bin = np.sum(mask)
            target_n = min_counts_per_bin[b]

            if n_in_bin > 0 and target_n < n_in_bin:
                weights[c, mask] = target_n / n_in_bin
            elif n_in_bin > 0:
                weights[c, mask] = 1.0
            else:
                weights[c, mask] = 0.0

    return {
        "weights": weights,
        "bin_edges": bin_edges,
        "counts_per_condition": counts_per_condition,
        "min_counts_per_bin": min_counts_per_bin,
    }
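A toy run of this helper: when one condition overfills a histogram bin relative to the "greatest common" distribution, the weights in that bin drop below 1. A sketch with hand-picked means:

```python
import numpy as np

# Two conditions, six neurons; condition 1 has systematically higher means.
means = np.array([
    [0.5, 0.6, 0.7, 2.0, 2.1, 2.2],
    [0.5, 2.0, 2.1, 2.2, 2.3, 2.4],
])
out = _compute_mean_matching_weights_numpy(means, n_bins=2)
print(out["min_counts_per_bin"])  # [1. 3.]: the common distribution per bin
print(out["weights"])  # 1/3 in condition 0's low bin, 0.6 in condition 1's high bin
```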

_compute_mean_matching_weights_torch(means, n_bins=10)

Torch implementation of mean matching weights.

Source code in btorch/analysis/dynamic_tools/spiking.py
def _compute_mean_matching_weights_torch(
    means: torch.Tensor,
    n_bins: int = 10,
) -> dict[str, torch.Tensor]:
    """Torch implementation of mean matching weights."""
    device = means.device
    n_conditions, n_neurons = means.shape

    mean_min, mean_max = torch.min(means), torch.max(means)
    if mean_min.item() == mean_max.item():
        return {
            "weights": torch.ones_like(means),
            "bin_edges": torch.tensor(
                [mean_min.item(), mean_max.item() + 1], device=device
            ),
            "counts_per_condition": torch.full(
                (n_conditions, 1), n_neurons, device=device
            ),
            "min_counts_per_bin": torch.full((1,), n_neurons, device=device),
        }

    bin_edges = torch.linspace(
        mean_min.item(), mean_max.item(), n_bins + 1, device=device
    )

    # Compute histogram counts
    counts_per_condition = torch.zeros((n_conditions, n_bins), device=device)
    for c in range(n_conditions):
        for b in range(n_bins):
            mask = (means[c] >= bin_edges[b]) & (means[c] < bin_edges[b + 1])
            if b == n_bins - 1:
                mask = mask | (means[c] == bin_edges[b + 1])
            counts_per_condition[c, b] = torch.sum(mask).float()

    min_counts_per_bin = torch.min(counts_per_condition, dim=0).values

    # Compute weights
    weights = torch.ones_like(means)
    for c in range(n_conditions):
        for b in range(n_bins):
            mask = (means[c] >= bin_edges[b]) & (means[c] < bin_edges[b + 1])
            if b == n_bins - 1:
                mask = mask | (means[c] == bin_edges[b + 1])

            n_in_bin = torch.sum(mask).float()
            target_n = min_counts_per_bin[b]

            if n_in_bin > 0 and target_n < n_in_bin:
                weights[c, mask] = target_n / n_in_bin
            elif n_in_bin > 0:
                weights[c, mask] = 1.0
            else:
                weights[c, mask] = 0.0

    return {
        "weights": weights,
        "bin_edges": bin_edges,
        "counts_per_condition": counts_per_condition,
        "min_counts_per_bin": min_counts_per_bin,
    }

_compute_operational_fano_numpy(spike_data, rate_hz, window_op, dt_ms, batch_axis)

Compute Fano factor in operational time for NumPy arrays.

The operational time approach rescales time by the firing rate: w = λ * t, where λ is the firing rate in Hz.

For a renewal process: F⁽ᵒ⁾ = CV² (independent of rate).

Implementation: We normalize spike counts by rate, effectively computing statistics as if rate = 1 everywhere.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `spike_data` | `ndarray` | Spike train [T, ...] | required |
| `rate_hz` | `ndarray` | Estimated rate in Hz [T, ...] or scalar | required |
| `window_op` | `float` | Window size in operational time units | required |
| `dt_ms` | `float` | Time step in milliseconds | required |
| `batch_axis` | `tuple[int, ...] \| None` | Axes to aggregate across | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `fano_op` | `ndarray` | Operational time Fano factor |
| `info` | `dict` | Computation info |

Source code in btorch/analysis/dynamic_tools/spiking.py
def _compute_operational_fano_numpy(
    spike_data: np.ndarray,
    rate_hz: np.ndarray,
    window_op: float,
    dt_ms: float,
    batch_axis: tuple[int, ...] | None,
) -> tuple[np.ndarray, dict]:
    """Compute Fano factor in operational time for NumPy arrays.

    The operational time approach rescales time by the firing rate:
    w = λ * t, where λ is the firing rate in Hz.

    For a renewal process: F⁽ᵒ⁾ = CV² (independent of rate).

    Implementation: We normalize spike counts by rate, effectively
    computing statistics as if rate = 1 everywhere.

    Args:
        spike_data: Spike train [T, ...]
        rate_hz: Estimated rate in Hz [T, ...] or scalar
        window_op: Window size in operational time units
        dt_ms: Time step in milliseconds
        batch_axis: Axes to aggregate across

    Returns:
        fano_op: Operational time Fano factor
        info: Computation info
    """
    T = spike_data.shape[0]
    rest_shape = spike_data.shape[1:]
    n_elements = np.prod(rest_shape) if rest_shape else 1

    # Flatten for processing
    flat_spikes = (
        spike_data.reshape(T, -1) if n_elements > 0 else spike_data.reshape(T, 1)
    )
    flat_rate = (
        rate_hz.reshape(T, -1)
        if isinstance(rate_hz, np.ndarray) and rate_hz.ndim > 0
        else rate_hz
    )

    n_neurons = flat_spikes.shape[1]

    # Compute operational time window in original time bins
    # w = λ * t, so for window W in operational time:
    # number of original bins ≈ W / (λ * dt)
    # Use median rate for window calculation
    if isinstance(flat_rate, np.ndarray) and flat_rate.ndim > 1:
        median_rate = np.median(flat_rate, axis=0)
    else:
        median_rate = np.full(
            n_neurons, flat_rate if np.isscalar(flat_rate) else np.median(flat_rate)
        )

    # Compute normalized spike counts (operational time counts)
    # The key insight: for rate λ, the expected count in window w is λ*w
    # Normalizing by λ gives counts as if rate = 1

    # Method: Compute counts in windows, normalize by expected count
    # Then compute variance/mean of normalized counts

    # For simplicity, use fixed window in original time and normalize
    window_bins = max(5, int(window_op / np.median(median_rate) / (dt_ms / 1000.0)))
    window_bins = min(window_bins, T // 3)  # Ensure reasonable size

    if window_bins < 2:
        window_bins = 2

    step = max(1, window_bins // 2)
    n_windows = (T - window_bins) // step + 1

    if n_windows < 2:
        # Not enough windows, return NaN
        return np.full(rest_shape if rest_shape else (1,), np.nan), {
            "error": "insufficient data"
        }

    # Compute counts per window for each neuron
    counts = np.zeros((n_windows, n_neurons))
    for w in range(n_windows):
        start = w * step
        end = start + window_bins
        counts[w] = flat_spikes[start:end].sum(axis=0)

    # Compute expected counts (rate * window duration)
    # For operational time: normalize counts by rate
    if isinstance(flat_rate, np.ndarray) and flat_rate.ndim > 1:
        mean_rate = np.mean(flat_rate, axis=0)  # [n_neurons]
    else:
        mean_rate = np.full(
            n_neurons, flat_rate if np.isscalar(flat_rate) else np.mean(flat_rate)
        )

    window_duration_s = window_bins * dt_ms / 1000.0
    expected_counts = mean_rate * window_duration_s  # [n_neurons]

    # Normalized counts (as if rate = 1). By construction the normalized
    # counts have mean ≈ 1; their variance-to-mean ratio is the
    # rate-compensated statistic reported below.
    normalized_counts = counts / (expected_counts[None, :] + 1e-12)

    # Compute mean and variance of normalized counts
    mean_norm = np.mean(normalized_counts, axis=0)
    var_norm = np.var(normalized_counts, axis=0, ddof=1)

    # Operational Fano factor
    fano_op = np.zeros(n_neurons)
    valid = (mean_norm > 0) & np.isfinite(var_norm)
    fano_op[valid] = var_norm[valid] / mean_norm[valid]

    # Reshape to original structure
    if rest_shape:
        fano_op = fano_op.reshape(rest_shape)
    else:
        fano_op = fano_op[0] if n_neurons == 1 else fano_op

    # Apply batch axis aggregation if requested
    if batch_axis is not None:
        fano_op = np.mean(fano_op, axis=tuple(batch_axis))

    info = {
        "method": "operational_time",
        "window_bins": window_bins,
        "n_windows": n_windows,
        "mean_rate_hz": np.mean(mean_rate),
        "window_duration_s": window_duration_s,
    }

    return fano_op, info
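A hedged sketch of the mechanics: because the window length is derived from the median rate, the expected count per window is roughly `window_op` at any firing rate, so the normalized statistic is comparable across populations with very different rates (the point of the compensation; its absolute scale depends on `window_op`):

```python
import numpy as np

rng = np.random.default_rng(0)
slow = (rng.random((20000, 30)) < 0.01).astype(float)  # ~10 Hz Bernoulli trains
fast = (rng.random((20000, 30)) < 0.04).astype(float)  # ~40 Hz Bernoulli trains

for spikes in (slow, fast):
    rate_hz = _estimate_rate_numpy(spikes, dt_ms=1.0)
    fano_op, info = _compute_operational_fano_numpy(
        spikes, rate_hz, window_op=5.0, dt_ms=1.0, batch_axis=None
    )
    # Similar values for both populations despite the 4x rate difference
    print(np.nanmean(fano_op), info["window_bins"])
```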

_compute_operational_fano_torch(spike_data, rate_hz, window_op, dt_ms, batch_axis)

Compute Fano factor in operational time for Torch tensors.

Source code in btorch/analysis/dynamic_tools/spiking.py
def _compute_operational_fano_torch(
    spike_data: torch.Tensor,
    rate_hz: torch.Tensor | float,
    window_op: float,
    dt_ms: float,
    batch_axis: tuple[int, ...] | None,
) -> tuple[torch.Tensor, dict]:
    """Compute Fano factor in operational time for Torch tensors."""
    device = spike_data.device
    T = spike_data.shape[0]
    rest_shape = spike_data.shape[1:]
    n_elements = int(np.prod(rest_shape)) if rest_shape else 1

    # Flatten for processing
    flat_spikes = (
        spike_data.reshape(T, -1) if n_elements > 0 else spike_data.reshape(T, 1)
    )

    if isinstance(rate_hz, torch.Tensor):
        flat_rate = rate_hz.reshape(T, -1) if rate_hz.numel() > 1 else rate_hz.item()
    else:
        flat_rate = rate_hz

    n_neurons = flat_spikes.shape[1]

    # Compute median rate
    if isinstance(flat_rate, torch.Tensor) and flat_rate.ndim > 1:
        median_rate = torch.median(flat_rate, dim=0).values
    else:
        scalar_rate = (
            flat_rate if isinstance(flat_rate, (int, float)) else flat_rate.item()
        )
        median_rate = torch.full((n_neurons,), scalar_rate, device=device)

    # Compute window size
    median_rate_val = torch.median(median_rate).item()
    window_bins = max(5, int(window_op / median_rate_val / (dt_ms / 1000.0)))
    window_bins = min(window_bins, T // 3)

    if window_bins < 2:
        window_bins = 2

    step = max(1, window_bins // 2)
    n_windows = (T - window_bins) // step + 1

    if n_windows < 2:
        return torch.full(
            rest_shape if rest_shape else (1,), float("nan"), device=device
        ), {"error": "insufficient data"}

    # Compute counts per window
    counts = torch.zeros((n_windows, n_neurons), device=device)
    for w in range(n_windows):
        start = w * step
        end = start + window_bins
        counts[w] = flat_spikes[start:end].sum(dim=0)

    # Compute expected counts
    if isinstance(flat_rate, torch.Tensor) and flat_rate.ndim > 1:
        mean_rate = torch.mean(flat_rate, dim=0)
    else:
        scalar_rate = (
            flat_rate if isinstance(flat_rate, (int, float)) else flat_rate.item()
        )
        mean_rate = torch.full((n_neurons,), scalar_rate, device=device)

    window_duration_s = window_bins * dt_ms / 1000.0
    expected_counts = mean_rate * window_duration_s

    # Normalized counts
    normalized_counts = counts / (expected_counts.unsqueeze(0) + 1e-12)

    # Compute statistics
    mean_norm = torch.mean(normalized_counts, dim=0)
    var_norm = torch.var(normalized_counts, dim=0, unbiased=True)

    # Operational Fano factor
    fano_op = torch.zeros(n_neurons, device=device)
    valid = (mean_norm > 0) & torch.isfinite(var_norm)
    fano_op[valid] = var_norm[valid] / mean_norm[valid]

    # Reshape
    if rest_shape:
        fano_op = fano_op.reshape(rest_shape)
    elif n_neurons == 1:
        fano_op = fano_op[0]

    # Batch aggregation
    if batch_axis is not None:
        fano_op = torch.mean(fano_op, dim=tuple(batch_axis))

    info = {
        "method": "operational_time",
        "window_bins": window_bins,
        "n_windows": n_windows,
        "mean_rate_hz": torch.mean(mean_rate).item(),
        "window_duration_s": window_duration_s,
    }

    return fano_op, info

_compute_weighted_fano_numpy(spike_counts, weights)

Compute weighted Fano factor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `spike_counts` | `ndarray` | Spike counts [n_trials, n_neurons] or [n_conditions, n_neurons] | required |
| `weights` | `ndarray` | Weights for each observation | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Weighted Fano factor per neuron |

Source code in btorch/analysis/dynamic_tools/spiking.py
def _compute_weighted_fano_numpy(
    spike_counts: np.ndarray,
    weights: np.ndarray,
) -> np.ndarray:
    """Compute weighted Fano factor.

    Args:
        spike_counts: Spike counts [n_trials, n_neurons] or [n_conditions, n_neurons]
        weights: Weights for each observation

    Returns:
        Weighted Fano factor per neuron
    """
    # Handle NaN values
    valid_mask = ~np.isnan(spike_counts) & ~np.isnan(weights)

    # Compute weighted mean
    weights_normalized = np.where(
        valid_mask, weights / (np.nansum(weights, axis=0, keepdims=True) + 1e-12), 0
    )

    weighted_mean = np.nansum(spike_counts * weights_normalized, axis=0)

    # Weighted variance
    V1 = np.nansum(weights, axis=0)
    V2 = np.nansum(weights**2, axis=0)

    # Avoid division by zero
    denom = V1 - V2 / (V1 + 1e-12) + 1e-12

    weighted_var = (
        np.nansum(weights * (spike_counts - weighted_mean[None, :]) ** 2, axis=0)
        / denom
    )

    # Fano factor
    fano = np.full_like(weighted_mean, np.nan)
    valid = (weighted_mean > 0) & (weighted_var >= 0)
    fano[valid] = weighted_var[valid] / weighted_mean[valid]

    return fano
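With uniform weights the weighted Fano factor reduces to the ordinary variance-to-mean ratio (with ddof=1 variance), which gives a convenient check:

```python
import numpy as np

rng = np.random.default_rng(0)
counts = rng.poisson(5.0, size=(200, 10)).astype(float)  # trials x neurons
weights = np.ones_like(counts)

ff_weighted = _compute_weighted_fano_numpy(counts, weights)
ff_plain = counts.var(axis=0, ddof=1) / counts.mean(axis=0)
print(np.allclose(ff_weighted, ff_plain))  # True; both ≈ 1 for Poisson counts
```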

_compute_weighted_fano_torch(spike_counts, weights)

Torch implementation of weighted Fano factor.

Source code in btorch/analysis/dynamic_tools/spiking.py
def _compute_weighted_fano_torch(
    spike_counts: torch.Tensor,
    weights: torch.Tensor,
) -> torch.Tensor:
    """Torch implementation of weighted Fano factor."""
    valid_mask = ~torch.isnan(spike_counts) & ~torch.isnan(weights)

    # Normalize weights
    weight_sums = torch.sum(weights * valid_mask.float(), dim=0, keepdim=True)
    weights_normalized = torch.where(
        valid_mask, weights / (weight_sums + 1e-12), torch.zeros_like(weights)
    )

    # Weighted mean
    weighted_mean = torch.sum(spike_counts * weights_normalized, dim=0)

    # Weighted variance
    V1 = torch.sum(weights * valid_mask.float(), dim=0)
    V2 = torch.sum((weights**2) * valid_mask.float(), dim=0)

    denom = V1 - V2 / (V1 + 1e-12) + 1e-12

    weighted_var = (
        torch.sum(weights * (spike_counts - weighted_mean.unsqueeze(0)) ** 2, dim=0)
        / denom
    )

    # Fano factor
    fano = torch.full_like(weighted_mean, float("nan"))
    valid = (weighted_mean > 0) & (weighted_var >= 0)
    fano[valid] = weighted_var[valid] / weighted_mean[valid]

    return fano

_estimate_rate_numpy(spike_data, dt_ms, window_ms=None)

Estimate instantaneous firing rate using sliding window.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `spike_data` | `ndarray` | Spike train of shape [T, ...]. First dimension is time. | required |
| `dt_ms` | `float` | Time step in milliseconds | required |
| `window_ms` | `float \| None` | Window size for rate estimation. If None, uses T//20 * dt_ms. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | Estimated rate in Hz [T, ...] |

Source code in btorch/analysis/dynamic_tools/spiking.py
def _estimate_rate_numpy(
    spike_data: np.ndarray,
    dt_ms: float,
    window_ms: float | None = None,
) -> np.ndarray:
    """Estimate instantaneous firing rate using sliding window.

    Args:
        spike_data: Spike train of shape [T, ...]. First dimension is time.
        dt_ms: Time step in milliseconds.
        window_ms: Window size for rate estimation. If None, uses T//20 * dt_ms.

    Returns:
        Estimated rate in Hz [T, ...].
    """
    T = spike_data.shape[0]

    if window_ms is None:
        window_bins = max(10, T // 20)
    else:
        window_bins = max(1, int(window_ms / dt_ms))

    # Simple moving average for rate estimation
    kernel = np.ones(window_bins) / (window_bins * dt_ms / 1000.0)

    # Convolve along time axis for each neuron independently
    rate = np.apply_along_axis(
        lambda x: np.convolve(x, kernel, mode="same"),
        axis=0,
        arr=spike_data,
    )

    # Ensure minimum rate to avoid division by zero
    return np.maximum(rate, 1e-6)

_estimate_rate_torch(spike_data, dt_ms, window_ms=None)

Torch implementation of rate estimation.

Source code in btorch/analysis/dynamic_tools/spiking.py
def _estimate_rate_torch(
    spike_data: torch.Tensor,
    dt_ms: float,
    window_ms: float | None = None,
) -> torch.Tensor:
    """Torch implementation of rate estimation."""
    T = spike_data.shape[0]

    if window_ms is None:
        window_bins = max(10, T // 20)
    else:
        window_bins = max(1, int(window_ms / dt_ms))

    # Compute moving average using cumsum approach
    # Pad with edge values to handle boundaries
    pad = window_bins // 2

    # Handle different dimensions
    original_shape = spike_data.shape
    if spike_data.ndim == 1:
        spike_data = spike_data.unsqueeze(-1)  # [T, 1]

    # Pad with zeros, then overwrite the padding with edge values (replicate)
    spike_padded = torch.nn.functional.pad(
        spike_data, (0, 0, pad, pad), mode="constant"
    )
    # Fill padding with edge values (replicate)
    if pad > 0:
        spike_padded[:pad] = spike_data[0]
        spike_padded[-pad:] = spike_data[-1]

    # Compute moving sum using cumsum
    cumsum = torch.cumsum(spike_padded, dim=0)
    moving_sum = cumsum[window_bins:] - cumsum[:-window_bins]

    # Extract center portion matching original size
    start_idx = window_bins // 2
    end_idx = start_idx + T
    moving_sum = moving_sum[start_idx:end_idx]

    # Ensure correct shape
    if moving_sum.shape[0] < T:
        # Pad if needed
        pad_right = T - moving_sum.shape[0]
        moving_sum = torch.nn.functional.pad(moving_sum, (0, 0, 0, pad_right))

    # Convert to Hz
    rate = moving_sum / (window_bins * dt_ms / 1000.0)

    # Restore original shape
    rate = rate.reshape(original_shape)

    return torch.clamp(rate, min=1e-6)

_flexible_overdispersion_moments(stimulus_drive, noise_std=0.5, nonlinearity='relu')

Compute mean and variance for flexible overdispersion model.

The flexible overdispersion model (Charles et al., 2018): λ_eff = f(g(x) + ε) where f is a nonlinearity.

Different nonlinearities produce different mean-FF relationships:

  • Rectified-linear: FF decreases with increasing rate
  • Rectified-squaring: FF ≈ constant
  • Exponential: FF increases with rate

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stimulus_drive` | `ndarray \| Tensor` | Stimulus-dependent drive g(x) | required |
| `noise_std` | `float` | Standard deviation of additive Gaussian noise ε | `0.5` |
| `nonlinearity` | `Literal['relu', 'square', 'exp', 'softplus']` | Type of nonlinearity to apply | `'relu'` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `mean` | `ndarray \| Tensor` | Expected spike count |
| `variance` | `ndarray \| Tensor` | Spike count variance |

Source code in btorch/analysis/dynamic_tools/spiking.py
def _flexible_overdispersion_moments(
    stimulus_drive: np.ndarray | torch.Tensor,
    noise_std: float = 0.5,
    nonlinearity: Literal["relu", "square", "exp", "softplus"] = "relu",
) -> tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]:
    """Compute mean and variance for flexible overdispersion model.

    The flexible overdispersion model (Charles et al., 2018):
    λ_eff = f(g(x) + ε) where f is a nonlinearity.

    Different nonlinearities produce different mean-FF relationships:
    - Rectified-linear: FF decreases with increasing rate
    - Rectified-squaring: FF ≈ constant
    - Exponential: FF increases with rate

    Args:
        stimulus_drive: Stimulus-dependent drive g(x)
        noise_std: Standard deviation of additive Gaussian noise ε
        nonlinearity: Type of nonlinearity to apply

    Returns:
        mean: Expected spike count
        variance: Spike count variance
    """
    is_torch = isinstance(stimulus_drive, torch.Tensor)

    if nonlinearity == "relu":
        # Rectified linear: f(z) = max(0, z)
        # Mean requires Gaussian integral
        if is_torch:
            # Approximation for rectified Gaussian
            phi = lambda x: 0.5 * (1 + torch.erf(x / np.sqrt(2)))
            mean = stimulus_drive * phi(stimulus_drive / noise_std)
            mean += (
                noise_std
                * torch.exp(-0.5 * (stimulus_drive / noise_std) ** 2)
                / np.sqrt(2 * np.pi)
            )
        else:
            from scipy.stats import norm

            mean = stimulus_drive * norm.cdf(stimulus_drive / noise_std)
            mean += noise_std * norm.pdf(stimulus_drive / noise_std)

    elif nonlinearity == "square":
        # Rectified squaring: f(z) = max(0, z)²
        # E[f(z)] where z ~ N(g, σ²)
        if is_torch:
            phi = lambda x: 0.5 * (1 + torch.erf(x / np.sqrt(2)))
            mean = (stimulus_drive**2 + noise_std**2) * phi(stimulus_drive / noise_std)
            mean += (
                stimulus_drive
                * noise_std
                * torch.exp(-0.5 * (stimulus_drive / noise_std) ** 2)
                / np.sqrt(2 * np.pi)
            )
        else:
            from scipy.stats import norm

            mean = (stimulus_drive**2 + noise_std**2) * norm.cdf(
                stimulus_drive / noise_std
            )
            mean += stimulus_drive * noise_std * norm.pdf(stimulus_drive / noise_std)

    elif nonlinearity == "exp":
        # Exponential: f(z) = exp(z)
        # E[exp(z)] = exp(g + σ²/2) for z ~ N(g, σ²)
        mean = (
            np.exp(stimulus_drive + 0.5 * noise_std**2)
            if not is_torch
            else torch.exp(stimulus_drive + 0.5 * noise_std**2)
        )

    elif nonlinearity == "softplus":
        # Softplus: f(z) = log(1 + exp(z))
        # Numerically stable computation
        if is_torch:
            mean = torch.nn.functional.softplus(stimulus_drive)
        else:
            mean = np.logaddexp(0.0, stimulus_drive)  # log(1 + exp(x)) without overflow

    else:
        raise ValueError(f"Unknown nonlinearity: {nonlinearity}")

    # Variance approximation: a single quadratic overdispersion form,
    # Var ≈ mean + mean²/2, applied uniformly across nonlinearities.
    # This is a coarse simplification of the exact second moment.
    variance = mean + 0.5 * (mean**2)

    return mean, variance
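A minimal sketch of how these moments behave (illustrative only; assumes the private helper is imported from btorch.analysis.dynamic_tools.spiking, and the Fano factors reflect the simplified variance approximation above):

import numpy as np

from btorch.analysis.dynamic_tools.spiking import _flexible_overdispersion_moments

drive = np.linspace(0.5, 4.0, 4)  # stimulus-dependent drive g(x)
for nl in ("relu", "square", "exp", "softplus"):
    mean, var = _flexible_overdispersion_moments(drive, noise_std=0.5, nonlinearity=nl)
    print(nl, np.round(var / mean, 2))  # model-implied Fano factor per drive level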

_modulated_poisson_moments(lambda_base, gain_mean=1.0, gain_var=0.5)

Compute mean and variance for modulated Poisson model.

The modulated Poisson model (Goris et al., 2014): r ~ Poisson(λ · g) where g is multiplicative gain noise.

Mean: E[r] = λ · E[g]
Variance: Var[r] = E[g] · λ + Var[g] · λ²

This produces a quadratic mean-variance relationship: Var[r] = E[r] + (Var[g]/E[g]²) · E[r]²

Parameters:

Name Type Description Default
lambda_base ndarray | Tensor

Base firing rate

required
gain_mean float

Mean of gain distribution (E[g])

1.0
gain_var float

Variance of gain distribution (Var[g])

0.5

Returns:

Name Type Description
mean ndarray | Tensor

Expected spike count

variance ndarray | Tensor

Spike count variance

Source code in btorch/analysis/dynamic_tools/spiking.py
def _modulated_poisson_moments(
    lambda_base: np.ndarray | torch.Tensor,
    gain_mean: float = 1.0,
    gain_var: float = 0.5,
) -> tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]:
    """Compute mean and variance for modulated Poisson model.

    The modulated Poisson model (Goris et al., 2014):
    r ~ Poisson(λ · g) where g is multiplicative gain noise.

    Mean: E[r] = λ · E[g]
    Variance: Var[r] = E[g] · λ + Var[g] · λ²

    This produces quadratic mean-variance relationship:
    Var[r] = E[r] + (Var[g]/E[g]²) · E[r]²

    Args:
        lambda_base: Base firing rate
        gain_mean: Mean of gain distribution (E[g])
        gain_var: Variance of gain distribution (Var[g])

    Returns:
        mean: Expected spike count
        variance: Spike count variance
    """
    mean = lambda_base * gain_mean
    # Law of total variance: Var[r] = E[Var[r|g]] + Var[E[r|g]]
    # Var[r|g] = λ·g (Poisson), E[r|g] = λ·g
    # E[Var[r|g]] = λ·E[g]
    # Var[E[r|g]] = λ²·Var[g]
    variance = lambda_base * gain_mean + (lambda_base**2) * gain_var
    return mean, variance
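Since the moments are closed-form, the quadratic law is easy to verify numerically. A minimal sketch (assuming the private helper is importable from btorch.analysis.dynamic_tools.spiking):

import numpy as np

from btorch.analysis.dynamic_tools.spiking import _modulated_poisson_moments

rates = np.array([1.0, 5.0, 20.0])                 # base firing rates λ
mean, var = _modulated_poisson_moments(rates, gain_mean=1.0, gain_var=0.5)
print(var / mean)                                  # FF = 1 + 0.5 * mean, grows with rate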

compare_fano_methods(spike_data, dt_ms=1.0, **kwargs)

Compare different Fano factor compensation methods.

Computes Fano factor using multiple methods for comparison.

Parameters:

Name Type Description Default
spike_data ndarray | Tensor

Spike train of shape [T, ...].

required
dt_ms float

Time step in milliseconds.

1.0
**kwargs

Additional arguments passed to methods.

{}

Returns:

Type Description
dict

Dictionary with results from each method.

Source code in btorch/analysis/dynamic_tools/spiking.py
def compare_fano_methods(
    spike_data: np.ndarray | torch.Tensor,
    dt_ms: float = 1.0,
    **kwargs,
) -> dict:
    """Compare different Fano factor compensation methods.

    Computes Fano factor using multiple methods for comparison.

    Args:
        spike_data: Spike train of shape [T, ...].
        dt_ms: Time step in milliseconds.
        **kwargs: Additional arguments passed to methods.

    Returns:
        Dictionary with results from each method.
    """
    results = {}

    # Standard Fano factor (no compensation)
    from btorch.analysis.spiking import fano as standard_fano

    results["standard"], _ = standard_fano(
        spike_data,
        **{k: v for k, v in kwargs.items() if k in ["window", "overlap", "batch_axis"]},
    )

    # Operational time
    try:
        results["operational_time"], _ = fano_operational_time(
            spike_data,
            dt_ms=dt_ms,
            **{
                k: v
                for k, v in kwargs.items()
                if k not in ["window", "overlap", "batch_axis"]
            },
        )
    except Exception as e:
        results["operational_time"] = f"Error: {e}"

    # Mean matching
    try:
        results["mean_matching"], _ = fano_mean_matching(
            spike_data, **{k: v for k, v in kwargs.items() if k not in ["dt_ms"]}
        )
    except Exception as e:
        results["mean_matching"] = f"Error: {e}"

    # Model-based
    try:
        results["modulated_poisson"], _ = fano_model_based(
            spike_data,
            model="modulated_poisson",
            **{k: v for k, v in kwargs.items() if k not in ["dt_ms"]},
        )
    except Exception as e:
        results["modulated_poisson"] = f"Error: {e}"

    return results
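A usage sketch with synthetic data (illustrative only; the spike train and window size are arbitrary, and failed methods come back as "Error: ..." strings rather than raising):

import numpy as np

rng = np.random.default_rng(0)
spikes = (rng.random((2000, 32)) < 0.05).astype(float)  # [T, n_neurons], ~50 Hz at dt=1 ms

results = compare_fano_methods(spikes, dt_ms=1.0, window=200)
for name, value in results.items():
    print(name, value if isinstance(value, str) else np.nanmean(value))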

fano_compensated(spike_data, method='operational_time', **kwargs)

Unified interface for compensated Fano factor computation.

This function provides a unified interface to all rate-compensation methods for the Fano factor. The choice of method depends on the experimental design and data characteristics:

  • operational_time: Best for comparing variability across different firing rates. Transforms to rate-independent reference frame.
  • mean_matching: Best for condition comparisons with overlapping rate distributions. Matches rate histograms before computing FF.
  • modulated_poisson: Model-based approach assuming multiplicative gain noise. Good for explaining overdispersion.
  • flexible_overdispersion: Model-based approach with flexible nonlinearities. Good for testing different rate-FF relationships.

Parameters:

Name Type Description Default
spike_data ndarray | Tensor

Spike train of shape [T, ...]. First dimension is time.

required
method Literal['operational_time', 'mean_matching', 'modulated_poisson', 'flexible_overdispersion']

Compensation method to use.

'operational_time'
**kwargs

Method-specific arguments passed to underlying functions.

{}

Returns:

Name Type Description
fano ndarray | Tensor

Compensated Fano factor values.

info dict

Dictionary with method info and computed statistics.

Example

# Operational time method (recommended for rate comparisons)
ff_op, _ = fano_compensated(spikes, method="operational_time")

# Mean matching (for condition comparisons)
ff_mm, _ = fano_compensated(spikes, method="mean_matching", n_bins=10)

# Model-based approach
ff_mod, _ = fano_compensated(
    spikes, method="modulated_poisson",
    model_params={"gain_var": 0.5}
)

Source code in btorch/analysis/dynamic_tools/spiking.py
@use_percentiles(value_key="fano")
@use_stats(value_key="fano")
def fano_compensated(
    spike_data: np.ndarray | torch.Tensor,
    method: Literal[
        "operational_time",
        "mean_matching",
        "modulated_poisson",
        "flexible_overdispersion",
    ] = "operational_time",
    **kwargs,
) -> tuple[np.ndarray | torch.Tensor, dict]:
    """Unified interface for compensated Fano factor computation.

    This function provides a unified interface to all rate-compensation
    methods for the Fano factor. The choice of method depends on the
    experimental design and data characteristics:

    - operational_time: Best for comparing variability across different
      firing rates. Transforms to rate-independent reference frame.
    - mean_matching: Best for condition comparisons with overlapping
      rate distributions. Matches rate histograms before computing FF.
    - modulated_poisson: Model-based approach assuming multiplicative
      gain noise. Good for explaining overdispersion.
    - flexible_overdispersion: Model-based approach with flexible
      nonlinearities. Good for testing different rate-FF relationships.

    Args:
        spike_data: Spike train of shape [T, ...]. First dimension is time.
        method: Compensation method to use.
        **kwargs: Method-specific arguments passed to underlying functions.

    Returns:
        fano: Compensated Fano factor values.
        info: Dictionary with method info and computed statistics.

    Example:
        >>> # Operational time method (recommended for rate comparisons)
        >>> ff_op, _ = fano_compensated(spikes, method="operational_time")
        >>>
        >>> # Mean matching (for condition comparisons)
        >>> ff_mm, _ = fano_compensated(spikes, method="mean_matching", n_bins=10)
        >>>
        >>> # Model-based approach
        >>> ff_mod, _ = fano_compensated(
        ...     spikes, method="modulated_poisson",
        ...     model_params={"gain_var": 0.5}
        ... )
    """
    if method == "operational_time":
        # Filter out model_params if accidentally passed
        kwargs.pop("model_params", None)
        return fano_operational_time(spike_data, **kwargs)
    elif method == "mean_matching":
        kwargs.pop("model_params", None)
        return fano_mean_matching(spike_data, **kwargs)
    elif method in ("modulated_poisson", "flexible_overdispersion"):
        return fano_model_based(spike_data, model=method, **kwargs)
    else:
        raise ValueError(f"Unknown method: {method}")

fano_mean_matching(spike_data, window=None, overlap=0, condition_axis=1, n_bins=10, n_resamples=50, batch_axis=None)

Compute mean-matched Fano factor controlling for rate effects.

The mean matching method (Churchland et al., 2010) ensures that the distribution of mean spike counts is matched across conditions or time points before computing the Fano factor. This removes artifacts caused by rate changes.

Reference: Churchland et al. (2010) "Stimulus onset quenches neural variability: a widespread cortical phenomenon", Nature Neurosci.

Parameters:

Name Type Description Default
spike_data ndarray | Tensor

Spike train of shape [T, n_conditions, ...] or [T, n_trials, ...]. First dimension is time.

required
window int | None

Window size for spike counting. If None, uses T//10.

None
overlap int

Overlap between consecutive windows.

0
condition_axis int

Axis representing conditions/trials (default 1).

1
n_bins int

Number of bins for mean count histogram matching.

10
n_resamples int

Number of resampling iterations for stability.

50
batch_axis tuple[int, ...] | None

Additional axes to aggregate across.

None

Returns:

Name Type Description
fano_mm ndarray | Tensor

Mean-matched Fano factor values.

info dict

Dictionary with matching info and computed statistics.

Example

# spike_data shape: [T, n_conditions, n_neurons]
ff_mm, info = fano_mean_matching(
    spike_data, condition_axis=1, n_bins=10
)
Source code in btorch/analysis/dynamic_tools/spiking.py
@use_percentiles(value_key="fano_mm")
@use_stats(value_key="fano_mm")
def fano_mean_matching(
    spike_data: np.ndarray | torch.Tensor,
    window: int | None = None,
    overlap: int = 0,
    condition_axis: int = 1,
    n_bins: int = 10,
    n_resamples: int = 50,
    batch_axis: tuple[int, ...] | None = None,
) -> tuple[np.ndarray | torch.Tensor, dict]:
    """Compute mean-matched Fano factor controlling for rate effects.

        The mean matching method (Churchland et al., 2010) ensures that the
    distribution of mean spike counts is matched across conditions or time
        points before computing the Fano factor. This removes artifacts caused
        by rate changes.

        Reference: Churchland et al. (2010) "Stimulus onset quenches neural
        variability: a widespread cortical phenomenon", Nature Neurosci.

        Args:
            spike_data: Spike train of shape [T, n_conditions, ...] or
                [T, n_trials, ...]. First dimension is time.
            window: Window size for spike counting. If None, uses T//10.
            overlap: Overlap between consecutive windows.
            condition_axis: Axis representing conditions/trials (default 1).
            n_bins: Number of bins for mean count histogram matching.
            n_resamples: Number of resampling iterations for stability.
            batch_axis: Additional axes to aggregate across.

        Returns:
            fano_mm: Mean-matched Fano factor values.
            info: Dictionary with matching info and computed statistics.

        Example:
            >>> # spike_data shape: [T, n_conditions, n_neurons]
            >>> ff_mm, info = fano_mean_matching(
            ...     spike_data, condition_axis=1, n_bins=10
            ... )
    """
    is_torch = isinstance(spike_data, torch.Tensor)
    T = spike_data.shape[0]

    if window is None:
        window = max(1, T // 10)

    step = window - overlap
    assert step > 0, "window must be greater than overlap"

    # Move condition axis to position 1 for processing
    if condition_axis != 1:
        perm = list(range(spike_data.ndim))
        perm[1], perm[condition_axis] = perm[condition_axis], perm[1]
        spike_data = (
            spike_data.transpose(*perm) if not is_torch else spike_data.permute(*perm)
        )

    n_conditions = spike_data.shape[1]

    # Flatten non-time, non-condition dimensions
    rest_shape = spike_data.shape[2:]
    spike_flat = spike_data.reshape(T, n_conditions, -1)
    n_neurons_flat = spike_flat.shape[2]

    # Compute counts per window
    n_windows = (T - window) // step + 1
    if n_windows < 2:
        if is_torch:
            result = torch.full(
                (n_neurons_flat,), float("nan"), device=spike_data.device
            )
        else:
            result = np.full((n_neurons_flat,), np.nan)
        return result.reshape(rest_shape) if rest_shape else result, {
            "error": "insufficient windows"
        }

    if is_torch:
        counts = torch.zeros(
            (n_windows, n_conditions, n_neurons_flat), device=spike_data.device
        )
    else:
        counts = np.zeros((n_windows, n_conditions, n_neurons_flat))

    for w in range(n_windows):
        start = w * step
        end = start + window
        window_spikes = spike_flat[start:end]  # [window, n_conditions, n_neurons]
        counts[w] = window_spikes.sum(dim=0) if is_torch else window_spikes.sum(axis=0)

    # Compute mean per condition (across windows)
    if is_torch:
        means = counts.mean(dim=0)  # [n_conditions, n_neurons]
    else:
        means = counts.mean(axis=0)  # [n_conditions, n_neurons]

    # Compute mean matching weights
    if is_torch:
        match_info = _compute_mean_matching_weights_torch(means, n_bins)
        weights = match_info["weights"]

        # Apply weighted Fano computation
        fano_mm = _compute_weighted_fano_torch(
            counts.reshape(-1, n_neurons_flat), weights.repeat(n_windows, 1)
        )
    else:
        match_info = _compute_mean_matching_weights_numpy(means, n_bins)
        weights = match_info["weights"]

        # Compute weighted Fano factor. Tile (not element-repeat) the weights
        # so they line up with the [n_windows * n_conditions, n] layout of the
        # reshaped counts, matching the torch branch above.
        fano_mm = _compute_weighted_fano_numpy(
            counts.reshape(-1, n_neurons_flat),
            np.tile(weights, (n_windows, 1)),
        )

    # Reshape to original non-time, non-condition dimensions
    fano_mm = fano_mm.reshape(rest_shape)

    # Apply batch axis aggregation if requested
    if batch_axis is not None:
        if is_torch:
            fano_mm = torch.mean(fano_mm, dim=tuple(batch_axis))
        else:
            fano_mm = np.mean(fano_mm, axis=tuple(batch_axis))

    info = {
        **match_info,
        "method": "mean_matching",
        "n_bins": n_bins,
        "n_resamples": n_resamples,
    }

    return fano_mm, info
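A self-contained sketch of the expected input layout (illustrative data: two conditions with different rates; the output is one mean-matched Fano factor per neuron):

import numpy as np

rng = np.random.default_rng(0)
T, n_cond, n_neurons = 1000, 2, 30
cond_rates = np.array([0.02, 0.08])                # spikes/bin, per condition
spikes = (rng.random((T, n_cond, n_neurons))
          < cond_rates[None, :, None]).astype(float)

ff_mm, info = fano_mean_matching(spikes, condition_axis=1, n_bins=10)
print(ff_mm.shape)                                 # (30,): one value per neuron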

fano_model_based(spike_data, window=None, overlap=0, model='modulated_poisson', model_params=None, stimulus_drive=None, batch_axis=None)

Compute model-based Fano factor with rate compensation.

Model-based approaches fit a generative model to the spike data and extract the underlying variability independent of rate effects.

Models:

  1. Modulated Poisson (Goris et al., 2014): r ~ Poisson(λ · g) where g is multiplicative gain noise. Produces a quadratic mean-variance relationship.
  2. Flexible Overdispersion (Charles et al., 2018): λ_eff = f(g(x) + ε) with different nonlinearities:
     • Rectified-linear: FF decreases with rate
     • Rectified-squaring: FF ≈ constant
     • Exponential: FF increases with rate

References:

  • Goris et al. (2014) "Partitioning neuronal variability"
  • Charles et al. (2018) "Dethroning the Fano factor"

Parameters:

Name Type Description Default
spike_data ndarray | Tensor

Spike train of shape [T, ...]. First dimension is time.

required
window int | None

Window size for spike counting. If None, uses T//10.

None
overlap int

Overlap between consecutive windows.

0
model Literal['modulated_poisson', 'flexible_overdispersion']

Model type to use.

'modulated_poisson'
model_params dict | None

Model-specific parameters.

None
stimulus_drive ndarray | Tensor | None

Stimulus-dependent drive (required for flexible_overdispersion).

None
batch_axis tuple[int, ...] | None

Axes to average across for FF computation.

None

Returns:

Name Type Description
fano_model ndarray | Tensor

Model-based Fano factor values.

info dict

Dictionary with model fit info and computed statistics.

Example

ff_mod, info = fano_model_based(
    spike_data,
    model="modulated_poisson",
    model_params={"gain_mean": 1.0, "gain_var": 0.5}
)

Source code in btorch/analysis/dynamic_tools/spiking.py
@use_percentiles(value_key="fano_model")
@use_stats(value_key="fano_model")
def fano_model_based(
    spike_data: np.ndarray | torch.Tensor,
    window: int | None = None,
    overlap: int = 0,
    model: Literal[
        "modulated_poisson", "flexible_overdispersion"
    ] = "modulated_poisson",
    model_params: dict | None = None,
    stimulus_drive: np.ndarray | torch.Tensor | None = None,
    batch_axis: tuple[int, ...] | None = None,
) -> tuple[np.ndarray | torch.Tensor, dict]:
    """Compute model-based Fano factor with rate compensation.

    Model-based approaches fit a generative model to the spike data and
    extract the underlying variability independent of rate effects.

    Models:
    1. Modulated Poisson (Goris et al., 2014):
       r ~ Poisson(λ · g) where g is multiplicative gain noise.
       Produces quadratic mean-variance relationship.

    2. Flexible Overdispersion (Charles et al., 2018):
       λ_eff = f(g(x) + ε) with different nonlinearities.
       - Rectified-linear: FF decreases with rate
       - Rectified-squaring: FF ≈ constant
       - Exponential: FF increases with rate

    References:
    - Goris et al. (2014) "Partitioning neuronal variability"
    - Charles et al. (2018) "Dethroning the Fano factor"

    Args:
        spike_data: Spike train of shape [T, ...]. First dimension is time.
        window: Window size for spike counting. If None, uses T//10.
        overlap: Overlap between consecutive windows.
        model: Model type to use.
        model_params: Model-specific parameters.
        stimulus_drive: Stimulus-dependent drive (required for flexible_overdispersion).
        batch_axis: Axes to average across for FF computation.

    Returns:
        fano_model: Model-based Fano factor values.
        info: Dictionary with model fit info and computed statistics.

    Example:
        >>> ff_mod, info = fano_model_based(
        ...     spike_data,
        ...     model="modulated_poisson",
        ...     model_params={"gain_mean": 1.0, "gain_var": 0.5}
        ... )
    """
    is_torch = isinstance(spike_data, torch.Tensor)
    T = spike_data.shape[0]

    if window is None:
        window = max(1, T // 10)

    model_params = model_params or {}

    # Compute empirical mean and variance
    flat_spike = spike_data.reshape(T, -1)
    n_flat = flat_spike.shape[1]

    step = window - overlap
    n_windows = (T - window) // step + 1

    if n_windows < 2:
        if is_torch:
            result = torch.full((n_flat,), float("nan"), device=spike_data.device)
        else:
            result = np.full((n_flat,), np.nan)
        rest_shape = spike_data.shape[1:]
        return result.reshape(rest_shape), {"error": "insufficient windows"}

    if is_torch:
        counts = torch.zeros((n_windows, n_flat), device=spike_data.device)
    else:
        counts = np.zeros((n_windows, n_flat))

    for w in range(n_windows):
        start = w * step
        end = start + window
        if is_torch:
            counts[w] = flat_spike[start:end].sum(dim=0)
        else:
            counts[w] = flat_spike[start:end].sum(axis=0)

    # Compute empirical statistics
    if is_torch:
        empirical_mean = counts.mean(dim=0)
        empirical_var = counts.var(dim=0, unbiased=True)
    else:
        empirical_mean = counts.mean(axis=0)
        empirical_var = counts.var(axis=0, ddof=1)

    # Compute model-based prediction
    if model == "modulated_poisson":
        gain_mean = model_params.get("gain_mean", 1.0)
        gain_var = model_params.get("gain_var", 0.5)

        model_mean, model_var = _modulated_poisson_moments(
            empirical_mean, gain_mean, gain_var
        )

        # Model-based FF: ratio of model variance to model mean
        # normalized by expected Poisson variance
        fano_model = model_var / (model_mean + 1e-12)

        info = {
            "model": "modulated_poisson",
            "gain_mean": gain_mean,
            "gain_var": gain_var,
            "empirical_mean": empirical_mean,
            "empirical_var": empirical_var,
            "model_mean": model_mean,
            "model_var": model_var,
        }

    elif model == "flexible_overdispersion":
        if stimulus_drive is None:
            # Use empirical mean as proxy for stimulus drive
            stimulus_drive = (
                np.log(empirical_mean + 1)
                if not is_torch
                else torch.log(empirical_mean + 1)
            )

        nonlinearity = model_params.get("nonlinearity", "relu")
        noise_std = model_params.get("noise_std", 0.5)

        model_mean, model_var = _flexible_overdispersion_moments(
            stimulus_drive, noise_std, nonlinearity
        )

        fano_model = model_var / (model_mean + 1e-12)

        info = {
            "model": "flexible_overdispersion",
            "nonlinearity": nonlinearity,
            "noise_std": noise_std,
            "empirical_mean": empirical_mean,
            "empirical_var": empirical_var,
            "model_mean": model_mean,
            "model_var": model_var,
        }

    else:
        raise ValueError(f"Unknown model: {model}")

    # Reshape to original non-time dimensions
    rest_shape = spike_data.shape[1:]
    fano_model = fano_model.reshape(rest_shape)

    # Apply batch axis aggregation if requested
    if batch_axis is not None:
        if is_torch:
            fano_model = torch.mean(fano_model, dim=tuple(batch_axis))
        else:
            fano_model = np.mean(fano_model, axis=tuple(batch_axis))

    return fano_model, info

fano_operational_time(spike_data, window=None, overlap=None, rate_hz=None, dt_ms=1.0, batch_axis=None)

Compute Fano factor in operational time (rate-independent).

The operational time Fano factor transforms the spike train such that the firing rate equals 1 (normalized), making FF independent of the absolute firing rate. This is the recommended method for comparing variability across conditions with different rates.

For renewal processes, the operational time Fano factor equals the squared coefficient of variation (CV²) of the ISI distribution.

Reference: Rajdl et al. (2020) "Fano Factor: A Potentially Useful Information", Front. Comput. Neurosci.

Parameters:

Name Type Description Default
spike_data ndarray | Tensor

Spike train of shape [T, ...]. First dimension is time. Values are binary (0/1) or spike counts.

required
window float | None

Window size in operational time units (default: 1.0). This is the expected count at rate=1 (i.e., 1 spike expected).

None
overlap float | None

Not used (kept for API compatibility).

None
rate_hz ndarray | Tensor | float | None

Firing rate in Hz. Can be a scalar (homogeneous rate), an array [T, ...] (time-varying rate), or None (estimated from data using a sliding window).

None
dt_ms float

Time step in milliseconds for original time axis.

1.0
batch_axis tuple[int, ...] | None

Axes to average across for FF computation.

None

Returns:

Name Type Description
fano_op ndarray | Tensor

Operational time Fano factor values.

info dict

Dictionary with operational time info and computed statistics.

Example

# Compare Fano factors at different rates
spikes_low_rate = generate_poisson_spikes(rate_hz=20, ...)
spikes_high_rate = generate_poisson_spikes(rate_hz=80, ...)
ff_low, _ = fano_operational_time(spikes_low_rate)
ff_high, _ = fano_operational_time(spikes_high_rate)
# Both should be ≈ 1 regardless of rate difference
Source code in btorch/analysis/dynamic_tools/spiking.py
@use_percentiles(value_key="fano_op")
@use_stats(value_key="fano_op")
def fano_operational_time(
    spike_data: np.ndarray | torch.Tensor,
    window: float | None = None,
    overlap: float | None = None,
    rate_hz: np.ndarray | torch.Tensor | float | None = None,
    dt_ms: float = 1.0,
    batch_axis: tuple[int, ...] | None = None,
) -> tuple[np.ndarray | torch.Tensor, dict]:
    """Compute Fano factor in operational time (rate-independent).

        The operational time Fano factor transforms the spike train such that
    the firing rate equals 1 (normalized), making FF independent of the
        absolute firing rate. This is the recommended method for comparing
        variability across conditions with different rates.

        For renewal processes, the operational time Fano factor equals the
        squared coefficient of variation (CV²) of the ISI distribution.

        Reference: Rajdl et al. (2020) "Fano Factor: A Potentially Useful
        Information", Front. Comput. Neurosci.

        Args:
            spike_data: Spike train of shape [T, ...]. First dimension is time.
                Values are binary (0/1) or spike counts.
            window: Window size in operational time units (default: 1.0).
                This is the expected count at rate=1 (i.e., 1 spike expected).
            overlap: Not used (kept for API compatibility).
            rate_hz: Firing rate in Hz. Can be:
                - Scalar: homogeneous rate
                - Array [T, ...]: time-varying rate
                - None: estimated from data using sliding window
            dt_ms: Time step in milliseconds for original time axis.
            batch_axis: Axes to average across for FF computation.

        Returns:
            fano_op: Operational time Fano factor values.
            info: Dictionary with operational time info and computed statistics.

        Example:
            >>> # Compare Fano factors at different rates
            >>> spikes_low_rate = generate_poisson_spikes(rate_hz=20, ...)
            >>> spikes_high_rate = generate_poisson_spikes(rate_hz=80, ...)
            >>> ff_low, _ = fano_operational_time(spikes_low_rate)
            >>> ff_high, _ = fano_operational_time(spikes_high_rate)
            >>> # Both should be ≈ 1 regardless of rate difference
    """
    if window is None:
        window = 1.0  # Unit operational time window

    is_torch = isinstance(spike_data, torch.Tensor)

    # Estimate rate if not provided
    if rate_hz is None:
        if is_torch:
            rate_hz = _estimate_rate_torch(spike_data, dt_ms)
        else:
            rate_hz = _estimate_rate_numpy(spike_data, dt_ms)

    # Compute operational time Fano factor
    if is_torch:
        return _compute_operational_fano_torch(
            spike_data, rate_hz, window, dt_ms, batch_axis
        )
    else:
        return _compute_operational_fano_numpy(
            spike_data, rate_hz, window, dt_ms, batch_axis
        )
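The docstring example relies on a hypothetical generate_poisson_spikes helper; a self-contained variant using Bernoulli-binned Poisson spikes (assuming dt = 1 ms) looks like this:

import numpy as np

rng = np.random.default_rng(0)
low = (rng.random((5000, 20)) < 0.02).astype(float)    # ≈ 20 Hz at dt = 1 ms
high = (rng.random((5000, 20)) < 0.08).astype(float)   # ≈ 80 Hz at dt = 1 ms

ff_low, _ = fano_operational_time(low, dt_ms=1.0)
ff_high, _ = fano_operational_time(high, dt_ms=1.0)
print(np.nanmean(ff_low), np.nanmean(ff_high))         # both ≈ 1 for Poisson inputs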

use_percentiles(func=None, *, value_key='values', default_percentiles=None)

Decorator to add percentiles arg and optionally compute percentiles.

This decorator adds a percentiles parameter to a function that returns per-neuron values. Percentiles are only computed if percentiles is not None. Results are stored in info[f"{value_key}_percentiles"] and info[f"{value_key}_levels"].

Can also accept a dict mapping return positions to labels for functions returning multiple values (e.g., {1: "eci", 3: "lag"}).

The decorated function should return either:

  • A tuple of (values, info_dict) where values are per-neuron metrics
  • Just the per-neuron values (will be wrapped in a tuple with empty dict)
  • A tuple of multiple values with info as the last element

Parameters:

Name Type Description Default
func Callable | None

The function to decorate (or None if using with parentheses)

None
value_key str | dict[int, str]

Key to use in info dict for the percentile result

'values'
default_percentiles float | tuple[float, ...] | None

Default percentiles to compute when the caller does not pass percentiles

None

Returns:

Type Description
Callable

Decorated function with added percentiles parameter

Example
@use_percentiles
def compute_metric(data, *, percentiles=None):
    values = some_computation(data)  # per-neuron values
    return values, {"raw": values}

# Usage:
values, info = compute_metric(data)  # no percentiles computed
values, info = compute_metric(data, percentiles=0.5)  # compute median
values, info = compute_metric(
    data, percentiles=(0.25, 0.5, 0.75)
)  # compute quartiles
Source code in btorch/analysis/statistics.py
def use_percentiles(
    func: Callable | None = None,
    *,
    value_key: str | dict[int, str] = "values",
    default_percentiles: float | tuple[float, ...] | None = None,
) -> Callable:
    """Decorator to add percentiles arg and optionally compute percentiles.

    This decorator adds a `percentiles` parameter to a function that returns
    per-neuron values. Percentiles are only computed if `percentiles` is not
    None. Results are stored in info[f"{value_key}_percentiles"] and
    info[f"{value_key}_levels"].

    Can also accept a dict mapping return positions to labels for functions
    returning multiple values (e.g., {1: "eci", 3: "lag"}).

    The decorated function should return either:
    - A tuple of (values, info_dict) where values are per-neuron metrics
    - Just the per-neuron values (will be wrapped in a tuple with empty dict)
    - A tuple of multiple values with info as the last element

    Args:
        func: The function to decorate (or None if using with parentheses)
        value_key: Key to use in info dict for the percentile result
        default_percentiles: Default percentiles used when the caller does not
            pass percentiles explicitly

    Returns:
        Decorated function with added percentiles parameter

    Example:
        ```python
        @use_percentiles
        def compute_metric(data, *, percentiles=None):
            values = some_computation(data)  # per-neuron values
            return values, {"raw": values}

        # Usage:
        values, info = compute_metric(data)  # no percentiles computed
        values, info = compute_metric(data, percentiles=0.5)  # compute median
        values, info = compute_metric(
            data, percentiles=(0.25, 0.5, 0.75)
        )  # compute quartiles
        ```
    """

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def wrapper(
            *args,
            percentiles: float
            | tuple[float, ...]
            | dict[int, float | tuple[float, ...]]
            | None = default_percentiles,
            **kwargs,
        ) -> tuple[Any, ...]:
            # Call the original function
            result = f(*args, **kwargs)
            if percentiles is None:
                return result

            # Unpack result using shared helper
            values_tuple, info = _unpack_result(result, value_key)

            # Ensure info is a dict
            if info is None:
                info = {}

            updated_info = dict(info)

            # Helper to get value key name for a position
            def _get_value_key_name(pos: int) -> str:
                if isinstance(value_key, dict):
                    return value_key.get(pos, f"values{pos}")
                return f"{value_key}{pos}"

            # Helper to get values at a position
            def _get_values(pos: int) -> Any:
                if pos < 0 or pos >= len(values_tuple):
                    raise IndexError(
                        f"Position {pos} out of range for return tuple "
                        f"of length {len(values_tuple)}"
                    )
                return values_tuple[pos]

            # Compute percentiles only if requested
            if isinstance(percentiles, dict):
                # Dict format: {position: percentile_value(s)}
                # Allows different percentiles for different return values
                for pos, perc_value in percentiles.items():
                    values = _get_values(pos)
                    key_name = _get_value_key_name(pos)
                    perc_result = compute_percentiles(values, perc_value)
                    updated_info[f"{key_name}_percentiles"] = perc_result["percentiles"]
                    updated_info[f"{key_name}_levels"] = perc_result["levels"]
            elif isinstance(value_key, dict):
                # Dict value_key with single percentiles value:
                # Apply same percentiles to all positions in value_key
                for pos, label in value_key.items():
                    values = _get_values(pos)
                    perc_result = compute_percentiles(values, percentiles)
                    updated_info[f"{label}_percentiles"] = perc_result["percentiles"]
                    updated_info[f"{label}_levels"] = perc_result["levels"]
            else:
                # Single percentiles format - apply to position 0
                values = _get_values(0)
                perc_result = compute_percentiles(values, percentiles)
                updated_info[f"{value_key}_percentiles"] = perc_result["percentiles"]
                updated_info[f"{value_key}_levels"] = perc_result["levels"]

            if len(values_tuple) > 1:
                return values_tuple + (updated_info,)
            else:
                return values_tuple[0], updated_info

        return wrapper

    if func is None:
        return decorator
    return decorator(func)
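The dict-valued value_key path for multi-output functions is easiest to see with a toy example (hypothetical metric names; assumes _unpack_result splits multi-value returns as described above):

import numpy as np

@use_percentiles(value_key={0: "eci", 1: "lag"})
def compute_pair(data):
    eci = data.mean(axis=0)   # toy per-neuron metric at position 0
    lag = data.std(axis=0)    # toy per-neuron metric at position 1
    return eci, lag, {}

data = np.random.rand(100, 8)
# Different percentiles per return position, keyed by position index.
eci, lag, info = compute_pair(data, percentiles={0: 0.5, 1: (0.25, 0.75)})
print(sorted(info))  # ['eci_levels', 'eci_percentiles', 'lag_levels', 'lag_percentiles']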

use_stats(func=None, *, value_key='values', dim=None, default_stat=None, default_stat_info=None, default_nan_policy='skip', default_inf_policy='propagate')

Decorator to add stat and stat_info args for aggregation.

This decorator adds stat, stat_info, nan_policy, and inf_policy parameters to a function that returns per-neuron values.

  • stat: If not None, returns the aggregated value instead of per-neuron values. The aggregation is stored in info[f"{value_key}_{stat}"]. Can be a StatChoice, or a dict mapping return position to a StatChoice (e.g., {0: "mean", 1: "median"}) for functions returning multiple values.
  • stat_info: Additional stats to compute and store in info dict without affecting the return value. Can be a single StatChoice, Iterable of StatChoice, a dict mapping position to label(s), or None. If dict, format is {position: stat_or_stats} where stat_or_stats can be a single StatChoice or Iterable of StatChoice.
  • dim: Dimension(s) to aggregate over. Can be:
    • None: Flatten all dimensions (default)
    • int: Aggregate over this dimension for all outputs
    • tuple[int, ...]: Aggregate over these dimensions for all outputs
    • dict[int, int | tuple[int, ...] | None]: Different dim for each output position (e.g., {0: 1, 1: 2, 2: None, 3: (1, 3, 4)})
  • nan_policy: How to handle NaN values:
    • "skip": Ignore NaN values (default)
    • "warn": Warn if NaN values found but continue
    • "assert": Raise error if NaN values found
  • inf_policy: How to handle Inf values:
    • "propagate": Keep Inf values (default)
    • "skip": Ignore Inf values
    • "warn": Warn if Inf values found but continue
    • "assert": Raise error if Inf values found

The decorated function should return either:

  • A tuple of (values, info_dict) where values are per-neuron metrics
  • Just the per-neuron values (will be wrapped in a tuple with empty dict)
  • A tuple of multiple values with info as the last element

Parameters:

Name Type Description Default
func Callable | None

The function to decorate (or None if using with parentheses)

None
value_key str | dict[int, str]

Key prefix to use in info dict for stat results

'values'
dim int | tuple[int, ...] | dict[int, int | tuple[int, ...] | None] | None

Dimension(s) to aggregate over for each output

None
default_nan_policy Literal['skip', 'warn', 'assert']

Default nan_policy for this decorated function

'skip'
default_inf_policy Literal['propagate', 'skip', 'warn', 'assert']

Default inf_policy for this decorated function

'propagate'
default_stat StatChoice | dict[int, StatChoice] | None

Default stat for this decorated function

None
default_stat_info StatChoice | Iterable[StatChoice] | dict[int, StatChoice | Iterable[StatChoice]] | None

Default stat_info for this decorated function

None

Returns:

Type Description
Callable

Decorated function with added stat, stat_info, nan_policy, and inf_policy parameters

Example
@use_stats
def compute_metric(
    data,
    *,
    stat=None,
    stat_info=None,
    nan_policy="skip",
    inf_policy="propagate",
):
    values = some_computation(data)  # per-neuron values
    return values, {"raw": values}

# Usage:
values, info = compute_metric(data)  # returns per-neuron values
mean_val, info = compute_metric(data, stat="mean")  # returns aggregated
values, info = compute_metric(
    data, stat_info=["mean", "max"]
)  # extra stats in info

# Multi-value return with dict stat:
@use_stats
def compute_multiple(data, *, stat=None, stat_info=None):
    eci = compute_eci(data)  # per-neuron
    lag = compute_lag(data)  # per-neuron
    return eci, lag, {}  # multiple values

# Aggregate specific positions with dict stat:
eci_mean, lag_mean, info = compute_multiple(
    data, stat={0: "mean", 1: "mean"}
)
Source code in btorch/analysis/statistics.py
def use_stats(
    func: Callable | None = None,
    *,
    value_key: str | dict[int, str] = "values",
    dim: int | tuple[int, ...] | dict[int, int | tuple[int, ...] | None] | None = None,
    default_stat: StatChoice | dict[int, StatChoice] | None = None,
    default_stat_info: (
        StatChoice
        | Iterable[StatChoice]
        | dict[int, StatChoice | Iterable[StatChoice]]
        | None
    ) = None,
    default_nan_policy: Literal["skip", "warn", "assert"] = "skip",
    default_inf_policy: Literal["propagate", "skip", "warn", "assert"] = "propagate",
) -> Callable:
    """Decorator to add stat and stat_info args for aggregation.

    This decorator adds `stat`, `stat_info`, `nan_policy`, and `inf_policy`
    parameters to a function that returns per-neuron values.

    - `stat`: If not None, returns the aggregated value instead of per-neuron
      values. The aggregation is stored in info[f"{value_key}_{stat}"].
      Can be a StatChoice, or a dict mapping return position to a StatChoice
      (e.g., {0: "mean", 1: "median"}) for functions returning multiple values.
    - `stat_info`: Additional stats to compute and store in info dict without
      affecting the return value. Can be a single StatChoice, Iterable of
      StatChoice, a dict mapping position to label(s), or None.
      If dict, format is {position: stat_or_stats} where stat_or_stats can be
      a single StatChoice or Iterable of StatChoice.
    - `dim`: Dimension(s) to aggregate over. Can be:
        - None: Flatten all dimensions (default)
        - int: Aggregate over this dimension for all outputs
        - tuple[int, ...]: Aggregate over these dimensions for all outputs
        - dict[int, int | tuple[int, ...] | None]: Different dim for each
          output position (e.g., {0: 1, 1: 2, 2: None, 3: (1, 3, 4)})
    - `nan_policy`: How to handle NaN values:
        - "skip": Ignore NaN values (default)
        - "warn": Warn if NaN values found but continue
        - "assert": Raise error if NaN values found
    - `inf_policy`: How to handle Inf values:
        - "propagate": Keep Inf values (default)
        - "skip": Ignore Inf values
        - "warn": Warn if Inf values found but continue
        - "assert": Raise error if Inf values found

    The decorated function should return either:
    - A tuple of (values, info_dict) where values are per-neuron metrics
    - Just the per-neuron values (will be wrapped in a tuple with empty dict)
    - A tuple of multiple values with info as the last element

    Args:
        func: The function to decorate (or None if using with parentheses)
        value_key: Key prefix to use in info dict for stat results
        dim: Dimension(s) to aggregate over for each output
        default_stat: Default stat for this decorated function
        default_stat_info: Default stat_info for this decorated function
        default_nan_policy: Default nan_policy for this decorated function
        default_inf_policy: Default inf_policy for this decorated function

    Returns:
        Decorated function with added stat, stat_info, nan_policy, and
        inf_policy parameters

    Example:
        ```python
        @use_stats
        def compute_metric(
            data,
            *,
            stat=None,
            stat_info=None,
            nan_policy="skip",
            inf_policy="propagate",
        ):
            values = some_computation(data)  # per-neuron values
            return values, {"raw": values}

        # Usage:
        values, info = compute_metric(data)  # returns per-neuron values
        mean_val, info = compute_metric(data, stat="mean")  # returns aggregated
        values, info = compute_metric(
            data, stat_info=["mean", "max"]
        )  # extra stats in info

        # Multi-value return with dict stat:
        @use_stats
        def compute_multiple(data, *, stat=None, stat_info=None):
            eci = compute_eci(data)  # per-neuron
            lag = compute_lag(data)  # per-neuron
            return eci, lag, {}  # multiple values

        # Aggregate specific positions with dict stat:
        eci_mean, lag_mean, info = compute_multiple(
            data, stat={0: "mean", 1: "mean"}
        )
        ```
    """

    def decorator(f: Callable) -> Callable:
        # Inspect the wrapped function to determine what arguments it accepts
        sig = inspect.signature(f)
        f_accepts_nan_policy = "nan_policy" in sig.parameters
        f_accepts_inf_policy = "inf_policy" in sig.parameters

        @wraps(f)
        def wrapper(
            *args,
            stat: StatChoice | dict[int, StatChoice] | None = default_stat,
            stat_info: StatChoice
            | Iterable[StatChoice]
            | dict[int, StatChoice | Iterable[StatChoice]]
            | None = default_stat_info,
            nan_policy: Literal["skip", "warn", "assert"] | None = None,
            inf_policy: Literal["propagate", "skip", "warn", "assert"] | None = None,
            **kwargs,
        ) -> tuple[Any, ...]:
            # Use effective policies (passed value > decorator default > "skip")
            effective_nan_policy = (
                nan_policy if nan_policy is not None else default_nan_policy
            )
            effective_inf_policy = (
                inf_policy if inf_policy is not None else default_inf_policy
            )

            # Pass policies to the wrapped function if it accepts them
            if f_accepts_nan_policy:
                kwargs["nan_policy"] = effective_nan_policy
            if f_accepts_inf_policy:
                kwargs["inf_policy"] = effective_inf_policy

            # Call the original function
            result = f(*args, **kwargs)

            # Unpack result using shared helper
            values_tuple, info = _unpack_result(result, value_key)

            # Ensure info is a dict
            if info is None:
                info = {}

            updated_info = dict(info)

            # Helper to get value key name for a position
            def _get_value_key_name(pos: int) -> str:
                if isinstance(value_key, dict):
                    return value_key.get(pos, f"values{pos}")
                if len(values_tuple) > 1:
                    return f"{value_key}{pos}"
                return value_key

            # Helper to get values at a position
            def _get_values(pos: int) -> Any:
                if pos < 0 or pos >= len(values_tuple):
                    raise IndexError(
                        f"Position {pos} out of range for return tuple "
                        f"of length {len(values_tuple)}"
                    )
                return values_tuple[pos]

            # Helper to get effective dim for a position
            def _get_dim_for_pos(pos: int) -> int | tuple[int, ...] | None:
                if dim is None:
                    return None
                if isinstance(dim, dict):
                    return dim.get(pos, None)
                return dim

            # Handle stat parameter
            if stat is not None:
                # Check if stat is a dict mapping positions to stats
                if isinstance(stat, dict):
                    # Multiple position aggregation with dict stat
                    results = []
                    for pos, stat_choice in stat.items():
                        values = _get_values(pos)
                        key_name = _get_value_key_name(pos)
                        effective_dim = _get_dim_for_pos(pos)
                        stat_value = _compute_stat(
                            values,
                            stat_choice,
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,  # type: ignore
                        )
                        results.append(stat_value)
                        updated_info[key_name] = values
                        updated_info[f"{key_name}_{stat_choice}"] = stat_value
                    return tuple(results) + (updated_info,)
                else:
                    # Single stat - apply to position 0
                    values = _get_values(0)
                    key_name = _get_value_key_name(0)
                    effective_dim = _get_dim_for_pos(0)
                    stat_value = _compute_stat(
                        values,
                        stat,
                        effective_nan_policy,
                        effective_inf_policy,
                        effective_dim,  # type: ignore
                    )
                    updated_info[key_name] = values
                    updated_info[f"{key_name}_{stat}"] = stat_value
                    return stat_value, updated_info

            # Handle stat_info parameter
            if stat_info is not None:
                # Check if stat_info is a dict mapping positions to stats
                if isinstance(stat_info, dict):
                    # Dict format: {position: stat_or_stats}
                    for pos, stats in stat_info.items():
                        # Normalize to iterable
                        if isinstance(stats, str):
                            stats_list = [stats]
                        else:
                            stats_list = list(stats)

                        # Use batch computation for efficiency
                        values = _get_values(pos)
                        key_name = _get_value_key_name(pos)
                        effective_dim = _get_dim_for_pos(pos)
                        if len(stats_list) > 1:
                            batch_results = _compute_stats_batch(
                                values,
                                [str(s) for s in stats_list],
                                effective_nan_policy,
                                effective_inf_policy,
                                effective_dim,
                            )
                            for s in stats_list:
                                updated_info[f"{key_name}_{s}"] = batch_results[str(s)]
                        else:
                            # Single stat - no need for batch optimization
                            stat_value = _compute_stat(
                                values,
                                stats_list[0],
                                effective_nan_policy,
                                effective_inf_policy,
                                effective_dim,  # type: ignore
                            )
                            updated_info[f"{key_name}_{stats_list[0]}"] = stat_value
                else:
                    # Original format: apply to position 0
                    # Normalize to iterable
                    if isinstance(stat_info, str):
                        stat_info_list = [stat_info]
                    else:
                        stat_info_list = list(stat_info)

                    # Use batch computation for efficiency (reuses mean/std for cv)
                    values = _get_values(0)
                    key_name = _get_value_key_name(0)
                    effective_dim = _get_dim_for_pos(0)
                    if len(stat_info_list) > 1:
                        batch_results = _compute_stats_batch(
                            values,
                            [str(s) for s in stat_info_list],
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,
                        )
                        for s in stat_info_list:
                            updated_info[f"{key_name}_{s}"] = batch_results[str(s)]
                    else:
                        # Single stat - no need for batch optimization
                        stat_value = _compute_stat(
                            values,
                            stat_info_list[0],
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,  # type: ignore
                        )
                        updated_info[f"{key_name}_{stat_info_list[0]}"] = stat_value

                # Return original values with updated info
                if len(values_tuple) > 1:
                    return values_tuple + (updated_info,)
                else:
                    return (values_tuple[0], updated_info)

            # No stat or stat_info - return original values with info
            if len(values_tuple) > 1:
                return values_tuple + (updated_info,)
            else:
                return (values_tuple[0], updated_info)

        return wrapper

    if func is None:
        return decorator
    return decorator(func)
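Putting stat and stat_info together on a single-output function (toy metric; the info keys follow the {value_key}_{stat} convention shown above):

import numpy as np

@use_stats(value_key="fano")
def per_neuron_fano(counts):
    mean = counts.mean(axis=0)
    var = counts.var(axis=0, ddof=1)
    return var / (mean + 1e-12), {}

counts = np.random.poisson(5.0, size=(200, 40))        # [n_windows, n_neurons]
ff_mean, info = per_neuron_fano(counts, stat="mean")   # scalar; info also holds "fano_mean"
ff, info = per_neuron_fano(counts, stat_info=["mean", "std"])
print(sorted(info))                                    # ['fano_mean', 'fano_std']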