
Visualisation

btorch.visualisation

Visualization tools for neuromorphic data analysis.

This module provides plotting utilities for spike trains, network dynamics, connectome structure, and neuron state traces. The API is organized into five plotting families:

Aggregation plots (aggregation):
- Grouped distributions: plot_group_distribution, plot_group_violin, plot_group_box, plot_group_ecdf
- Neuropil timeseries: plot_neuropil_timeseries_overview, plot_neuropil_timeseries_panels

Dynamics plots (dynamics):
- Multiscale analysis: plot_multiscale_fano, plot_dfa_analysis, plot_isi_cv
- Criticality and attractors: plot_avalanche_analysis, plot_eigenvalue_spectrum, plot_lyapunov_spectrum
- Micro-dynamics: plot_firing_rate_distribution, plot_micro_dynamics, plot_gain_stability

Timeseries plots (timeseries):
- Spike visualization: plot_raster
- Continuous traces: plot_traces, plot_neuron_traces
- Spectral analysis: plot_spectrum, plot_log_hist

Network plots (network, hexmap):
- Graph layout: plot_network
- Hexagonal heatmaps: hex_heatmap

Tuning plots (tuning):
- Response curves: plot_fi_vi_curve

Attributes

__all__ = ['hex_heatmap', 'plot_network', 'plot_group_box', 'plot_group_distribution', 'plot_group_ecdf', 'plot_group_violin', 'plot_neuropil_timeseries_overview', 'plot_neuropil_timeseries_panels', 'plot_log_hist', 'plot_raster', 'plot_spectrum', 'plot_traces', 'plot_neuron_traces', 'SimulationStates', 'TracePlotFormat', 'plot_multiscale_fano', 'plot_dfa_analysis', 'plot_isi_cv', 'DynamicsData', 'DFAConfig', 'DynamicsPlotFormat', 'FanoFactorConfig', 'plot_avalanche_analysis', 'plot_eigenvalue_spectrum', 'plot_firing_rate_distribution', 'plot_gain_stability', 'plot_lyapunov_spectrum', 'plot_micro_dynamics'] module-attribute

Classes

DFAConfig dataclass

Configuration for DFA (Detrended Fluctuation Analysis).

Attributes:

- min_window (int): Minimum window size for DFA in timesteps.
- max_window (int | None): Maximum window size. If None, auto-calculated.
- bin_size (int): Bin size for spike binning in timesteps.

Source code in btorch/visualisation/dynamics.py
@dataclass
class DFAConfig:
    """Configuration for DFA (Detrended Fluctuation Analysis).

    Attributes:
        min_window: Minimum window size for DFA in timesteps.
        max_window: Maximum window size. If None, auto-calculated.
        bin_size: Bin size for spike binning in timesteps.
    """

    min_window: int = 4
    max_window: int | None = None
    bin_size: int = 1

DynamicsData dataclass

Container for dynamics analysis data and configs.

Attributes:

- spikes (ndarray | Tensor): Spike trains with shape (time, neurons).
- dt (float): Simulation timestep in milliseconds.
- neurons_df (DataFrame | None): DataFrame with neuron metadata for grouping.
- connections_df (DataFrame | None): DataFrame with connection metadata for neuropil aggregation.

Source code in btorch/visualisation/dynamics.py
@dataclass
class DynamicsData:
    """Container for dynamics analysis data and configs.

    Attributes:
        spikes: Spike trains with shape (time, neurons).
        dt: Simulation timestep in milliseconds.
        neurons_df: DataFrame with neuron metadata for grouping.
        connections_df: DataFrame with connection metadata for neuropil
            aggregation.
    """

    spikes: np.ndarray | torch.Tensor
    dt: float = 1.0
    neurons_df: pd.DataFrame | None = None
    connections_df: pd.DataFrame | None = None
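
A minimal sketch of preparing the container's inputs with NumPy and pandas. The spike layout and the cell_type column follow the documented conventions; the constructor call at the end is shown commented out since it assumes btorch is importable:

```python
import numpy as np
import pandas as pd

# Synthetic spike trains in the documented (time, neurons) layout.
rng = np.random.default_rng(0)
spikes = (rng.random((2_000, 8)) < 0.03).astype(np.uint8)

# Optional metadata: one row per neuron, used for grouping.
neurons_df = pd.DataFrame({"cell_type": ["E"] * 6 + ["I"] * 2})

# data = DynamicsData(spikes=spikes, dt=1.0, neurons_df=neurons_df)
```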

DynamicsPlotFormat dataclass

Figure formatting for dynamics plots.

Attributes:

- mode (Literal['individual', 'grouped', 'distribution']): Visualization mode: "individual" for specific neurons, "grouped" for aggregated groups, "distribution" for summary stats.
- group_by (Literal['neuropil', 'neuron_type', None]): Grouping method for aggregation ("neuropil" or "neuron_type").
- neuron_type_column (str): Column name in neurons_df for neuron classification.
- neuron_indices (list[int] | None): Specific neuron indices for individual mode.
- colors (dict | None): Color mapping dictionary.
- figsize (tuple[float, float] | None): Figure size as (width, height) in inches.

Source code in btorch/visualisation/dynamics.py
@dataclass
class DynamicsPlotFormat:
    """Figure formatting for dynamics plots.

    Attributes:
        mode: Visualization mode - "individual" for specific neurons,
            "grouped" for aggregated groups, "distribution" for summary stats.
        group_by: Grouping method for aggregation ("neuropil" or "neuron_type").
        neuron_type_column: Column name in neurons_df for neuron classification.
        neuron_indices: Specific neuron indices for individual mode.
        colors: Color mapping dictionary.
        figsize: Figure size as (width, height) in inches.
    """

    mode: Literal["individual", "grouped", "distribution"] = "individual"
    group_by: Literal["neuropil", "neuron_type", None] = None
    neuron_type_column: str = "cell_type"
    neuron_indices: list[int] | None = None
    colors: dict | None = None
    figsize: tuple[float, float] | None = None
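
To make the grouped mode concrete, here is a numpy/pandas sketch of the kind of aggregation that mode="grouped" with group_by="neuron_type" implies: per-neuron mean firing rates pooled by the documented neuron_type_column. The actual grouping lives inside the library; this only illustrates the idea:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
spikes = rng.random((1_000, 6)) < 0.05          # (time, neurons)
neurons_df = pd.DataFrame({"cell_type": list("EEEIII")})

# Per-neuron mean firing rate, then aggregated by the neuron_type_column.
rates = pd.Series(spikes.mean(axis=0))
grouped = rates.groupby(neurons_df["cell_type"]).mean()
```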

FanoFactorConfig dataclass

Configuration for Fano factor analysis.

Attributes:

- windows (list[int] | None): Time windows in timesteps for multiscale analysis. If None, logarithmically spaced windows are auto-generated.
- overlap (int): Overlap between consecutive windows in timesteps.

Source code in btorch/visualisation/dynamics.py
@dataclass
class FanoFactorConfig:
    """Configuration for Fano factor analysis.

    Attributes:
        windows: Time windows in timesteps for multiscale analysis.
            If None, logarithmically spaced windows are auto-generated.
        overlap: Overlap between consecutive windows in timesteps.
    """

    windows: list[int] | None = None
    overlap: int = 0
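
For reference, a self-contained numpy sketch of the windowed Fano factor that plot_multiscale_fano visualizes (variance/mean of pooled spike counts per window). This uses non-overlapping windows, i.e. the overlap=0 default; the library's implementation may handle overlap and window selection differently:

```python
import numpy as np

def fano_factor(spikes: np.ndarray, window: int) -> float:
    """Variance/mean of spike counts in non-overlapping windows."""
    T = (spikes.shape[0] // window) * window
    # (n_windows, window, neurons) -> per-neuron, per-window counts
    counts = spikes[:T].reshape(-1, window, spikes.shape[1]).sum(axis=1)
    m = counts.mean()
    return float(counts.var() / m) if m > 0 else float("nan")

rng = np.random.default_rng(0)
spikes = (rng.random((10_000, 20)) < 0.02).astype(float)  # (time, neurons)
ffs = [fano_factor(spikes, w) for w in (10, 50, 250)]     # multiscale sweep
```

Independent Bernoulli spiking gives a Fano factor close to 1 at every scale; deviations across scales are what the multiscale plot is designed to reveal.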

SimulationStates dataclass

Container for simulation state data and configs.

Attributes:

- voltage (ndarray | Tensor): Membrane voltage traces, (time, neurons) or (time, batch, neurons) if a batch dimension is present.
- dt (float): Simulation timestep in ms.
- asc (ndarray | Tensor | None): Afterspike current traces, (time, neurons), (time, batch, neurons), or (time, batch, neurons, n_asc) for multiple ASC components.
- psc (ndarray | Tensor | None): Total postsynaptic current, (time, neurons), (time, batch, neurons), or (time, batch, neurons, n_psc) for multiple PSC components.
- epsc (ndarray | Tensor | None): Excitatory PSC, (time, neurons) or (time, batch, neurons).
- ipsc (ndarray | Tensor | None): Inhibitory PSC, (time, neurons) or (time, batch, neurons).
- input (ndarray | Tensor | None): Input current, (time, neurons) or (time, batch, neurons).
- spikes (ndarray | Tensor | None): Spike trains, (time, neurons) or (time, batch, neurons).
- v_threshold (float | Sequence[float] | ndarray | Tensor | None): Spike threshold voltage(s), scalar or per-neuron.
- v_reset (float | Sequence[float] | ndarray | Tensor | None): Reset voltage(s), scalar or per-neuron.

Source code in btorch/visualisation/timeseries.py
@dataclass
class SimulationStates:
    """Container for simulation state data and configs.

    Attributes:
        voltage: Membrane voltage traces (time, neurons) or
            (time, batch, neurons) if batch dimension present
        dt: Simulation timestep in ms
        asc: Afterspike current traces (time, neurons), (time, batch, neurons),
            or (time, batch, neurons, n_asc) for multiple ASC components
        psc: Total postsynaptic current (time, neurons), (time, batch, neurons),
            or (time, batch, neurons, n_psc) for multiple PSC components
        epsc: Excitatory PSC (time, neurons) or (time, batch, neurons)
        ipsc: Inhibitory PSC (time, neurons) or (time, batch, neurons)
        input: Input current (time, neurons) or (time, batch, neurons)
        spikes: Spike trains (time, neurons) or (time, batch, neurons)
        v_threshold: Spike threshold voltage(s), scalar or per-neuron
        v_reset: Reset voltage(s), scalar or per-neuron
    """

    voltage: np.ndarray | torch.Tensor
    dt: float = 1.0
    asc: np.ndarray | torch.Tensor | None = None
    psc: np.ndarray | torch.Tensor | None = None
    epsc: np.ndarray | torch.Tensor | None = None
    ipsc: np.ndarray | torch.Tensor | None = None
    input: np.ndarray | torch.Tensor | None = None
    spikes: np.ndarray | torch.Tensor | None = None
    v_threshold: float | Sequence[float] | np.ndarray | torch.Tensor | None = None
    v_reset: float | Sequence[float] | np.ndarray | torch.Tensor | None = None
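
A small numpy sketch of the two accepted voltage layouts and how a batched trace reduces to the 2-D (time, neurons) case. The fallback to the first batch sample mirrors the documented batch_idx default in TracePlotFormat; the exact selection logic is the library's:

```python
import numpy as np

T, B, N = 500, 4, 10
voltage = np.zeros((T, B, N))          # batched traces: (time, batch, neurons)

# Documented default: for 3-D data, a None batch_idx falls back to batch 0.
batch_idx = None
sel = 0 if batch_idx is None else batch_idx
v2d = voltage[:, sel, :] if voltage.ndim == 3 else voltage
```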

TracePlotFormat dataclass

Figure formatting configuration.

Attributes:

- neuron_indices (list[int] | None): Specific neuron indices to plot.
- sample_size (int | None): Number of neurons to randomly sample.
- seed (int): Random seed for sampling.
- show_voltage (bool): Whether to show voltage subplot.
- show_asc (bool): Whether to show ASC subplot.
- show_psc (bool): Whether to show PSC subplot.
- show_spikes_on_voltage (bool): Mark spikes on voltage trace.
- separate_figures (bool): Return dict of figures (one per trace type) if True.
- auto_width (bool): Adjust figure width based on simulation duration.
- colors (dict[str, str] | None): Color mapping for different traces.
- figsize_per_neuron (tuple[float, float]): Figure size per neuron row (width, height).
- neuron_labels (Sequence[str] | Callable[[int], str] | None): Side labels as sequence or callable(neuron_idx) -> str. Default None disables side labels.
- neuron_label_position (Literal['side', 'top']): Position for neuron labels when enabled. "side" places labels at the right of each neuron slot; "top" places labels above each neuron slot.
- neurons_per_row (int | None): Number of neurons to place per row in combined mode.
- batch_idx (int | None): Batch index to plot when data has shape (time, batch, neurons). If None and data is 3D, defaults to 0 (first batch sample).

Source code in btorch/visualisation/timeseries.py
@dataclass
class TracePlotFormat:
    """Figure formatting configuration.

    Attributes:
        neuron_indices: Specific neuron indices to plot
        sample_size: Number of neurons to randomly sample
        seed: Random seed for sampling
        show_voltage: Whether to show voltage subplot
        show_asc: Whether to show ASC subplot
        show_psc: Whether to show PSC subplot
        show_spikes_on_voltage: Mark spikes on voltage trace
        separate_figures: Return dict of figures (one per trace type) if True
        auto_width: Adjust figure width based on simulation duration
        colors: Color mapping for different traces
        figsize_per_neuron: Figure size per neuron row (width, height)
        neuron_labels: Side labels as sequence or callable(neuron_idx) -> str.
            Default None disables side labels.
        neuron_label_position: Position for neuron labels when enabled.
            "side" places labels at the right of each neuron slot; "top"
            places labels above each neuron slot.
        neurons_per_row: Number of neurons to place per row in combined mode
        batch_idx: Batch index to plot when data has shape (time, batch, neurons).
            If None and data is 3D, defaults to 0 (first batch sample).
    """

    neuron_indices: list[int] | None = None
    sample_size: int | None = None
    seed: int = 42
    show_voltage: bool = True
    show_asc: bool = True
    show_psc: bool = True
    show_spikes_on_voltage: bool = True
    separate_figures: bool = False
    auto_width: bool = True
    colors: dict[str, str] | None = None
    figsize_per_neuron: tuple[float, float] = (12, 2.5)
    neuron_labels: Sequence[str] | Callable[[int], str] | None = None
    neuron_label_position: Literal["side", "top"] = "side"
    neuron_specs: list[NeuronSpec | dict] | NeuronSpec | dict | None = None
    neurons_per_row: int | None = None
    batch_idx: int | None = None
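
One plausible reading of the sample_size/seed pair, as a numpy sketch: draw sample_size distinct neuron indices reproducibly from the seed. The library's actual sampling routine may differ in detail (e.g. ordering):

```python
import numpy as np

n_neurons, sample_size, seed = 100, 5, 42
rng = np.random.default_rng(seed)
# Distinct indices, sorted for stable subplot ordering.
neuron_indices = sorted(rng.choice(n_neurons, size=sample_size, replace=False))
```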

Functions

hex_heatmap(df, dataset, style=None, sizing=None, dpi=72, custom_colorscale=None)

Generate an interactive hexagonal heatmap.

Visualizes data on a hexagonal grid layout. Single-column data produces a static heatmap; multi-column DataFrames produce an animated heatmap with a slider to navigate through timepoints or conditions.

Parameters:

- df (Series | DataFrame, required): Data to visualize. Single Series for static plot, or DataFrame with multiple columns for animated plot (one frame per column). Must have 'p' and 'q' columns representing hex grid coordinates.
- dataset (DataFrame, required): Reference dataset defining the full hex grid background. Used to render empty hexagons for spatial context.
- style (dict | None, default None): Styling options dict with keys:
  - "font_type": Font family (default: "arial")
  - "markerlinecolor": Marker line color
  - "linecolor": Axis/line color (default: "black")
  - "papercolor": Background color (default: "rgba(255,255,255,255)")
- sizing (dict | None, default None): Size configuration dict with keys:
  - "fig_width", "fig_height": Figure dimensions in mm
  - "markersize": Hexagon marker size (default: 16)
  - "cbar_thickness", "cbar_len": Colorbar dimensions
- dpi (int, default 72): Dots per inch for pixel calculations.
- custom_colorscale (list | None, default None): Custom Plotly colorscale. Default is white-to-blue.

Returns:

- Figure: Plotly Figure with hexagonal heatmap. Static for Series input, animated with slider for DataFrame input.

Raises:

- ValueError: If df is not a Series or DataFrame.

Example:

    # Static heatmap
    fig = hex_heatmap(data_series, background_dataset)
    fig.show()

    # Animated heatmap with timepoints
    fig = hex_heatmap(timepoint_df, background_dataset)
    fig.write_html("animated_hexmap.html")

Source code in btorch/visualisation/hexmap.py
def hex_heatmap(
    df: pd.Series | pd.DataFrame,
    dataset: pd.DataFrame,
    style: dict | None = None,
    sizing: dict | None = None,
    dpi: int = 72,
    custom_colorscale: list | None = None,
) -> go.Figure:
    """Generate an interactive hexagonal heatmap.

    Visualizes data on a hexagonal grid layout. Single-column data produces
    a static heatmap; multi-column DataFrames produce an animated heatmap
    with a slider to navigate through timepoints or conditions.

    Args:
        df: Data to visualize. Single Series for static plot, or DataFrame
            with multiple columns for animated plot (one frame per column).
            Must have 'p' and 'q' columns representing hex grid coordinates.
        dataset: Reference dataset defining the full hex grid background.
            Used to render empty hexagons for spatial context.
        style: Styling options dict with keys:
            - "font_type": Font family (default: "arial")
            - "markerlinecolor": Marker line color
            - "linecolor": Axis/line color (default: "black")
            - "papercolor": Background color (default: "rgba(255,255,255,255)")
        sizing: Size configuration dict with keys:
            - "fig_width", "fig_height": Figure dimensions in mm
            - "markersize": Hexagon marker size (default: 16)
            - "cbar_thickness", "cbar_len": Colorbar dimensions
        dpi: Dots per inch for pixel calculations (default: 72).
        custom_colorscale: Custom Plotly colorscale. Default is white-to-blue.

    Returns:
        Plotly Figure with hexagonal heatmap. Static for Series input,
        animated with slider for DataFrame input.

    Raises:
        ValueError: If `df` is not a Series or DataFrame.

    Example:
        >>> # Static heatmap
        >>> fig = hex_heatmap(data_series, background_dataset)
        >>> fig.show()
        >>>
        >>> # Animated heatmap with timepoints
        >>> fig = hex_heatmap(timepoint_df, background_dataset)
        >>> fig.write_html("animated_hexmap.html")
    """

    def bg_hex():
        goscatter = go.Scatter(
            x=background_hex["x"],
            y=background_hex["y"],
            mode="markers",
            marker_symbol=symbol_number,
            marker={
                "size": sizing["markersize"],
                "color": "white",
                "line": {
                    "width": sizing["markerlinewidth"],
                    "color": "lightgrey",
                },
            },
            showlegend=False,
        )
        return goscatter

    def data_hex(aseries):
        goscatter = go.Scatter(
            x=x_vals,
            y=y_vals,
            mode="markers",
            marker_symbol=symbol_number,
            marker={
                "cmin": global_min,
                "cmax": global_max,
                "size": sizing["markersize"],
                "color": aseries.values,
                "line": {
                    "width": sizing["markerlinewidth"],
                    "color": "lightgrey",
                },
                "colorbar": {
                    "orientation": "v",
                    "outlinecolor": style["linecolor"],
                    "outlinewidth": sizing["axislinewidth"],
                    "thickness": sizing["cbar_thickness"],
                    "len": sizing["cbar_len"],
                    "tickmode": "array",
                    "ticklen": sizing["ticklen"],
                    "tickwidth": sizing["tickwidth"],
                    "tickcolor": style["linecolor"],
                    "tickfont": {
                        "size": fsize_ticks_px,
                        "family": style["font_type"],
                        "color": style["linecolor"],
                    },
                    "tickformat": ".5f",
                    "title": {
                        "font": {
                            "family": style["font_type"],
                            "size": fsize_title_px,
                            "color": style["linecolor"],
                        },
                        "side": "right",
                    },
                },
                "colorscale": custom_colorscale,
            },
            showlegend=False,
        )
        return goscatter

    default_style = {
        "font_type": "arial",
        "markerlinecolor": "rgba(0,0,0,0)",
        "linecolor": "black",
        "papercolor": "rgba(255,255,255,255)",
    }

    markersize = 16

    default_sizing = {
        "fig_width": 260,
        "fig_height": 220,
        "fig_margin": 0,
        "fsize_ticks_pt": 20,
        "fsize_title_pt": 20,
        "markersize": markersize,
        "ticklen": 15,
        "tickwidth": 5,
        "axislinewidth": 3,
        "markerlinewidth": 0.9,
        "cbar_thickness": 20,
        "cbar_len": 0.75,
    }

    if style is not None:
        default_style.update(style)
    style = default_style

    if sizing is not None:
        default_sizing.update(sizing)
    sizing = default_sizing

    POINTS_PER_INCH = 72
    MM_PER_INCH = 25.4

    pixelsperinch = dpi
    pixelspermm = pixelsperinch / MM_PER_INCH

    if custom_colorscale is None:
        custom_colorscale = [[0, "rgb(255, 255, 255)"], [1, "rgb(0, 20, 200)"]]

    area_width = (sizing["fig_width"] - sizing["fig_margin"]) * pixelspermm
    area_height = (sizing["fig_height"] - sizing["fig_margin"]) * pixelspermm

    fsize_ticks_px = sizing["fsize_ticks_pt"] * (1 / POINTS_PER_INCH) * pixelsperinch
    fsize_title_px = sizing["fsize_title_pt"] * (1 / POINTS_PER_INCH) * pixelsperinch

    global_min = min(0, df.values.min())
    global_max = df.values.max()

    symbol_number = 15

    background_hex = dataset
    background_hex = background_hex.drop_duplicates(subset=["p", "q"])[
        ["p", "q"]
    ].astype(float)
    x, y = hex_to_pixel(background_hex.p, background_hex.q, mode="flat")
    background_hex["x"], background_hex["y"] = x, y

    fig = go.Figure()
    fig.update_layout(
        autosize=False,
        height=area_height,
        width=area_width,
        margin={"l": 0, "r": 0, "b": 0, "t": 0, "pad": 0},
        paper_bgcolor=style["papercolor"],
        plot_bgcolor=style["papercolor"],
    )
    fig.update_xaxes(
        showgrid=False, showticklabels=False, showline=False, visible=False
    )
    fig.update_yaxes(
        showgrid=False,
        showticklabels=False,
        showline=False,
        visible=False,
        scaleanchor="x",
        scaleratio=1,
    )

    df["x"], df["y"] = hex_to_pixel(df.p, df.q, mode="flat")
    x_vals, y_vals = df.x, df.y
    df = df.drop(columns=["p", "q", "x", "y"])

    if len(df.columns) == 1:
        if isinstance(df, pd.DataFrame):
            df = df.iloc[:, 0]
        fig.add_trace(bg_hex())
        fig.add_trace(data_hex(df))

    elif isinstance(df, pd.DataFrame):
        slider_height = 100
        area_height += slider_height

        frames = []
        slider_steps = []

        fig.update_layout(
            autosize=False,
            height=area_height,
            width=area_width,
            margin={
                "l": 0,
                "r": 0,
                "b": slider_height,
                "t": 0,
                "pad": 0,
            },
            paper_bgcolor=style["papercolor"],
            plot_bgcolor=style["papercolor"],
            sliders=[
                {
                    "active": 0,
                    "currentvalue": {
                        "font": {"size": 16},
                        "visible": True,
                        "xanchor": "right",
                    },
                    "pad": {"b": 10, "t": 0},
                    "len": 0.9,
                    "x": 0.1,
                    "y": 0,
                    "steps": [],
                }
            ],
        )

        for i, col_name in enumerate(df.columns):
            series = df[col_name]
            frame_data = [
                bg_hex(),
                data_hex(series),
            ]

            frames.append(go.Frame(data=frame_data, name=str(i)))

            slider_steps.append(
                {
                    "args": [
                        [str(i)],
                        {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"},
                    ],
                    "label": col_name,
                    "method": "animate",
                }
            )

            if i == 0:
                fig.add_traces(frame_data)

        fig.layout.sliders[0].steps = slider_steps
        fig.frames = frames

        fig.update_xaxes(
            showgrid=False, showticklabels=False, showline=False, visible=False
        )
        fig.update_yaxes(
            showgrid=False, showticklabels=False, showline=False, visible=False
        )

    else:
        raise ValueError("df must be a pd.Series or pd.DataFrame")

    return fig
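
A sketch of how the inputs fit together, using pandas only. The column names t0/t1 and the values are made up for illustration; the calls at the end follow the documented signature and are commented out since they assume btorch and Plotly are importable:

```python
import pandas as pd

# Reference grid: every hexagon, rendered as empty background context.
dataset = pd.DataFrame({"p": [0, 1, 2, 0], "q": [0, 0, 1, 1]})

# Values to plot: 'p'/'q' coordinates plus one column per animation frame.
df = dataset.copy()
df["t0"] = [0.1, 0.5, 0.9, 0.3]
df["t1"] = [0.2, 0.4, 0.8, 0.6]

# fig = hex_heatmap(df, dataset)                    # animated, slider over t0/t1
# fig = hex_heatmap(df[["p", "q", "t0"]], dataset)  # single value column -> static
```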

plot_avalanche_analysis(spikes, bin_size=1, dt=1.0)

Plot avalanche size and duration distributions to analyze criticality.

Creates a 3-panel figure showing:

1. Avalanche size distribution P(S) with power-law fit
2. Avalanche duration distribution P(T) with power-law fit
3. Average size vs duration scaling relation <S>(T)

Criticality is indicated by power-law distributions and specific scaling exponents (tau, alpha, gamma).

Parameters:

- spikes (ndarray | Tensor, required): Spike matrix with shape (time, neurons).
- bin_size (int, default 1): Bin size for avalanche detection in timesteps.
- dt (float, default 1.0): Timestep in ms (unused, kept for interface consistency).

Returns:

- tuple[Figure, dict]: Tuple of (figure, results) where results contains fitted exponents (tau, alpha, gamma), CCC (criticality consistency check), and power-law fit objects.

Example:

    fig, results = plot_avalanche_analysis(spikes, bin_size=5)
    print(f"Tau: {results['tau']:.2f}, CCC: {results['CCC']:.2f}")

Source code in btorch/visualisation/dynamics.py
def plot_avalanche_analysis(
    spikes: np.ndarray | torch.Tensor,
    bin_size: int = 1,
    dt: float = 1.0,  # Added for consistency
) -> tuple[Figure, dict]:
    """Plot avalanche size and duration distributions to analyze criticality.

    Creates a 3-panel figure showing:
    1. Avalanche size distribution P(S) with power-law fit
    2. Avalanche duration distribution P(T) with power-law fit
    3. Average size vs duration scaling relation `<S>(T)`

    Criticality is indicated by power-law distributions and specific
    scaling exponents (tau, alpha, gamma).

    Args:
        spikes: Spike matrix with shape (time, neurons).
        bin_size: Bin size for avalanche detection in timesteps.
        dt: Timestep in ms (unused, kept for interface consistency).

    Returns:
        Tuple of (figure, results) where results contains fitted exponents
        (tau, alpha, gamma), CCC (criticality consistency check), and
        power-law fit objects.

    Example:
        >>> fig, results = plot_avalanche_analysis(spikes, bin_size=5)
        >>> print(f"Tau: {results['tau']:.2f}, CCC: {results['CCC']:.2f}")
    """
    from ..analysis.dynamic_tools.criticality import compute_avalanche_statistics

    spikes = _to_numpy(spikes)
    results = compute_avalanche_statistics(spikes, bin_size=bin_size)

    fig = plt.figure(figsize=(15, 4))

    # 1. Size Distribution P(S)
    ax1 = fig.add_subplot(1, 3, 1)
    if results["fit_S"]:
        results["fit_S"].plot_pdf(color="b", linewidth=2, ax=ax1, label="Data")
        results["fit_S"].power_law.plot_pdf(
            color="b", linestyle="--", ax=ax1, label=f"Fit (tau={results['tau']:.2f})"
        )
    ax1.set_xlabel("Avalanche Size (S)")
    ax1.set_ylabel("P(S)")
    ax1.set_title("Size Distribution")
    if ax1.get_legend_handles_labels()[0]:
        ax1.legend()

    # 2. Duration Distribution P(T)
    ax2 = fig.add_subplot(1, 3, 2)
    if results["fit_T"]:
        results["fit_T"].plot_pdf(color="r", linewidth=2, ax=ax2, label="Data")
        results["fit_T"].power_law.plot_pdf(
            color="r",
            linestyle="--",
            ax=ax2,
            label=f"Fit (alpha={results['alpha']:.2f})",
        )
    ax2.set_xlabel("Avalanche Duration (T)")
    ax2.set_ylabel("P(T)")
    ax2.set_title("Duration Distribution")
    if ax2.get_legend_handles_labels()[0]:
        ax2.legend()

    # 3. Average Size vs Duration <S>(T)
    ax3 = fig.add_subplot(1, 3, 3)
    if (
        "avg_size_by_duration" in results
        and results["avg_size_by_duration"] is not None
    ):
        durations, mean_sizes = results["avg_size_by_duration"]
        ax3.loglog(durations, mean_sizes, "ko", markersize=4, label="Data")

        # Plot fit
        if not np.isnan(results["gamma"]):
            if "gamma_stats" in results and "popt" in results["gamma_stats"]:
                popt = results["gamma_stats"]["popt"]
                a, gamma = popt
                x_fit = np.logspace(
                    np.log10(durations.min()), np.log10(durations.max()), 100
                )
                y_fit = a * np.power(x_fit, gamma)
                ax3.loglog(
                    x_fit,
                    y_fit,
                    "g--",
                    label=f"Fit (gamma={results['gamma']:.2f})",
                )

    # Annotate CCC
    if not np.isnan(results["CCC"]):
        txt = f"CCC = {results['CCC']:.2f}\nPred gamma = {results['gamma_pred']:.2f}"
        ax3.text(
            0.05,
            0.95,
            txt,
            transform=ax3.transAxes,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

    ax3.set_xlabel("Duration (T)")
    ax3.set_ylabel("Average Size <S>")
    ax3.set_title("Scaling Relation")
    if ax3.get_legend_handles_labels()[0]:
        ax3.legend()

    plt.tight_layout()
    return fig, results
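
The statistics themselves come from compute_avalanche_statistics. As a reference for what an "avalanche" means here, this numpy sketch implements the standard definition (a run of consecutive non-empty population bins), not the library's exact code:

```python
import numpy as np

def avalanches(spikes: np.ndarray, bin_size: int = 1):
    """Avalanche sizes and durations from runs of non-empty population bins."""
    T = (spikes.shape[0] // bin_size) * bin_size
    # Total spike count per bin, pooled over neurons.
    counts = spikes[:T].reshape(-1, bin_size, spikes.shape[1]).sum(axis=(1, 2))
    sizes, durations = [], []
    s = d = 0
    for c in counts:
        if c > 0:               # avalanche continues
            s += int(c)
            d += 1
        elif d:                 # empty bin terminates the avalanche
            sizes.append(s)
            durations.append(d)
            s = d = 0
    if d:                       # avalanche still running at the end
        sizes.append(s)
        durations.append(d)
    return np.array(sizes), np.array(durations)
```

At criticality, the resulting size and duration distributions follow power laws with exponents tau and alpha, which is what the first two panels fit.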

plot_dfa_analysis(data=None, config=None, format=None, spikes=None, dt=1.0, min_window=4, max_window=None, bin_size=1, mode='individual', neurons_df=None, **kwargs)

Plot DFA (Detrended Fluctuation Analysis) results.

DFA quantifies long-range temporal correlations in spike trains. The scaling exponent (alpha) indicates:

- alpha ≈ 0.5: Uncorrelated (random) activity
- alpha > 0.5: Long-range positive correlations
- alpha < 0.5: Long-range anti-correlations

Supports both dataclass and plain argument interfaces.

Parameters:

- data (DynamicsData | None, default None): DynamicsData container with spikes.
- config (DFAConfig | None, default None): DFAConfig with window and bin settings.
- format (DynamicsPlotFormat | None, default None): DynamicsPlotFormat (mode affects plot style).
- spikes (ndarray | Tensor | None, default None): Spike trains (time, neurons). Required if data is not provided.
- dt (float, default 1.0): Timestep in milliseconds (for label consistency).
- min_window (int, default 4): Minimum window size for DFA in timesteps.
- max_window (int | None, default None): Maximum window size. Auto-calculated if None.
- bin_size (int, default 1): Bin size for spike binning in timesteps.
- mode (Literal['individual', 'grouped', 'distribution'], default 'individual'): Visualization mode (affects annotation style).
- neurons_df (DataFrame | None, default None): Neuron metadata for potential grouping.
- **kwargs: Additional arguments.

Returns:

- Figure: Figure with DFA summary and interpretation guide.

Raises:

- ValueError: If spikes are not provided.

Example:

    fig = plot_dfa_analysis(spikes, bin_size=10)

Source code in btorch/visualisation/dynamics.py
def plot_dfa_analysis(
    # Dataclass interface
    data: DynamicsData | None = None,
    config: DFAConfig | None = None,
    format: DynamicsPlotFormat | None = None,
    # Plain args interface
    spikes: np.ndarray | torch.Tensor | None = None,
    dt: float = 1.0,
    min_window: int = 4,
    max_window: int | None = None,
    bin_size: int = 1,
    mode: Literal["individual", "grouped", "distribution"] = "individual",
    neurons_df: pd.DataFrame | None = None,
    **kwargs,
) -> Figure:
    """Plot DFA (Detrended Fluctuation Analysis) results.

    DFA quantifies long-range temporal correlations in spike trains.
    The scaling exponent (alpha) indicates:
    - alpha ≈ 0.5: Uncorrelated (random) activity
    - alpha > 0.5: Long-range positive correlations
    - alpha < 0.5: Long-range anti-correlations

    Supports both dataclass and plain argument interfaces.

    Args:
        data: DynamicsData container with spikes.
        config: DFAConfig with window and bin settings.
        format: DynamicsPlotFormat (mode affects plot style).
        spikes: Spike trains (time, neurons). Required if `data` not provided.
        dt: Timestep in milliseconds (for label consistency).
        min_window: Minimum window size for DFA in timesteps.
        max_window: Maximum window size. Auto-calculated if None.
        bin_size: Bin size for spike binning in timesteps.
        mode: Visualization mode (affects annotation style).
        neurons_df: Neuron metadata for potential grouping.
        **kwargs: Additional arguments.

    Returns:
        Figure with DFA summary and interpretation guide.

    Raises:
        ValueError: If spikes are not provided.

    Example:
        >>> fig = plot_dfa_analysis(spikes, bin_size=10)
    """
    # Resolve dataclass vs plain args
    if data is not None:
        spikes = data.spikes if spikes is None else spikes
        dt = data.dt if dt == 1.0 else dt
        neurons_df = data.neurons_df if neurons_df is None else neurons_df

    if config is not None:
        max_window = config.max_window if max_window is None else max_window
        bin_size = config.bin_size

    # Validate
    if spikes is None:
        raise ValueError("spikes is required")

    spikes = _to_numpy(spikes)

    # Compute DFA
    from ..analysis.dynamic_tools.criticality import calculate_dfa

    alpha = calculate_dfa(spikes, bin_size=bin_size)

    # Create simple plot showing the result
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.text(
        0.5,
        0.5,
        f"DFA Exponent (α): {alpha:.3f}",
        ha="center",
        va="center",
        fontsize=16,
        transform=ax.transAxes,
    )
    ax.text(
        0.5,
        0.4,
        "α ≈ 0.5: uncorrelated\n"
        "α > 0.5: long-range correlations\n"
        "α < 0.5: anti-correlations",
        ha="center",
        va="center",
        fontsize=10,
        transform=ax.transAxes,
    )
    ax.set_title("Detrended Fluctuation Analysis")
    ax.axis("off")
    return fig
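For intuition about the exponent this figure annotates, DFA can be sketched in a few lines of plain NumPy. This toy (not btorch's `calculate_dfa`) integrates the mean-subtracted signal, linearly detrends fixed-size windows, and fits the slope of the fluctuation function on log-log axes:

```python
import numpy as np

def dfa_alpha(signal: np.ndarray, windows: list[int]) -> float:
    """Toy detrended fluctuation analysis; returns the scaling exponent."""
    y = np.cumsum(signal - signal.mean())  # integrated profile
    flucts = []
    for n in windows:
        n_seg = len(y) // n
        segs = y[: n_seg * n].reshape(n_seg, n)
        t = np.arange(n)
        # Linearly detrend each segment and collect the RMS residual
        rms = [
            np.sqrt(np.mean((seg - np.polyval(np.polyfit(t, seg, 1), t)) ** 2))
            for seg in segs
        ]
        flucts.append(np.mean(rms))
    # alpha is the slope of log F(n) versus log n
    alpha, _ = np.polyfit(np.log(windows), np.log(flucts), 1)
    return float(alpha)

rng = np.random.default_rng(0)
noise = rng.normal(size=20_000)
alpha = dfa_alpha(noise, windows=[16, 32, 64, 128, 256, 512])
print(round(alpha, 2))  # close to 0.5 for uncorrelated noise
```

Running the same sketch on cumulatively summed noise (a random walk) would instead give an exponent near 1.5, deep in the long-range-correlated regime.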

plot_eigenvalue_spectrum(weight_matrix, ax=None)

Plot the eigenvalue spectrum of a weight matrix.

Visualizes eigenvalues in the complex plane with the spectral radius indicated by a dashed circle. Outliers (eigenvalues outside the bulk) are highlighted in red.

Parameters:

- `weight_matrix` (ndarray | Tensor): Square connectivity matrix (N, N). Required.
- `ax` (Axes | None): Existing axes to plot on. Creates new figure if None. Default: None.

Returns:

- tuple[Figure, Axes, dict]: Tuple of (figure, axes, results) where results contains:
  - "eigenvalues": Complex array of all eigenvalues
  - "spectral_radius": Radius of the spectral bulk
  - "outliers": Array of outlier eigenvalues

Example:

>>> fig, ax, results = plot_eigenvalue_spectrum(W)
>>> print(f"Spectral radius: {results['spectral_radius']:.2f}")

Source code in btorch/visualisation/dynamics.py
def plot_eigenvalue_spectrum(
    weight_matrix: np.ndarray | torch.Tensor, ax: Axes | None = None
) -> tuple[Figure, Axes, dict]:
    """Plot the eigenvalue spectrum of a weight matrix.

    Visualizes eigenvalues in the complex plane with the spectral radius
    indicated by a dashed circle. Outliers (eigenvalues outside the bulk)
    are highlighted in red.

    Args:
        weight_matrix: Square connectivity matrix (N, N).
        ax: Existing axes to plot on. Creates new figure if None.

    Returns:
        Tuple of (figure, axes, results) where results contains:
        - "eigenvalues": Complex array of all eigenvalues
        - "spectral_radius": Radius of spectral bulk
        - "outliers": Array of outlier eigenvalues

    Example:
        >>> fig, ax, results = plot_eigenvalue_spectrum(W)
        >>> print(f"Spectral radius: {results['spectral_radius']:.2f}")
    """
    from ..analysis.dynamic_tools.attractor_dynamics import (
        calculate_structural_eigenvalue_outliers,
    )

    W = _to_numpy(weight_matrix)
    results = calculate_structural_eigenvalue_outliers(W)

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    else:
        fig = ax.get_figure()

    evals = results["eigenvalues"]
    r_spec = results["spectral_radius"]

    # Draw unit circle / spectral radius
    circle = plt.Circle(
        (0, 0),
        r_spec,
        color="black",
        fill=False,
        linestyle="--",
        alpha=0.5,
        label=f"R={r_spec:.2f}",
    )
    ax.add_artist(circle)

    # Scatter eigenvalues
    ax.scatter(evals.real, evals.imag, s=10, alpha=0.6, c="gray", edgecolors="none")

    # Highlight outliers
    outliers = results["outliers"]
    if len(outliers) > 0:
        ax.scatter(
            outliers.real,
            outliers.imag,
            s=30,
            c="red",
            label=f"Outliers ({len(outliers)})",
        )

    ax.set_aspect("equal")
    ax.set_xlabel(r"Re($\lambda$)")
    ax.set_ylabel(r"Im($\lambda$)")
    ax.set_title("Eigenvalue Spectrum")
    ax.grid(True, linestyle=":", alpha=0.3)
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()

    return fig, ax, results
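As a sanity check on the dashed circle, the bulk of a random coupling matrix follows Girko's circular law: for i.i.d. entries with variance g²/N, the eigenvalues fill a disk of radius ≈ g. A quick NumPy illustration, independent of btorch's outlier detection:

```python
import numpy as np

rng = np.random.default_rng(1)
N, g = 500, 0.9
# i.i.d. Gaussian couplings with variance g**2 / N
W = rng.normal(scale=g / np.sqrt(N), size=(N, N))
evals = np.linalg.eigvals(W)
radius = np.abs(evals).max()
print(round(radius, 2))  # the spectral bulk radius is approximately g = 0.9
```

Eigenvalues escaping well outside this disk (the red "outliers" above) typically reflect low-rank structure superimposed on the random bulk.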

plot_firing_rate_distribution(spikes, dt=1.0, ax=None)

Plot the distribution of firing rates across neurons.

Computes per-neuron firing rates and displays as a histogram with mean indicator.

Parameters:

- `spikes` (ndarray | Tensor): Spike matrix with shape (time, neurons). Required.
- `dt` (float): Timestep in milliseconds for rate calculation. Default: 1.0.
- `ax` (Axes | None): Existing axes to plot on. Creates new figure if None. Default: None.

Returns:

- tuple[Figure, dict]: Tuple of (figure, stats) where stats contains:
  - "rates": Array of firing rates per neuron (Hz)
  - "mean", "std", "median": Summary statistics

Example:

>>> fig, stats = plot_firing_rate_distribution(spikes, dt=1.0)
>>> print(f"Mean rate: {stats['mean']:.1f} Hz")

Source code in btorch/visualisation/dynamics.py
def plot_firing_rate_distribution(
    spikes: np.ndarray | torch.Tensor,
    dt: float = 1.0,
    ax: Axes | None = None,
) -> tuple[Figure, dict]:
    """Plot the distribution of firing rates across neurons.

    Computes per-neuron firing rates and displays as a histogram with
    mean indicator.

    Args:
        spikes: Spike matrix with shape (time, neurons).
        dt: Timestep in milliseconds for rate calculation.
        ax: Existing axes to plot on. Creates new figure if None.

    Returns:
        Tuple of (figure, stats) where stats contains:
        - "rates": Array of firing rates per neuron (Hz)
        - "mean", "std", "median": Summary statistics

    Example:
        >>> fig, stats = plot_firing_rate_distribution(spikes, dt=1.0)
        >>> print(f"Mean rate: {stats['mean']:.1f} Hz")
    """
    from ..analysis.dynamic_tools.micro_scale import calculate_fr_distribution

    spikes = _to_numpy(spikes)
    stats = calculate_fr_distribution(spikes, dt=dt)

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=(6, 4))
    else:
        fig = ax.get_figure()

    rates = stats["rates"]
    ax.hist(rates, bins=30, color="skyblue", edgecolor="black", alpha=0.7)
    ax.axvline(
        stats["mean"],
        color="red",
        linestyle="--",
        label=f"Mean={stats['mean']:.1f} Hz",
    )
    ax.set_xlabel("Firing Rate (Hz)")
    ax.set_ylabel("Count")
    ax.set_title("Rate Distribution")
    ax.legend()
    ax.grid(alpha=0.3)

    # `ax` was reassigned above, so test the flag rather than `ax is None`
    if created_fig:
        plt.tight_layout()

    return fig, stats
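The statistic being histogrammed is simple to reproduce: a per-neuron rate is its spike count divided by the recording duration. A hedged sketch of what `calculate_fr_distribution`'s "rates" entry presumably contains:

```python
import numpy as np

rng = np.random.default_rng(2)
dt_ms = 1.0
# Binary (time, neurons) spike matrix with ~1% spike probability per 1 ms bin,
# i.e. an expected rate of ~10 Hz per neuron
spikes = (rng.random((5000, 200)) < 0.01).astype(float)
duration_s = spikes.shape[0] * dt_ms / 1000.0
rates = spikes.sum(axis=0) / duration_s  # Hz, one entry per neuron
print(round(float(rates.mean()), 1))  # ~10.0
```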

plot_gain_stability(data)

Plot gain stability analysis results.

Visualizes the relationship between network gain (g) and stability metrics (e.g., maximum Lyapunov exponent or spectral abscissa). A linear fit indicates consistent scaling behavior.

Parameters:

- `data` (tuple): Tuple of (slope, intercept, g_values, lambda_values) where:
  - slope, intercept: Linear fit parameters
  - g_values: Array of gain values tested
  - lambda_values: Corresponding stability metrics

  Required.

Returns:

- tuple[Figure, Axes]: Tuple of (figure, axes) with scatter plot and fit line.

Example:

>>> data = (slope, intercept, g_vals, lyap_vals)
>>> fig, ax = plot_gain_stability(data)

Source code in btorch/visualisation/dynamics.py
def plot_gain_stability(data: tuple) -> tuple[Figure, Axes]:
    """Plot gain stability analysis results.

    Visualizes the relationship between network gain (g) and stability
    metrics (e.g., maximum Lyapunov exponent or spectral abscissa).
    A linear fit indicates consistent scaling behavior.

    Args:
        data: Tuple of (slope, intercept, g_values, lambda_values) where:
            - slope, intercept: Linear fit parameters
            - g_values: Array of gain values tested
            - lambda_values: Corresponding stability metrics

    Returns:
        Tuple of (figure, axes) with scatter plot and fit line.

    Example:
        >>> data = (slope, intercept, g_vals, lyap_vals)
        >>> fig, ax = plot_gain_stability(data)
    """
    slope, intercept, g_values, lambda_values = data

    fig, ax = plt.subplots(figsize=(6, 4))

    # Plot scatter of metrics
    ax.scatter(g_values, lambda_values, label="Data", color="blue", alpha=0.6)

    # Plot fit line
    x_range = np.linspace(min(g_values), max(g_values), 100)
    y_fit = slope * x_range + intercept
    ax.plot(x_range, y_fit, "r--", label=f"Fit: y={slope:.2f}x+{intercept:.2f}")

    ax.set_xlabel("Gain (g)")
    ax.set_ylabel("Lyapunov / Eigenvalue metric")
    ax.set_title("Gain Stability Analysis")
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig, ax
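The expected input tuple can be assembled with `np.polyfit`. The values below are synthetic and purely illustrative (a toy metric that crosses zero near g = 1, as stability metrics often do in gain sweeps):

```python
import numpy as np

# Hypothetical stability metric measured at several gains
g_vals = np.array([0.5, 0.8, 1.0, 1.2, 1.5])
lyap_vals = 0.9 * (g_vals - 1.0)  # toy data, exactly linear
slope, intercept = np.polyfit(g_vals, lyap_vals, 1)
data = (slope, intercept, g_vals, lyap_vals)  # matches plot_gain_stability's input
print(round(float(slope), 2), round(float(intercept), 2))  # 0.9 -0.9
```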

plot_group_box(values, neurons_df, group_by, **kwargs)

Convenience wrapper for plot_group_distribution(..., kind='box').

Source code in btorch/visualisation/aggregation.py
def plot_group_box(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    **kwargs,
) -> tuple[Figure, Axes]:
    """Convenience wrapper for `plot_group_distribution(..., kind='box')`."""
    return _plot_group_distribution_kind(
        values=values,
        neurons_df=neurons_df,
        group_by=group_by,
        kind="box",
        **kwargs,
    )

plot_group_distribution(values, neurons_df, group_by, *, kind='violin', simple_id_col='simple_id', value_name='value', group_order=None, dropna=True, ax=None, figsize=(9.0, 5.0), title=None, showfliers=False, linewidth=1.5, alpha=0.8)

Plot grouped value distributions as violin, box, or ECDF.

Groups per-neuron values by a categorical column in neurons_df and visualizes the distribution using the specified plot type.

Parameters:

- `values` (TensorLike): Per-neuron values with shape (neurons,) or (time, neurons). If 2D, values are flattened across time for each neuron. Required.
- `neurons_df` (DataFrame): DataFrame with neuron metadata. Must contain simple_id_col and group_by columns. Required.
- `group_by` (str): Column name in neurons_df to group by (e.g., "cell_type"). Required.
- `kind` (GroupPlotKind): Plot type - "violin", "box", or "ecdf". Default: 'violin'.
- `simple_id_col` (str): Column name for neuron identifiers in neurons_df. Default: 'simple_id'.
- `value_name` (str): Label for the y-axis / value dimension. Default: 'value'.
- `group_order` (Sequence | None): Explicit order for groups. If None, uses natural sort. Default: None.
- `dropna` (bool): Whether to drop NaN values before plotting. Default: True.
- `ax` (Axes | None): Existing axes to plot on. If None, creates new figure. Default: None.
- `figsize` (tuple[float, float]): Figure size (width, height) in inches. Default: (9.0, 5.0).
- `title` (str | None): Plot title. If None, auto-generated from kind and group_by. Default: None.
- `showfliers` (bool): For box plots, whether to show outlier points. Default: False.
- `linewidth` (float): Line width for ECDF curves. Default: 1.5.
- `alpha` (float): Opacity for violin/box fill. Default: 0.8.

Returns:

- tuple[Figure, Axes]: Tuple of (figure, axes) containing the plot.

Raises:

- ValueError: If kind is not one of "violin", "box", "ecdf".

Example:

>>> fig, ax = plot_group_distribution(
...     firing_rates, neurons_df, group_by="cell_type", kind="violin"
... )

Source code in btorch/visualisation/aggregation.py
def plot_group_distribution(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    kind: GroupPlotKind = "violin",
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
    ax: Axes | None = None,
    figsize: tuple[float, float] = (9.0, 5.0),
    title: str | None = None,
    showfliers: bool = False,
    linewidth: float = 1.5,
    alpha: float = 0.8,
) -> tuple[Figure, Axes]:
    """Plot grouped value distributions as violin, box, or ECDF.

    Groups per-neuron values by a categorical column in `neurons_df` and
    visualizes the distribution using the specified plot type.

    Args:
        values: Per-neuron values with shape (neurons,) or (time, neurons).
            If 2D, values are flattened across time for each neuron.
        neurons_df: DataFrame with neuron metadata. Must contain `simple_id_col`
            and `group_by` columns.
        group_by: Column name in `neurons_df` to group by (e.g., "cell_type").
        kind: Plot type - "violin", "box", or "ecdf".
        simple_id_col: Column name for neuron identifiers in `neurons_df`.
        value_name: Label for the y-axis / value dimension.
        group_order: Explicit order for groups. If None, uses natural sort.
        dropna: Whether to drop NaN values before plotting.
        ax: Existing axes to plot on. If None, creates new figure.
        figsize: Figure size (width, height) in inches.
        title: Plot title. If None, auto-generated from `kind` and `group_by`.
        showfliers: For box plots, whether to show outlier points.
        linewidth: Line width for ECDF curves.
        alpha: Opacity for violin/box fill.

    Returns:
        Tuple of (figure, axes) containing the plot.

    Raises:
        ValueError: If `kind` is not one of "violin", "box", "ecdf".

    Example:
        >>> fig, ax = plot_group_distribution(
        ...     firing_rates, neurons_df, group_by="cell_type", kind="violin"
        ... )
    """
    fig, ax = _resolve_figure_ax(ax=ax, figsize=figsize)

    if kind in {"violin", "box"}:
        grouped = group_values(
            values,
            neurons_df,
            group_by,
            simple_id_col=simple_id_col,
            value_name=value_name,
            group_order=group_order,
            dropna=dropna,
        )
        order = list(grouped.keys())
        grouped_arrays = [grouped[group] for group in order]

        if kind == "violin":
            _plot_violin(ax, grouped_arrays, order, alpha=alpha)
        else:
            _plot_box(
                ax,
                grouped_arrays,
                order,
                showfliers=showfliers,
                alpha=alpha,
            )

        ax.set_xlabel(group_by)
        ax.set_ylabel(value_name)
    elif kind == "ecdf":
        ecdf_by_group = group_ecdf(
            values,
            neurons_df,
            group_by,
            simple_id_col=simple_id_col,
            value_name=value_name,
            group_order=group_order,
            dropna=dropna,
        )
        _plot_ecdf(
            ax,
            ecdf_by_group,
            group_by=group_by,
            value_name=value_name,
            linewidth=linewidth,
        )
    else:
        raise ValueError(f"Unsupported kind `{kind}`.")

    if title is None:
        title = f"{kind.upper()} grouped by {group_by}"
    ax.set_title(title)
    ax.grid(alpha=0.2)

    return fig, ax
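Before drawing, the function builds one value array per group. That mapping can be sketched with plain pandas; the column names below follow the function's defaults (`simple_id`, a `cell_type` grouping column), and the data are made up:

```python
import numpy as np
import pandas as pd

neurons_df = pd.DataFrame(
    {
        "simple_id": range(6),
        "cell_type": ["KC", "KC", "PN", "PN", "LN", "LN"],
    }
)
values = np.array([1.0, 2.0, 5.0, 7.0, 3.0, 3.0])  # one value per neuron

# Group per-neuron values by the metadata column
grouped = {
    name: values[sub["simple_id"].to_numpy()]
    for name, sub in neurons_df.groupby("cell_type")
}
print({k: float(v.mean()) for k, v in grouped.items()})
# → {'KC': 1.5, 'LN': 3.0, 'PN': 6.0}
```

Each array in `grouped` then feeds one violin, box, or ECDF curve.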

plot_group_ecdf(values, neurons_df, group_by, **kwargs)

Convenience wrapper for plot_group_distribution(..., kind='ecdf').

Source code in btorch/visualisation/aggregation.py
def plot_group_ecdf(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    **kwargs,
) -> tuple[Figure, Axes]:
    """Convenience wrapper for `plot_group_distribution(..., kind='ecdf')`."""
    return _plot_group_distribution_kind(
        values=values,
        neurons_df=neurons_df,
        group_by=group_by,
        kind="ecdf",
        **kwargs,
    )

plot_group_violin(values, neurons_df, group_by, **kwargs)

Convenience wrapper for plot_group_distribution(..., kind='violin').

Source code in btorch/visualisation/aggregation.py
def plot_group_violin(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    **kwargs,
) -> tuple[Figure, Axes]:
    """Convenience wrapper for `plot_group_distribution(...,
    kind='violin')`."""
    return _plot_group_distribution_kind(
        values=values,
        neurons_df=neurons_df,
        group_by=group_by,
        kind="violin",
        **kwargs,
    )

plot_isi_cv(data=None, format=None, spikes=None, dt=1.0, mode='individual', neurons_df=None, group_by=None, neuron_type_column='cell_type', **kwargs)

Plot ISI CV (Coefficient of Variation) distribution.

ISI CV measures spike train irregularity:

- CV = 1: Poisson-like (irregular) firing
- CV < 1: Regular, periodic firing
- CV > 1: Bursty, irregular firing

Supports histogram view for distributions and bar plots for grouped comparisons.

Parameters:

- `data` (DynamicsData | None): DynamicsData container with spikes and metadata. Default: None.
- `format` (DynamicsPlotFormat | None): DynamicsPlotFormat with visualization settings. Default: None.
- `spikes` (ndarray | Tensor | None): Spike trains (time, neurons). Required if data not provided. Default: None.
- `dt` (float): Timestep in milliseconds for ISI calculation. Default: 1.0.
- `mode` (Literal['individual', 'grouped', 'distribution']): Visualization mode - "distribution", "individual", or "grouped". Default: 'individual'.
- `neurons_df` (DataFrame | None): DataFrame with neuron metadata for grouping. Default: None.
- `group_by` (Literal['neuropil', 'neuron_type', None]): Grouping method - "neuropil" or "neuron_type". Default: None.
- `neuron_type_column` (str): Column name for neuron types in neurons_df. Default: 'cell_type'.
- `**kwargs`: Additional arguments.

Returns:

- Figure: Figure with ISI CV histogram or grouped bar plot.

Raises:

- ValueError: If spikes are not provided, or if grouped mode is requested without required metadata.

Example:

>>> fig = plot_isi_cv(spikes, dt=1.0, mode="distribution")
>>>
>>> # Grouped by cell type
>>> fig = plot_isi_cv(spikes, neurons_df=df,
...                   mode="grouped", group_by="neuron_type")

Source code in btorch/visualisation/dynamics.py
def plot_isi_cv(
    # Dataclass interface
    data: DynamicsData | None = None,
    format: DynamicsPlotFormat | None = None,
    # Plain args interface
    spikes: np.ndarray | torch.Tensor | None = None,
    dt: float = 1.0,
    mode: Literal["individual", "grouped", "distribution"] = "individual",
    neurons_df: pd.DataFrame | None = None,
    group_by: Literal["neuropil", "neuron_type", None] = None,
    neuron_type_column: str = "cell_type",
    **kwargs,
) -> Figure:
    """Plot ISI CV (Coefficient of Variation) distribution.

    ISI CV measures spike train irregularity:
    - CV = 1: Poisson-like (irregular) firing
    - CV < 1: Regular, periodic firing
    - CV > 1: Bursty, irregular firing

    Supports histogram view for distributions and bar plots for grouped
    comparisons.

    Args:
        data: DynamicsData container with spikes and metadata.
        format: DynamicsPlotFormat with visualization settings.
        spikes: Spike trains (time, neurons). Required if `data` not provided.
        dt: Timestep in milliseconds for ISI calculation.
        mode: Visualization mode - "distribution", "individual", or "grouped".
        neurons_df: DataFrame with neuron metadata for grouping.
        group_by: Grouping method - "neuropil" or "neuron_type".
        neuron_type_column: Column name for neuron types in neurons_df.
        **kwargs: Additional arguments.

    Returns:
        Figure with ISI CV histogram or grouped bar plot.

    Raises:
        ValueError: If spikes are not provided, or if grouped mode is
            requested without required metadata.

    Example:
        >>> fig = plot_isi_cv(spikes, dt=1.0, mode="distribution")
        >>>
        >>> # Grouped by cell type
        >>> fig = plot_isi_cv(spikes, neurons_df=df,
        ...                   mode="grouped", group_by="neuron_type")
    """
    # Resolve dataclass vs plain args
    if data is not None:
        spikes = data.spikes if spikes is None else spikes
        dt = data.dt if dt == 1.0 else dt
        neurons_df = data.neurons_df if neurons_df is None else neurons_df

    if format is not None:
        mode = format.mode
        group_by = format.group_by if group_by is None else group_by
        neuron_type_column = format.neuron_type_column

    # Validate
    if spikes is None:
        raise ValueError("spikes is required")

    spikes = _to_numpy(spikes)

    # Compute ISI CV
    cv_results = calculate_cv_isi(spikes, dt=dt)
    cv_values = cv_results["cv_isi"]

    # Create figure based on mode
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    if mode == "distribution" or mode == "individual":
        # Histogram
        valid_cv = cv_values[~np.isnan(cv_values)]
        ax.hist(valid_cv, bins=30, color="skyblue", edgecolor="black", alpha=0.7)
        ax.axvline(
            cv_results["mean"],
            color="red",
            linestyle="--",
            label=f"Mean={cv_results['mean']:.2f}",
        )
        ax.set_xlabel("ISI CV")
        ax.set_ylabel("Count")
        ax.set_title("ISI Coefficient of Variation Distribution")
        ax.legend()
        ax.grid(alpha=0.3)

    elif mode == "grouped":
        if group_by is None or neurons_df is None:
            raise ValueError("group_by and neurons_df required for grouped mode")

        # Group by neuron type
        grouped = agg_by_neuron(
            cv_values, neurons_df, agg="mean", neuron_type_column=neuron_type_column
        )

        # Bar plot
        names = list(grouped.keys())
        values = list(grouped.values())
        ax.bar(names, values, color="teal", alpha=0.7, edgecolor="black")
        ax.set_xlabel("Neuron Type")
        ax.set_ylabel("Mean ISI CV")
        ax.set_title(f"ISI CV by {neuron_type_column}")
        ax.grid(axis="y", alpha=0.3)
        plt.xticks(rotation=45, ha="right")

    plt.tight_layout()
    return fig
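The underlying quantity is easy to reproduce for a single train. A standalone sketch (not btorch's `calculate_cv_isi`) confirming the interpretation guide above, i.e. that exponential ISIs give CV ≈ 1:

```python
import numpy as np

rng = np.random.default_rng(3)
# Poisson-like spike train: exponential inter-spike intervals (mean 20 ms)
spike_times_ms = np.cumsum(rng.exponential(scale=20.0, size=2000))
isis = np.diff(spike_times_ms)
cv = isis.std() / isis.mean()  # coefficient of variation of the ISIs
print(round(float(cv), 1))  # ~1.0, the Poisson signature
```

A perfectly periodic train (constant ISIs) would give CV = 0, and a bursty train with clustered spikes would push CV above 1.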

plot_log_hist(values, ax=None, title='Distribution', xlabel='Value', **kwargs)

Plot log-log histogram with logarithmic binning.

Creates a scatter plot of histogram counts using logarithmically spaced bins. Useful for visualizing heavy-tailed distributions (e.g., power laws).

Parameters:

- `values` (ndarray | Tensor): Input values to histogram. Flattened if multidimensional. Required.
- `ax` (Axes | None): Existing axes to plot on. Creates new figure if None. Default: None.
- `title` (str): Plot title. Default: 'Distribution'.
- `xlabel` (str): X-axis label. Default: 'Value'.
- `**kwargs`: Additional arguments passed to ax.scatter().

Returns:

- Axes: Axes containing the log-log histogram.

Example:

>>> ax = plot_log_hist(synapse_weights, title="Weight Distribution")

Source code in btorch/visualisation/timeseries.py
def plot_log_hist(
    values: Union[np.ndarray, torch.Tensor],
    ax: Axes | None = None,
    title: str = "Distribution",
    xlabel: str = "Value",
    **kwargs,
) -> Axes:
    """Plot log-log histogram with logarithmic binning.

    Creates a scatter plot of histogram counts using logarithmically
    spaced bins. Useful for visualizing heavy-tailed distributions
    (e.g., power laws).

    Args:
        values: Input values to histogram. Flattened if multidimensional.
        ax: Existing axes to plot on. Creates new figure if None.
        title: Plot title.
        xlabel: X-axis label.
        **kwargs: Additional arguments passed to ax.scatter().

    Returns:
        Axes containing the log-log histogram.

    Example:
        >>> ax = plot_log_hist(synapse_weights, title="Weight Distribution")
    """
    vals = _to_numpy(values)
    hist, bin_centers = compute_log_hist(vals)

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 4))

    ax.scatter(bin_centers, hist, **kwargs)
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Count")
    ax.set_title(title)

    return ax
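`compute_log_hist` is not shown here, but logarithmic binning itself is short. A sketch of the presumed approach, with geometric bin centers for the log-log scatter:

```python
import numpy as np

rng = np.random.default_rng(4)
vals = rng.pareto(a=2.0, size=10_000) + 1.0  # heavy-tailed sample, all > 1

# 29 bins with edges evenly spaced in log space; the +0.01 pads the top
# edge so the maximum value is always inside the last bin
edges = np.logspace(0.0, np.log10(vals.max()) + 0.01, 30)
hist, _ = np.histogram(vals, bins=edges)
centers = np.sqrt(edges[:-1] * edges[1:])  # geometric midpoints
print(len(centers), int(hist.sum()))  # 29 10000
```

On linear bins a power law crams nearly all counts into the first bin; log-spaced bins spread the tail out so the slope is visible.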

plot_lyapunov_spectrum(spectrum, ax=None)

Plot the Lyapunov exponents spectrum.

Displays Lyapunov exponents sorted by magnitude. Positive exponents indicate chaos; the number of non-negative exponents relates to the Kaplan-Yorke dimension.

Parameters:

- `spectrum` (list[float] | ndarray): List or array of Lyapunov exponents. Required.
- `ax` (Axes | None): Existing axes to plot on. Creates new figure if None. Default: None.

Returns:

- tuple[Figure, Axes]: Tuple of (figure, axes) with the spectrum plot.

Example:

>>> fig, ax = plot_lyapunov_spectrum(lyap_spectrum)
>>> # Positive exponents indicate chaotic dynamics
Source code in btorch/visualisation/dynamics.py
def plot_lyapunov_spectrum(
    spectrum: list[float] | np.ndarray, ax: Axes | None = None
) -> tuple[Figure, Axes]:
    """Plot the Lyapunov exponents spectrum.

    Displays Lyapunov exponents sorted by magnitude. Positive exponents
    indicate chaos; the number of non-negative exponents relates to the
    Kaplan-Yorke dimension.

    Args:
        spectrum: List or array of Lyapunov exponents.
        ax: Existing axes to plot on. Creates new figure if None.

    Returns:
        Tuple of (figure, axes) with the spectrum plot.

    Example:
        >>> fig, ax = plot_lyapunov_spectrum(lyap_spectrum)
        >>> # Positive exponents indicate chaotic dynamics
    """
    from ..analysis.dynamic_tools.attractor_dynamics import (
        calculate_kaplan_yorke_dimension,
    )

    spec = _to_numpy(spectrum)
    # Ensure descending order (the conventional presentation of a spectrum)
    spec = np.sort(spec)[::-1]

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 4))
    else:
        fig = ax.get_figure()

    x = np.arange(1, len(spec) + 1)
    ax.plot(x, spec, "o-", markersize=4, linewidth=1, color="black")
    ax.axhline(0, color="k", linestyle="--", linewidth=0.8)

    # Calculate Kaplan-Yorke Dim
    ky_dim = calculate_kaplan_yorke_dimension(spec)

    title = f"Lyapunov Spectrum (D_KY = {ky_dim:.2f})"
    ax.set_title(title)
    ax.set_xlabel("Index")
    ax.set_ylabel("Lyapunov Exponent")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    return fig, ax
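The D_KY value in the title comes from the Kaplan-Yorke formula. A standalone sketch, not necessarily identical to btorch's `calculate_kaplan_yorke_dimension`:

```python
import numpy as np

def kaplan_yorke(spectrum) -> float:
    """D_KY = k + (sum of first k exponents) / |lambda_{k+1}|, where k is
    the largest count of leading exponents whose partial sum is non-negative."""
    spec = np.sort(np.asarray(spectrum, dtype=float))[::-1]  # descending
    csum = np.cumsum(spec)
    nonneg = np.where(csum >= 0)[0]
    if len(nonneg) == 0:
        return 0.0  # even the largest exponent is negative
    k = int(nonneg[-1]) + 1
    if k == len(spec):
        return float(k)  # all partial sums non-negative
    return float(k + csum[k - 1] / abs(spec[k]))

print(kaplan_yorke([0.5, 0.0, -1.0]))  # → 2.5
```

In the printed case two exponents sum to 0.5, and the third (-1.0) absorbs it, giving a fractal dimension of 2 + 0.5/1.0 = 2.5.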

plot_micro_dynamics(spikes, dt=1.0, ax=None)

Plot firing rate and ISI CV distributions side-by-side.

Creates a 2-panel figure summarizing micro-scale dynamics:

- Left: Firing rate distribution histogram
- Right: ISI CV distribution histogram

Parameters:

- `spikes` (ndarray | Tensor): Spike matrix with shape (time, neurons). Required.
- `dt` (float): Timestep in milliseconds. Default: 1.0.
- `ax` (Axes | None): Unused parameter (kept for API compatibility). Default: None.

Returns:

- tuple[Figure, dict]: Tuple of (figure, stats) where stats is a dict with keys:
  - "fr": Firing rate statistics
  - "cv": ISI CV statistics

Example:

>>> fig, stats = plot_micro_dynamics(spikes, dt=1.0)
>>> print(f"Rate: {stats['fr']['mean']:.1f} Hz, CV: {stats['cv']['mean']:.2f}")

Source code in btorch/visualisation/dynamics.py
def plot_micro_dynamics(
    spikes: np.ndarray | torch.Tensor,
    dt: float = 1.0,
    ax: Axes | None = None,
) -> tuple[Figure, dict]:
    """Plot firing rate and ISI CV distributions side-by-side.

    Creates a 2-panel figure summarizing micro-scale dynamics:
    - Left: Firing rate distribution histogram
    - Right: ISI CV distribution histogram

    Args:
        spikes: Spike matrix with shape (time, neurons).
        dt: Timestep in milliseconds.
        ax: Unused parameter (kept for API compatibility).

    Returns:
        Tuple of (figure, stats) where stats is a dict with keys:
        - "fr": Firing rate statistics
        - "cv": ISI CV statistics

    Example:
        >>> fig, stats = plot_micro_dynamics(spikes, dt=1.0)
        >>> print(f"Rate: {stats['fr']['mean']:.1f} Hz, CV: {stats['cv']['mean']:.2f}")
    """
    from ..analysis.dynamic_tools.micro_scale import calculate_cv_isi

    # Plot FR
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    _, fr_stats = plot_firing_rate_distribution(spikes, dt=dt, ax=ax1)

    # Plot CV: reuse the ISI-CV computation behind plot_isi_cv,
    # without its full figure-creation overhead
    spikes_np = _to_numpy(spikes)
    cv_results = calculate_cv_isi(spikes_np, dt=dt)
    cv_values = cv_results["cv_isi"]
    valid_cv = cv_values[~np.isnan(cv_values)]

    if len(valid_cv) > 0:
        ax2.hist(valid_cv, bins=30, color="orange", edgecolor="black", alpha=0.7)
        ax2.axvline(
            cv_results["mean"],
            color="red",
            linestyle="--",
            label=f"Mean={cv_results['mean']:.2f}",
        )
        ax2.set_xlabel("CV (ISI)")
        ax2.set_ylabel("Count")
        ax2.set_title("CV Distribution")
        ax2.legend()
    else:
        ax2.text(0.5, 0.5, "No valid CVs", ha="center")
    ax2.grid(alpha=0.3)

    plt.tight_layout()
    return fig, {"fr": fr_stats, "cv": cv_results}

plot_multiscale_fano(data=None, config=None, format=None, spikes=None, dt=1.0, windows=None, overlap=0, mode='individual', neurons_df=None, connections_df=None, group_by=None, neuron_type_column='cell_type', neuron_indices=None, **kwargs)

Plot multiscale Fano factor analysis.

Computes and visualizes Fano factor (spike count variance/mean) across multiple time windows. Supports three visualization modes:

- "individual": Line plots for selected neurons
- "grouped": Aggregated by neuron type or neuropil
- "distribution": Violin plots showing population distribution

Supports both dataclass and plain argument interfaces. When both are provided, explicitly passed plain arguments take precedence over the corresponding dataclass fields.

Parameters:

- `data` (DynamicsData | None): DynamicsData container with spikes and metadata. Default: None.
- `config` (FanoFactorConfig | None): FanoFactorConfig with window settings. Default: None.
- `format` (DynamicsPlotFormat | None): DynamicsPlotFormat with visualization options. Default: None.
- `spikes` (ndarray | Tensor | None): Spike trains with shape (time, neurons). Required if data is not provided. Default: None.
- `dt` (float): Timestep in milliseconds. Default: 1.0.
- `windows` (list[int] | None): List of window sizes in timesteps. Auto-generated if None. Default: None.
- `overlap` (int): Window overlap in timesteps. Default: 0.
- `mode` (Literal['individual', 'grouped', 'distribution']): Visualization mode - "individual", "grouped", "distribution". Default: 'individual'.
- `neurons_df` (DataFrame | None): DataFrame with neuron metadata for grouping. Default: None.
- `connections_df` (DataFrame | None): DataFrame with connection metadata for neuropil grouping. Default: None.
- `group_by` (Literal['neuropil', 'neuron_type', None]): Grouping method - "neuropil" or "neuron_type". Default: None.
- `neuron_type_column` (str): Column name for neuron types in neurons_df. Default: 'cell_type'.
- `neuron_indices` (list[int] | None): Specific neuron indices for "individual" mode. If None, first 10 neurons are plotted. Default: None.
- `**kwargs`: Additional arguments passed to plotting functions.

Returns:

- Figure: Figure with Fano factor plots.

Raises:

- ValueError: If spikes are not provided through either data or spikes argument.

Example:

>>> # Plain args interface
>>> fig = plot_multiscale_fano(spikes, dt=1.0, mode="distribution")
>>>
>>> # Dataclass interface
>>> data = DynamicsData(spikes=spikes, dt=1.0, neurons_df=df)
>>> config = FanoFactorConfig(windows=[10, 50, 100])
>>> fig = plot_multiscale_fano(data=data, config=config)

Source code in btorch/visualisation/dynamics.py
def plot_multiscale_fano(
    # Dataclass interface
    data: DynamicsData | None = None,
    config: FanoFactorConfig | None = None,
    format: DynamicsPlotFormat | None = None,
    # Plain args interface
    spikes: np.ndarray | torch.Tensor | None = None,
    dt: float = 1.0,
    windows: list[int] | None = None,
    overlap: int = 0,
    mode: Literal["individual", "grouped", "distribution"] = "individual",
    neurons_df: pd.DataFrame | None = None,
    connections_df: pd.DataFrame | None = None,
    group_by: Literal["neuropil", "neuron_type", None] = None,
    neuron_type_column: str = "cell_type",
    neuron_indices: list[int] | None = None,
    **kwargs,
) -> Figure:
    """Plot multiscale Fano factor analysis.

    Computes and visualizes Fano factor (spike count variance/mean) across
    multiple time windows. Supports three visualization modes:
    - "individual": Line plots for selected neurons
    - "grouped": Aggregated by neuron type or neuropil
    - "distribution": Violin plots showing population distribution

    Supports both dataclass and plain argument interfaces. When both are
    provided, explicitly passed plain arguments take precedence over the
    corresponding dataclass fields.

    Args:
        data: DynamicsData container with spikes and metadata.
        config: FanoFactorConfig with window settings.
        format: DynamicsPlotFormat with visualization options.
        spikes: Spike trains with shape (time, neurons). Required if
            `data` is not provided.
        dt: Timestep in milliseconds. Default 1.0.
        windows: List of window sizes in timesteps. Auto-generated if None.
        overlap: Window overlap in timesteps. Default 0.
        mode: Visualization mode - "individual", "grouped", "distribution".
        neurons_df: DataFrame with neuron metadata for grouping.
        connections_df: DataFrame with connection metadata for neuropil
            grouping.
        group_by: Grouping method - "neuropil" or "neuron_type".
        neuron_type_column: Column name for neuron types in neurons_df.
        neuron_indices: Specific neuron indices for "individual" mode.
            If None, first 10 neurons are plotted.
        **kwargs: Additional arguments passed to plotting functions.

    Returns:
        Figure with Fano factor plots.

    Raises:
        ValueError: If spikes are not provided through either `data` or
            `spikes` argument.

    Example:
        >>> # Plain args interface
        >>> fig = plot_multiscale_fano(spikes, dt=1.0, mode="distribution")
        >>>
        >>> # Dataclass interface
        >>> data = DynamicsData(spikes=spikes, dt=1.0, neurons_df=df)
        >>> config = FanoFactorConfig(windows=[10, 50, 100])
        >>> fig = plot_multiscale_fano(data=data, config=config)
    """
    # Resolve dataclass vs plain args
    if data is not None:
        spikes = data.spikes if spikes is None else spikes
        dt = data.dt if dt == 1.0 else dt
        neurons_df = data.neurons_df if neurons_df is None else neurons_df
        connections_df = (
            data.connections_df if connections_df is None else connections_df
        )

    if config is not None:
        windows = config.windows if windows is None else windows
        overlap = config.overlap if overlap == 0 else overlap

    if format is not None:
        mode = format.mode
        group_by = format.group_by if group_by is None else group_by
        neuron_type_column = format.neuron_type_column
        neuron_indices = (
            format.neuron_indices if neuron_indices is None else neuron_indices
        )

    # Validate
    if spikes is None:
        raise ValueError("spikes is required")

    spikes = _to_numpy(spikes)
    n_time, n_neurons = spikes.shape

    # Default windows: logarithmically spaced
    if windows is None:
        windows = [int(w) for w in np.logspace(1, np.log10(n_time // 4), 10)]

    # Compute Fano factor for each window
    fano_results = {}
    for w in windows:
        fano_values, _ = fano(spikes, window=w, overlap=overlap)
        fano_results[w] = fano_values

    # Create figure based on mode
    if mode == "individual":
        return _plot_fano_individual(
            fano_results, windows, dt, neuron_indices, n_neurons
        )
    elif mode == "grouped":
        if group_by is None:
            raise ValueError("group_by must be specified for grouped mode")
        return _plot_fano_grouped(
            fano_results,
            windows,
            dt,
            spikes,
            neurons_df,
            connections_df,
            group_by,
            neuron_type_column,
        )
    elif mode == "distribution":
        return _plot_fano_distribution(fano_results, windows, dt)
    else:
        raise ValueError(f"Unknown mode: {mode}")
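The quantity being plotted is simple to state: for each window size, spike trains are binned into non-overlapping counts, and the Fano factor is the per-neuron variance of those counts divided by their mean. A minimal NumPy sketch of that computation (`multiscale_fano` here is a hypothetical helper for illustration, not the module's `fano`, which additionally supports overlapping windows):

```python
import numpy as np

def multiscale_fano(spikes: np.ndarray, windows: list[int]) -> dict[int, np.ndarray]:
    """Per-neuron Fano factor of binned spike counts, one entry per window size.

    spikes: binary array of shape (time, neurons).
    """
    n_time, _ = spikes.shape
    results = {}
    for w in windows:
        n_bins = n_time // w
        # Sum spikes within non-overlapping windows of length w
        counts = spikes[: n_bins * w].reshape(n_bins, w, -1).sum(axis=1)
        mean = counts.mean(axis=0)
        var = counts.var(axis=0)
        # Silent neurons (mean == 0) get NaN instead of a division error
        results[w] = np.divide(
            var, mean, out=np.full_like(mean, np.nan, dtype=float), where=mean > 0
        )
    return results

rng = np.random.default_rng(0)
spikes = (rng.random((1000, 20)) < 0.05).astype(int)  # Bernoulli spike trains
ff = multiscale_fano(spikes, windows=[10, 50, 100])
```

For Bernoulli trains the expected Fano factor is 1 - p, so values near 1 across all window sizes indicate Poisson-like variability; systematic growth with window size is the signature of long-range correlations that this plot is meant to expose.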

plot_network(sparse_mat, ax=None)

Plot a network graph from a sparse connectivity matrix.

Uses NetworkX spring layout to visualize the graph structure. Nodes are colored skyblue, edges are gray.

Parameters:

- sparse_mat (required): Sparse matrix (scipy.sparse) representing connections. Non-zero entries indicate edges.
- ax (Axes | None, default None): Existing axes to plot on. If None, creates a new figure.

Returns:

- Figure: Figure containing the network plot.

Raises:

- ImportError: If networkx is not installed.

Example

from scipy.sparse import random
mat = random(50, 50, density=0.1, format="csr")
fig = plot_network(mat)

Source code in btorch/visualisation/network.py
def plot_network(sparse_mat, ax: Axes | None = None) -> Figure:
    """Plot a network graph from a sparse connectivity matrix.

    Uses NetworkX spring layout to visualize the graph structure.
    Nodes are colored skyblue, edges are gray.

    Args:
        sparse_mat: Sparse matrix (scipy.sparse) representing connections.
            Non-zero entries indicate edges.
        ax: Existing axes to plot on. If None, creates new figure.

    Returns:
        Figure containing the network plot.

    Raises:
        ImportError: If networkx is not installed.

    Example:
        >>> from scipy.sparse import random
        >>> mat = random(50, 50, density=0.1, format="csr")
        >>> fig = plot_network(mat)
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    G = nx.from_scipy_sparse_array(sparse_mat)
    pos = nx.spring_layout(G)

    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 8))
    else:
        fig = ax.figure

    nx.draw(
        G,
        pos,
        with_labels=True,
        node_color="skyblue",
        edge_color="gray",
        ax=ax,
    )
    ax.set_title("Network Graph")
    return fig
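Since plot_network draws one edge per non-zero entry, a quick edge count before plotting can flag graphs too dense for a readable spring layout. A pure-NumPy sketch of that check (illustrative only; plot_network itself takes scipy.sparse input):

```python
import numpy as np

rng = np.random.default_rng(42)
# Random adjacency matrix, ~10% density, mirroring the docstring example
adj = (rng.random((50, 50)) < 0.1).astype(float)

# Every non-zero entry becomes an edge in the plotted graph
rows, cols = np.nonzero(adj)
n_edges = len(rows)
out_degree = adj.sum(axis=1)  # edges leaving each node

print(f"{n_edges} edges, mean out-degree {out_degree.mean():.1f}")
```

At 10% density on 50 nodes this is roughly 250 edges, which is already fairly busy for a spring layout; much denser graphs are usually better shown as a heatmap.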

plot_neuron_traces(states=None, format=None, voltage=None, dt=1.0, asc=None, psc=None, epsc=None, ipsc=None, input=None, psc_labels=None, spikes=None, v_threshold=None, v_reset=None, neuron_indices=None, sample_size=None, seed=42, show_voltage=True, show_asc=True, show_psc=True, neuron_labels=None, neuron_label_position='side', neuron_specs=None, neurons_df=None, separate_figures=False, auto_width=True, neurons_per_row=None, batch_idx=None)

Plot neuron state traces with flexible interface.

Supports both dataclass and plain argument interfaces. Each neuron gets a row of subplots showing voltage, ASC, and PSC traces.

Parameters:

- states (SimulationStates | DataFrame | None, default None): SimulationStates dataclass with all state data.
- format (TracePlotFormat | None, default None): TracePlotFormat dataclass with formatting options.
- voltage (ndarray | Tensor | None, default None): Voltage traces (time, neurons) or (time, batch, neurons).
- dt (float, default 1.0): Timestep in ms.
- asc (ndarray | Tensor | None, default None): Afterspike current traces (time, neurons), (time, batch, neurons), or (time, batch, neurons, n_asc) for multiple ASC components.
- psc (ndarray | Tensor | None, default None): Postsynaptic current traces (time, neurons), (time, batch, neurons), or (time, batch, neurons, n_psc) for multiple PSC components. If psc has additional dims (n_psc > 1), epsc, ipsc, and input should be None.
- epsc (ndarray | Tensor | None, default None): Excitatory PSC traces (time, neurons) or (time, batch, neurons).
- ipsc (ndarray | Tensor | None, default None): Inhibitory PSC traces (time, neurons) or (time, batch, neurons).
- input (ndarray | Tensor | None, default None): Input current traces (time, neurons) or (time, batch, neurons).
- psc_labels (Sequence[str] | None, default None): Labels for PSC components when psc has shape (time, neurons, n_psc) or (time, batch, neurons, n_psc). If None, defaults to ["PSC_0", "PSC_1", ...].
- spikes (ndarray | Tensor | None, default None): Spike trains (time, neurons) or (time, batch, neurons).
- v_threshold (float | Sequence[float] | ndarray | Tensor | None, default None): Spike threshold(s), scalar or per-neuron values.
- v_reset (float | Sequence[float] | ndarray | Tensor | None, default None): Reset voltage reference line(s), scalar or per-neuron values.
- neuron_indices (list[int] | None, default None): Specific neurons to plot.
- sample_size (int | None, default None): Number of neurons to randomly sample.
- seed (int, default 42): Random seed for sampling.
- show_voltage (bool, default True): Show voltage subplot.
- show_asc (bool, default True): Show ASC subplot.
- show_psc (bool, default True): Show PSC subplot.
- neuron_labels (Sequence[str] | Callable[[int], str] | None, default None): Side labels as sequence or callable(neuron_idx) -> str. Default None disables side labels.
- neuron_label_position (Literal['side', 'top'], default 'side'): Position for neuron labels when enabled: "side" or "top".
- neuron_specs (list[NeuronSpec | dict] | NeuronSpec | dict | None, default None): Specifications for per-neuron styling (scalar or list).
- neurons_df (DataFrame | None, default None): DataFrame with neuron metadata for labels.
- separate_figures (bool, default False): Return dict of figures (one per trace type).
- auto_width (bool, default True): Adjust width based on duration.
- neurons_per_row (int | None, default None): Number of neurons per row in combined figure.
- batch_idx (int | None, default None): Batch index to plot when data has shape (time, batch, neurons). If None and data is 3D, defaults to 0.

Returns:

- Figure | dict[str, Figure]: Figure with neuron trace subplots, or a dict of Figures (one per trace type) when separate_figures is True.

Source code in btorch/visualisation/timeseries.py
def plot_neuron_traces(
    # Dataclass interface
    states: SimulationStates | pd.DataFrame | None = None,
    format: TracePlotFormat | None = None,
    # Plain args interface
    voltage: np.ndarray | torch.Tensor | None = None,
    dt: float = 1.0,
    asc: np.ndarray | torch.Tensor | None = None,
    psc: np.ndarray | torch.Tensor | None = None,
    epsc: np.ndarray | torch.Tensor | None = None,
    ipsc: np.ndarray | torch.Tensor | None = None,
    input: np.ndarray | torch.Tensor | None = None,
    psc_labels: Sequence[str] | None = None,
    spikes: np.ndarray | torch.Tensor | None = None,
    v_threshold: float | Sequence[float] | np.ndarray | torch.Tensor | None = None,
    v_reset: float | Sequence[float] | np.ndarray | torch.Tensor | None = None,
    neuron_indices: list[int] | None = None,
    sample_size: int | None = None,
    seed: int = 42,
    show_voltage: bool = True,
    show_asc: bool = True,
    show_psc: bool = True,
    neuron_labels: Sequence[str] | Callable[[int], str] | None = None,
    neuron_label_position: Literal["side", "top"] = "side",
    neuron_specs: list[NeuronSpec | dict] | NeuronSpec | dict | None = None,
    neurons_df: pd.DataFrame | None = None,
    separate_figures: bool = False,
    auto_width: bool = True,
    neurons_per_row: int | None = None,
    batch_idx: int | None = None,
) -> Figure | dict[str, Figure]:
    """Plot neuron state traces with flexible interface.

    Supports both dataclass and plain argument interfaces. Each neuron gets
    a row of subplots showing voltage, ASC, and PSC traces.

    Args:
        states: SimulationStates dataclass with all state data
        format: TracePlotFormat dataclass with formatting options
        voltage: Voltage traces (time, neurons) or (time, batch, neurons)
        dt: Timestep in ms
        asc: Afterspike current traces (time, neurons), (time, batch, neurons),
            or (time, batch, neurons, n_asc) for multiple ASC components
        psc: Postsynaptic current traces (time, neurons), (time, batch, neurons),
            or (time, batch, neurons, n_psc) for multiple PSC components.
            If psc has additional dims (n_psc > 1), epsc, ipsc, and input
            should be None.
        epsc: Excitatory PSC traces (time, neurons) or (time, batch, neurons)
        ipsc: Inhibitory PSC traces (time, neurons) or (time, batch, neurons)
        input: Input current traces (time, neurons) or (time, batch, neurons)
        psc_labels: Labels for PSC components when psc has shape
            (time, neurons, n_psc) or (time, batch, neurons, n_psc).
            If None, defaults to ["PSC_0", "PSC_1", ...].
        spikes: Spike trains (time, neurons) or (time, batch, neurons)
        v_threshold: Spike threshold(s), scalar or per-neuron values
        v_reset: Reset voltage reference line(s), scalar or per-neuron values
        neuron_indices: Specific neurons to plot
        sample_size: Number of neurons to randomly sample
        seed: Random seed for sampling
        show_voltage: Show voltage subplot
        show_asc: Show ASC subplot
        show_psc: Show PSC subplot
        neuron_labels: Side labels as sequence or callable(neuron_idx) -> str.
            Default None disables side labels.
        neuron_label_position: Position for neuron labels when enabled.
            "side" or "top".
        neuron_specs: Specifications for per-neuron styling (scalar or list)
        neurons_df: DataFrame with neuron metadata for labels
        separate_figures: Return dict of figures (one per trace type)
        auto_width: Adjust width based on duration
        neurons_per_row: Number of neurons per row in combined figure
        batch_idx: Batch index to plot when data has shape (time, batch, neurons).
            If None and data is 3D, defaults to 0.

    Returns:
        Figure with neuron trace subplots OR dict of Figures
    """
    # Resolve dataclass vs plain args
    if states is not None:
        voltage = states.voltage if voltage is None else voltage
        dt = states.dt if dt == 1.0 else dt
        asc = states.asc if asc is None else asc
        psc = states.psc if psc is None else psc
        epsc = states.epsc if epsc is None else epsc
        ipsc = states.ipsc if ipsc is None else ipsc
        input = states.input if input is None else input
        spikes = states.spikes if spikes is None else spikes
        v_threshold = states.v_threshold if v_threshold is None else v_threshold
        v_reset = states.v_reset if v_reset is None else v_reset

    if format is not None:
        neuron_indices = (
            format.neuron_indices if neuron_indices is None else neuron_indices
        )
        sample_size = format.sample_size if sample_size is None else sample_size
        seed = format.seed if seed == 42 else seed
        show_voltage = format.show_voltage
        show_asc = format.show_asc
        show_psc = format.show_psc
        neuron_labels = format.neuron_labels if neuron_labels is None else neuron_labels
        neuron_label_position = format.neuron_label_position
        neuron_specs = format.neuron_specs if neuron_specs is None else neuron_specs
        separate_figures = format.separate_figures
        auto_width = format.auto_width
        neurons_per_row = (
            format.neurons_per_row if neurons_per_row is None else neurons_per_row
        )
        batch_idx = format.batch_idx if batch_idx is None else batch_idx

    # Validate required data
    if voltage is None:
        raise ValueError("voltage is required (provide via states or direct arg)")

    # Default batch_idx to 0 if data is 3D and no index specified
    if batch_idx is None:
        batch_idx = 0

    # Check if PSC has additional dimensions (n_psc > 1) BEFORE batch extraction
    # This must be done first because 3D PSC (time, neurons, n_psc) would be
    # incorrectly treated as (time, batch, neurons) by _extract_batch_dim
    _psc_temp = _to_numpy(psc) if psc is not None else None
    psc_has_extra_dim = False
    n_psc_components = 1
    if _psc_temp is not None and _psc_temp.ndim == 3:
        # Check if shape matches (time, neurons, n_psc) vs (time, batch, neurons)
        # We need voltage shape to determine n_neurons
        _voltage_temp = _to_numpy(voltage)
        n_neurons_from_v = _voltage_temp.shape[1]
        if _psc_temp.shape[1] == n_neurons_from_v:
            # Shape is (time, neurons, n_psc) - has extra dimension
            psc_has_extra_dim = True
            n_psc_components = _psc_temp.shape[2]
            # Validate that epsc, ipsc, input are None when PSC has extra dims
            if epsc is not None:
                raise ValueError(
                    "epsc must be None when psc has additional dimension (n_psc > 1)"
                )
            if ipsc is not None:
                raise ValueError(
                    "ipsc must be None when psc has additional dimension (n_psc > 1)"
                )
            if input is not None:
                raise ValueError(
                    "input must be None when psc has additional dimension (n_psc > 1)"
                )
            # Default labels if not provided
            if psc_labels is None:
                psc_labels = [f"PSC_{i}" for i in range(n_psc_components)]

    # Extract batch dimension from all data arrays
    voltage = _extract_batch_dim(voltage, batch_idx)
    spikes = _extract_batch_dim(spikes, batch_idx)
    asc = _extract_batch_dim(asc, batch_idx)
    # PSC with extra dims is handled separately (no batch dim expected)
    if not psc_has_extra_dim:
        psc = _extract_batch_dim(psc, batch_idx)
    else:
        # For multi-dim PSC, we still need to validate batch_idx
        psc_arr_check = _to_numpy(psc)
        if psc_arr_check.ndim == 4:  # (time, batch, neurons, n_psc)
            if batch_idx >= psc_arr_check.shape[1]:
                raise ValueError(
                    f"batch_idx {batch_idx} out of bounds for batch dim "
                    f"{psc_arr_check.shape[1]}"
                )
            psc = psc_arr_check[:, batch_idx, :, :]
        elif psc_arr_check.ndim == 3:  # (time, neurons, n_psc)
            # No batch dim, use as-is
            psc = psc_arr_check
    epsc = _extract_batch_dim(epsc, batch_idx)
    ipsc = _extract_batch_dim(ipsc, batch_idx)
    input = _extract_batch_dim(input, batch_idx)

    # Convert to numpy (preserve None for optional arrays)
    voltage = _to_numpy(voltage)
    spikes = _to_numpy(spikes) if spikes is not None else None
    psc_arr = _to_numpy(psc) if psc is not None else None
    n_time, n_neurons = voltage.shape
    times = np.arange(n_time) * dt
    duration_ms = n_time * dt

    # Select neurons to plot
    if neuron_indices is None and sample_size is None:
        # Default: plot first 5 neurons
        neuron_indices = list(range(min(5, n_neurons)))
    elif neuron_indices is None:
        # Random sample
        np.random.seed(seed)
        neuron_indices = sorted(
            np.random.choice(n_neurons, min(sample_size, n_neurons), replace=False)
        )

    n_plot = len(neuron_indices)
    if neurons_per_row is None:
        neurons_per_row = 1
    if neurons_per_row < 1:
        raise ValueError("neurons_per_row must be >= 1")

    v_threshold_per_neuron = _resolve_per_neuron_values(
        v_threshold, neuron_indices, n_neurons, "v_threshold"
    )
    v_reset_per_neuron = _resolve_per_neuron_values(
        v_reset, neuron_indices, n_neurons, "v_reset"
    )

    # Determine figure dimensions
    base_width = 12.0
    if auto_width:
        # Scale: ~1 inch per 40ms, bounded [10, 30]
        base_width = max(10.0, min(duration_ms * 0.025, 30.0))
    elif format:
        base_width = format.figsize_per_neuron[0]

    height_per_row = format.figsize_per_neuron[1] if format else 2.5
    total_height = height_per_row * n_plot

    # Default colors
    default_colors = {
        "voltage": "#2E86AB",
        "asc": "#A23B72",
        "psc": "#F18F01",
        "epsc": "#06A77D",
        "ipsc": "#D62246",
        "input": "#9467bd",  # purple
        "spike": "#000000",
    }
    colors = format.colors if format and format.colors else default_colors

    label_values: Sequence[str] | None = None
    label_fn: Callable[[int], str] | None = None
    if callable(neuron_labels):
        label_fn = neuron_labels
    elif neuron_labels is not None:
        label_values = neuron_labels

    def _resolve_side_label(
        plot_idx: int, neuron_idx: int, spec: NeuronSpec | None = None
    ) -> str | None:
        if spec is not None and spec.label is not None:
            return spec.label
        if label_fn is not None:
            return str(label_fn(neuron_idx))
        if label_values is not None and plot_idx < len(label_values):
            return str(label_values[plot_idx])
        return None

    resolved_labels = [
        _resolve_side_label(plot_idx, neuron_idx)
        for plot_idx, neuron_idx in enumerate(neuron_indices)
    ]
    final_labels = list(resolved_labels)
    max_label_len = max((len(label) for label in resolved_labels if label), default=0)

    # Determine subplot layout based on data availability
    # Only show columns if requested AND data is present
    _show_v = show_voltage and (voltage is not None)
    _show_asc = show_asc and (asc is not None)
    _show_psc = show_psc and (psc is not None)

    if separate_figures:
        figures = {}
        trace_types = []
        if _show_v:
            trace_types.append("voltage")
        if _show_asc:
            trace_types.append("asc")
        if _show_psc:
            trace_types.append("psc")

        for t_type in trace_types:
            fig, axes = plt.subplots(
                n_plot, 1, figsize=(base_width, total_height), squeeze=False
            )

            for i, neuron_idx in enumerate(neuron_indices):
                ax = axes[i, 0]
                label = _resolve_side_label(i, neuron_idx)

                if t_type == "voltage":
                    _plot_voltage_on_ax(
                        ax,
                        times,
                        voltage[:, neuron_idx],
                        spikes[:, neuron_idx] if spikes is not None else None,
                        colors,
                        format,
                        v_threshold_per_neuron[i],
                        v_reset_per_neuron[i],
                    )
                    ax.set_ylabel("V (mV)")
                    if i == 0:
                        ax.set_title("Voltage Traces")
                        if (
                            v_threshold_per_neuron[i] is not None
                            or v_reset_per_neuron[i] is not None
                        ):
                            ax.legend(loc="upper right", fontsize=8)

                elif t_type == "asc":
                    asc_arr = _to_numpy(asc)
                    _plot_simple_trace_on_ax(
                        ax, times, asc_arr[:, neuron_idx], colors["asc"], "ASC (pA)"
                    )
                    if i == 0:
                        ax.set_title("Afterspike Current")

                elif t_type == "psc":
                    if psc_has_extra_dim:
                        # PSC has additional dimension: plot each component
                        psc_traces = psc_arr[:, neuron_idx, :]  # Shape: (time, n_psc)
                        _plot_multi_psc_on_ax(ax, times, psc_traces, psc_labels, colors)
                        if i == 0:
                            ax.set_title("Postsynaptic Current")
                            ax.legend(loc="upper right", fontsize=8)
                    else:
                        psc_arr_2d = _to_numpy(psc)
                        epsc_arr = (
                            _to_numpy(epsc[:, neuron_idx]) if epsc is not None else None
                        )
                        ipsc_arr = (
                            _to_numpy(ipsc[:, neuron_idx]) if ipsc is not None else None
                        )
                        input_arr_plot = (
                            _to_numpy(input[:, neuron_idx])
                            if input is not None
                            else None
                        )
                        _plot_psc_on_ax(
                            ax,
                            times,
                            psc_arr_2d[:, neuron_idx],
                            epsc_arr,
                            ipsc_arr,
                            input_arr_plot,
                            colors,
                        )
                        if i == 0:
                            ax.set_title("Postsynaptic Current")
                            if (
                                epsc is not None
                                or ipsc is not None
                                or input is not None
                            ):
                                ax.legend(loc="upper right", fontsize=8)

                if i == n_plot - 1:
                    ax.set_xlabel("Time (ms)")
                ax.grid(alpha=0.3, linewidth=0.5)
                if label is not None:
                    if neuron_label_position == "top":
                        ax.text(
                            0.5,
                            1.12,
                            label,
                            transform=ax.transAxes,
                            fontsize=10,
                            fontweight="bold",
                            va="bottom",
                            ha="center",
                        )
                    else:
                        ax.text(
                            1.02,
                            0.5,
                            label,
                            transform=ax.transAxes,
                            fontsize=10,
                            fontweight="bold",
                            va="center",
                            ha="left",
                        )

            plt.tight_layout()
            figures[t_type] = fig

        return figures

    # Combined figure
    n_cols = sum([_show_v, _show_asc, _show_psc])
    if n_cols == 0:
        # Default fallback: if nothing strictly requested by data presence,
        # but voltage is required arg, show voltage
        if voltage is not None:
            _show_v = True
            n_cols = 1
        else:
            raise ValueError(
                "No data available to plot (voltage, asc, or psc required)"
            )

    n_rows = int(ceil(n_plot / neurons_per_row))
    total_cols = n_cols * neurons_per_row
    use_top_label_rows = neuron_label_position == "top"
    label_height_ratio = 0.22
    max_label_chars_per_line = 0
    if use_top_label_rows and max_label_len > 0:
        # Rough estimate for wrapping purposes only (not for sizing)
        max_label_chars_per_line = max(36, int(base_width * 8))
        label_line_count = max(
            1,
            int(ceil(max_label_len / max(max_label_chars_per_line, 1))),
        )
        label_height_ratio = 0.22 + 0.12 * (label_line_count - 1)
    total_height_grid = (
        height_per_row
        * n_rows
        * (1.0 + label_height_ratio if use_top_label_rows else 1.0)
    )
    # Keep enough width per trace column to avoid label crowding.
    base_width = max(base_width, 4.0 * n_cols)
    fig_width = base_width * neurons_per_row
    n_grid_rows = n_rows * 2 if use_top_label_rows else n_rows
    gridspec_kw = (
        {"height_ratios": [v for _ in range(n_rows) for v in (label_height_ratio, 1.0)]}
        if use_top_label_rows
        else None
    )
    fig = plt.figure(figsize=(fig_width, total_height_grid))
    grid_spec = fig.add_gridspec(
        n_grid_rows,
        total_cols,
        **(gridspec_kw or {}),
    )
    axes: dict[tuple[int, int], Axes] = {}
    label_axes: dict[tuple[int, int], Axes] = {}

    for row_idx in range(n_rows):
        if use_top_label_rows:
            label_row = row_idx * 2
            for slot_idx in range(neurons_per_row):
                col_base = slot_idx * n_cols
                label_ax = fig.add_subplot(
                    grid_spec[label_row, col_base : col_base + n_cols]
                )
                label_ax.set_axis_off()
                label_axes[(row_idx, slot_idx)] = label_ax

        plot_row = row_idx * 2 + 1 if use_top_label_rows else row_idx
        for c in range(total_cols):
            axes[(plot_row, c)] = fig.add_subplot(grid_spec[plot_row, c])

    asc_arr = _to_numpy(asc) if _show_asc else None
    # psc_arr is already converted to numpy earlier (with extra dim handling)
    # input_arr for plotting (only if input is 2D, not when psc has extra dim)
    input_arr = (
        _to_numpy(input) if input is not None and not psc_has_extra_dim else None
    )
    used_axes: set[tuple[int, int, int]] = set()

    for plot_idx, neuron_idx in enumerate(neuron_indices):
        row_idx = plot_idx // neurons_per_row
        slot_idx = plot_idx % neurons_per_row
        # Calculate the actual grid row (accounting for label rows)
        plot_row = row_idx * 2 + 1 if use_top_label_rows else row_idx

        # Resolve spec
        spec = NeuronSpec()
        if neuron_specs is not None:
            if isinstance(neuron_specs, list):
                if plot_idx < len(neuron_specs):
                    s = neuron_specs[plot_idx]
                    spec = NeuronSpec(**s) if isinstance(s, dict) else s
            elif isinstance(neuron_specs, dict):
                spec = NeuronSpec(**neuron_specs)
            elif isinstance(neuron_specs, NeuronSpec):
                spec = neuron_specs

        label = spec.label if spec.label is not None else resolved_labels[plot_idx]
        final_labels[plot_idx] = label

        # Color resolution
        local_colors = colors.copy()
        if spec.color is not None:
            if isinstance(spec.color, dict):
                local_colors.update(spec.color)
            else:
                for k in local_colors:
                    if k != "spike":
                        local_colors[k] = spec.color

        col_base = slot_idx * n_cols
        col_idx = 0

        # Voltage subplot
        if _show_v:
            ax = axes[(plot_row, col_base + col_idx)]
            _plot_voltage_on_ax(
                ax,
                times,
                voltage[:, neuron_idx],
                spikes[:, neuron_idx] if spikes is not None else None,
                local_colors,
                format,
                v_threshold_per_neuron[plot_idx],
                v_reset_per_neuron[plot_idx],
                linestyle=spec.linestyle,
                linewidth=spec.linewidth,
                alpha=spec.alpha,
            )
            ax.set_ylabel("V (mV)")
            if row_idx == 0:
                ax.set_title("Voltage")
                if (
                    v_threshold_per_neuron[plot_idx] is not None
                    or v_reset_per_neuron[plot_idx] is not None
                ):
                    ax.legend(loc="upper right", fontsize=8)
            if row_idx == n_rows - 1:
                ax.set_xlabel("Time (ms)")
            ax.grid(alpha=0.3, linewidth=0.5)
            used_axes.add((plot_row, col_base + col_idx))
            col_idx += 1

        # ASC subplot
        if _show_asc and asc_arr is not None:
            ax = axes[(plot_row, col_base + col_idx)]
            _plot_simple_trace_on_ax(
                ax,
                times,
                asc_arr[:, neuron_idx],
                local_colors["asc"],
                "ASC (pA)",
                linestyle=spec.linestyle,
                linewidth=spec.linewidth,
                alpha=spec.alpha,
            )
            if row_idx == 0:
                ax.set_title("Afterspike Current")
            if row_idx == n_rows - 1:
                ax.set_xlabel("Time (ms)")
            ax.grid(alpha=0.3, linewidth=0.5)
            used_axes.add((plot_row, col_base + col_idx))
            col_idx += 1

        # PSC subplot
        if _show_psc and psc_arr is not None:
            ax = axes[(plot_row, col_base + col_idx)]
            if psc_has_extra_dim:
                # PSC has additional dimension: plot each component
                psc_traces = psc_arr[:, neuron_idx, :]  # Shape: (time, n_psc)
                _plot_multi_psc_on_ax(
                    ax,
                    times,
                    psc_traces,
                    psc_labels,
                    local_colors,
                    linestyle=spec.linestyle,
                    linewidth=spec.linewidth,
                    alpha=spec.alpha,
                )
                if row_idx == 0:
                    ax.set_title("Postsynaptic Current")
                    ax.legend(loc="upper right", fontsize=8)
            else:
                # Standard PSC: plot total, epsc, ipsc, input
                epsc_arr = _to_numpy(epsc[:, neuron_idx]) if epsc is not None else None
                ipsc_arr = _to_numpy(ipsc[:, neuron_idx]) if ipsc is not None else None
                input_arr_single = (
                    input_arr[:, neuron_idx] if input_arr is not None else None
                )
                _plot_psc_on_ax(
                    ax,
                    times,
                    psc_arr[:, neuron_idx],
                    epsc_arr,
                    ipsc_arr,
                    input_arr_single,
                    local_colors,
                    linestyle=spec.linestyle,
                    linewidth=spec.linewidth,
                    alpha=spec.alpha,
                )
                if row_idx == 0:
                    ax.set_title("Postsynaptic Current")
                    if epsc is not None or ipsc is not None or input is not None:
                        ax.legend(loc="upper right", fontsize=8)
            if row_idx == n_rows - 1:
                ax.set_xlabel("Time (ms)")
            ax.grid(alpha=0.3, linewidth=0.5)
            used_axes.add((plot_row, col_base + col_idx))
            col_idx += 1

        if label is not None:
            if use_top_label_rows:
                label_ax = label_axes[(row_idx, slot_idx)]
                wrapped_label = _format_top_neuron_label(
                    label, max_label_chars_per_line
                )
                label_ax.text(
                    0.5,
                    0.5,
                    wrapped_label,
                    transform=label_ax.transAxes,
                    fontsize=10,
                    fontweight="bold",
                    va="center",
                    ha="center",
                )
            else:
                # Add label to the rightmost subplot in this neuron slot.
                last_ax = axes[(plot_row, col_base + n_cols - 1)]
                last_ax.text(
                    1.02,
                    0.5,
                    label,
                    transform=last_ax.transAxes,
                    fontsize=10,
                    fontweight="bold",
                    va="center",
                    ha="left",
                )

    # Hide unused plot axes for empty neuron slots in the final row.
    for r in range(n_rows):
        plot_row = r * 2 + 1 if use_top_label_rows else r
        for c in range(total_cols):
            if (plot_row, c) not in used_axes:
                axes[(plot_row, c)].set_visible(False)

        if use_top_label_rows:
            for slot_idx in range(neurons_per_row):
                plot_idx = r * neurons_per_row + slot_idx
                if plot_idx >= n_plot:
                    label_axes[(r, slot_idx)].set_visible(False)

    # Adjust figure width based on actual label widths if using top labels
    if use_top_label_rows and label_axes:
        fig.canvas.draw()  # Ensure text is rendered
        max_label_width_inches = 0.0
        for (row_idx, slot_idx), label_ax in label_axes.items():
            for text in label_ax.texts:
                bbox = text.get_window_extent(renderer=fig.canvas.get_renderer())
                width_inches = bbox.width / fig.dpi
                max_label_width_inches = max(max_label_width_inches, width_inches)
        if max_label_width_inches > 0:
            # Add padding and account for column count per slot
            required_slot_width = max_label_width_inches * 1.2 + 1.0  # 20% pad + margin
            min_fig_width = required_slot_width * neurons_per_row
            current_width = fig.get_figwidth()
            if min_fig_width > current_width:
                fig.set_figwidth(min_fig_width)

    right_margin = 0.96 if neuron_label_position == "side" else 1.0
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig.tight_layout(rect=(0.0, 0.0, right_margin, 1.0), w_pad=0.8, h_pad=0.8)
    return fig

plot_neuropil_timeseries_overview(data, *, dt, mode='all_innervated', agg='mean', connections=None, neurons=None, kind='wave', figsize=(12, 8), cmap='viridis', top_n=50, use_polars=False, show=False)

Plot aggregated neuropil traces as a single overview figure.

Visualizes neural activity aggregated by brain regions (neuropils). Can display as stacked waveforms or a heatmap.

Parameters:

Name Type Description Default
data TensorLike | Mapping[str, TensorLike]

Spike matrix (time, neurons) or pre-computed dict of {region_name: activity_array}.

required
dt float

Time step in seconds for x-axis scaling.

required
mode Literal['top_innervated', 'all_innervated']

Aggregation mode - "top_innervated" uses primary neuropil per neuron, "all_innervated" includes all neuropil connections.

'all_innervated'
agg Literal['mean', 'sum', 'std']

Aggregation function applied per neuropil ("mean", "sum", "std").

'mean'
connections DataFrame | None

DataFrame with connection metadata (required if data is a matrix, not a dict).

None
neurons DataFrame | None

DataFrame with neuron metadata (required if data is a matrix, not a dict).

None
kind Literal['wave', 'heatmap']

Visualization style - "wave" for stacked traces, "heatmap" for 2D intensity map.

'wave'
figsize tuple[float, float]

Figure size (width, height) in inches.

(12, 8)
cmap str

Colormap for heatmap style.

'viridis'
top_n int

Number of top regions to show in wave style (ranked by maximum absolute activity).

50
use_polars bool

Whether to use Polars for aggregation (faster for large datasets).

False
show bool

Whether to call plt.show() before returning.

False

Returns:

Type Description
tuple[Figure, Axes]

Tuple of (figure, axes) containing the plot.

Raises:

Type Description
ValueError

If no neuropil traces can be computed from inputs.

Example

fig, ax = plot_neuropil_timeseries_overview(
    spikes, dt=0.001, connections=conn_df, neurons=neurons_df
)
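`data` may also be a pre-computed `{region_name: activity_array}` mapping, in which case `connections` and `neurons` are not needed. A minimal sketch of building such input with plain NumPy (the region names here are illustrative, not real neuropil identifiers):

```python
import numpy as np

rng = np.random.default_rng(0)
n_time = 1000

# Pre-aggregated per-region traces: each value is a 1-D array of a shared length.
region_traces = {
    "AL_R": rng.poisson(2.0, size=n_time).astype(float),
    "MB_L": rng.poisson(0.5, size=n_time).astype(float),
}

# fig, ax = plot_neuropil_timeseries_overview(region_traces, dt=0.001)
```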

Source code in btorch/visualisation/aggregation.py
def plot_neuropil_timeseries_overview(
    data: TensorLike | Mapping[str, TensorLike],
    *,
    dt: float,
    mode: Literal["top_innervated", "all_innervated"] = "all_innervated",
    agg: Literal["mean", "sum", "std"] = "mean",
    connections: pd.DataFrame | None = None,
    neurons: pd.DataFrame | None = None,
    kind: Literal["wave", "heatmap"] = "wave",
    figsize: tuple[float, float] = (12, 8),
    cmap: str = "viridis",
    top_n: int = 50,
    use_polars: bool = False,
    show: bool = False,
) -> tuple[Figure, Axes]:
    """Plot aggregated neuropil traces as a single overview figure.

    Visualizes neural activity aggregated by brain regions (neuropils).
    Can display as stacked waveforms or a heatmap.

    Args:
        data: Spike matrix (time, neurons) or pre-computed dict of
            {region_name: activity_array}.
        dt: Time step in seconds for x-axis scaling.
        mode: Aggregation mode - "top_innervated" uses primary neuropil per
            neuron, "all_innervated" includes all neuropil connections.
        agg: Aggregation function applied per neuropil ("mean", "sum", "std").
        connections: DataFrame with connection metadata (required if `data`
            is a matrix, not a dict).
        neurons: DataFrame with neuron metadata (required if `data` is a
            matrix, not a dict).
        kind: Visualization style - "wave" for stacked traces, "heatmap"
            for 2D intensity map.
        figsize: Figure size (width, height) in inches.
        cmap: Colormap for heatmap style.
        top_n: Number of top regions to show in wave style (ranked by
            maximum absolute activity).
        use_polars: Whether to use Polars for aggregation (faster for large
            datasets).
        show: Whether to call `plt.show()` before returning.

    Returns:
        Tuple of (figure, axes) containing the plot.

    Raises:
        ValueError: If no neuropil traces can be computed from inputs.

    Example:
        >>> fig, ax = plot_neuropil_timeseries_overview(
        ...     spikes, dt=0.001, connections=conn_df, neurons=neurons_df
        ... )
    """
    traces = _resolve_neuropil_traces(
        data,
        mode=mode,
        agg=agg,
        connections=connections,
        neurons=neurons,
        use_polars=use_polars,
    )
    if not traces:
        raise ValueError("No neuropil traces available for plotting.")

    n_time = len(next(iter(traces.values())))
    time_points = np.arange(n_time, dtype=float) * dt

    fig, ax = plt.subplots(figsize=figsize)
    if kind == "wave":
        _plot_wave_style(ax, traces, time_points, top_n=top_n)
    else:
        _plot_heatmap_style(ax, traces, time_points, cmap=cmap)

    ax.set_ylabel("Neuropil Activity (z-scored)", fontsize=12)
    ax.set_title("Neuropil Activity Traces", fontsize=14)
    ax.set_xlabel("Time (s)", fontsize=12)

    fig.tight_layout()
    if show:
        plt.show()
    return fig, ax

plot_neuropil_timeseries_panels(data, *, dt, mode='all_innervated', agg='mean', connections=None, neurons=None, regions=None, figsize=(15, 10), cols=3, use_polars=False, show=False)

Plot selected neuropil traces as a subplot grid.

Creates a grid of subplots showing individual neuropil activity traces, with optional statistics annotations.

Parameters:

Name Type Description Default
data TensorLike | Mapping[str, TensorLike]

Spike matrix (time, neurons) or pre-computed dict of {region_name: activity_array}.

required
dt float

Time step in seconds for x-axis scaling.

required
mode Literal['top_innervated', 'all_innervated']

Aggregation mode - "top_innervated" or "all_innervated".

'all_innervated'
agg Literal['mean', 'sum', 'std']

Aggregation function applied per neuropil ("mean", "sum", "std").

'mean'
connections DataFrame | None

DataFrame with connection metadata.

None
neurons DataFrame | None

DataFrame with neuron metadata.

None
regions Sequence[str] | None

List of region names to plot. If None, top 9 regions by maximum activity are selected automatically.

None
figsize tuple[float, float]

Figure size (width, height) in inches.

(15, 10)
cols int

Number of columns in the subplot grid.

3
use_polars bool

Whether to use Polars for aggregation.

False
show bool

Whether to call plt.show() before returning.

False

Returns:

Type Description
tuple[Figure, ndarray]

Tuple of (figure, axes_array) containing the subplot grid.

Raises:

Type Description
ValueError

If no neuropil traces are available or regions is empty.

Source code in btorch/visualisation/aggregation.py
def plot_neuropil_timeseries_panels(
    data: TensorLike | Mapping[str, TensorLike],
    *,
    dt: float,
    mode: Literal["top_innervated", "all_innervated"] = "all_innervated",
    agg: Literal["mean", "sum", "std"] = "mean",
    connections: pd.DataFrame | None = None,
    neurons: pd.DataFrame | None = None,
    regions: Sequence[str] | None = None,
    figsize: tuple[float, float] = (15, 10),
    cols: int = 3,
    use_polars: bool = False,
    show: bool = False,
) -> tuple[Figure, np.ndarray]:
    """Plot selected neuropil traces as a subplot grid.

    Creates a grid of subplots showing individual neuropil activity traces,
    with optional statistics annotations.

    Args:
        data: Spike matrix (time, neurons) or pre-computed dict of
            {region_name: activity_array}.
        dt: Time step in seconds for x-axis scaling.
        mode: Aggregation mode - "top_innervated" or "all_innervated".
        agg: Aggregation function applied per neuropil ("mean", "sum", "std").
        connections: DataFrame with connection metadata.
        neurons: DataFrame with neuron metadata.
        regions: List of region names to plot. If None, top 9 regions by
            maximum activity are selected automatically.
        figsize: Figure size (width, height) in inches.
        cols: Number of columns in the subplot grid.
        use_polars: Whether to use Polars for aggregation.
        show: Whether to call `plt.show()` before returning.

    Returns:
        Tuple of (figure, axes_array) containing the subplot grid.

    Raises:
        ValueError: If no neuropil traces are available or `regions` is empty.
    """
    traces = _resolve_neuropil_traces(
        data,
        mode=mode,
        agg=agg,
        connections=connections,
        neurons=neurons,
        use_polars=use_polars,
    )
    if not traces:
        raise ValueError("No neuropil traces available for plotting.")

    if regions is None:
        ranked = sorted(
            traces.items(),
            key=lambda x: float(np.max(np.abs(np.asarray(x[1])))),
            reverse=True,
        )[:9]
        regions = [region for region, _ in ranked]

    n_regions = len(regions)
    if n_regions == 0:
        raise ValueError("`regions` must contain at least one region.")

    rows = int(np.ceil(n_regions / cols))
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    if rows == 1:
        axes = np.asarray(axes).reshape(1, -1)
    if cols == 1:
        axes = np.asarray(axes).reshape(-1, 1)

    n_time = len(next(iter(traces.values())))
    time_points = np.arange(n_time, dtype=float) * dt

    for i, region in enumerate(regions):
        row = i // cols
        col = i % cols
        ax = axes[row, col]

        if region in traces:
            activity = np.asarray(traces[region])
            ax.plot(time_points, activity, linewidth=1.5, color="darkblue")
            ax.set_title(str(region), fontsize=10)
            ax.set_xlabel("Time (s)", fontsize=8)
            ax.set_ylabel("Activity", fontsize=8)
            ax.grid(True, alpha=0.3)

            mean_act = float(np.mean(activity))
            std_act = float(np.std(activity))
            ax.text(
                0.02,
                0.98,
                f"μ={mean_act:.2f}\nσ={std_act:.2f}",
                transform=ax.transAxes,
                verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
                fontsize=8,
            )

    for i in range(n_regions, rows * cols):
        row = i // cols
        col = i % cols
        fig.delaxes(axes[row, col])

    fig.suptitle("Neuropil Activity Comparison", fontsize=16)
    fig.tight_layout()
    if show:
        plt.show()
    return fig, axes
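When `regions` is None, the panels are chosen automatically: the top 9 regions ranked by maximum absolute activity, as in the source above. That selection rule can be reproduced with plain NumPy (synthetic traces with illustrative names):

```python
import numpy as np

# Synthetic traces: a larger index means larger constant activity.
traces = {f"region_{i}": np.full(200, float(i)) for i in range(12)}

# Same ranking the function applies when `regions` is None:
ranked = sorted(
    traces.items(),
    key=lambda x: float(np.max(np.abs(np.asarray(x[1])))),
    reverse=True,
)[:9]
regions = [region for region, _ in ranked]
```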

plot_raster(spikes, dt=None, times=None, ax=None, neurons_df=None, group_key=None, group_sort=None, spike_color='black', marker='.', marker_size=5.0, neuron_specs=None, show_group_separators=True, separator_style=None, title=None, xlabel='Time (ms)', ylabel='Neuron Index', rate=False, group_rate=False, rate_window_ms=10.0, show_group_strip=False, group_color_key=None, strip_cmap='tab10', group_strip_kwargs=None, group_strip_legend=True, group_label_mode='top_sub', group_strip_side='right', sort_neurons=True, events=None, regions=None, show_tracks=False, event_kwargs=None, region_kwargs=None)

Plot spike raster with optional grouping and styling.

Parameters

spikes : np.ndarray or torch.Tensor
    Spike matrix of shape (time, neurons).
dt : float, optional
    Time step in ms. Default is 1.0 if times is not provided.
times : array-like, optional
    Explicit time array.
ax : matplotlib.axes.Axes, optional
    Axis to plot on. If None, a new figure is created.
neurons_df : pd.DataFrame, optional
    Dataframe containing neuron metadata, required for grouping.
group_key : str, optional
    Column name in neurons_df to group neurons by.
group_sort : list[str], optional
    Specific order for the groups.
spike_color : str or dict or sequence, optional
    Default color for spikes. Can be a dict mapping group names or neuron indices to colors, or a per-neuron color sequence.
marker : str
    Marker type.
marker_size : float
    Size of the markers.
neuron_specs : dict, list, or NeuronSpec, optional
    Specific styling per neuron.
show_group_separators : bool
    Whether to draw lines separating groups.
separator_style : dict, optional
    Arguments for separator lines (color, linewidth, etc.).
title : str, optional
    Plot title.
xlabel : str
    Label for x-axis.
ylabel : str
    Label for y-axis.
rate : bool or array-like, optional
    If True, compute and plot the population firing rate. If array-like, use it directly with length matching the time axis.
group_rate : bool or dict or array-like, optional
    If True, compute and plot per-group firing rates when grouping is available. If dict, map group names to per-group rate arrays. If array-like, interpret as (T, G) in the order of resolved groups.
rate_window_ms : float
    Window size for firing rate smoothing in ms.
show_group_strip : bool
    If True, draw a colorbar-like group strip on the side.
group_color_key : str, optional
    Column name in neurons_df to color the group strip. Defaults to group_key.
strip_cmap : str
    Matplotlib colormap name used to derive both top-group and subgroup colors.
group_strip_kwargs : dict, optional
    Additional options for colorbar layout and labels.
group_strip_legend : bool
    If True, add a legend for group colors.
group_label_mode : {"top", "sub", "top_sub"}
    Label mode for the colorbar when using subgroups.
group_strip_side : {"left", "right"}
    Side on which to draw the group strip and labels.
sort_neurons : bool
    If True (default), neurons are reordered by group and subgroup so bands are continuous. If False, original order is preserved.

Returns

ax or (ax_raster, ax_rate)
    The axis object(s).
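Internally the raster points come from `compute_raster`; an equivalent NumPy sketch (not the library implementation) of turning a (time, neurons) spike matrix into scatter coordinates:

```python
import numpy as np

spikes = np.zeros((5, 3))  # (time, neurons)
spikes[1, 0] = 1.0
spikes[3, 2] = 1.0
t = np.arange(5) * 1.0     # time axis with dt = 1.0 ms

# One (time, neuron) point per spike, suitable for ax.scatter(spike_times, neuron_idx).
time_idx, neuron_idx = np.nonzero(spikes)
spike_times = t[time_idx]
```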

Source code in btorch/visualisation/timeseries.py
def plot_raster(
    spikes: Union[np.ndarray, torch.Tensor],
    dt: float | None = None,
    times: Sequence[float] | None = None,
    ax: Axes | None = None,
    # Grouping and Metadata
    neurons_df: pd.DataFrame | None = None,
    group_key: str | None = None,
    group_sort: list[str] | None = None,
    # Styling
    spike_color: str | dict | Sequence[Any] | None = "black",
    marker: str = ".",
    marker_size: float = 5.0,
    neuron_specs: dict | list | NeuronSpec | None = None,
    show_group_separators: bool = True,
    separator_style: dict | None = None,
    # Standard Plot Args
    title: str | None = None,
    xlabel: str = "Time (ms)",
    ylabel: str = "Neuron Index",
    rate: bool | np.ndarray | torch.Tensor | None = False,
    group_rate: bool | dict[str, np.ndarray | torch.Tensor] | np.ndarray | None = False,
    rate_window_ms: float = 10.0,
    show_group_strip: bool = False,
    group_color_key: str | None = None,
    strip_cmap: str = "tab10",
    group_strip_kwargs: dict | None = None,
    group_strip_legend: bool = True,
    group_label_mode: Literal["top", "sub", "top_sub"] = "top_sub",
    group_strip_side: Literal["left", "right"] = "right",
    sort_neurons: bool = True,
    events: Sequence[float] | dict[str, Sequence[float]] | None = None,
    regions: Sequence[tuple[float, float]]
    | dict[str, Sequence[tuple[float, float]]]
    | None = None,
    show_tracks: bool = False,
    event_kwargs: dict | None = None,
    region_kwargs: dict | None = None,
) -> Union[Axes, tuple[Axes, Axes]]:
    """Plot spike raster with optional grouping and styling.

    Parameters
    ----------
    spikes : np.ndarray or torch.Tensor
        Spike matrix of shape (time, neurons).
    dt : float, optional
        Time step in ms. Default is 1.0 if times is not provided.
    times : array-like, optional
        Explicit time array.
    ax : matplotlib.axes.Axes, optional
        Axis to plot on. If None, a new figure is created.
    neurons_df : pd.DataFrame, optional
        Dataframe containing neuron metadata, required for grouping.
    group_key : str, optional
        Column name in neurons_df to group neurons by.
    group_sort : list[str], optional
        Specific order for the groups.
    spike_color : str or dict or sequence, optional
        Default color for spikes. Can be a dict mapping group names or neuron
        indices to colors, or a per-neuron color sequence.
    marker : str
        Marker type.
    marker_size : float
        Size of the markers.
    neuron_specs : dict, list, or NeuronSpec, optional
        Specific styling per neuron.
    show_group_separators : bool
        Whether to draw lines separating groups.
    separator_style : dict, optional
        Arguments for separator lines (color, linewidth, etc.).
    title : str, optional
        Plot title.
    xlabel : str
        Label for x-axis.
    ylabel : str
        Label for y-axis.
    rate : bool or array-like, optional
        If True, compute and plot the population firing rate. If array-like,
        use it directly with length matching the time axis.
    group_rate : bool or dict or array-like, optional
        If True, compute and plot per-group firing rates when grouping is
        available. If dict, map group names to per-group rate arrays. If
        array-like, interpret as (T, G) in the order of resolved groups.
    rate_window_ms : float
        Window size for firing rate smoothing in ms.
    show_group_strip : bool
        If True, draw a colorbar-like group strip on the side.
    group_color_key : str, optional
        Column name in neurons_df to color the group strip. Defaults to group_key.
    strip_cmap : str
        Matplotlib colormap name used to derive both top-group and subgroup colors.
    group_strip_kwargs : dict, optional
        Additional options for colorbar layout and labels.
    group_strip_legend : bool
        If True, add a legend for group colors.
    group_label_mode : {"top", "sub", "top_sub"}
        Label mode for the colorbar when using subgroups.
    group_strip_side : {"left", "right"}
        Side on which to draw the group strip and labels.
    sort_neurons : bool
        If True (default), neurons are reordered by group and subgroup so bands are
        continuous. If False, original order is preserved.

    Returns
    -------
    ax or (ax_raster, ax_rate)
        The axis object(s).
    """
    spikes_np = _to_numpy(spikes)
    if spikes_np.ndim != 2:
        raise ValueError("spikes must be 2D (time, neurons)")

    n_time, n_neurons = spikes_np.shape
    t = _get_time_axis(n_time, dt, times)

    # Evaluate isinstance first to avoid calling bool() on arrays/tensors
    # (which raises ValueError for >1-element numpy arrays).
    show_rate = isinstance(rate, (np.ndarray, torch.Tensor)) or bool(rate)
    show_group_rate = isinstance(group_rate, (dict, np.ndarray, torch.Tensor)) or bool(
        group_rate
    )

    raster_height = _auto_raster_height(n_neurons)
    raster_width = 8.0
    rate_height = 2.6  # keep rate panel at a stable height

    if show_rate or show_group_rate:
        if ax is not None:
            warnings.warn(
                "ax argument is ignored when rate/group_rate is enabled. "
                "Creating new figure."
            )
        fig, (ax_raster, ax_rate) = plt.subplots(
            2,
            1,
            figsize=(raster_width, raster_height + rate_height),
            gridspec_kw={
                "height_ratios": [raster_height, rate_height],
                "hspace": 0.06,
            },
        )
    else:
        if ax is None:
            fig, ax = plt.subplots(figsize=(raster_width, raster_height))
        ax_raster = ax
        ax_rate = None

    # Handle Grouping
    sorted_indices = np.arange(n_neurons)
    group_boundaries = []  # List of (y_coord, label)

    group_labels = np.full(n_neurons, "Unknown", dtype=object)
    subgroup_labels = None

    if group_key is not None or group_color_key is not None:
        if neurons_df is None:
            raise ValueError("neurons_df must be provided when grouping is used.")
        if group_key is not None and group_key not in neurons_df.columns:
            raise ValueError(f"Column '{group_key}' not found in neurons_df.")
        if group_color_key is not None and group_color_key not in neurons_df.columns:
            raise ValueError(f"Column '{group_color_key}' not found in neurons_df.")

    if group_key is not None:
        group_values = neurons_df[group_key].to_numpy()
        n_copy = min(n_neurons, len(group_values))
        group_labels[:n_copy] = group_values[:n_copy]
        group_labels[pd.isna(group_labels)] = "Unknown"

    if group_color_key is not None:
        subgroup_labels = np.full(n_neurons, "Unknown", dtype=object)
        sub_values = neurons_df[group_color_key].to_numpy()
        n_copy = min(n_neurons, len(sub_values))
        subgroup_labels[:n_copy] = sub_values[:n_copy]
        subgroup_labels[pd.isna(subgroup_labels)] = "Unknown"
    else:
        subgroup_labels = group_labels

    if group_key is not None:
        present_groups = set(group_labels.tolist())
        if group_sort:
            groups = [g for g in group_sort if g in present_groups]
            remaining = sorted(present_groups - set(groups))
            groups.extend(remaining)
        else:
            groups = sorted(present_groups)

        if sort_neurons:
            new_order = []
            current_y = 0
            for g in groups:
                g_indices = np.flatnonzero(group_labels == g)
                if g_indices.size == 0:
                    continue
                if group_color_key is not None:
                    subgroup_vals = subgroup_labels[g_indices]
                    subgroup_order = list(dict.fromkeys(subgroup_vals.tolist()))
                    order_map = {k: i for i, k in enumerate(subgroup_order)}
                    subgroup_rank = np.array(
                        [order_map[v] for v in subgroup_vals], dtype=int
                    )
                    g_indices = g_indices[np.argsort(subgroup_rank, kind="stable")]

                new_order.append(g_indices)
                current_y += g_indices.size
                group_boundaries.append((current_y - 0.5, g))

            if new_order:
                sorted_indices = np.concatenate(new_order)
            else:
                sorted_indices = np.array([], dtype=int)
            if len(sorted_indices) < n_neurons:
                warnings.warn(
                    "Not all neurons were assigned to a group. Appending defaults."
                )
                missing = np.setdiff1d(np.arange(n_neurons), sorted_indices)
                sorted_indices = np.concatenate([sorted_indices, missing])
        else:
            sorted_indices = np.arange(n_neurons)
            prev_g = None
            for i, idx in enumerate(sorted_indices):
                g = group_labels[idx]
                if i == 0:
                    prev_g = g
                else:
                    if g != prev_g:
                        group_boundaries.append((i - 0.5, prev_g))
                        prev_g = g

    # Map original neuron index -> plot y-coordinate. The y-axis follows the
    # standard raster convention: the first sorted neuron sits at y=0 (bottom).
    idx_map = np.empty(n_neurons, dtype=int)
    idx_map[sorted_indices] = np.arange(len(sorted_indices))

    # Compute raster coordinates. spikes_np has shape (time, neurons);
    # compute_raster returns (neuron_indices, spike_times) with neuron
    # indices in 0..N-1.
    orig_neuron_indices, spike_times = compute_raster(spikes_np, t)

    # Map neuron indices to sorted plot positions
    plot_neuron_indices = idx_map[orig_neuron_indices]

    # Handle Colors
    c_array = spike_color
    skip_main_scatter = False
    draw_spikes_later = show_group_strip
    ms_array = marker_size  # default fallback if no specs
    marker_list = None
    size_list = None

    color_by_neuron = None

    if isinstance(spike_color, dict):
        has_int_keys = any(isinstance(k, (int, np.integer)) for k in spike_color)
        if has_int_keys:
            color_by_neuron = np.array(
                [spike_color.get(i, "black") for i in range(n_neurons)],
                dtype=object,
            )
        elif group_key is not None:
            color_by_neuron = np.array(
                [spike_color.get(g, "black") for g in group_labels],
                dtype=object,
            )
        else:
            warnings.warn(
                "spike_color dict provided but group_key not set. Using black."
            )
            c_array = "black"
    elif isinstance(spike_color, (list, tuple, np.ndarray)):
        if len(spike_color) != n_neurons:
            raise ValueError(
                "spike_color sequence length must match number of neurons."
            )
        color_by_neuron = np.array(spike_color, dtype=object)

    if color_by_neuron is not None:
        c_array = color_by_neuron[orig_neuron_indices]
    elif neuron_specs is not None:
        c_list = []
        m_list = []
        ms_list = []

        # Helper to get spec for an index
        def get_spec_attrs(idx):
            s = None
            if isinstance(neuron_specs, list):
                if idx < len(neuron_specs):
                    s = neuron_specs[idx]
            elif isinstance(neuron_specs, dict):
                if idx in neuron_specs:
                    s = neuron_specs[idx]

            c = "black"
            m = marker
            ms = marker_size

            if s is not None:
                if isinstance(s, NeuronSpec):
                    c = s.color if s.color is not None else c
                    m = s.marker if s.marker is not None else m
                    ms = s.markersize if s.markersize is not None else ms
                elif isinstance(s, dict):
                    c = s.get("color", c)
                    m = s.get("marker", m)
                    ms = s.get("markersize", ms)
            return c, m, ms

        for orig_idx in orig_neuron_indices:
            c, m, ms = get_spec_attrs(orig_idx)
            c_list.append(c)
            m_list.append(m)
            ms_list.append(ms)

        c_array = c_list
        marker_list = np.array(m_list)
        size_list = np.array(ms_list)
        # Matplotlib's scatter accepts per-point color and size arrays but
        # only a single marker style per call, so varying markers require a
        # separate scatter call per marker style.

        # Check if multiple markers used
        unique_markers = set(m_list)
        if len(unique_markers) > 1:
            if not show_group_strip:
                # One scatter call per marker style
                for um in unique_markers:
                    mask = np.array(m_list) == um
                    # Line-based markers (x, +, |, _) need linewidths > 0
                    lw = 0.5 if um in ("x", "+", "|", "_", "1", "2", "3", "4") else 0
                    ax_raster.scatter(
                        spike_times[mask],
                        plot_neuron_indices[mask],
                        s=np.array(ms_list)[mask],
                        c=np.array(c_list, dtype=object)[mask],
                        marker=um,
                        linewidths=lw,
                    )
            # Skip the main scatter call
            skip_main_scatter = True
        else:
            marker = m_list[0] if m_list else marker
            ms_array = ms_list
            skip_main_scatter = False

    if not skip_main_scatter:
        # scatter accepts an array of per-point sizes via 's'
        if neuron_specs is not None:
            # per-neuron sizes were collected above
            s_arg = ms_array
        else:
            s_arg = marker_size

        # Line-based markers need linewidths > 0
        lw = 0.5 if marker in ("x", "+", "|", "_", "1", "2", "3", "4") else 0

        if not draw_spikes_later:
            ax_raster.scatter(
                spike_times,
                plot_neuron_indices,
                s=s_arg,
                c=c_array,
                marker=marker,
                linewidths=lw,
            )
    ax_raster.set_xlim(t[0], t[-1])
    ax_raster.set_ylim(-0.5, n_neurons - 0.5)
    ax_raster.set_ylabel(ylabel)
    ax_raster.yaxis.set_major_locator(MaxNLocator(integer=True))

    # Advanced Annotations
    # 1. Tracks (Horizontal lines for each neuron)
    if show_tracks:
        # Horizontal guide line at each integer y position. Fade the lines
        # for large populations so they do not overwhelm the spikes.
        track_alpha = 0.1 if n_neurons > 100 else 0.2
        track_lw = 0.5
        ax_raster.hlines(
            y=np.arange(n_neurons),
            xmin=t[0],
            xmax=t[-1],
            colors="gray",
            alpha=track_alpha,
            linewidth=track_lw,
            zorder=0,
        )

    # 2. Events (Vertical lines)
    if events is not None:
        def_evt_kwargs = {
            "color": "red",
            "linestyle": "--",
            "alpha": 0.8,
            "linewidth": 1.0,
        }
        if event_kwargs:
            def_evt_kwargs.update(event_kwargs)

        if isinstance(events, dict):
            # All labeled event families share the same default style.
            for label, times in events.items():
                for et in times:
                    ax_raster.axvline(x=et, **def_evt_kwargs)
        else:
            # Sequence
            for et in events:
                ax_raster.axvline(x=et, **def_evt_kwargs)

    # 3. Regions (Shaded intervals)
    if regions is not None:
        def_reg_kwargs = {"color": "yellow", "alpha": 0.2}
        if region_kwargs:
            def_reg_kwargs.update(region_kwargs)

        if isinstance(regions, dict):
            for label, intervals in regions.items():
                for start, end in intervals:
                    ax_raster.axvspan(start, end, **def_reg_kwargs)
        else:
            for start, end in regions:
                ax_raster.axvspan(start, end, **def_reg_kwargs)

    spike_count = len(spike_times)
    fired_neurons = len(np.unique(orig_neuron_indices)) if spike_count > 0 else 0
    stats_title = f"Fired {fired_neurons}/{n_neurons}, Spikes {spike_count}"

    if title:
        ax_raster.set_title(title)
    else:
        ax_raster.set_title(f"Spike raster {stats_title}")

    # Add separators and group labels
    if group_key and show_group_separators:
        sep_args = (
            separator_style
            if separator_style
            else {"color": "gray", "linestyle": "--", "alpha": 0.5, "linewidth": 0.8}
        )

        # Boundaries mark the top edge of each group; labels are centered
        # within each group band.

        prev_y = -0.5
        for y_limit, label in group_boundaries:
            if y_limit < n_neurons - 0.5:  # Don't draw line at very top if fully filled
                ax_raster.axhline(y_limit, **sep_args)

            # Add text label only when no strip is shown (strip draws labels itself)
            if not show_group_strip:
                mid_y = (prev_y + y_limit) / 2
                label_x = -0.02 if group_strip_side == "left" else 1.01
                label_ha = "right" if group_strip_side == "left" else "left"
                ax_raster.text(
                    label_x,
                    mid_y,
                    str(label),
                    transform=ax_raster.get_yaxis_transform(),
                    va="center",
                    ha=label_ha,
                    fontsize=8,
                    color=sep_args.get("color", "black"),
                )

            prev_y = y_limit

    # Optional group strip
    if show_group_strip:
        if neurons_df is None:
            raise ValueError("neurons_df must be provided for group strip.")

        group_col = group_color_key or group_key
        if group_col is None:
            raise ValueError(
                "group_color_key or group_key must be set for group strip."
            )
        if group_col not in neurons_df.columns:
            raise ValueError(f"Column '{group_col}' not found in neurons_df.")

        cb_args = {
            "width": 0.06,
            "pad": 0.005,
            "alpha": 0.9,
            "label_fontsize": 7,
            "label_weight": "bold",
            "legend_fontsize": 6,
            "legend_ncol_threshold": 15,
            "min_label_distance": 0.02,
            "min_span_frac": 0.01,
            "span_line_frac": 0.005,
            "strip_x0": 0.3,
            "strip_width": 0.4,
            "label_x": None,
            "label_gap": 0.05,
            "bracket_x0": 0.78,
            "bracket_x1": 0.95,
            "label_sep": " / ",
            "group_sep_color": "black",
            "group_sep_lw": 1.4,
            "sub_hue_span": 0.06,
            "sub_val_span": 0.18,
            "left_extra_pad": 0.04,
        }
        if group_strip_kwargs:
            cb_args.update(group_strip_kwargs)

        fig = ax_raster.figure
        pos = ax_raster.get_position()
        if group_strip_side == "right":
            cax_x0 = pos.x1 + cb_args["pad"]
        else:
            cax_x0 = (
                pos.x0 - cb_args["pad"] - cb_args["width"] - cb_args["left_extra_pad"]
            )
        cax = fig.add_axes([cax_x0, pos.y0, cb_args["width"], pos.height])

        if group_strip_side == "left":
            ylabel_x = (cax_x0 - cb_args["pad"] - pos.x0) / pos.width
            ax_raster.yaxis.set_label_coords(ylabel_x, 0.5)

        # Resolve subgroup and top-group labels per neuron in sorted order
        sub_labels_raw = subgroup_labels[sorted_indices]
        top_group_labels = group_labels[sorted_indices]
        group_labels_list = sub_labels_raw.tolist()

        use_subgroups = group_key is not None and group_col != group_key
        if use_subgroups:
            if group_label_mode == "top":
                group_labels_list = [str(top) for top in top_group_labels]
            elif group_label_mode == "sub":
                group_labels_list = [str(sub) for sub in group_labels_list]
            else:
                group_labels_list = [
                    f"{top}{cb_args['label_sep']}{sub}"
                    for top, sub in zip(top_group_labels, group_labels_list)
                ]

        base_colors, subgroup_colors, subgroups_by_top, top_groups_order = (
            _build_group_color_maps(
                top_group_labels,
                sub_labels_raw,
                use_subgroups,
                strip_cmap,
                strip_cmap,
                cb_args["sub_hue_span"],
                cb_args["sub_val_span"],
            )
        )

        # Draw patches using group/subgroup colors
        for i, _ in enumerate(group_labels_list):
            tg = top_group_labels[i]
            sub = sub_labels_raw[i]
            if use_subgroups:
                color = subgroup_colors.get((tg, sub), base_colors.get(tg, "#cccccc"))
            else:
                color = base_colors.get(sub, "#cccccc")
            cax.add_patch(
                Rectangle(
                    (cb_args["strip_x0"], i - 0.5),
                    cb_args["strip_width"],
                    1.0,
                    facecolor=color,
                    edgecolor="none",
                    alpha=cb_args["alpha"],
                )
            )

        # Compute ranges for labels
        type_ranges: dict[str, dict[str, int]] = {}
        for i, label in enumerate(group_labels_list):
            if label not in type_ranges:
                type_ranges[label] = {"start": i, "end": i}
            else:
                type_ranges[label]["end"] = i

        unique_types = list(dict.fromkeys(group_labels_list))
        sorted_types = sorted(unique_types, key=lambda x: type_ranges[x]["start"])
        label_positions: list[float] = []
        for label in sorted_types:
            start_idx = type_ranges[label]["start"]
            end_idx = type_ranges[label]["end"]
            mid_y = (start_idx + end_idx) / 2

            min_distance = n_neurons * cb_args["min_label_distance"]
            too_close = any(abs(mid_y - pos) < min_distance for pos in label_positions)

            if (not too_close) or (
                (end_idx - start_idx) > (n_neurons * cb_args["min_span_frac"])
            ):
                label_x = cb_args["label_x"]
                if label_x is None:
                    if group_strip_side == "right":
                        label_x = (
                            cb_args["strip_x0"]
                            + cb_args["strip_width"]
                            + cb_args["label_gap"]
                        )
                        label_ha = "left"
                    else:
                        label_x = cb_args["strip_x0"] - cb_args["label_gap"]
                        label_ha = "right"
                else:
                    label_ha = "left"
                cax.text(
                    label_x,
                    mid_y,
                    str(label),
                    ha=label_ha,
                    va="center",
                    fontsize=cb_args["label_fontsize"],
                    transform=cax.transData,
                    weight=cb_args["label_weight"],
                )
                label_positions.append(mid_y)

        cax.set_xlim(0, 1)
        cax.set_ylim(ax_raster.get_ylim())
        cax.set_xticks([])
        cax.set_yticks([])
        cax.set_frame_on(False)
        for spine in cax.spines.values():
            spine.set_visible(False)

        if group_strip_legend:
            if group_label_mode == "top":
                legend_elements = [
                    mpatches.Patch(color=base_colors[tg], label=str(tg))
                    for tg in top_groups_order
                ]
            elif group_label_mode == "sub":
                legend_elements = [
                    mpatches.Patch(
                        color=subgroup_colors.get((tg, sub), base_colors.get(tg)),
                        label=str(sub),
                    )
                    for tg in top_groups_order
                    for sub in subgroups_by_top[tg]
                ]
            else:  # top_sub
                legend_elements = [
                    mpatches.Patch(
                        color=subgroup_colors.get((tg, sub), base_colors.get(tg)),
                        label=f"{tg}{cb_args['label_sep']}{sub}",
                    )
                    for tg in top_groups_order
                    for sub in subgroups_by_top[tg]
                ]
            ncol = 2 if len(legend_elements) > cb_args["legend_ncol_threshold"] else 1
            cax.legend(
                handles=legend_elements,
                loc="upper right",
                bbox_to_anchor=(1, 1),
                fontsize=cb_args["legend_fontsize"],
                ncol=ncol,
                frameon=True,
                shadow=True,
            )

        # If we postponed spike drawing earlier, now draw spikes with
        # matching group/subgroup colors derived above.
        if draw_spikes_later:
            # Build color list per spike (orig_neuron_indices order)
            top_vals = group_labels[orig_neuron_indices]
            sub_vals = subgroup_labels[orig_neuron_indices]
            spike_colors = []
            for top, sub in zip(top_vals, sub_vals):
                if use_subgroups:
                    spike_colors.append(
                        subgroup_colors.get((top, sub), base_colors.get(top, "black"))
                    )
                else:
                    spike_colors.append(base_colors.get(sub, "black"))

            spike_colors = np.array(spike_colors, dtype=object)
            if size_list is not None:
                s_arg = size_list
            else:
                s_arg = marker_size

            if marker_list is not None and len(set(marker_list)) > 1:
                for um in sorted(set(marker_list)):
                    mask = marker_list == um
                    lw = 0.5 if um in ("x", "+", "|", "_", "1", "2", "3", "4") else 0
                    ax_raster.scatter(
                        spike_times[mask],
                        plot_neuron_indices[mask],
                        s=s_arg[mask],
                        c=spike_colors[mask],
                        marker=um,
                        linewidths=lw,
                    )
            else:
                marker_use = marker_list[0] if marker_list is not None else marker
                lw = (
                    0.5 if marker_use in ("x", "+", "|", "_", "1", "2", "3", "4") else 0
                )

                ax_raster.scatter(
                    spike_times,
                    plot_neuron_indices,
                    s=s_arg,
                    c=spike_colors,
                    marker=marker_use,
                    linewidths=lw,
                )

    ax_raster.text(
        0.01,
        0.99,  # Move to top left to avoid conflict with right-side group labels
        f"N={spike_count}",
        transform=ax_raster.transAxes,
        ha="left",
        va="top",
        bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        fontsize=8,
    )

    if show_rate or show_group_rate:
        assert ax_rate is not None
        fr = None
        if isinstance(rate, (np.ndarray, torch.Tensor)):
            fr = _to_numpy(rate)
            if fr.ndim == 2 and fr.shape[1] == 1:
                fr = fr[:, 0]
            if fr.ndim != 1:
                raise ValueError("rate must be 1D or shape (T, 1).")
            if fr.shape[0] != n_time:
                raise ValueError("rate length must match time axis length.")
        elif rate is True:
            eff_dt = dt if dt is not None else (t[1] - t[0] if len(t) > 1 else 1.0)
            fr = firing_rate(
                spikes_np, width=rate_window_ms / eff_dt, dt=eff_dt * 1e-3, axis=-1
            )

        group_alpha = 0.45
        group_lw = 0.9
        group_zorder = 1
        total_lw = 1.8
        total_zorder = 2

        if show_group_rate and group_key is not None:
            if neurons_df is None or group_key not in neurons_df.columns:
                raise ValueError(
                    "neurons_df with group_key is required for group_rate."
                )
            group_color_map: dict[str, Any] = {}
            if isinstance(spike_color, dict):
                if not any(isinstance(k, (int, np.integer)) for k in spike_color):
                    group_color_map = dict(spike_color)

            if not group_color_map:
                group_palette = _sample_cmap_colors(strip_cmap, len(groups))
                group_color_map = dict(zip(groups, group_palette))

            if isinstance(group_rate, dict):
                group_rates = {k: _to_numpy(v) for k, v in group_rate.items()}
                for g in groups:
                    if g not in group_rates:
                        continue
                    g_rate = group_rates[g]
                    if g_rate.ndim == 2 and g_rate.shape[1] == 1:
                        g_rate = g_rate[:, 0]
                    if g_rate.ndim != 1 or g_rate.shape[0] != n_time:
                        raise ValueError(
                            "group_rate values must be 1D and match time axis."
                        )
                    ax_rate.plot(
                        t,
                        g_rate,
                        color=group_color_map.get(g, "black"),
                        alpha=group_alpha,
                        lw=group_lw,
                        zorder=group_zorder,
                        label=str(g),
                    )
            elif isinstance(group_rate, (np.ndarray, torch.Tensor)):
                group_rate_arr = _to_numpy(group_rate)
                if group_rate_arr.ndim != 2 or group_rate_arr.shape[0] != n_time:
                    raise ValueError("group_rate array must have shape (T, G).")
                if group_rate_arr.shape[1] != len(groups):
                    raise ValueError("group_rate array must match number of groups.")
                for idx, g in enumerate(groups):
                    ax_rate.plot(
                        t,
                        group_rate_arr[:, idx],
                        color=group_color_map.get(g, "black"),
                        alpha=group_alpha,
                        lw=group_lw,
                        zorder=group_zorder,
                        label=str(g),
                    )
            elif group_rate is True:
                eff_dt = dt if dt is not None else (t[1] - t[0] if len(t) > 1 else 1.0)
                for g in groups:
                    g_indices = np.flatnonzero(group_labels == g)
                    if g_indices.size == 0:
                        continue
                    g_rate = firing_rate(
                        spikes_np[:, g_indices],
                        width=rate_window_ms / eff_dt,
                        dt=eff_dt * 1e-3,
                        axis=-1,
                    )
                    ax_rate.plot(
                        t,
                        g_rate,
                        color=group_color_map.get(g, "black"),
                        alpha=group_alpha,
                        lw=group_lw,
                        zorder=group_zorder,
                        label=str(g),
                    )

        if fr is not None:
            ax_rate.plot(
                t,
                fr,
                color="black",
                lw=total_lw,
                alpha=0.9,
                zorder=total_zorder,
            )
        ax_rate.set_xlim(t[0], t[-1])
        ax_rate.set_ylabel("Rate (Hz)")
        ax_rate.set_xlabel(xlabel)
        # Hide x-labels of raster
        ax_raster.set_xticklabels([])
        ax_raster.set_xlabel("")

        return ax_raster, ax_rate
    else:
        ax_raster.set_xlabel(xlabel)
        return ax_raster
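The sorting logic above builds an inverse permutation so that spikes recorded against original neuron indices can be looked up directly at their plot rows. A minimal NumPy sketch of the idea (standalone, not part of btorch):

```python
import numpy as np

# Suppose neurons 0..4 are reordered for plotting so that group members are
# contiguous: sorted_indices[k] is the original index drawn at row y=k.
sorted_indices = np.array([3, 0, 4, 1, 2])

# Inverse permutation: idx_map[original_index] -> y coordinate.
idx_map = np.empty(len(sorted_indices), dtype=int)
idx_map[sorted_indices] = np.arange(len(sorted_indices))

# Spikes recorded against original indices map directly to plot rows.
orig_neuron_indices = np.array([0, 0, 2, 3])
plot_rows = idx_map[orig_neuron_indices]
```

This is why the raster needs only one vectorized lookup per spike, regardless of how the grouping reordered the neurons.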

plot_spectrum(data, dt=None, nperseg=None, ax=None, mode='loglog', show_mean=True, title='Frequency Spectrum', color=None, label='Mean', alpha=0.2, mean_linewidth=1.5)

Plot frequency spectrum of timeseries data.

Computes power spectral density using Welch's method and visualizes the frequency content. For 2D input (time, neurons), plots individual traces with optional mean overlay.

Parameters:

Name Type Description Default
data Union[ndarray, Tensor]

Input timeseries with shape (time,) or (time, neurons).

required
dt float | None

Sampling interval in ms. Default 1.0.

None
nperseg int | None

Length of FFT segments. Default is min(256, time//4).

None
ax Axes | None

Existing axes to plot on. Creates new figure if None.

None
mode str

Plot scale - "loglog" (default) or "semilogx".

'loglog'
show_mean bool

Whether to overlay the mean spectrum (for 2D data).

True
title str

Plot title.

'Frequency Spectrum'
color str | None

Color for traces. Uses default if None.

None
label str | None

Legend label for mean trace.

'Mean'
alpha float

Opacity for individual traces.

0.2
mean_linewidth float

Line width for mean trace.

1.5

Returns:

Type Description
tuple[ndarray, ndarray, Axes]

Tuple of (frequencies, power_spectrum, axes).

Example

freqs, power, ax = plot_spectrum(spikes, dt=1.0, mode="loglog")

Source code in btorch/visualisation/timeseries.py
def plot_spectrum(
    data: Union[np.ndarray, torch.Tensor],
    dt: float | None = None,
    nperseg: int | None = None,
    ax: Axes | None = None,
    mode: str = "loglog",
    show_mean: bool = True,
    title: str = "Frequency Spectrum",
    color: str | None = None,
    label: str | None = "Mean",
    alpha: float = 0.2,
    mean_linewidth: float = 1.5,
) -> tuple[np.ndarray, np.ndarray, Axes]:
    """Plot frequency spectrum of timeseries data.

    Computes power spectral density using Welch's method and visualizes
    the frequency content. For 2D input (time, neurons), plots individual
    traces with optional mean overlay.

    Args:
        data: Input timeseries with shape (time,) or (time, neurons).
        dt: Sampling interval in ms. Default 1.0.
        nperseg: Length of FFT segments. Default is min(256, time//4).
        ax: Existing axes to plot on. Creates new figure if None.
        mode: Plot scale - "loglog" (default) or "semilogx".
        show_mean: Whether to overlay the mean spectrum (for 2D data).
        title: Plot title.
        color: Color for traces. Uses default if None.
        label: Legend label for mean trace.
        alpha: Opacity for individual traces.
        mean_linewidth: Line width for mean trace.

    Returns:
        Tuple of (frequencies, power_spectrum, axes).

    Example:
        >>> freqs, power, ax = plot_spectrum(spikes, dt=1.0, mode="loglog")
    """
    data_np = _to_numpy(data)
    if dt is None:
        dt = 1.0

    freqs, power = compute_spectrum(data_np, dt=dt, nperseg=nperseg)

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 4))

    power_db = 10 * np.log10(power)

    # "loglog" plots raw power on log-log axes; "semilogx" plots power in dB.
    y_data = power if mode == "loglog" else power_db

    # Defaults
    trace_color = color if color else "blue"
    mean_color = color if color else "black"

    if show_mean and data_np.ndim > 1:
        # Plot individual traces
        if alpha > 0:
            ax.plot(freqs, y_data, color=trace_color, alpha=alpha, lw=0.5)
        # Plot mean
        mean_power = y_data.mean(axis=1) if y_data.ndim > 1 else y_data
        ax.plot(freqs, mean_power, color=mean_color, lw=mean_linewidth, label=label)
    else:
        ax.plot(freqs, y_data, color=mean_color, label=label)

    if mode == "loglog":
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_ylabel("Power")
    elif mode == "semilogx":
        ax.set_xscale("log")
        ax.set_ylabel("Power (dB)")

    ax.set_xlabel("Frequency (Hz)")
    ax.set_title(title)

    return freqs, power, ax
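`compute_spectrum` itself is not shown on this page. Assuming it follows the Welch's-method behaviour described above (with `dt` in ms converted to a sampling rate in Hz, and the `min(256, time // 4)` default segment length), its core can be sketched with SciPy; `welch_spectrum` is an illustrative stand-in, not the btorch implementation:

```python
import numpy as np
from scipy.signal import welch

def welch_spectrum(data, dt=1.0, nperseg=None):
    """Sketch of compute_spectrum: Welch PSD along the time axis."""
    fs = 1000.0 / dt  # dt is in ms, so the sampling frequency is in Hz
    if nperseg is None:
        nperseg = min(256, data.shape[0] // 4)
    # axis=0 treats each column (neuron) as an independent timeseries
    freqs, power = welch(data, fs=fs, nperseg=nperseg, axis=0)
    return freqs, power

rng = np.random.default_rng(0)
x = rng.standard_normal((2048, 8))  # (time, neurons)
freqs, power = welch_spectrum(x, dt=1.0)  # freqs in Hz, one PSD per neuron
```

With `nperseg=256` the spectrum has `256 // 2 + 1 = 129` frequency bins, ranging from 0 Hz up to the Nyquist frequency (500 Hz at `dt=1.0` ms).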

plot_traces(data, dt=None, times=None, ax=None, neurons=None, labels=None, colors=None, title=None, xlabel='Time (ms)', ylabel=None, legend=True, alpha=0.8)

Plot continuous timeseries traces.

Parameters

data : array-like
    Shape (Time, Neurons) or (Time, Neurons, Features).
dt : float, optional
    Time step.
times : array-like, optional
    Explicit time array.
ax : Axes, optional
    Axis to plot on.
neurons : list of int or int, optional
    Indices of neurons to plot. If None, plots all (careful with large N).
    If int, samples that many neurons randomly.
labels : list of str, optional
    Labels for the legend.
colors : list of colors, optional
    Colors for traces.
title : str, optional
    Plot title.

Returns

Axes

Source code in btorch/visualisation/timeseries.py
def plot_traces(
    data: Union[np.ndarray, torch.Tensor],
    dt: float | None = None,
    times: Sequence[float] | None = None,
    ax: Axes | None = None,
    neurons: Sequence[int] | int | None = None,
    labels: Sequence[str] | str | None = None,
    colors: Sequence[Any] | None = None,
    title: str | None = None,
    xlabel: str = "Time (ms)",
    ylabel: str | None = None,
    legend: bool = True,
    alpha: float = 0.8,
) -> Axes:
    """Plot continuous timeseries traces.

    Parameters
    ----------
    data : array-like
        Shape (Time, Neurons) or (Time, Neurons, Features).
    dt : float, optional
        Time step.
    times : array-like, optional
        Explicit time array.
    ax : Axes, optional
        Axis to plot on.
    neurons : list of int or int, optional
        Indices of neurons to plot. If None, plots all (careful with large N).
        If int, samples that many neurons randomly.
    labels : list of str, optional
        Labels for the legend.
    colors : list of colors, optional
        Colors for traces.
    title : str, optional
        Plot title.

    Returns
    -------
    Axes
    """
    data_np = _to_numpy(data)
    t = _get_time_axis(data_np.shape[0], dt, times)

    if data_np.ndim == 2:
        # (Time, Neurons)
        data_np = data_np[:, :, np.newaxis]  # make it (Time, Neurons, 1)
    elif data_np.ndim != 3:
        raise ValueError("Data must be 2D (T, N) or 3D (T, N, F)")

    # Select neurons
    n_neurons = data_np.shape[1]
    if neurons is None:
        neuron_indices = np.arange(n_neurons)
    elif isinstance(neurons, int):
        if neurons >= n_neurons:
            neuron_indices = np.arange(n_neurons)
        else:
            neuron_indices = np.sort(
                np.random.choice(n_neurons, neurons, replace=False)
            )
    else:
        neuron_indices = np.array(neurons)

    n_features = data_np.shape[2]

    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 4))

    if colors is None:
        # Generate distinct colors for each neuron
        cmap = plt.get_cmap("turbo", len(neuron_indices))
        colors = [cmap(i) for i in range(len(neuron_indices))]

    for i, idx in enumerate(neuron_indices):
        c = (
            colors[i]
            if isinstance(colors, (list, np.ndarray))
            and len(colors) == len(neuron_indices)
            else None
        )

        for feat in range(n_features):
            trace = data_np[:, idx, feat]

            # Construct label
            lbl = None
            if labels is not None:
                if isinstance(labels, str):
                    lbl = f"{labels} {idx}"
                elif len(labels) == len(neuron_indices):
                    lbl = labels[i]
                else:
                    lbl = f"Neuron {idx}"
            else:
                lbl = f"Neuron {idx}"

            if n_features > 1:
                lbl += f" (f{feat})"

            ax.plot(t, trace, label=lbl, color=c, alpha=alpha)

    ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    ax.set_xlim(t[0], t[-1])

    if title:
        ax.set_title(title)

    if legend and len(neuron_indices) <= 20:  # Limit legend clutter
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    return ax
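The `neurons` argument's three selection modes (None plots all neurons, an int samples that many at random, a sequence is used as explicit indices) can be sketched standalone; `select_neurons` below is an illustrative helper mirroring the logic in the source, not part of the btorch API:

```python
import numpy as np

def select_neurons(n_neurons, neurons=None, seed=None):
    """Mirror plot_traces' neuron selection (a sketch, not the btorch API)."""
    rng = np.random.default_rng(seed)
    if neurons is None:
        return np.arange(n_neurons)      # plot every neuron
    if isinstance(neurons, int):
        if neurons >= n_neurons:
            return np.arange(n_neurons)  # asked for more than exist
        # random subsample, sorted so traces stay in index order
        return np.sort(rng.choice(n_neurons, neurons, replace=False))
    return np.array(neurons)             # explicit indices, order preserved

idx = select_neurons(100, neurons=5, seed=0)
```

Note that the actual implementation draws from the global `np.random` state; the seeded generator here is only for reproducibility of the sketch.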