Analysis

btorch.analysis

Attributes

StatChoice = Literal['mean', 'median', 'max', 'min', 'std', 'var', 'argmax', 'argmin', 'cv'] module-attribute

__all__ = ['agg_by_neuropil', 'agg_by_neuron', 'agg_conn', 'build_group_frame', 'group_values', 'group_summary', 'group_ecdf', 'branching_ratio', 'HopDistanceModel', 'compute_ie_ratio', 'indices_to_mask', 'select_on_metric', 'isi_cv', 'fano', 'kurtosis', 'local_variation', 'isi_cv_population', 'fano_population', 'kurtosis_population', 'cv_temporal', 'fano_temporal', 'fano_sweep', 'firing_rate', 'compute_raster', 'compute_log_hist', 'compute_spectrum', 'describe_array', 'suggest_skip_timestep', 'voltage_overshoot', 'StatChoice', 'use_stats', 'use_percentiles'] module-attribute

Classes

HopDistanceModel

Fast computation of hop distances in networks using BFS.

Source code in btorch/analysis/connectivity.py
class HopDistanceModel:
    """Fast computation of hop distances in networks using BFS."""

    def __init__(
        self,
        edges: Optional[pd.DataFrame] = None,
        adjacency: Optional[sparse.sparray] = None,
        node_mapping: Optional[Dict] = None,
        source: str = "source",
        target: str = "target",
    ):
        if edges is None and adjacency is None:
            raise ValueError("Must provide either edges DataFrame or adjacency matrix")

        if adjacency is not None:
            self.adjacency = adjacency.tocsr()
            self.use_sparse = True
            if node_mapping is None:
                self.node_mapping = {i: i for i in range(adjacency.shape[0])}
                self.reverse_mapping = {i: i for i in range(adjacency.shape[0])}
            else:
                self.node_mapping = node_mapping
                self.reverse_mapping = {v: k for k, v in node_mapping.items()}
            self.n_nodes = adjacency.shape[0]
        else:
            self.edges = edges.copy()
            self.source = source
            self.target = target
            self.use_sparse = False

            assert source in edges.columns, f'edges must contain "{source}" column'
            assert target in edges.columns, f'edges must contain "{target}" column'

            all_nodes = np.unique(
                np.concatenate([edges[source].values, edges[target].values])
            )
            self.node_mapping = {node: idx for idx, node in enumerate(all_nodes)}
            self.reverse_mapping = {
                idx: node for node, idx in self.node_mapping.items()
            }
            self.n_nodes = len(all_nodes)

            self._build_adjacency_dict()

    def _build_adjacency_dict(self):
        """Build adjacency dictionary for fast neighbor lookup."""
        self.adj_dict = {}
        for _, row in self.edges.iterrows():
            src_idx = self.node_mapping[row[self.source]]
            tgt_idx = self.node_mapping[row[self.target]]

            if src_idx not in self.adj_dict:
                self.adj_dict[src_idx] = []
            self.adj_dict[src_idx].append(tgt_idx)

        for node in self.adj_dict:
            self.adj_dict[node] = np.array(self.adj_dict[node], dtype=np.int32)

    def compute_distances(
        self, seeds: List[Union[int, str]], max_hops: Optional[int] = None
    ) -> pd.DataFrame:
        """Compute hop distances from seeds to all reachable nodes using
        BFS."""
        seed_indices = [self.node_mapping.get(seed, None) for seed in seeds]
        missing = [s for s in seeds if s not in self.node_mapping]
        if len(missing) != 0:
            print(f"Seeds not found in network: {missing}")
            seed_indices = [s for s in seed_indices if s is not None]

        distances = np.full(self.n_nodes, -1, dtype=np.int32)
        predecessors = np.full(self.n_nodes, -1, dtype=np.int32)

        queue = deque()
        for seed_idx in seed_indices:
            distances[seed_idx] = 0
            predecessors[seed_idx] = seed_idx
            queue.append(seed_idx)

        while queue:
            current_idx = queue.popleft()
            current_dist = distances[current_idx]

            if max_hops is not None and current_dist >= max_hops:
                continue

            if self.use_sparse:
                neighbors = self.adjacency[[current_idx]].indices.flatten()
            else:
                neighbors = self.adj_dict.get(current_idx, np.array([], dtype=np.int32))

            for neighbor_idx in neighbors:
                if distances[neighbor_idx] == -1:
                    distances[neighbor_idx] = current_dist + 1
                    predecessors[neighbor_idx] = current_idx
                    queue.append(neighbor_idx)

        reachable_mask = distances >= 0
        result_nodes = [
            self.reverse_mapping[idx] for idx in np.where(reachable_mask)[0]
        ]
        result_distances = distances[reachable_mask]
        result_predecessors = [
            self.reverse_mapping[predecessors[idx]]
            for idx in np.where(reachable_mask)[0]
        ]

        return pd.DataFrame(
            {
                "node": result_nodes,
                "distance": result_distances,
                "predecessor": result_predecessors,
            }
        )

    def hop_statistics(
        self, seeds: List[Union[int, str]], max_hops: Optional[int] = None
    ) -> pd.DataFrame:
        """Compute statistics about network reachability by hop distance."""
        distances_df = self.compute_distances(seeds, max_hops)

        if distances_df.empty:
            return pd.DataFrame(
                columns=[
                    "hops",
                    "nodes_count",
                    "nodes_percentage",
                    "cumulative_count",
                    "cumulative_percentage",
                ]
            )

        hop_counts = distances_df["distance"].value_counts().sort_index()

        stats = []
        cumulative = 0
        for hops in range(hop_counts.index.min(), hop_counts.index.max() + 1):
            count = hop_counts.get(hops, 0)
            cumulative += count

            stats.append(
                {
                    "hops": hops,
                    "nodes_count": count,
                    "nodes_percentage": 100.0 * count / self.n_nodes,
                    "cumulative_count": cumulative,
                    "cumulative_percentage": 100.0 * cumulative / self.n_nodes,
                }
            )

        return pd.DataFrame(stats)

    def reconstruct_path(
        self,
        source_node: Union[int, str],
        target_node: Union[int, str],
        distances_df: Optional[pd.DataFrame] = None,
    ) -> List[Union[int, str]]:
        """Reconstruct shortest path from source to target using predecessor
        info."""
        if distances_df is None:
            distances_df = self.compute_distances([source_node])

        target_row = distances_df[distances_df["node"] == target_node]
        if target_row.empty:
            return []

        path = []
        current = target_node
        predecessors_dict = dict(zip(distances_df["node"], distances_df["predecessor"]))

        while current != source_node:
            path.append(current)
            if current not in predecessors_dict:
                return []
            current = predecessors_dict[current]

        path.append(source_node)
        return path[::-1]
Functions
compute_distances(seeds, max_hops=None)

Compute hop distances from seeds to all reachable nodes using BFS.

hop_statistics(seeds, max_hops=None)

Compute statistics about network reachability by hop distance.

reconstruct_path(source_node, target_node, distances_df=None)

Reconstruct shortest path from source to target using predecessor info.

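A minimal end-to-end sketch of HopDistanceModel on a hypothetical three-node edge list; the column names match the constructor defaults:

import pandas as pd

from btorch.analysis import HopDistanceModel

# Hypothetical toy network: A -> B, A -> C, B -> C
edges = pd.DataFrame({"source": ["A", "A", "B"], "target": ["B", "C", "C"]})
model = HopDistanceModel(edges=edges)

distances = model.compute_distances(["A"])          # columns: node, distance, predecessor
stats = model.hop_statistics(["A"], max_hops=2)     # per-hop reachability counts
path = model.reconstruct_path("A", "C", distances)  # ["A", "C"]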

Functions

agg_by_neuron(y, neurons, agg='mean', neuron_type_column='cell_type', **kwargs)

Aggregate data by neuron type.

Source code in btorch/analysis/aggregation.py
def agg_by_neuron(
    y,
    neurons: pd.DataFrame,
    agg: Literal["mean", "sum", "std"] = "mean",
    neuron_type_column: str = "cell_type",
    **kwargs,
) -> dict:
    """Aggregate data by neuron type."""
    agg_func = getattr(np, agg) if isinstance(y, np.ndarray) else getattr(torch, agg)
    ret = {}
    for neuron_type, group in neurons.groupby(
        neuron_type_column, dropna=True, **kwargs
    ):
        ret[neuron_type] = agg_func(y[..., group.simple_id.to_numpy()], -1)
    return ret
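
A short usage sketch with hypothetical activity and metadata; neurons must carry simple_id values that index the last axis of y:

import numpy as np
import pandas as pd

from btorch.analysis import agg_by_neuron

y = np.random.rand(100, 4)  # hypothetical activity: 100 time steps x 4 neurons
neurons = pd.DataFrame(
    {"simple_id": [0, 1, 2, 3], "cell_type": ["KC", "KC", "PN", "PN"]}
)

means = agg_by_neuron(y, neurons, agg="mean")
# {"KC": array of shape (100,), "PN": array of shape (100,)} -- mean over neurons of each type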

agg_by_neuropil(y, neurons=None, connections=None, mode='all_innervated', agg='mean', use_polars=False)

Aggregate activations by neuropil under a validated aggregation mode.

Source code in btorch/analysis/aggregation.py
def agg_by_neuropil(
    y,
    neurons: pd.DataFrame | None = None,
    connections: pd.DataFrame | None = None,
    mode: Literal["top_innervated", "all_innervated"] = "all_innervated",
    agg: Literal["mean", "sum", "std"] = "mean",
    use_polars: bool = False,
):
    """Aggregate activations by neuropil under a validated aggregation mode."""
    agg_func = getattr(np, agg) if isinstance(y, np.ndarray) else getattr(torch, agg)
    if use_polars:
        try:
            import polars as pl
        except ImportError:
            use_polars = False

    if mode == "top_innervated":
        assert neurons is not None, "neurons must be provided for top_innervated mode"
        tmp = neurons[["group", "simple_id"]].copy()
        tmp = tmp[tmp["simple_id"] < y.shape[-1]]
        pre_ret: dict = {}
        post_ret: dict = {}

        if use_polars:
            tmp["pre"] = tmp["group"].str.split(".").str[0]
            tmp["post"] = tmp["group"].str.split(".").str[-1]
            ptbl = pl.from_pandas(tmp)

            pre_groups = ptbl.group_by("pre", maintain_order=True).agg(
                pl.col("simple_id")
            )
            for row in pre_groups.iter_rows(named=True):
                sid = np.asarray(row["simple_id"], dtype=int)
                pre_ret[row["pre"]] = agg_func(y[..., sid], -1)

            post_groups = ptbl.group_by("post", maintain_order=True).agg(
                pl.col("simple_id")
            )
            for row in post_groups.iter_rows(named=True):
                sid = np.asarray(row["simple_id"], dtype=int)
                post_ret[row["post"]] = agg_func(y[..., sid], -1)
        else:
            tmp["pre"] = tmp["group"].apply(lambda x: x.split(".")[0])
            tmp["post"] = tmp["group"].apply(lambda x: x.split(".")[-1])
            for pre, group in tmp.groupby("pre", dropna=True):
                pre_ret[pre] = agg_func(y[..., group.simple_id], -1)
            for post, group in tmp.groupby("post", dropna=True):
                post_ret[post] = agg_func(y[..., group.simple_id], -1)
        return pre_ret, post_ret
    if mode == "all_innervated":
        assert (
            connections is not None
        ), "connections must be provided for all_innervated mode"
        tmp = connections[["pre_simple_id", "post_simple_id", "neuropil"]]
        tmp = tmp[
            (tmp["pre_simple_id"] < y.shape[-1]) & (tmp["post_simple_id"] < y.shape[-1])
        ]
        pre_ret: dict = {}
        post_ret: dict = {}

        if use_polars:
            ptbl = pl.from_pandas(tmp)
            groups = ptbl.group_by("neuropil", maintain_order=True).agg(
                pl.col("pre_simple_id"), pl.col("post_simple_id")
            )
            for row in groups.iter_rows(named=True):
                neuropil = row["neuropil"]
                pre_ids = np.asarray(row["pre_simple_id"], dtype=int)
                post_ids = np.asarray(row["post_simple_id"], dtype=int)
                pre_ret[neuropil] = agg_func(y[..., pre_ids], -1)
                post_ret[neuropil] = agg_func(y[..., post_ids], -1)
        else:
            for neuropil, group in tmp.groupby("neuropil", dropna=True):
                pre_ret[neuropil] = agg_func(y[..., group.pre_simple_id.to_numpy()], -1)
                post_ret[neuropil] = agg_func(
                    y[..., group.post_simple_id.to_numpy()], -1
                )
        return pre_ret, post_ret

    raise ValueError(
        "Invalid `mode` for `agg_by_neuropil`. "
        "Expected one of: {'top_innervated', 'all_innervated'}."
    )
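
A sketch of the all_innervated mode with a hypothetical connections table; pre_simple_id and post_simple_id index the last axis of y:

import numpy as np
import pandas as pd

from btorch.analysis import agg_by_neuropil

y = np.random.rand(50, 3)  # hypothetical activity: 50 time steps x 3 neurons
connections = pd.DataFrame(
    {
        "pre_simple_id": [0, 1, 2],
        "post_simple_id": [1, 2, 0],
        "neuropil": ["AL", "AL", "MB"],
    }
)

pre_ret, post_ret = agg_by_neuropil(
    y, connections=connections, mode="all_innervated", agg="mean"
)
# pre_ret["AL"]: mean over presynaptic neurons 0 and 1, shape (50,)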

agg_conn(y, conn, conn_weight=None, neurons=None, mode='neuron', neuron_type_column='cell_type', agg='mean')

Aggregate connectivity weights by neuropil or neuron-type pairs.

Source code in btorch/analysis/aggregation.py
def agg_conn(
    y,
    conn: pd.DataFrame,
    conn_weight: scipy.sparse.sparray | None = None,
    neurons: pd.DataFrame | None = None,
    mode: Literal["neuropil", "neuron"] = "neuron",
    neuron_type_column: str = "cell_type",
    agg: Literal["mean", "sum", "std"] = "mean",
):
    """Aggregate connectivity weights by neuropil or neuron-type pairs."""
    if conn_weight is not None:
        conn_weight = conn_weight.tocoo()
        conn = conn.merge(
            pd.DataFrame(
                {
                    "pre_simple_id": conn_weight.row,
                    "post_simple_id": conn_weight.col,
                    "weight": conn_weight.data,
                }
            ),
            how="left",
            on=["pre_simple_id", "post_simple_id"],
        )
    if mode == "neuropil":
        return conn.groupby("neuropil")["weight"].agg(agg)
    if mode == "neuron":
        assert neurons is not None, "neurons must be provided for neuron mode"
        conn = conn.merge(
            neurons[["simple_id", neuron_type_column]].rename(
                columns={
                    "simple_id": "pre_simple_id",
                    neuron_type_column: f"pre_{neuron_type_column}",
                }
            ),
            how="left",
            on="pre_simple_id",
        )
        conn = conn.merge(
            neurons[["simple_id", neuron_type_column]].rename(
                columns={
                    "simple_id": "post_simple_id",
                    neuron_type_column: f"post_{neuron_type_column}",
                }
            ),
            how="left",
            on="post_simple_id",
        )
        return conn.groupby(
            [f"pre_{neuron_type_column}", f"post_{neuron_type_column}"]
        )["weight"].agg(agg)

    raise ValueError(
        "Invalid `mode` for `agg_conn`. Expected one of: {'neuropil', 'neuron'}."
    )
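
A sketch with a hypothetical connection table that already carries a weight column (so conn_weight stays None); note the y argument is unused on the code paths shown, so None is passed here:

import pandas as pd

from btorch.analysis import agg_conn

conn = pd.DataFrame(
    {
        "pre_simple_id": [0, 1, 2],
        "post_simple_id": [1, 2, 0],
        "weight": [1.0, 2.0, 3.0],
        "neuropil": ["AL", "AL", "MB"],
    }
)
neurons = pd.DataFrame({"simple_id": [0, 1, 2], "cell_type": ["KC", "KC", "PN"]})

by_pair = agg_conn(None, conn, neurons=neurons, mode="neuron")  # mean weight per (pre, post) cell_type
by_neuropil = agg_conn(None, conn, mode="neuropil")             # mean weight per neuropil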

branching_ratio(all_counts, k_max=40, *, maxslopes=None, scatterpoints=False, eps=1e-12, ar1_fallback=True)

Estimate branching ratio via subsampling-invariant MR fitting.

Citation: Wilting, J., & Priesemann, V. (2018). Inferring collective dynamical states from widely unobserved systems. Nature Communications, 9(1), 2325. https://doi.org/10.1038/s41467-018-04725-4

Parameters:

- all_counts (ndarray | list[ndarray] | list[list[float]] | str): One trial, stacked trials, list of trials, or path input. Required.
- k_max (int): Maximum lag used for the MR fit. Default: 40.
- maxslopes (int | None): Legacy synonym for k_max. Default: None.
- scatterpoints (bool): Keep centered lagged x and y samples in result. Default: False.
- eps (float): Small stabilizer for divisions and weights. Default: 1e-12.
- ar1_fallback (bool): Use AR(1) fallback when MR fit is ill-posed. Default: True.

Returns:

- dict: Dictionary with branching_ratio, naive_branching_ratio, k, r_k, stderr and fit metadata.

Raises:

- ValueError: If input data are invalid or too short.
- RuntimeError: If MR and fallback estimation both fail.

Source code in btorch/analysis/branching.py
def branching_ratio(
    all_counts: np.ndarray | list[np.ndarray] | list[list[float]] | str,
    k_max: int = 40,
    *,
    maxslopes: int | None = None,
    scatterpoints: bool = False,
    eps: float = 1e-12,
    ar1_fallback: bool = True,
) -> dict:
    """Estimate branching ratio via subsampling-invariant MR fitting.

    Citation:
    Wilting, J., & Priesemann, V. (2018). Inferring collective dynamical states
    from widely unobserved systems. Nature Communications, 9(1), 2325.
    https://doi.org/10.1038/s41467-018-04725-4

    Args:
        all_counts: One trial, stacked trials, list of trials, or path input.
        k_max: Maximum lag used for the MR fit.
        maxslopes: Legacy synonym for ``k_max``.
        scatterpoints: Keep centered lagged ``x`` and ``y`` samples in result.
        eps: Small stabilizer for divisions and weights.
        ar1_fallback: Use AR(1) fallback when MR fit is ill-posed.

    Returns:
        Dictionary with ``branching_ratio``, ``naive_branching_ratio``, ``k``,
        ``r_k``, ``stderr`` and fit metadata.

    Raises:
        ValueError: If input data are invalid or too short.
        RuntimeError: If MR and fallback estimation both fail.
    """
    counts_list = input_handler(all_counts)

    if maxslopes is not None:
        k_max = maxslopes
    k_max = int(max(2, k_max))

    shortest = min(len(c) for c in counts_list)
    max_possible = min(k_max, shortest - 1)
    if max_possible < 1:
        raise ValueError("Time series too short for branching-ratio estimation")

    k, r_k_raw, stderr_raw, data_length, mean_activity, xs, ys = get_slopes(
        counts_list,
        max_possible + 1,
        scatterpoints=scatterpoints,
        eps=eps,
    )

    r_k = np.asarray(r_k_raw, dtype=np.float64)
    stderr = np.asarray(stderr_raw, dtype=np.float64)

    valid = np.isfinite(r_k) & np.isfinite(stderr) & (r_k > 0) & (stderr > 0)
    if np.sum(valid) < 2:
        if ar1_fallback:
            result = _ar1_fallback(counts_list, data_length, mean_activity, eps)
            if scatterpoints:
                result["xs"] = xs
                result["ys"] = ys
            return result
        raise RuntimeError("Insufficient valid r_k values for MR estimation")

    k_valid = k[valid].astype(np.float64)
    r_valid = np.clip(r_k[valid], 1e-12, 1e12)
    stderr_valid = stderr[valid]

    y = np.log(r_valid)
    w = 1.0 / (stderr_valid + eps)
    w = w / np.mean(w)

    slope, intercept = np.polyfit(k_valid, y, 1, w=w)
    m_hat = float(np.exp(slope))
    a_fit = float(np.exp(intercept))

    if not np.isfinite(m_hat) or m_hat <= 0:
        raise RuntimeError("Fitted branching ratio is non-positive or invalid")

    result = {
        "branching_ratio": m_hat,
        "a_fit": a_fit,
        "autocorrelationtime": float(-1.0 / np.log(m_hat))
        if (m_hat > 0 and m_hat != 1.0)
        else np.inf,
        "naive_branching_ratio": float(r_k[0]) if len(r_k) > 0 else np.nan,
        "k": k,
        "r_k": r_k,
        "stderr": stderr,
        "fit_slope": float(slope),
        "fit_intercept": float(intercept),
        "fit_points_used": int(np.sum(valid)),
        "data_length": data_length,
        "mean_activity": mean_activity,
    }
    if scatterpoints:
        result["xs"] = xs
        result["ys"] = ys
    return result
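
A sketch on synthetic data, assuming a single 1D trial is accepted by input_handler as documented ("one trial"); the AR(1)-like process below has a true branching ratio of about 0.9:

import numpy as np

from btorch.analysis import branching_ratio

rng = np.random.default_rng(0)
counts = np.zeros(5000)
for t in range(1, len(counts)):
    counts[t] = 0.9 * counts[t - 1] + rng.poisson(1.0)  # subcritical dynamics

res = branching_ratio(counts, k_max=40)
print(res["branching_ratio"], res["naive_branching_ratio"])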

build_group_frame(values, neurons_df, group_by, *, simple_id_col='simple_id', value_name='value', dropna=True)

Convert neuron-aligned values to a tidy frame for grouped analyses.

Parameters:

- values (TensorLike): Array/tensor shaped [N] or [..., N] where the last axis is neuron. All leading dimensions are flattened into independent samples (e.g., trials, conditions, or time points). Required.
- neurons_df (DataFrame): DataFrame containing at least simple_id_col and group_by. Required.
- group_by (str): Column in neurons_df used as grouping key. Required.
- simple_id_col (str): Column mapping rows in neurons_df to neuron index in values. Default: 'simple_id'.
- value_name (str): Name for the output value column. Default: 'value'.
- dropna (bool): Drop missing values in group/value columns when True. Default: True.
Source code in btorch/analysis/aggregation.py
def build_group_frame(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    dropna: bool = True,
) -> pd.DataFrame:
    """Convert neuron-aligned values to a tidy frame for grouped analyses.

    Args:
        values: Array/tensor shaped `[N]` or `[..., N]` where the last axis is
            neuron. All leading dimensions are flattened into independent
            samples (e.g., trials, conditions, or time points).
        neurons_df: DataFrame containing at least `simple_id_col` and `group_by`.
        group_by: Column in `neurons_df` used as grouping key.
        simple_id_col: Column mapping rows in `neurons_df` to neuron index in
            `values`.
        value_name: Name for the output value column.
        dropna: Drop missing values in group/value columns when `True`.
    """
    y = _to_numpy(values)
    if y.ndim < 1:
        raise ValueError("`values` must have at least one dimension.")

    if simple_id_col not in neurons_df.columns:
        raise ValueError(f"Missing `{simple_id_col}` in `neurons_df`.")
    if group_by not in neurons_df.columns:
        raise ValueError(f"Missing `{group_by}` in `neurons_df`.")

    metadata = neurons_df.loc[:, [simple_id_col, group_by]].copy()
    if dropna:
        metadata = metadata.dropna(subset=[group_by])
    if metadata.empty:
        raise ValueError("No neuron metadata available after filtering.")

    if metadata[simple_id_col].duplicated().any():
        raise ValueError(f"`{simple_id_col}` must be unique in `neurons_df`.")

    try:
        simple_ids = pd.to_numeric(metadata[simple_id_col], errors="raise").to_numpy(
            dtype=np.int64
        )
    except Exception as exc:
        raise ValueError(f"`{simple_id_col}` must be numeric.") from exc

    n_neurons = y.shape[-1]
    out_of_range = (simple_ids < 0) | (simple_ids >= n_neurons)
    if np.any(out_of_range):
        bad_ids = simple_ids[out_of_range]
        raise ValueError(
            f"Found `{simple_id_col}` outside [0, {n_neurons - 1}]: {bad_ids.tolist()}"
        )

    selected = y[..., simple_ids]
    n_samples = int(np.prod(selected.shape[:-1], dtype=np.int64))
    n_samples = max(1, n_samples)

    flattened = selected.reshape(n_samples, len(simple_ids))
    group_labels = metadata[group_by].to_numpy()

    frame = pd.DataFrame(
        {
            group_by: np.repeat(group_labels, n_samples),
            value_name: flattened.T.reshape(-1),
        }
    )

    if dropna:
        frame = frame.dropna(subset=[value_name])
    if frame.empty:
        raise ValueError("No values available after filtering.")

    return frame
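
A minimal sketch with hypothetical values and metadata; the two leading samples (e.g., trials) are flattened into rows:

import numpy as np
import pandas as pd

from btorch.analysis import build_group_frame

values = np.random.rand(2, 3)  # 2 samples x 3 neurons
neurons_df = pd.DataFrame({"simple_id": [0, 1, 2], "cell_type": ["KC", "PN", "PN"]})

frame = build_group_frame(values, neurons_df, group_by="cell_type")
# 6 rows with columns ["cell_type", "value"]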

compute_ie_ratio(excitatory_mat, inhibitory_mat, excitatory_neuron_only=True, neurons=None, warn_strict=True)

Compute inhibitory/excitatory ratio per neuron and whole-brain mean.

Source code in btorch/analysis/connectivity.py
def compute_ie_ratio(
    excitatory_mat: sparray,
    inhibitory_mat: sparray,
    excitatory_neuron_only: bool = True,
    neurons: Optional[pd.DataFrame] = None,
    warn_strict: bool = True,
) -> tuple[float, np.ndarray]:
    """Compute inhibitory/excitatory ratio per neuron and whole-brain mean."""
    if excitatory_neuron_only:
        assert neurons is not None

    sum_inhibitory = inhibitory_mat.sum(axis=0).astype(float)
    sum_excitatory = excitatory_mat.sum(axis=0).astype(float)

    # neurons that have zero inputs from both e and i connections are likely input,
    # so don't warn on them
    input_indices = (sum_excitatory == 0) & (sum_inhibitory == 0)

    sum_excitatory[sum_excitatory == 0] = np.nan
    if excitatory_neuron_only:
        sum_excitatory[neurons[neurons.EI == "E"].simple_id.to_numpy()] = np.nan

    ie_ratios = sum_inhibitory / sum_excitatory

    if warn_strict:
        nan_indices = (np.isnan(ie_ratios) & ~input_indices).nonzero()[0]
        if nan_indices.size > 0:
            print(f"Warning: IE ratio contains NaN values at indices {nan_indices}")

        inf_indices = np.isinf(ie_ratios).nonzero()[0]
        if inf_indices.size > 0:
            print(f"Warning: IE ratio contains Inf values at indices {inf_indices}")

    ie_ratios = np.where(np.isinf(ie_ratios), np.nan, ie_ratios)
    ie_ratio_whole_brain: float = np.nanmean(ie_ratios[ie_ratios != 0])

    return ie_ratio_whole_brain, ie_ratios
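
A sketch with hypothetical 3-neuron connectivity; in this toy setup column j holds the inputs to neuron j, and neurons carries the EI labels used to mask excitatory cells:

import numpy as np
import pandas as pd
from scipy import sparse

from btorch.analysis import compute_ie_ratio

exc = sparse.csr_array(np.array([[0.0, 2.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]))
inh = sparse.csr_array(np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 2.0], [0.0, 0.0, 0.0]]))
neurons = pd.DataFrame({"simple_id": [0, 1, 2], "EI": ["E", "I", "I"]})

whole_brain, per_neuron = compute_ie_ratio(exc, inh, neurons=neurons)
# per_neuron ~ [nan, 0.5, 1.0] (neuron 0 has no inputs); whole_brain = 0.75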

compute_log_hist(data, bins=1000, edge_pos='mid')

Compute histogram with logarithmically-spaced bins.

Useful for visualizing heavy-tailed distributions like synaptic weights or degree distributions.

Parameters:

- data: Input data array (must be positive). Required.
- bins: Number of histogram bins. Default: 1000.
- edge_pos (Literal['mid', 'sep']): Position to return for bin edges. "mid": return bin centers (midpoints); "sep": return bin separators (edges). Default: 'mid'.

Returns:

- Tuple of (hist, bin_edges) where hist is the count array and bin_edges are the positions (centers or edges based on edge_pos).

Raises:

- ValueError: If data contains non-positive values (required for log scale).

Source code in btorch/analysis/statistics.py
def compute_log_hist(data, bins=1000, edge_pos: Literal["mid", "sep"] = "mid"):
    """Compute histogram with logarithmically-spaced bins.

    Useful for visualizing heavy-tailed distributions like synaptic weights
    or degree distributions.

    Args:
        data: Input data array (must be positive).
        bins: Number of histogram bins.
        edge_pos: Position to return for bin edges.
            "mid": Return bin centers (midpoints).
            "sep": Return bin separators (edges).

    Returns:
        Tuple of (hist, bin_edges) where hist is the count array and
        bin_edges are the positions (centers or edges based on edge_pos).

    Raises:
        ValueError: If data contains non-positive values (required for log scale).
    """
    data = np.asarray(data)
    if np.any(data <= 0):
        # Log-spaced bins require strictly positive data (see Raises above).
        raise ValueError("data must contain only positive values for log-spaced bins")
    bin_edges = np.logspace(np.log10(np.min(data)), np.log10(np.max(data)), num=bins)
    hist, edges = np.histogram(data, bins=bin_edges)
    if edge_pos == "mid":
        bin_edges = 0.5 * (edges[:-1] + edges[1:])
    return hist, bin_edges
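
A quick sketch on a heavy-tailed synthetic sample (log-normal, hence strictly positive):

import numpy as np

from btorch.analysis import compute_log_hist

weights = np.random.default_rng(0).lognormal(size=10_000)
hist, centers = compute_log_hist(weights, bins=50)
# e.g. plt.loglog(centers, hist) to view the tail on log-log axes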

compute_raster(sp_matrix, times)

Compute spike raster coordinates (neuron indices and spike times) for displaying the spiking activity of a group of neurons over time.

Source code in btorch/analysis/spiking.py
def compute_raster(sp_matrix: np.ndarray, times: np.ndarray):
    """Get spike raster plot which displays the spiking activity of a group of
    neurons over time."""
    times = np.asarray(times)
    elements = np.where(sp_matrix > 0.0)
    index = elements[1]
    time = times[elements[0]]
    return index, time
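
A sketch with hypothetical binary spikes of shape [T, N]; the returned arrays are ready for a scatter plot:

import numpy as np

from btorch.analysis import compute_raster

spikes = (np.random.default_rng(0).random((100, 10)) < 0.05).astype(float)
times = np.arange(100) * 0.1  # hypothetical 0.1 ms time step

neuron_index, spike_time = compute_raster(spikes, times)
# e.g. plt.scatter(spike_time, neuron_index, s=1)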

compute_spectrum(y, dt, nperseg=None)

Estimate the power spectrum of y using Welch's method (scipy.signal.welch); dt is the sample interval, so the sampling rate is 1/dt.

Source code in btorch/analysis/spiking.py
def compute_spectrum(y, dt, nperseg=None):
    from scipy.signal import welch

    freqs, Y_mag = welch(y, fs=1 / dt, nperseg=nperseg, axis=0)
    return freqs, Y_mag
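
A sketch recovering a known frequency from a noisy sinusoid:

import numpy as np

from btorch.analysis import compute_spectrum

dt = 0.001  # 1 kHz sampling
t = np.arange(0, 2.0, dt)
y = np.sin(2 * np.pi * 10 * t) + 0.1 * np.random.default_rng(0).standard_normal(t.size)

freqs, power = compute_spectrum(y, dt, nperseg=512)
# power peaks near 10 Hz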

cv_temporal(spike_data, dt_ms=1.0, window=100, step=1, batch_axis=None, dtype=None)

Compute CV in sliding temporal windows.

Calculates the coefficient of variation of ISIs within sliding windows over time, giving a time-resolved measure of spike train irregularity.

Parameters:

- spike_data (ndarray | Tensor): Spike train array of shape [T, ...]. First dimension is time. Required.
- dt_ms (float): Time step in milliseconds. Default: 1.0.
- window (int): Size of the sliding window in time steps. Default: 100.
- step (int): Step size between consecutive windows. Default: 1.
- batch_axis (tuple[int, ...] | int | None): Axes to aggregate across (e.g., (1, 2) for trials). Default: None.
- dtype (np.dtype | torch.dtype | None): Data type for output arrays. If None, uses float64 for NumPy and float32 for Torch. Default: None.

Returns:

- cv_temporal: CV values for each window. Shape: [n_windows, ...] where n_windows = (T - window) // step + 1.
- info: Dictionary with window boundaries.

Source code in btorch/analysis/spiking.py
@use_stats(value_key="cv_temporal", dim=1)
def cv_temporal(
    spike_data: np.ndarray | torch.Tensor,
    dt_ms: float = 1.0,
    window: int = 100,
    step: int = 1,
    batch_axis: tuple[int, ...] | int | None = None,
    dtype: np.dtype | torch.dtype | None = None,
):
    """Compute CV in sliding temporal windows.

    Calculates the coefficient of variation of ISIs within sliding windows
    over time, giving a time-resolved measure of spike train irregularity.

    Args:
        spike_data: Spike train array of shape [T, ...]. First dimension is time.
        dt_ms: Time step in milliseconds.
        window: Size of the sliding window in time steps.
        step: Step size between consecutive windows.
        batch_axis: Axes to aggregate across (e.g., (1, 2) for trials).
        dtype: Data type for output arrays. If None, uses float64 for NumPy
            and float32 for Torch.

    Returns:
        cv_temporal: CV values for each window. Shape: [n_windows, ...]
            where n_windows = (T - window) // step + 1.
        info: Dictionary with window boundaries.
    """
    if isinstance(batch_axis, int):
        batch_axis = (batch_axis,)

    T = spike_data.shape[0]
    n_windows = (T - window) // step + 1

    # Determine output dtype
    if isinstance(spike_data, torch.Tensor):
        out_dtype = dtype if dtype is not None else torch.float32
        cv_values = torch.full(
            (n_windows,) + spike_data.shape[1:],
            float("nan"),
            dtype=out_dtype,
            device=spike_data.device,
        )
    else:
        out_dtype = dtype if dtype is not None else float
        cv_values = np.full(
            (n_windows,) + spike_data.shape[1:], np.nan, dtype=out_dtype
        )

    for i in range(n_windows):
        start = i * step
        end = start + window
        if end > T:
            break

        window_data = spike_data[start:end]

        cv_window, _ = _cv(window_data, dt_ms, batch_axis, dtype)
        cv_values[i] = cv_window

    window_starts = np.arange(n_windows) * step * dt_ms
    window_ends = window_starts + window * dt_ms

    info = {
        "window": window,
        "step": step,
        "window_starts_ms": window_starts,
        "window_ends_ms": window_ends,
    }
    return cv_values, info
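
A sketch on hypothetical binary spikes; with T=1000, window=200 and step=50 this yields (1000 - 200) // 50 + 1 = 17 windows:

import numpy as np

from btorch.analysis import cv_temporal

spikes = (np.random.default_rng(0).random((1000, 20)) < 0.05).astype(float)  # [T, N]

cv_vals, info = cv_temporal(spikes, dt_ms=1.0, window=200, step=50)
# cv_vals: shape (17, 20); info["window_starts_ms"] gives window onsets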

describe_array(array)

Print descriptive statistics for a 1D array.

Displays mean, median, std, min, max, and quartiles.

Parameters:

Name Type Description Default
array ndarray

1D NumPy array to describe.

required
Example:

>>> describe_array(np.random.randn(100))
Mean: 0.05
Median: 0.12
...

Source code in btorch/analysis/statistics.py
def describe_array(array: np.ndarray):
    """Print descriptive statistics for a 1D array.

    Displays mean, median, std, min, max, and quartiles.

    Args:
        array: 1D NumPy array to describe.

    Example:
        >>> describe_array(np.random.randn(100))
        Mean: 0.05
        Median: 0.12
        ...
    """
    mean = np.mean(array)
    median = np.median(array)
    std_dev = np.std(array)
    min_val = np.min(array)
    max_val = np.max(array)
    q25 = np.percentile(array, 25)
    q50 = np.percentile(array, 50)  # This is the same as the median
    q75 = np.percentile(array, 75)

    print(f"Mean: {mean}")
    print(f"Median: {median}")
    print(f"Standard Deviation: {std_dev}")
    print(f"Min: {min_val}")
    print(f"Max: {max_val}")
    print(f"25th Percentile (Q1): {q25}")
    print(f"50th Percentile (Q2/Median): {q50}")
    print(f"75th Percentile (Q3): {q75}")

fano(spike, window=None, overlap=0, batch_axis=None, dtype=None)

Compute Fano factor for spike trains using optimized cumulative sums.

Supports both NumPy and PyTorch inputs. GPU-friendly operation.

This function is decorated with @use_stats and @use_percentiles. See use_stats() and use_percentiles() for detailed usage.

Parameters:

- spike (ndarray | Tensor): Spike train of shape [T, ...]. First dimension is time. Required.
- window (int | None): Window size for spike counting. If None, defaults to max(1, T // 10) so that roughly 10 count bins are used. Default: None.
- overlap (int): Overlap between consecutive windows. Default: 0.
- batch_axis (tuple[int, ...] | int | None): Axes to average across for FF computation (e.g., trials). Default: None.
- dtype (np.dtype | torch.dtype | None): Data type for accumulation. If None, uses float64 for NumPy and the input dtype (or float32 for float16/bfloat16) for Torch. Default: None.
- stat: Aggregation statistic to return instead of per-neuron values. Options: "mean", "median", "max", "min", "std", "var", "argmax", "argmin", "cv". See use_stats().
- stat_info: Additional statistics to compute and store in the info dict. See use_stats().
- nan_policy: How to handle NaN values ("skip", "warn", "assert"). See use_stats().
- inf_policy: How to handle Inf values ("propagate", "skip", "warn", "assert"). See use_stats().
- percentiles: Percentile level(s) in [0, 100] to compute. See use_percentiles().

Returns:

- fano: Fano factor values with shape [...] (input shape without the time dimension). If stat is provided, returns the aggregated statistic instead.
- info: Dictionary with optional computed statistics and percentile data.

Source code in btorch/analysis/spiking.py
@use_percentiles(value_key="fano")
@use_stats(value_key="fano")
def fano(
    spike: np.ndarray | torch.Tensor,
    window: int | None = None,
    overlap: int = 0,
    batch_axis: tuple[int, ...] | int | None = None,
    dtype: np.dtype | torch.dtype | None = None,
):
    """Compute Fano factor for spike trains using optimized cumulative sums.

    Supports both NumPy and PyTorch inputs. GPU-friendly operation.

    This function is decorated with `@use_stats` and `@use_percentiles`.
    See [`use_stats()`](btorch/analysis/statistics.py:483) and
    [`use_percentiles()`](btorch/analysis/statistics.py:777) for detailed usage.

    Args:
        spike: Spike train of shape [T, ...]. First dimension is time.
        window: Window size for spike counting. If None, defaults to
            max(1, T // 10) so that roughly 10 count bins are used.
        overlap: Overlap between consecutive windows.
        batch_axis: Axes to average across for FF computation (e.g., trials).
        dtype: Data type for accumulation. If None, uses float64 for NumPy
            and input dtype (or float32 for float16/bfloat16) for Torch.
        stat: Aggregation statistic to return instead of per-neuron values.
            Options: "mean", "median", "max", "min", "std", "var", "argmax",
            "argmin", "cv". See [`use_stats()`](btorch/analysis/statistics.py:483).
        stat_info: Additional statistics to compute and store in info dict.
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        nan_policy: How to handle NaN values ("skip", "warn", "assert").
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        inf_policy: How to handle Inf values ("propagate", "skip", "warn",
            "assert"). See [`use_stats()`](btorch/analysis/statistics.py:483).
        percentiles: Percentile level(s) in [0, 100] to compute.
            See [`use_percentiles()`](btorch/analysis/statistics.py:777).

    Returns:
        fano: Fano factor values with shape [...]
            (input shape without time dimension).
            If `stat` is provided, returns the aggregated statistic instead.
        info: Dictionary with optional computed statistics and percentile data.
    """
    if isinstance(batch_axis, int):
        batch_axis = (batch_axis,)
    if window is None:
        # Default window size to get ~10 bins for variance computation
        # Need at least 2 bins for valid variance with unbiased=True
        window = max(1, spike.shape[0] // 10)
    if isinstance(spike, torch.Tensor):
        return _fano_torch(spike, window, overlap, batch_axis, dtype)
    else:
        return _fano_numpy(spike, window, overlap, batch_axis, dtype)
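
A sketch on hypothetical Poisson-like spikes; the stat keyword is consumed by the @use_stats decorator:

import numpy as np

from btorch.analysis import fano

spikes = (np.random.default_rng(0).random((1000, 20)) < 0.05).astype(float)  # [T, N]

ff, info = fano(spikes, window=100)                     # per-neuron Fano factors, shape (20,)
ff_mean, info = fano(spikes, window=100, stat="mean")   # single aggregated value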

fano_population(spike, window=None, overlap=0)

Compute Fano factor for the pooled population activity.

This computes Fano factor from the summed population spike count, giving a single population-level metric.

Parameters:

- spike (ndarray | Tensor): Spike train of shape [T, ...]. First dimension is time. Required.
- window (int | None): Window size for spike counting. If None, uses full duration T. Default: None.
- overlap (int): Overlap between consecutive windows. Default: 0.

Returns:

- fano_pop: Single scalar Fano factor for the population.
- info: Dictionary with 'window' and 'n_windows'.

Source code in btorch/analysis/spiking.py
def fano_population(
    spike: np.ndarray | torch.Tensor,
    window: int | None = None,
    overlap: int = 0,
):
    """Compute Fano factor for the pooled population activity.

    This computes Fano factor from the summed population spike count,
    giving a single population-level metric.

    Args:
        spike: Spike train of shape [T, ...]. First dimension is time.
        window: Window size for spike counting. If None, uses full duration T.
        overlap: Overlap between consecutive windows.

    Returns:
        fano_pop: Single scalar Fano factor for the population.
        info: Dictionary with 'window' and 'n_windows'.
    """
    if isinstance(spike, torch.Tensor):
        return _fano_population_torch(spike, window, overlap)
    else:
        return _fano_population_numpy(spike, window, overlap)
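
A sketch on the same kind of hypothetical spike array; the population count is pooled before the Fano factor is taken:

import numpy as np

from btorch.analysis import fano_population

spikes = (np.random.default_rng(0).random((1000, 20)) < 0.05).astype(float)  # [T, N]

ff_pop, info = fano_population(spikes, window=100)
# ff_pop is a single scalar; info holds 'window' and 'n_windows'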

fano_sweep(spike, window=None, overlap=0, batch_axis=None, dtype=None)

Compute Fano factor sweeping over window sizes.

This sweeps through window sizes and computes the Fano factor for each, useful for analyzing how variability scales with counting window size.

The window parameter follows numpy.arange semantics:

- window=10: windows 1 to 10 (step=1)
- window=(5, 20): windows 5 to 19 (step=1)
- window=(5, 20, 2): windows 5, 7, 9, ..., 19 (step=2)
- window=None: defaults to range(1, T//20 + 1, 1)

Parameters:

- spike (ndarray | Tensor): Spike train of shape [T, ...]. First dimension is time. Required.
- window (int | tuple[int, ...] | None): Window size specification following the arange convention: an int is the stop value (start=1, step=1); (start, stop) is a range with step=1; (start, stop, step) is the full specification; None auto-determines (1, T//20 + 1, 1). Default: None.
- overlap (int): Overlap between consecutive windows. Default: 0.
- batch_axis (tuple[int, ...] | int | None): Axes to average across (e.g., trials). Default: None.
- dtype (np.dtype | torch.dtype | None): Data type for output arrays. If None, uses float64 for NumPy and float32 for Torch. Default: None.

Returns:

- fano_sweep: Fano factor values for each window size. Shape: [n_windows, ...] where n_windows depends on the range.
- info: Dictionary with 'window_sizes' array and 'window' spec.

Examples:

>>> # Sweep window sizes 1 to 50
>>> fano_sweep(spike, window=50)
>>> # Sweep window sizes 10, 20, 30, ..., 100
>>> fano_sweep(spike, window=(10, 101, 10))
>>> # Sweep window sizes 20, 30, 40, 50
>>> fano_sweep(spike, window=(20, 51, 1))
Source code in btorch/analysis/spiking.py
def fano_sweep(
    spike: np.ndarray | torch.Tensor,
    window: int | tuple[int, ...] | None = None,
    overlap: int = 0,
    batch_axis: tuple[int, ...] | int | None = None,
    dtype: np.dtype | torch.dtype | None = None,
):
    """Compute Fano factor sweeping over window sizes.

    This sweeps through window sizes and computes the Fano factor for each,
    useful for analyzing how variability scales with counting window size.

    The window parameter follows numpy.arange semantics:
        - window=10: windows from 1 to 10 (step=1)
        - window=(5, 20): windows from 5 to 19 (step=1)
        - window=(5, 20, 2): windows 5, 7, 9, ..., 19 (step=2)
        - window=None: defaults to range(1, T//20 + 1, 1)

    Args:
        spike: Spike train of shape [T, ...]. First dimension is time.
        window: Window size specification following arange convention:
            - int: stop value (start=1, step=1)
            - tuple (start, stop): range with step=1
            - tuple (start, stop, step): full range specification
            - None: auto-determine as (1, T//20 + 1, 1)
        overlap: Overlap between consecutive windows.
        batch_axis: Axes to average across (e.g., trials).
        dtype: Data type for output arrays. If None, uses float64 for NumPy
            and float32 for Torch.

    Returns:
        fano_sweep: Fano factor values for each window size.
            Shape: [n_windows, ...] where n_windows depends on range.
        info: Dictionary with 'window_sizes' array and 'window' spec.

    Examples:
        >>> # Sweep window sizes 1 to 50
        >>> fano_sweep(spike, window=50)
        >>> # Sweep window sizes 10, 20, 30, ..., 100
        >>> fano_sweep(spike, window=(10, 101, 10))
        >>> # Sweep window sizes 20, 30, 40, 50
        >>> fano_sweep(spike, window=(20, 51, 1))
    """
    if isinstance(batch_axis, int):
        batch_axis = (batch_axis,)
    T = spike.shape[0]

    # Parse window specification following arange semantics
    if window is None:
        start, stop, step = 1, T // 20 + 1, 1
    elif isinstance(window, int):
        start, stop, step = 1, window + 1, 1
    elif len(window) == 2:
        start, stop, step = window[0], window[1], 1
    elif len(window) == 3:
        start, stop, step = window[0], window[1], window[2]
    else:
        raise ValueError("window must be int, (start, stop), or (start, stop, step)")

    if step <= 0:
        raise ValueError("step must be positive")
    if start < 1:
        raise ValueError("start must be >= 1")
    if stop > T + 1:
        raise ValueError(f"stop must be <= T+1 ({T+1})")

    window_sizes = np.arange(start, stop, step)
    n_windows = len(window_sizes)

    if n_windows == 0:
        raise ValueError("window range produces no valid window sizes")

    if isinstance(spike, torch.Tensor):
        device = spike.device
        out = torch.zeros(
            (n_windows,) + spike.shape[1:], device=device, dtype=torch.float64
        )
        for i, w in enumerate(window_sizes):
            fano_val, _ = _fano_torch(spike, int(w), overlap, batch_axis)
            out[i] = fano_val
    else:
        out = np.zeros((n_windows,) + spike.shape[1:])
        for i, w in enumerate(window_sizes):
            fano_val, _ = _fano_numpy(spike, int(w), overlap, batch_axis)
            out[i] = fano_val

    info = {
        "window": (start, stop, step),
        "window_sizes": window_sizes,
        "n_windows": n_windows,
    }
    return out, info

fano_temporal(spike, window=100, step=1, batch_axis=None)

Compute Fano factor in sliding temporal windows.

Calculates the Fano factor within sliding windows over time, giving a time-resolved measure of spike count variability.

Parameters:

- spike (ndarray | Tensor): Spike train of shape [T, ...]. First dimension is time. Required.
- window (int): Size of the sliding window in time steps for Fano computation. Default: 100.
- step (int): Step size between consecutive windows. Default: 1.
- batch_axis (tuple[int, ...] | None): Axes to average across (e.g., trials). Default: None.

Returns:

- fano_temporal: Fano factor values for each window. Shape: [n_windows, ...]
- info: Dictionary with window boundaries.

Source code in btorch/analysis/spiking.py
@use_stats(value_key="fano_temporal", dim=1)
def fano_temporal(
    spike: np.ndarray | torch.Tensor,
    window: int = 100,
    step: int = 1,
    batch_axis: tuple[int, ...] | None = None,
):
    """Compute Fano factor in sliding temporal windows.

    Calculates the Fano factor within sliding windows over time,
    giving a time-resolved measure of spike count variability.

    Args:
        spike: Spike train of shape [T, ...]. First dimension is time.
        window: Size of the sliding window in time steps for Fano computation.
        step: Step size between consecutive windows.
        batch_axis: Axes to average across (e.g., trials).

    Returns:
        fano_temporal: Fano factor values for each window. Shape: [n_windows, ...]
        info: Dictionary with window boundaries.
    """
    T = spike.shape[0]
    n_windows = (T - window) // step + 1

    if isinstance(spike, torch.Tensor):
        fano_values = torch.full(
            (n_windows,) + spike.shape[1:],
            float("nan"),
            dtype=torch.float64,
            device=spike.device,
        )
    else:
        fano_values = np.full((n_windows,) + spike.shape[1:], np.nan, dtype=float)

    for i in range(n_windows):
        start = i * step
        end = start + window
        if end > T:
            break

        window_data = spike[start:end]
        bin = max(1, window_data.shape[0] // 10)

        if isinstance(window_data, torch.Tensor):
            fano_window, _ = _fano_torch(window_data, bin, 0, batch_axis)
            fano_values[i] = fano_window
        else:
            fano_window, _ = _fano_numpy(window_data, bin, 0, batch_axis)
            fano_values[i] = fano_window

    window_starts = np.arange(n_windows) * step
    window_ends = window_starts + window

    info = {
        "window": window,
        "step": step,
        "window_starts": window_starts,
        "window_ends": window_ends,
    }
    return fano_values, info
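
A sketch on hypothetical spikes; each sliding window is internally re-binned to roughly 10 count bins before the Fano factor is computed:

import numpy as np

from btorch.analysis import fano_temporal

spikes = (np.random.default_rng(0).random((1000, 20)) < 0.05).astype(float)  # [T, N]

ff_t, info = fano_temporal(spikes, window=200, step=50)
# ff_t: shape (17, 20); info["window_starts"] is in time steps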

firing_rate(spikes, width=4, dt=None, axis=None)

Smooth spikes into firing rates.

Supports input shapes like [T, ...]. If axis is not None, averages over the specified dimensions before smoothing.

Parameters:

- spikes (ndarray | Tensor): Spike train array of shape [T, ...]. Required.
- width (int | float | None): Smoothing window width. If None or 0, no smoothing is applied. Default: 4.
- dt (int | float | None): Time step in milliseconds. If None, defaults to 1.0. Default: None.
- axis (int | Sequence[int] | None): Axes to average over before smoothing. Can be an int or a tuple of ints. Default: None.

Returns:

- firing_rates: Smoothed firing rates with the same shape as the input (minus averaged axes if axis is specified).

Source code in btorch/analysis/spiking.py
def firing_rate(
    spikes: np.ndarray | torch.Tensor,
    width: int | float | None = 4,
    dt: int | float | None = None,
    axis: int | Sequence[int] | None = None,
):
    """Smooth spikes into firing rates.

    Supports input shapes like [T, ...].
    If axis is not None, averages over the specified dimensions before smoothing.

    Args:
        spikes: Spike train array of shape [T, ...].
        width: Smoothing window width. If None or 0, no smoothing is applied.
        dt: Time step in milliseconds. If None, defaults to 1.0.
        axis: Axes to average over before smoothing. Can be int or tuple of ints.

    Returns:
        firing_rates: Smoothed firing rates with same shape as input (minus
            averaged axes if axis is specified).
    """
    if dt is None:
        dt = 1.0

    if axis is not None:
        # Normalize axis to tuple for consistent handling
        if isinstance(axis, int):
            axis = (axis,)
        if isinstance(spikes, np.ndarray):
            spikes = spikes.mean(axis=axis)
        else:
            spikes = spikes.mean(dim=axis)

    if width is None or width == 0:
        return spikes / dt

    width1 = int(width // 2) * 2 + 1

    if isinstance(spikes, np.ndarray):
        if spikes.dtype == np.float16:
            spikes = spikes.astype(np.float32)
        window = np.ones(width1, dtype=float) / width1
        # Convolve along time axis (0) for all other dimensions
        out = convolve1d(spikes, window, axis=0, mode="constant", cval=0.0)
        return out / dt

    else:
        # torch implementation for arbitrary dimensions [T, *others]
        orig_shape = spikes.shape
        T = orig_shape[0]

        # Flatten others to treat as batches for conv1d: [T, B] -> [B, 1, T]
        x = spikes.reshape(T, -1).T.unsqueeze(1)

        window = torch.ones(width1, device=spikes.device, dtype=spikes.dtype) / width1
        weight = window.view(1, 1, -1)

        y = torch.conv1d(x, weight, padding="same")

        # [B, 1, T] -> [B, T] -> [T, B] -> [T, *others]
        return y.squeeze(1).T.reshape(orig_shape) / dt
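
A sketch averaging hypothetical trials before smoothing; rates are in spikes per unit of dt:

import numpy as np

from btorch.analysis import firing_rate

spikes = (np.random.default_rng(0).random((1000, 8, 20)) < 0.05).astype(float)  # [T, trials, N]

rates = firing_rate(spikes, width=10, dt=1.0, axis=1)  # trial-average, then boxcar smooth
# rates: shape (1000, 20)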

group_ecdf(values, neurons_df, group_by, *, simple_id_col='simple_id', value_name='value', group_order=None, dropna=True)

Compute grouped ECDF points ready for plotting or analysis.

Source code in btorch/analysis/aggregation.py
def group_ecdf(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> dict[object, pd.DataFrame]:
    """Compute grouped ECDF points ready for plotting or analysis."""
    grouped = group_values(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        group_order=group_order,
        dropna=dropna,
    )

    ret: dict[object, pd.DataFrame] = {}
    for group, vals in grouped.items():
        x = np.sort(vals)
        y = np.arange(1, len(x) + 1, dtype=float) / len(x)
        ret[group] = pd.DataFrame({value_name: x, "ecdf": y})
    return ret

group_summary(values, neurons_df, group_by, *, simple_id_col='simple_id', value_name='value', group_order=None, dropna=True)

Compute per-group summary statistics from neuron-aligned values.

Source code in btorch/analysis/aggregation.py
def group_summary(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> pd.DataFrame:
    """Compute per-group summary statistics from neuron-aligned values."""
    grouped = group_values(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        group_order=group_order,
        dropna=dropna,
    )

    rows = []
    for group, vals in grouped.items():
        rows.append(
            {
                group_by: group,
                "n": int(vals.size),
                "mean": float(np.mean(vals)),
                "std": float(np.std(vals)),
                "min": float(np.min(vals)),
                "q25": float(np.quantile(vals, 0.25)),
                "median": float(np.median(vals)),
                "q75": float(np.quantile(vals, 0.75)),
                "max": float(np.max(vals)),
            }
        )

    return pd.DataFrame(rows)

group_values(values, neurons_df, group_by, *, simple_id_col='simple_id', value_name='value', group_order=None, dropna=True)

Return grouped value arrays, keyed by group label in plotting order.

Source code in btorch/analysis/aggregation.py
def group_values(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> dict[object, np.ndarray]:
    """Return grouped value arrays, keyed by group label in plotting order."""
    frame = build_group_frame(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        dropna=dropna,
    )
    order = _resolve_group_order(frame, group_by, group_order)
    return {
        group: frame.loc[frame[group_by] == group, value_name].to_numpy(dtype=float)
        for group in order
    }
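
A combined sketch of the three grouping helpers on hypothetical values; all share the build_group_frame conventions:

import numpy as np
import pandas as pd

from btorch.analysis import group_ecdf, group_summary, group_values

values = np.random.rand(100, 3)  # 100 samples x 3 neurons
neurons_df = pd.DataFrame({"simple_id": [0, 1, 2], "cell_type": ["KC", "PN", "PN"]})

arrays = group_values(values, neurons_df, "cell_type")    # {"KC": ndarray, "PN": ndarray}
summary = group_summary(values, neurons_df, "cell_type")  # one row of stats per group
ecdfs = group_ecdf(values, neurons_df, "cell_type")       # {group: DataFrame(value, ecdf)}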

indices_to_mask(indices, shape=None, array=None)

Convert an array of indices to a boolean mask.

For multi-dimensional masks, a 1D index array is treated as flattened indices. Provide a tuple of index arrays for per-axis indexing.

Source code in btorch/analysis/metrics.py
def indices_to_mask(indices: np.ndarray, shape=None, array=None) -> np.ndarray:
    """Convert an array of indices to a boolean mask.

    For multi-dimensional masks, a 1D index array is treated as
    flattened indices. Provide a tuple of index arrays for per-axis
    indexing.
    """
    assert not (shape is None and array is None)
    mask = (
        np.zeros(shape, dtype=bool)
        if shape is not None
        else np.zeros_like(array, dtype=bool)
    )
    indices_arr = np.asarray(indices)
    if mask.ndim > 1 and indices_arr.ndim == 1 and not isinstance(indices, tuple):
        mask.flat[indices_arr] = True
    else:
        mask[indices] = True
    return mask
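
Two quick sketches: flat indices into a 1D mask, and a tuple of per-axis index arrays for a 2D mask:

import numpy as np

from btorch.analysis import indices_to_mask

mask = indices_to_mask(np.array([0, 3]), shape=5)
# array([ True, False, False,  True, False])

mask2d = indices_to_mask((np.array([0, 1]), np.array([2, 0])), shape=(2, 3))
# True at positions (0, 2) and (1, 0)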

isi_cv(spike_data, dt_ms=1.0, batch_axis=None, dtype=None)

Calculate coefficient of variation of ISIs per neuron.

Supports both NumPy and PyTorch inputs. For GPU tensors, uses a hybrid approach: aggregates on GPU, transfers to CPU for ISI extraction, then returns to GPU.

This function is decorated with @use_stats and @use_percentiles. See use_stats() and use_percentiles() for detailed usage.

Parameters:

- spike_data (ndarray | Tensor): Spike train array of shape [T, ...]. First dimension is time. Values are binary (0/1) or spike counts. Required.
- dt_ms (float): Time step in milliseconds for converting ISIs to ms. Default: 1.0.
- batch_axis (tuple[int, ...] | int | None): Axes to aggregate ISIs across (e.g., (1, 2) for trials). If None, computes CV per element in the non-time dimensions. Default: None.
- dtype (np.dtype | torch.dtype | None): Data type for accumulation. Default: None.
- stat: Aggregation statistic to return instead of per-neuron values. Options: "mean", "median", "max", "min", "std", "var", "argmax", "argmin", "cv". See use_stats().
- stat_info: Additional statistics to compute and store in the info dict. See use_stats().
- nan_policy: How to handle NaN values ("skip", "warn", "assert"). See use_stats().
- inf_policy: How to handle Inf values ("propagate", "skip", "warn", "assert"). See use_stats().
- percentiles: Percentile level(s) in [0, 100] to compute. See use_percentiles().

Returns:

- cv_values: CV values reshaped to match the input without the time dimension. Shape: [...] (original shape without T, aggregated over batch_axis). If stat is provided, returns the aggregated statistic instead.
- info: Dictionary with 'isi_total' (aggregate ISI statistics), 'isi_stats' (per-neuron statistics), and optional percentile data.

Source code in btorch/analysis/spiking.py
@use_percentiles(value_key="cv")
@use_stats(value_key="cv")
def isi_cv(
    spike_data: np.ndarray | torch.Tensor,
    dt_ms: float = 1.0,
    batch_axis: tuple[int, ...] | int | None = None,
    dtype: np.dtype | torch.dtype | None = None,
):
    """Calculate coefficient of variation of ISIs per neuron.

    Supports both NumPy and PyTorch inputs. For GPU tensors, uses a hybrid
    approach: aggregates on GPU, transfers to CPU for ISI extraction, then
    returns to GPU.

    This function is decorated with `@use_stats` and `@use_percentiles`.
    See [`use_stats()`](btorch/analysis/statistics.py:483) and
    [`use_percentiles()`](btorch/analysis/statistics.py:777) for detailed usage.

    Args:
        spike_data: Spike train array of shape [T, ...]. First dimension is time.
            Values are binary (0/1) or spike counts.
        dt_ms: Time step in milliseconds for converting ISI to ms.
        batch_axis: Axes to aggregate ISIs across (e.g., (1, 2) for trials).
            If None, computes CV per element in the non-time dimensions.
        dtype: Data type for accumulation.
        stat: Aggregation statistic to return instead of per-neuron values.
            Options: "mean", "median", "max", "min", "std", "var", "argmax",
            "argmin", "cv". See [`use_stats()`](btorch/analysis/statistics.py:483).
        stat_info: Additional statistics to compute and store in info dict.
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        nan_policy: How to handle NaN values ("skip", "warn", "assert").
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        inf_policy: How to handle Inf values ("propagate", "skip", "warn",
            "assert"). See [`use_stats()`](btorch/analysis/statistics.py:483).
        percentiles: Percentile level(s) in [0, 100] to compute.
            See [`use_percentiles()`](btorch/analysis/statistics.py:777).

    Returns:
        cv_values: CV values reshaped to match input without time dimension.
            Shape: [...] (original shape without T, aggregated over batch_axis).
            If `stat` is provided, returns the aggregated statistic instead.
        info: Dictionary with 'isi_total' (aggregate ISI statistics),
            'isi_stats' (per-neuron statistics), and optional percentile data.
    """
    if isinstance(batch_axis, int):
        batch_axis = (batch_axis,)
    return _cv(spike_data, dt_ms, batch_axis, dtype)
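
A quick sketch under the documented call signature; the raster here is synthetic:

import numpy as np

from btorch.analysis import isi_cv

rng = np.random.default_rng(0)
spikes = (rng.random((2000, 50)) < 0.02).astype(np.int8)  # [T, N] raster

# Per-neuron CV values plus an info dict with ISI statistics.
cv, info = isi_cv(spikes, dt_ms=1.0)

# Decorator-added arguments: a single aggregated number ...
cv_mean, info = isi_cv(spikes, dt_ms=1.0, stat="mean")

# ... or per-neuron values with quartiles stored in the info dict.
cv, info = isi_cv(spikes, dt_ms=1.0, percentiles=(25, 50, 75))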

isi_cv_population(spike_data, dt_ms=1.0)

Calculate coefficient of variation of ISIs pooled across all neurons.

This computes CV from the pooled ISI distribution across the entire population, giving a single population-level metric.

This function is decorated with @use_stats. See use_stats() for detailed usage.

Parameters:

    spike_data (ndarray | Tensor): Spike train array of shape [T, ...]. First dimension is time. Required.
    dt_ms (float): Time step in milliseconds for converting ISI to ms. Default: 1.0.
    stat: Aggregation statistic to return. Default is "cv". Options: "mean", "median", "max", "min", "std", "var", "argmax", "argmin", "cv". See use_stats().
    stat_info: Additional statistics to compute and store in the info dict. See use_stats().
    nan_policy: How to handle NaN values ("skip", "warn", "assert"). See use_stats().
    inf_policy: How to handle Inf values ("propagate", "skip", "warn", "assert"). See use_stats().

Returns:

    cv_pop: Single scalar CV value for the population, or the aggregated statistic if stat is provided.
    info: Dictionary with computed statistics.

Source code in btorch/analysis/spiking.py
@use_stats(value_key="isi_population", default_stat="cv")
def isi_cv_population(
    spike_data: np.ndarray | torch.Tensor,
    dt_ms: float = 1.0,
):
    """Calculate coefficient of variation of ISIs pooled across all neurons.

    This computes CV from the pooled ISI distribution across the entire
    population, giving a single population-level metric.

    This function is decorated with `@use_stats`.
    See [`use_stats()`](btorch/analysis/statistics.py:483) for detailed usage.

    Args:
        spike_data: Spike train array of shape [T, ...]. First dimension is time.
        dt_ms: Time step in milliseconds for converting ISI to ms.
        stat: Aggregation statistic to return. Default is "cv".
            Options: "mean", "median", "max", "min", "std", "var", "argmax",
            "argmin", "cv". See [`use_stats()`](btorch/analysis/statistics.py:483).
        stat_info: Additional statistics to compute and store in info dict.
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        nan_policy: How to handle NaN values ("skip", "warn", "assert").
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        inf_policy: How to handle Inf values ("propagate", "skip", "warn",
            "assert"). See [`use_stats()`](btorch/analysis/statistics.py:483).

    Returns:
        cv_pop: Single scalar CV value for the population, or aggregated
            statistic if `stat` is provided.
        info: Dictionary with computed statistics.
    """
    if isinstance(spike_data, torch.Tensor):
        return _isis_population_torch(spike_data, dt_ms)
    else:
        return _isis_population_numpy(spike_data, dt_ms)
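
And the population-level counterpart on the same kind of synthetic input:

import numpy as np

from btorch.analysis import isi_cv_population

spikes = (np.random.default_rng(1).random((2000, 50)) < 0.02).astype(np.int8)

# ISIs are pooled across all neurons; the default stat is "cv",
# so a single scalar comes back.
cv_pop, info = isi_cv_population(spikes, dt_ms=1.0)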

kurtosis(spike, window=None, overlap=0, fisher=True, batch_axis=None)

Compute kurtosis of spike counts using optimized cumulative sums.

Supports both NumPy and PyTorch inputs. GPU-friendly operation.

This function is decorated with @use_stats and @use_percentiles. See use_stats() and use_percentiles() for detailed usage.

Parameters:

    spike (ndarray | Tensor): Spike train of shape [T, ...]. First dimension is time. Required.
    window (int | None): Window size for spike counting. If None, defaults to max(1, T // 10) so that roughly 10 bins are available. Default: None.
    overlap (int): Overlap between consecutive windows. Default: 0.
    fisher (bool): If True, return excess kurtosis (subtract 3). Default: True.
    batch_axis (tuple[int, ...] | int | None): Axes to average across for kurtosis computation. Default: None.
    stat: Aggregation statistic to return instead of per-neuron values. Options: "mean", "median", "max", "min", "std", "var", "argmax", "argmin", "cv". See use_stats().
    stat_info: Additional statistics to compute and store in the info dict. See use_stats().
    nan_policy: How to handle NaN values ("skip", "warn", "assert"). See use_stats().
    inf_policy: How to handle Inf values ("propagate", "skip", "warn", "assert"). See use_stats().
    percentiles: Percentile level(s) in [0, 100] to compute. See use_percentiles().

Returns:

    kurt: Kurtosis values with shape [...] (input shape without the time dimension). If stat is provided, returns the aggregated statistic instead.
    info: Dictionary with optional computed statistics and percentile data.

Source code in btorch/analysis/spiking.py
@use_percentiles(value_key="kurtosis")
@use_stats(value_key="kurtosis")
def kurtosis(
    spike: np.ndarray | torch.Tensor,
    window: int | None = None,
    overlap: int = 0,
    fisher: bool = True,
    batch_axis: tuple[int, ...] | int | None = None,
):
    """Compute kurtosis of spike counts using optimized cumulative sums.

    Supports both NumPy and PyTorch inputs. GPU-friendly operation.

    This function is decorated with `@use_stats` and `@use_percentiles`.
    See [`use_stats()`](btorch/analysis/statistics.py:483) and
    [`use_percentiles()`](btorch/analysis/statistics.py:777) for detailed usage.

    Args:
        spike: Spike train of shape [T, ...]. First dimension is time.
        window: Window size for spike counting. If None, defaults to
            max(1, T // 10) so that roughly 10 bins are available.
        overlap: Overlap between consecutive windows.
        fisher: If True, return excess kurtosis (subtract 3).
        batch_axis: Axes to average across for kurtosis computation.
        stat: Aggregation statistic to return instead of per-neuron values.
            Options: "mean", "median", "max", "min", "std", "var", "argmax",
            "argmin", "cv". See [`use_stats()`](btorch/analysis/statistics.py:483).
        stat_info: Additional statistics to compute and store in info dict.
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        nan_policy: How to handle NaN values ("skip", "warn", "assert").
            See [`use_stats()`](btorch/analysis/statistics.py:483).
        inf_policy: How to handle Inf values ("propagate", "skip", "warn",
            "assert"). See [`use_stats()`](btorch/analysis/statistics.py:483).
        percentiles: Percentile level(s) in [0, 100] to compute.
            See [`use_percentiles()`](btorch/analysis/statistics.py:777).

    Returns:
        kurt: Kurtosis values with shape [...]
            (input shape without time dimension).
            If `stat` is provided, returns the aggregated statistic instead.
        info: Dictionary with optional computed statistics and percentile data.
    """
    if isinstance(batch_axis, int):
        batch_axis = (batch_axis,)
    if window is None:
        # Default window size to get ~10 bins for variance computation
        # Need at least 2 bins for valid variance with unbiased=True
        window = max(1, spike.shape[0] // 10)
    if isinstance(spike, torch.Tensor):
        return _kurtosis_torch(spike, window, overlap, fisher, batch_axis)
    else:
        return _kurtosis_numpy(spike, window, overlap, fisher, batch_axis)
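
A small sketch with a synthetic raster:

import numpy as np

from btorch.analysis import kurtosis

spikes = (np.random.default_rng(2).random((5000, 30)) < 0.05).astype(np.int8)

# Excess kurtosis of 500-step spike counts, per neuron.
kurt, info = kurtosis(spikes, window=500)

# window=None falls back to max(1, T // 10) -- here also 500.
kurt_default, _ = kurtosis(spikes)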

kurtosis_population(spike, window=None, overlap=0, fisher=True)

Compute kurtosis for the pooled population activity.

This computes kurtosis from the summed population spike count, giving a single population-level metric.

Parameters:

    spike (ndarray | Tensor): Spike train of shape [T, ...]. First dimension is time. Required.
    window (int | None): Window size for spike counting. If None, uses full duration T. Default: None.
    overlap (int): Overlap between consecutive windows. Default: 0.
    fisher (bool): If True, return excess kurtosis (subtract 3). Default: True.

Returns:

    kurt_pop: Single scalar kurtosis for the population.
    info: Dictionary with 'window' and 'n_windows'.

Source code in btorch/analysis/spiking.py
def kurtosis_population(
    spike: np.ndarray | torch.Tensor,
    window: int | None = None,
    overlap: int = 0,
    fisher: bool = True,
):
    """Compute kurtosis for the pooled population activity.

    This computes kurtosis from the summed population spike count,
    giving a single population-level metric.

    Args:
        spike: Spike train of shape [T, ...]. First dimension is time.
        window: Window size for spike counting. If None, uses full duration T.
        overlap: Overlap between consecutive windows.
        fisher: If True, return excess kurtosis (subtract 3).

    Returns:
        kurt_pop: Single scalar kurtosis for the population.
        info: Dictionary with 'window' and 'n_windows'.
    """
    if isinstance(spike, torch.Tensor):
        return _kurtosis_population_torch(spike, window, overlap, fisher)
    else:
        return _kurtosis_population_numpy(spike, window, overlap, fisher)

local_variation(spike_data, dt_ms=1.0, batch_axis=None)

Calculate Local Variation (LV) of ISIs per neuron.

LV is a measure of spike train irregularity that is less sensitive to slow rate fluctuations than CV. For a Poisson process, LV = 1.

LV = (1/(N-1)) * sum(3*(ISI_i - ISI_{i+1})^2 / (ISI_i + ISI_{i+1})^2)

Supports both NumPy and PyTorch inputs.

Parameters:

    spike_data (ndarray | Tensor): Spike train array of shape [T, ...]. First dimension is time. Required.
    dt_ms (float): Time step in milliseconds. Default: 1.0.
    batch_axis (tuple[int, ...] | None): Axes to aggregate ISIs across (e.g., (1, 2) for trials). Default: None.

Returns:

    lv_values: LV values reshaped to match the input without the time dimension.
    lv_stats: Dictionary with per-neuron LV statistics.

Source code in btorch/analysis/spiking.py
@use_percentiles(value_key="lv")
@use_stats(value_key="lv")
def local_variation(
    spike_data: np.ndarray | torch.Tensor,
    dt_ms: float = 1.0,
    batch_axis: tuple[int, ...] | None = None,
):
    """Calculate Local Variation (LV) of ISIs per neuron.

    LV is a measure of spike train irregularity that is less sensitive to
    slow rate fluctuations than CV. For a Poisson process, LV = 1.

    LV = (1/(N-1)) * sum(3*(ISI_i - ISI_{i+1})^2 / (ISI_i + ISI_{i+1})^2)

    Supports both NumPy and PyTorch inputs.

    Args:
        spike_data: Spike train array of shape [T, ...]. First dimension is time.
        dt_ms: Time step in milliseconds.
        batch_axis: Axes to aggregate ISIs across (e.g., (1, 2) for trials).

    Returns:
        lv_values: LV values reshaped to match input without time dimension.
        lv_stats: Dictionary with per-neuron LV statistics.
    """
    if isinstance(spike_data, torch.Tensor):
        return _lv_torch(spike_data, dt_ms, batch_axis)
    else:
        return _lv_numpy(spike_data, dt_ms, batch_axis)
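
To make the LV formula concrete, here is a minimal NumPy rendering of the same quantity for a single ISI sequence (a reference for intuition, not the library's batched implementation):

import numpy as np

def lv_from_isis(isis: np.ndarray) -> float:
    # LV = (1/(N-1)) * sum(3*(ISI_i - ISI_{i+1})^2 / (ISI_i + ISI_{i+1})^2),
    # i.e. the mean over consecutive ISI pairs.
    a, b = isis[:-1], isis[1:]
    return float(np.mean(3.0 * (a - b) ** 2 / (a + b) ** 2))

# Exponential (Poisson) ISIs give LV close to 1.
isis = np.random.default_rng(3).exponential(scale=10.0, size=100_000)
print(lv_from_isis(isis))  # ~1.0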

select_on_metric(metrics, num=None, mode='topk', ret_indices=False)

Select neurons based on a metric array.

Source code in btorch/analysis/metrics.py
def select_on_metric(
    metrics: np.ndarray,
    num: int | None = None,
    mode: Literal["topk", "any"] = "topk",
    ret_indices: bool = False,
):
    """Select neurons based on a metric array."""
    if mode == "topk":
        assert num is not None
        ret = np.argpartition(metrics, -num)[-num:]
    elif mode == "any":
        ret = metrics.nonzero()[0]
        if num is not None and len(ret) > num:
            ret = np.random.choice(ret, num, replace=False)
    else:
        raise ValueError(f"Unsupported mode {mode}")

    if ret_indices:
        return ret, indices_to_mask(ret, array=metrics)
    else:
        return ret
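
A short sketch of both modes:

import numpy as np

from btorch.analysis import select_on_metric

rates = np.array([0.0, 2.5, 0.7, 3.1, 0.0, 1.2])

# Indices of the two largest values (order within the top-k is arbitrary).
top2 = select_on_metric(rates, num=2, mode="topk")

# Indices plus a boolean mask over the metric array.
idx, mask = select_on_metric(rates, num=2, mode="topk", ret_indices=True)

# "any" selects among nonzero entries, subsampling at random if over num.
active = select_on_metric(rates, mode="any")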

suggest_skip_timestep(data)

Suggest a burn-in period based on trace length.

Source code in btorch/analysis/voltage.py
def suggest_skip_timestep(data: np.ndarray) -> int:
    """Suggest a burn-in period based on trace length.

    Uses one eighth of the trace, dropped to 0 when that is under 100
    timesteps and capped at 1000 timesteps.
    """
    skip_timestep = data.shape[0] // 8
    if skip_timestep < 100:
        skip_timestep = 0
    if skip_timestep > 1000:
        skip_timestep = 1000
    return skip_timestep
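
The clamping behavior in numbers:

import numpy as np

from btorch.analysis import suggest_skip_timestep

suggest_skip_timestep(np.zeros((400, 10)))    # 400 // 8 = 50 < 100  -> 0
suggest_skip_timestep(np.zeros((4000, 10)))   # 4000 // 8 = 500      -> 500
suggest_skip_timestep(np.zeros((16000, 10)))  # 2000, capped         -> 1000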

use_percentiles(func=None, *, value_key='values', default_percentiles=None)

Decorator to add percentiles arg and optionally compute percentiles.

This decorator adds a percentiles parameter to a function that returns per-neuron values. Percentiles are only computed if percentiles is not None. Results are stored in info[f"{value_key}_percentiles"], with the requested levels in info[f"{value_key}_levels"].

value_key can also be a dict mapping return positions to labels for functions returning multiple values (e.g., {1: "eci", 3: "lag"}).

The decorated function should return either:

  • A tuple of (values, info_dict) where values are per-neuron metrics
  • Just the per-neuron values (will be wrapped in a tuple with empty dict)
  • A tuple of multiple values with info as the last element

Parameters:

    func (Callable | None): The function to decorate (or None if using with parentheses). Default: None.
    value_key (str | dict[int, str]): Key to use in the info dict for the percentile result. Default: 'values'.
    default_percentiles (float | tuple[float, ...] | None): Default percentile level(s) for the decorated function. Default: None.

Returns:

    Callable: Decorated function with an added percentiles parameter.

Example
@use_percentiles
def compute_metric(data, *, percentiles=None):
    values = some_computation(data)  # per-neuron values
    return values, {"raw": values}

# Usage:
values, info = compute_metric(data)  # no percentiles computed
values, info = compute_metric(data, percentiles=50)  # compute median
values, info = compute_metric(
    data, percentiles=(25, 50, 75)
)  # compute quartiles
Source code in btorch/analysis/statistics.py
def use_percentiles(
    func: Callable | None = None,
    *,
    value_key: str | dict[int, str] = "values",
    default_percentiles: float | tuple[float, ...] | None = None,
) -> Callable:
    """Decorator to add percentiles arg and optionally compute percentiles.

    This decorator adds a `percentiles` parameter to a function that returns
    per-neuron values. Percentiles are only computed if percentiles is not None.
    Results are stored in info[f"{value_key}_percentiles"], with the requested
    levels in info[f"{value_key}_levels"].

    `value_key` can also be a dict mapping return positions to labels for
    functions returning multiple values (e.g., {1: "eci", 3: "lag"}).

    The decorated function should return either:
    - A tuple of (values, info_dict) where values are per-neuron metrics
    - Just the per-neuron values (will be wrapped in a tuple with empty dict)
    - A tuple of multiple values with info as the last element

    Args:
        func: The function to decorate (or None if using with parentheses)
        value_key: Key to use in info dict for the percentile result

    Returns:
        Decorated function with added percentiles parameter

    Example:
        ```python
        @use_percentiles
        def compute_metric(data, *, percentiles=None):
            values = some_computation(data)  # per-neuron values
            return values, {"raw": values}

        # Usage:
        values, info = compute_metric(data)  # no percentiles computed
        values, info = compute_metric(data, percentiles=50)  # compute median
        values, info = compute_metric(
            data, percentiles=(25, 50, 75)
        )  # compute quartiles
        ```
    """

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def wrapper(
            *args,
            percentiles: float
            | tuple[float, ...]
            | dict[int, float | tuple[float, ...]]
            | None = default_percentiles,
            **kwargs,
        ) -> tuple[Any, ...]:
            # Call the original function
            result = f(*args, **kwargs)
            if percentiles is None:
                return result

            # Unpack result using shared helper
            values_tuple, info = _unpack_result(result, value_key)

            # Ensure info is a dict
            if info is None:
                info = {}

            updated_info = dict(info)

            # Helper to get value key name for a position
            def _get_value_key_name(pos: int) -> str:
                if isinstance(value_key, dict):
                    return value_key.get(pos, f"values{pos}")
                return f"{value_key}{pos}"

            # Helper to get values at a position
            def _get_values(pos: int) -> Any:
                if pos < 0 or pos >= len(values_tuple):
                    raise IndexError(
                        f"Position {pos} out of range for return tuple "
                        f"of length {len(values_tuple)}"
                    )
                return values_tuple[pos]

            # Compute percentiles only if requested
            if isinstance(percentiles, dict):
                # Dict format: {position: percentile_value(s)}
                # Allows different percentiles for different return values
                for pos, perc_value in percentiles.items():
                    values = _get_values(pos)
                    key_name = _get_value_key_name(pos)
                    perc_result = compute_percentiles(values, perc_value)
                    updated_info[f"{key_name}_percentiles"] = perc_result["percentiles"]
                    updated_info[f"{key_name}_levels"] = perc_result["levels"]
            elif isinstance(value_key, dict):
                # Dict value_key with single percentiles value:
                # Apply same percentiles to all positions in value_key
                for pos, label in value_key.items():
                    values = _get_values(pos)
                    perc_result = compute_percentiles(values, percentiles)
                    updated_info[f"{label}_percentiles"] = perc_result["percentiles"]
                    updated_info[f"{label}_levels"] = perc_result["levels"]
            else:
                # Single percentiles format - apply to position 0
                values = _get_values(0)
                perc_result = compute_percentiles(values, percentiles)
                updated_info[f"{value_key}_percentiles"] = perc_result["percentiles"]
                updated_info[f"{value_key}_levels"] = perc_result["levels"]

            if len(values_tuple) > 1:
                return values_tuple + (updated_info,)
            else:
                return values_tuple[0], updated_info

        return wrapper

    if func is None:
        return decorator
    return decorator(func)
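
A sketch of the dict-valued forms handled above; compute_multiple and its two per-neuron outputs are hypothetical stand-ins:

import numpy as np

from btorch.analysis import use_percentiles

@use_percentiles(value_key={0: "eci", 1: "lag"})
def compute_multiple(data):
    # Stand-ins for real per-neuron metrics.
    eci = data.mean(axis=0)
    lag = data.argmax(axis=0)
    return eci, lag, {}

data = np.random.default_rng(4).random((100, 20))

# Same levels applied to every labeled position ...
eci, lag, info = compute_multiple(data, percentiles=(25, 75))
# info["eci_percentiles"], info["eci_levels"], info["lag_percentiles"], ...

# ... or different levels per return position.
eci, lag, info = compute_multiple(data, percentiles={0: 50, 1: (10, 90)})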

use_stats(func=None, *, value_key='values', dim=None, default_stat=None, default_stat_info=None, default_nan_policy='skip', default_inf_policy='propagate')

Decorator to add stat and stat_info args for aggregation.

This decorator adds stat, stat_info, nan_policy, and inf_policy parameters to a function that returns per-neuron values.

  • stat: If not None, returns the aggregated value instead of per-neuron values. The aggregation is stored in info[f"{value_key}_{stat}"]. Can be a single StatChoice, or a dict mapping return position to a StatChoice (e.g., {0: "mean", 1: "median"}) for functions returning multiple values.
  • stat_info: Additional stats to compute and store in info dict without affecting the return value (ignored when stat is also given). Can be a single StatChoice, an Iterable of StatChoice, a dict mapping position to stat(s), or None. If dict, format is {position: stat_or_stats} where stat_or_stats can be a single StatChoice or an Iterable of StatChoice.
  • dim: Dimension(s) to aggregate over. Can be:
    • None: Flatten all dimensions (default)
    • int: Aggregate over this dimension for all outputs
    • tuple[int, ...]: Aggregate over these dimensions for all outputs
    • dict[int, int | tuple[int, ...] | None]: Different dim for each output position (e.g., {0: 1, 1: 2, 2: None, 3: (1, 3, 4)})
  • nan_policy: How to handle NaN values:
    • "skip": Ignore NaN values (default)
    • "warn": Warn if NaN values found but continue
    • "assert": Raise error if NaN values found
  • inf_policy: How to handle Inf values:
    • "propagate": Keep Inf values (default)
    • "skip": Ignore Inf values
    • "warn": Warn if Inf values found but continue
    • "assert": Raise error if Inf values found

The decorated function should return either:

  • A tuple of (values, info_dict) where values are per-neuron metrics
  • Just the per-neuron values (will be wrapped in a tuple with empty dict)
  • A tuple of multiple values with info as the last element

Parameters:

    func (Callable | None): The function to decorate (or None if using with parentheses). Default: None.
    value_key (str | dict[int, str]): Key prefix to use in the info dict for stat results. Default: 'values'.
    dim (int | tuple[int, ...] | dict[int, int | tuple[int, ...] | None] | None): Dimension(s) to aggregate over for each output. Default: None.
    default_stat (StatChoice | dict[int, StatChoice] | None): Default stat for this decorated function. Default: None.
    default_stat_info (StatChoice | Iterable[StatChoice] | dict[int, StatChoice | Iterable[StatChoice]] | None): Default stat_info for this decorated function. Default: None.
    default_nan_policy (Literal['skip', 'warn', 'assert']): Default nan_policy for this decorated function. Default: 'skip'.
    default_inf_policy (Literal['propagate', 'skip', 'warn', 'assert']): Default inf_policy for this decorated function. Default: 'propagate'.

Returns:

    Callable: Decorated function with added stat, stat_info, nan_policy, and inf_policy parameters.

Example
@use_stats
def compute_metric(
    data,
    *,
    stat=None,
    stat_info=None,
    nan_policy="skip",
    inf_policy="propagate",
):
    values = some_computation(data)  # per-neuron values
    return values, {"raw": values}

# Usage:
values, info = compute_metric(data)  # returns per-neuron values
mean_val, info = compute_metric(data, stat="mean")  # returns aggregated
values, info = compute_metric(
    data, stat_info=["mean", "max"]
)  # extra stats in info

# Multi-value return with dict stat:
@use_stats
def compute_multiple(data, *, stat=None, stat_info=None):
    eci = compute_eci(data)  # per-neuron
    lag = compute_lag(data)  # per-neuron
    return eci, lag, {}  # multiple values

# Aggregate specific positions with dict stat:
eci_mean, lag_mean, info = compute_multiple(
    data, stat={0: "mean", 1: "mean"}
)
Source code in btorch/analysis/statistics.py
def use_stats(
    func: Callable | None = None,
    *,
    value_key: str | dict[int, str] = "values",
    dim: int | tuple[int, ...] | dict[int, int | tuple[int, ...] | None] | None = None,
    default_stat: StatChoice | dict[int, StatChoice] | None = None,
    default_stat_info: (
        StatChoice
        | Iterable[StatChoice]
        | dict[int, StatChoice | Iterable[StatChoice]]
        | None
    ) = None,
    default_nan_policy: Literal["skip", "warn", "assert"] = "skip",
    default_inf_policy: Literal["propagate", "skip", "warn", "assert"] = "propagate",
) -> Callable:
    """Decorator to add stat and stat_info args for aggregation.

    This decorator adds `stat`, `stat_info`, `nan_policy`, and `inf_policy`
    parameters to a function that returns per-neuron values.

    - `stat`: If not None, returns the aggregated value instead of per-neuron
      values. The aggregation is stored in info[f"{value_key}_{stat}"].
      Can be a single StatChoice, or a dict mapping return position to a
      StatChoice (e.g., {0: "mean", 1: "median"}) for functions returning
      multiple values.
    - `stat_info`: Additional stats to compute and store in info dict without
      affecting the return value (ignored when `stat` is also given). Can be
      a single StatChoice, an Iterable of StatChoice, a dict mapping position
      to stat(s), or None. If dict, format is {position: stat_or_stats} where
      stat_or_stats can be a single StatChoice or an Iterable of StatChoice.
    - `dim`: Dimension(s) to aggregate over. Can be:
        - None: Flatten all dimensions (default)
        - int: Aggregate over this dimension for all outputs
        - tuple[int, ...]: Aggregate over these dimensions for all outputs
        - dict[int, int | tuple[int, ...] | None]: Different dim for each
          output position (e.g., {0: 1, 1: 2, 2: None, 3: (1, 3, 4)})
    - `nan_policy`: How to handle NaN values:
        - "skip": Ignore NaN values (default)
        - "warn": Warn if NaN values found but continue
        - "assert": Raise error if NaN values found
    - `inf_policy`: How to handle Inf values:
        - "propagate": Keep Inf values (default)
        - "skip": Ignore Inf values
        - "warn": Warn if Inf values found but continue
        - "assert": Raise error if Inf values found

    The decorated function should return either:
    - A tuple of (values, info_dict) where values are per-neuron metrics
    - Just the per-neuron values (will be wrapped in a tuple with empty dict)
    - A tuple of multiple values with info as the last element

    Args:
        func: The function to decorate (or None if using with parentheses)
        value_key: Key prefix to use in info dict for stat results
        dim: Dimension(s) to aggregate over for each output
        default_nan_policy: Default nan_policy for this decorated function
        default_inf_policy: Default inf_policy for this decorated function
        default_stat: Default stat for this decorated function

    Returns:
        Decorated function with added stat, stat_info, nan_policy, and
        inf_policy parameters

    Example:
        ```python
        @use_stats
        def compute_metric(
            data,
            *,
            stat=None,
            stat_info=None,
            nan_policy="skip",
            inf_policy="propagate",
        ):
            values = some_computation(data)  # per-neuron values
            return values, {"raw": values}

        # Usage:
        values, info = compute_metric(data)  # returns per-neuron values
        mean_val, info = compute_metric(data, stat="mean")  # returns aggregated
        values, info = compute_metric(
            data, stat_info=["mean", "max"]
        )  # extra stats in info

        # Multi-value return with dict stat:
        @use_stats
        def compute_multiple(data, *, stat=None, stat_info=None):
            eci = compute_eci(data)  # per-neuron
            lag = compute_lag(data)  # per-neuron
            return eci, lag, {}  # multiple values

        # Aggregate specific positions with dict stat:
        eci_mean, lag_mean, info = compute_multiple(
            data, stat={0: "mean", 1: "mean"}
        )
        ```
    """

    def decorator(f: Callable) -> Callable:
        # Inspect the wrapped function to determine what arguments it accepts
        sig = inspect.signature(f)
        f_accepts_nan_policy = "nan_policy" in sig.parameters
        f_accepts_inf_policy = "inf_policy" in sig.parameters

        @wraps(f)
        def wrapper(
            *args,
            stat: StatChoice | dict[int, StatChoice] | None = default_stat,
            stat_info: StatChoice
            | Iterable[StatChoice]
            | dict[int, StatChoice | Iterable[StatChoice]]
            | None = default_stat_info,
            nan_policy: Literal["skip", "warn", "assert"] | None = None,
            inf_policy: Literal["propagate", "skip", "warn", "assert"] | None = None,
            **kwargs,
        ) -> tuple[Any, ...]:
            # Use effective policies (passed value > decorator default > "skip")
            effective_nan_policy = (
                nan_policy if nan_policy is not None else default_nan_policy
            )
            effective_inf_policy = (
                inf_policy if inf_policy is not None else default_inf_policy
            )

            # Pass policies to the wrapped function if it accepts them
            if f_accepts_nan_policy:
                kwargs["nan_policy"] = effective_nan_policy
            if f_accepts_inf_policy:
                kwargs["inf_policy"] = effective_inf_policy

            # Call the original function
            result = f(*args, **kwargs)

            # Unpack result using shared helper
            values_tuple, info = _unpack_result(result, value_key)

            # Ensure info is a dict
            if info is None:
                info = {}

            updated_info = dict(info)

            # Helper to get value key name for a position
            def _get_value_key_name(pos: int) -> str:
                if isinstance(value_key, dict):
                    return value_key.get(pos, f"values{pos}")
                if len(values_tuple) > 1:
                    return f"{value_key}{pos}"
                return value_key

            # Helper to get values at a position
            def _get_values(pos: int) -> Any:
                if pos < 0 or pos >= len(values_tuple):
                    raise IndexError(
                        f"Position {pos} out of range for return tuple "
                        f"of length {len(values_tuple)}"
                    )
                return values_tuple[pos]

            # Helper to get effective dim for a position
            def _get_dim_for_pos(pos: int) -> int | tuple[int, ...] | None:
                if dim is None:
                    return None
                if isinstance(dim, dict):
                    return dim.get(pos, None)
                return dim

            # Handle stat parameter (takes precedence; stat_info is
            # ignored when stat is set)
            if stat is not None:
                # Check if stat is a dict mapping positions to stats
                if isinstance(stat, dict):
                    # Multiple position aggregation with dict stat
                    results = []
                    for pos, stat_choice in stat.items():
                        values = _get_values(pos)
                        key_name = _get_value_key_name(pos)
                        effective_dim = _get_dim_for_pos(pos)
                        stat_value = _compute_stat(
                            values,
                            stat_choice,
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,  # type: ignore
                        )
                        results.append(stat_value)
                        updated_info[key_name] = values
                        updated_info[f"{key_name}_{stat_choice}"] = stat_value
                    return tuple(results) + (updated_info,)
                else:
                    # Single stat - apply to position 0
                    values = _get_values(0)
                    key_name = _get_value_key_name(0)
                    effective_dim = _get_dim_for_pos(0)
                    stat_value = _compute_stat(
                        values,
                        stat,
                        effective_nan_policy,
                        effective_inf_policy,
                        effective_dim,  # type: ignore
                    )
                    updated_info[key_name] = values
                    updated_info[f"{key_name}_{stat}"] = stat_value
                    return stat_value, updated_info

            # Handle stat_info parameter
            if stat_info is not None:
                # Check if stat_info is a dict mapping positions to stats
                if isinstance(stat_info, dict):
                    # Dict format: {position: stat_or_stats}
                    for pos, stats in stat_info.items():
                        # Normalize to iterable
                        if isinstance(stats, str):
                            stats_list = [stats]
                        else:
                            stats_list = list(stats)

                        # Use batch computation for efficiency
                        values = _get_values(pos)
                        key_name = _get_value_key_name(pos)
                        effective_dim = _get_dim_for_pos(pos)
                        if len(stats_list) > 1:
                            batch_results = _compute_stats_batch(
                                values,
                                [str(s) for s in stats_list],
                                effective_nan_policy,
                                effective_inf_policy,
                                effective_dim,
                            )
                            for s in stats_list:
                                updated_info[f"{key_name}_{s}"] = batch_results[str(s)]
                        else:
                            # Single stat - no need for batch optimization
                            stat_value = _compute_stat(
                                values,
                                stats_list[0],
                                effective_nan_policy,
                                effective_inf_policy,
                                effective_dim,  # type: ignore
                            )
                            updated_info[f"{key_name}_{stats_list[0]}"] = stat_value
                else:
                    # Original format: apply to position 0
                    # Normalize to iterable
                    if isinstance(stat_info, str):
                        stat_info_list = [stat_info]
                    else:
                        stat_info_list = list(stat_info)

                    # Use batch computation for efficiency (reuses mean/std for cv)
                    values = _get_values(0)
                    key_name = _get_value_key_name(0)
                    effective_dim = _get_dim_for_pos(0)
                    if len(stat_info_list) > 1:
                        batch_results = _compute_stats_batch(
                            values,
                            [str(s) for s in stat_info_list],
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,
                        )
                        for s in stat_info_list:
                            updated_info[f"{key_name}_{s}"] = batch_results[str(s)]
                    else:
                        # Single stat - no need for batch optimization
                        stat_value = _compute_stat(
                            values,
                            stat_info_list[0],
                            effective_nan_policy,
                            effective_inf_policy,
                            effective_dim,  # type: ignore
                        )
                        updated_info[f"{key_name}_{stat_info_list[0]}"] = stat_value

                # Return original values with updated info
                if len(values_tuple) > 1:
                    return values_tuple + (updated_info,)
                else:
                    return (values_tuple[0], updated_info)

            # No stat or stat_info - return original values with info
            if len(values_tuple) > 1:
                return values_tuple + (updated_info,)
            else:
                return (values_tuple[0], updated_info)

        return wrapper

    if func is None:
        return decorator
    return decorator(func)
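
A sketch tying the pieces together; mean_rate and its raster input are hypothetical:

import numpy as np

from btorch.analysis import use_stats

@use_stats(value_key="rate", dim=0)
def mean_rate(spike):
    # Per-(trial, neuron) firing rates from a [T, trials, neurons] raster.
    return spike.mean(axis=0), {}

spikes = np.random.default_rng(5).random((1000, 8, 20))

# Per-(trial, neuron) values, untouched.
rates, info = mean_rate(spikes)

# Aggregate over dim=0 (trials): one mean per neuron; raw values stay
# in info["rate"], the aggregate in info["rate_mean"].
rate_mean, info = mean_rate(spikes, stat="mean")

# Or keep the values and stash extra stats in the info dict instead
# (stat takes precedence, so request these without stat).
rates, info = mean_rate(spikes, stat_info=["std", "cv"])
# info["rate_std"], info["rate_cv"]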

voltage_overshoot(V, mode='threshold_resting', skip_timestep=None, **params)

Quantify voltage stability/overshoot in different ways.

Source code in btorch/analysis/voltage.py
def voltage_overshoot(
    V,
    mode: Literal["std", "mse_threshold", "threshold_resting"] = "threshold_resting",
    skip_timestep: Optional[int] = None,
    **params,
):
    """Quantify voltage stability/overshoot in different ways."""
    is_numpy = isinstance(V, np.ndarray)
    V = V.astype(np.float32) if is_numpy else V.to(torch.float32)

    if skip_timestep is None:
        skip_timestep = suggest_skip_timestep(V)

    V_slice = V[skip_timestep:]
    if mode == "std":
        return V_slice.std(0)

    elif mode == "mse_threshold":
        V_th = params["V_th"]
        return (
            (V_slice - V_th).pow(2).mean(0)
            if not is_numpy
            else np.mean((V_slice - V_th) ** 2, axis=0)
        )

    elif mode == "threshold_resting":
        V_th = params["V_th"]
        V_reset = params["V_reset"]
        n_scale = params.get("n_scale", 3)

        scale = V_th - V_reset

        upper = V_th + n_scale * scale
        lower = V_reset - n_scale * scale

        mask = (V_slice > upper) | (V_slice < lower)
        # Fraction of out-of-band timesteps per neuron; torch cannot
        # average bool tensors directly, so cast on the tensor path.
        return mask.mean(0) if is_numpy else mask.float().mean(0)

    else:
        raise ValueError(f"Unsupported mode: {mode}")
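
A quick usage sketch with illustrative parameter values:

import numpy as np

from btorch.analysis import voltage_overshoot

V = np.random.default_rng(6).normal(-60.0, 5.0, size=(5000, 100)).astype(np.float32)

# Fraction of post burn-in timesteps spent far outside the
# [V_reset, V_th] band, per neuron.
frac_out = voltage_overshoot(V, mode="threshold_resting", V_th=-50.0, V_reset=-70.0)

# Per-neuron standard deviation after the automatic burn-in.
v_std = voltage_overshoot(V, mode="std")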