Skip to content

Utils

btorch.utils.bench

Benchmarking utilities.

Performance measurement tools for PyTorch code, supporting both CPU wall-clock and GPU event-based timing with warmup and statistical summarization.

Classes

PerfTimer

Context manager for measuring execution time.

Example

>>> with PerfTimer() as timer:
...     result = some_function()
>>> print(f"Took {timer.elapsed_ms():.2f} ms")

Source code in btorch/utils/bench.py
class PerfTimer:
    """Context manager for measuring wall-clock execution time.

    Timing uses :func:`time.perf_counter`. An instance may be reused: every
    ``with`` block restarts the measurement.

    Example:
        >>> with PerfTimer() as timer:
        ...     result = some_function()
        >>> print(f"Took {timer.elapsed_ms():.2f} ms")
    """

    def __init__(self):
        # perf_counter() reading taken in __enter__; None until first started.
        self.start_time: float | None = None
        # perf_counter() reading taken in __exit__; None while a block runs.
        self.end_time: float | None = None

    def __enter__(self):
        # Drop the end mark of any previous run so elapsed_ms() called inside
        # this block measures against "now" instead of a stale (earlier) stop
        # time, which would yield a negative duration.
        self.end_time = None
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *args):
        # Returns None, so exceptions raised in the block propagate.
        self.end_time = time.perf_counter()

    def elapsed_ms(self) -> float:
        """Return elapsed time in milliseconds.

        Returns:
            Elapsed time from ``__enter__`` to ``__exit__`` (or until now
            if ``__exit__`` hasn't been called for the current run).

        Raises:
            RuntimeError: If timer was never started.
        """
        if self.start_time is None:
            raise RuntimeError("Timer never started")
        end = self.end_time if self.end_time is not None else time.perf_counter()
        return (end - self.start_time) * 1000
Functions
elapsed_ms()

Return elapsed time in milliseconds.

Returns:

Type Description
float

Elapsed time from __enter__ to __exit__ (or until now if __exit__ hasn't been called).

Raises:

Type Description
RuntimeError

If timer was never started.

Source code in btorch/utils/bench.py
def elapsed_ms(self) -> float:
    """Report the measured duration in milliseconds.

    Returns:
        Milliseconds between ``__enter__`` and ``__exit__``; while the timer
        is still running, the span up to the current ``perf_counter()``.

    Raises:
        RuntimeError: If the timer has no start time.
    """
    began = self.start_time
    if began is None:
        raise RuntimeError("Timer never started")
    # A missing end mark means the block is still running: measure to "now".
    finished = time.perf_counter() if self.end_time is None else self.end_time
    return 1000 * (finished - began)

Functions

do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode='mean', timing_method='cpu', sync_cuda=True)

Benchmark function runtime with warmup and statistics.

Supports both CPU wall-clock timing and GPU CUDA event timing. Warmup and repetition can be specified as iteration counts (int) or durations in milliseconds (float).

Parameters:

Name Type Description Default
fn Callable

Function to benchmark (callable with no arguments).

required
warmup int | float

Warmup iterations (int) or duration in ms (float).

25
rep int | float

Measurement iterations (int) or duration in ms (float).

100
grad_to_none Optional[Tensor]

Optional tensor whose gradient is reset to None between repetitions.

None
quantiles Optional[List[float]]

Optional quantiles to compute (e.g., [0.05, 0.95]).

None
return_mode Literal['min', 'max', 'mean', 'median', 'all']

Statistic to return: "min", "max", "mean", "median", or "all" for the list of per-repetition times.

'mean'
timing_method Literal['gpu', 'cpu']

"gpu" for CUDA events (if available) or "cpu" for wall-clock timing.

'cpu'
sync_cuda bool

Whether to synchronize CUDA before/after timing (only applies to CPU timing).

True

Returns:

Type Description
Union[float, Dict[str, float]]

Timing result in milliseconds: a float for a single statistic or a single quantile; a list of floats for several quantiles or for "all".

Example

>>> def bench_fn():
...     return torch.mm(a, b)
>>> do_bench(bench_fn, warmup=10, rep=100, return_mode="median")
0.523

Source code in btorch/utils/bench.py
def do_bench(
    fn: Callable,
    warmup: int | float = 25,
    rep: int | float = 100,
    grad_to_none: Optional[torch.Tensor] = None,
    quantiles: Optional[List[float]] = None,
    return_mode: Literal["min", "max", "mean", "median", "all"] = "mean",
    timing_method: Literal["gpu", "cpu"] = "cpu",
    sync_cuda: bool = True,
) -> Union[float, Dict[str, float]]:
    """Benchmark function runtime with warmup and statistics.

    Supports both CPU wall-clock timing and GPU CUDA event timing.
    Warmup and repetition can be specified as iteration counts (int)
    or durations in milliseconds (float).

    Args:
        fn: Function to benchmark (callable with no arguments).
        warmup: Warmup iterations (int) or duration in ms (float).
        rep: Measurement iterations (int) or duration in ms (float).
        grad_to_none: Optional tensor whose gradient is reset to None
            between repetitions.
        quantiles: Optional quantiles to compute (e.g., [0.05, 0.95]).
        return_mode: Central statistic to return:
            "min", "max", "mean", "median", or "all" for all stats.
        timing_method: "gpu" for CUDA events (if available) or "cpu"
            for wall-clock timing.
        sync_cuda: Whether to synchronize CUDA before/after timing
            (only applies to CPU timing).

    Returns:
        Timing result. Float for single statistics, dict for "all"
        or when quantiles are specified.

    Example:
        >>> def bench_fn():
        ...     return torch.mm(a, b)
        >>> do_bench(bench_fn, warmup=10, rep=100, return_mode="median")
        0.523
    """
    if not callable(fn):
        raise TypeError("The 'fn' parameter must be callable")

    if timing_method == "total":
        timing_method = "cpu"
    if timing_method not in ["gpu", "cpu"]:
        raise ValueError("timing_method must be either 'gpu' or 'cpu'")

    if timing_method == "gpu" and not torch.cuda.is_available():
        print(
            "Warning: GPU timing requested but CUDA is not available. "
            "Falling back to cpu timing."
        )
        timing_method = "cpu"

    if timing_method == "gpu":
        use_reps = isinstance(warmup, int) and isinstance(rep, int)
        if use_reps:
            if warmup < 0 or rep < 1:
                raise ValueError("warmup must be >= 0 and rep must be >= 1")
        else:
            warmup = float(warmup)
            rep = float(rep)
            if warmup < 0.0 or rep <= 0.0:
                raise ValueError("warmup and rep must be positive durations in ms")

        from triton.testing import _summarize_statistics, runtime

        di = runtime.driver.active.get_device_interface()

        fn()
        di.synchronize()

        cache = runtime.driver.active.get_empty_cache_for_benchmark()

        if use_reps:
            n_warmup = warmup
            n_repeat = rep
        else:
            start_event = di.Event(enable_timing=True)
            end_event = di.Event(enable_timing=True)
            start_event.record()
            for _ in range(5):
                runtime.driver.active.clear_cache(cache)
                fn()
            end_event.record()
            di.synchronize()
            estimate_ms = start_event.elapsed_time(end_event) / 5
            n_warmup = max(1, int(warmup / estimate_ms))
            n_repeat = max(1, int(rep / estimate_ms))

        start_event = [di.Event(enable_timing=True) for _ in range(n_repeat)]
        end_event = [di.Event(enable_timing=True) for _ in range(n_repeat)]
        for _ in range(n_warmup):
            fn()
        for i in range(n_repeat):
            if grad_to_none is not None:
                for x in grad_to_none:
                    x.grad = None
            runtime.driver.active.clear_cache(cache)
            start_event[i].record()
            fn()
            end_event[i].record()
        di.synchronize()
        times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
        return _summarize_statistics(times, quantiles, return_mode)

    use_reps = isinstance(warmup, int) and isinstance(rep, int)
    if use_reps:
        if warmup < 0 or rep < 1:
            raise ValueError("warmup must be >= 0 and rep must be >= 1")
    else:
        warmup = float(warmup)
        rep = float(rep)
        if warmup < 0.0 or rep <= 0.0:
            raise ValueError("warmup and rep must be positive durations in ms")

    if use_reps:
        for _ in range(warmup):
            fn()
    else:
        warmup_start = time.perf_counter()
        while (time.perf_counter() - warmup_start) * 1000 < warmup:
            fn()

    times = []
    if use_reps:
        rep_start = None
    else:
        rep_start = time.perf_counter()
    while True:
        if use_reps:
            if len(times) >= rep:
                break
        else:
            if (time.perf_counter() - rep_start) * 1000 >= rep:
                break
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        if sync_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()
        with PerfTimer() as timer:
            fn()
            if sync_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
        times.append(timer.elapsed_ms())

    times = torch.tensor(times)

    if quantiles is not None:
        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
        if len(ret) == 1:
            ret = ret[0]
        return ret
    return getattr(torch, return_mode)(times).item()

btorch.utils.conf

OmegaConf configuration utilities.

Helpers for loading, manipulating, and comparing structured configs. Provides CLI-style dotlist conversion and diff operations for configuration management workflows.

Attributes

ConfigT = TypeVar('ConfigT') module-attribute

Functions

diff_conf(conf_a, conf_b, mode=None)

Compare conf_b to conf_a and return a structured OmegaConf diff.

The returned config contains only the selected changed keys. For removed keys (when selected by mode), values are set to None so callers can render these entries via to_dotlist as key=null.

Source code in btorch/utils/conf.py
def diff_conf(
    conf_a: DictConfig | ListConfig | Any,
    conf_b: DictConfig | ListConfig,
    mode: (Iterable[Literal["changed", "added", "removed"]] | None) = None,
) -> DictConfig | ListConfig:
    """Compare ``conf_b`` to ``conf_a`` and return a structured OmegaConf diff.

    The returned config contains only the selected changed keys. For removed
    keys (when selected by ``mode``), values are set to ``None`` so callers can
    render these entries via :func:`to_dotlist` as ``key=null``.

    Args:
        conf_a: Baseline config. Inputs that are not ``DictConfig``/``ListConfig``
            are converted with ``OmegaConf.structured``.
        conf_b: Target config, converted the same way if needed.
        mode: Record statuses to include (``"changed"``, ``"added"``,
            ``"removed"``); ``None`` includes all of them.

    Returns:
        An OmegaConf container holding only the differing entries. Intermediate
        containers take their kind (dict vs. list) from ``conf_b``, falling
        back to ``conf_a``. If nothing differs, an empty container of
        ``conf_b``'s kind is returned.

    Raises:
        ValueError: If the diff would have a scalar root.
    """

    if not isinstance(conf_a, (DictConfig, ListConfig)):
        conf_a = OmegaConf.structured(conf_a)
    if not isinstance(conf_b, (DictConfig, ListConfig)):
        conf_b = OmegaConf.structured(conf_b)

    records = diff_conf_records(conf_a, conf_b, mode=mode)

    # Plain views of both inputs, consulted to decide whether a path segment
    # is a dict or a list: a digit token alone is ambiguous, since dict keys
    # such as "0" (or int 0) render the same as list indices.
    plain_sources = (
        OmegaConf.to_container(conf_b, resolve=False),
        OmegaConf.to_container(conf_a, resolve=False),
    )
    missing = object()

    def _child(node, token: str):
        # Resolve one path token; dict keys are matched by their str() form.
        if isinstance(node, dict):
            if token in node:
                return node[token]
            for key, value in node.items():
                if str(key) == token:
                    return value
            return missing
        if isinstance(node, list):
            if token.isdigit() and int(token) < len(node):
                return node[int(token)]
        return missing

    def _source_node(tokens):
        # Node at ``tokens`` in conf_b if present there, else in conf_a.
        for source in plain_sources:
            cur = source
            for token in tokens:
                cur = _child(cur, token)
                if cur is missing:
                    break
            else:
                return cur
        return missing

    def _empty_container(tokens, depth: int):
        # Fresh container for the node at tokens[: depth + 1] (depth=-1 is the
        # root), shaped like the corresponding node in the inputs.
        source = _source_node(tokens[: depth + 1])
        if isinstance(source, list):
            return []
        if isinstance(source, dict):
            return {}
        return [] if tokens[depth + 1].isdigit() else {}

    def _set_path(root, path: str, value):
        if path == "<root>":
            return value

        tokens = path.split(".")
        if root is None:
            root = _empty_container(tokens, -1)

        cur = root
        last = len(tokens) - 1
        for i, token in enumerate(tokens):
            is_last = i == last

            if isinstance(cur, dict):
                if is_last:
                    cur[token] = value
                    break
                if token not in cur:
                    cur[token] = _empty_container(tokens, i)
                cur = cur[token]
                continue

            if isinstance(cur, list):
                idx = int(token)
                # Grow the list so ``idx`` is addressable.
                while len(cur) <= idx:
                    cur.append(None if is_last else _empty_container(tokens, i))
                if is_last:
                    cur[idx] = value
                    break
                if cur[idx] is None:
                    cur[idx] = _empty_container(tokens, i)
                cur = cur[idx]
                continue

            raise TypeError(f"Cannot set nested path '{path}' through scalar node.")

        return root

    diff_tree = None
    for path, entry in sorted(records.items()):
        value = None if entry["status"] == "removed" else entry["new"]
        diff_tree = _set_path(diff_tree, path, value)

    if diff_tree is None:
        # Keep return type stable and flattenable.
        if isinstance(conf_b, ListConfig):
            return OmegaConf.create([])
        return OmegaConf.create({})

    if not isinstance(diff_tree, (dict, list, DictConfig, ListConfig)):
        raise ValueError(
            "diff_conf produced a scalar root diff. "
            "Expected a DictConfig/ListConfig root."
        )

    return OmegaConf.create(diff_tree)

diff_conf_dotlist(conf_a, conf_b, mode=None, removed_prefix='~')

Build CLI-style overrides that transform conf_a into conf_b.

For added and changed entries, this emits "path=value". For removed entries (when selected by mode), this emits "{removed_prefix}path".

Source code in btorch/utils/conf.py
def diff_conf_dotlist(
    conf_a: DictConfig | ListConfig | Any,
    conf_b: DictConfig | ListConfig,
    mode: (Iterable[Literal["changed", "added", "removed"]] | None) = None,
    removed_prefix: str = "~",
) -> list[str]:
    """Build CLI-style overrides that transform ``conf_a`` into ``conf_b``.

    For ``added`` and ``changed`` entries, this emits ``"path=value"``
    (``None`` is rendered as ``null``). For ``removed`` entries (when selected
    by ``mode``), this emits ``"{removed_prefix}path"``.

    Args:
        conf_a: Baseline config. As in :func:`diff_conf`, inputs that are not
            ``DictConfig``/``ListConfig`` are converted with
            ``OmegaConf.structured``.
        conf_b: Target config, converted the same way if needed.
        mode: Record statuses to include; ``None`` includes all of them.
        removed_prefix: Prefix marking a removed key.

    Returns:
        Override strings sorted by path.

    Raises:
        ValueError: If an added/changed value is itself a container, which
            cannot be expressed as a single ``path=value`` override.
    """

    if not isinstance(conf_a, (DictConfig, ListConfig)):
        conf_a = OmegaConf.structured(conf_a)
    if not isinstance(conf_b, (DictConfig, ListConfig)):
        conf_b = OmegaConf.structured(conf_b)

    records = diff_conf_records(conf_a, conf_b, mode=mode)

    dotlist: list[str] = []
    for key in sorted(records):
        status = records[key]["status"]
        new_value = records[key]["new"]

        if status == "removed":
            dotlist.append(f"{removed_prefix}{key}")
            continue

        if isinstance(new_value, (DictConfig, ListConfig, dict, list)):
            raise ValueError(
                "diff_conf_dotlist cannot serialize container-valued changes at "
                f"'{key}'. Use diff_conf_records for this case."
            )

        value = "null" if new_value is None else new_value
        dotlist.append(f"{key}={value}")

    return dotlist

diff_conf_records(conf_a, conf_b, mode=None)

Compare conf_b to conf_a and return per-key value-level records.

Each record has the shape {"status": str, "old": object, "new": object}.

  • status='changed': key exists in both, value differs.
  • status='added': key exists only in conf_b.
  • status='removed': key exists only in conf_a.

This representation is suitable when a caller needs both key names and values, for example to build child-process overrides from a baseline config.

Source code in btorch/utils/conf.py
def diff_conf_records(
    conf_a: DictConfig | ListConfig,
    conf_b: DictConfig | ListConfig,
    mode: (Iterable[Literal["changed", "added", "removed"]] | None) = None,
) -> dict[str, dict[str, object]]:
    """Compare ``conf_b`` to ``conf_a`` and return per-key value-level records.

    Each record has the shape ``{"status": str, "old": object, "new": object}``.

    - ``status='changed'``: key exists in both, value differs.
    - ``status='added'``: key exists only in ``conf_b``.
    - ``status='removed'``: key exists only in ``conf_a``.

    Record keys are dotted paths built from ``str(key)`` and list indices; a
    change of the whole root is reported under ``"<root>"``. Values come from
    ``OmegaConf.to_container(..., resolve=False)``, so interpolations are
    reported unresolved. Added/removed subtrees are reported per leaf; empty
    dicts/lists have no leaves, so adding or removing one is not reported.
    When a dict's ``_type_`` (structured-union tag) differs, the whole dict is
    reported as a single ``changed`` record.

    This representation is suitable when a caller needs both key names and values,
    for example to build child-process overrides from a baseline config.

    Args:
        conf_a: Baseline config.
        conf_b: Target config.
        mode: Statuses to include; ``None`` includes all three.

    Returns:
        Mapping from dotted path to record, in sorted path order.

    Raises:
        TypeError: If either input is not a ``DictConfig``/``ListConfig``.
        ValueError: If ``mode`` contains an unknown status.
    """

    if not isinstance(conf_a, (DictConfig, ListConfig)):
        raise TypeError(
            "diff_conf_records expects conf_a to be DictConfig or ListConfig."
        )
    if not isinstance(conf_b, (DictConfig, ListConfig)):
        raise TypeError(
            "diff_conf_records expects conf_b to be DictConfig or ListConfig."
        )

    if mode is None:
        mode_set = {"changed", "added", "removed"}
    else:
        mode_set = set(mode)

    valid_status = {"changed", "added", "removed"}
    if not mode_set.issubset(valid_status):
        raise ValueError("mode must only contain: 'changed', 'added', 'removed'.")

    changed: set[str] = set()
    added: set[str] = set()
    removed: set[str] = set()
    changed_values: dict[str, tuple[object, object]] = {}
    added_values: dict[str, object] = {}
    removed_values: dict[str, object] = {}

    # Use plain containers so structured union explicit type metadata (`_type_`)
    # emitted by OmegaConf is visible to the recursive diff.
    plain_a = OmegaConf.to_container(conf_a, resolve=False)
    plain_b = OmegaConf.to_container(conf_b, resolve=False)

    def _join(path: str, key) -> str:
        # Path tokens are the str() of dict keys / list indices.
        return f"{path}.{key}" if path else str(key)

    def _is_container(node) -> bool:
        return isinstance(node, (dict, list))

    def _select_plain(node, path: str):
        if path == "<root>":
            return node

        cur = node
        for token in path.split("."):
            if isinstance(cur, dict):
                if token in cur:
                    cur = cur[token]
                    continue
                # Non-string keys (int, bool, Enum, ...) enter paths via
                # str(key); resolve them the same way.
                for key, value in cur.items():
                    if str(key) == token:
                        cur = value
                        break
                else:
                    raise KeyError(f"Key '{token}' not found while selecting '{path}'.")
                continue
            if isinstance(cur, list):
                cur = cur[int(token)]
                continue
            raise KeyError(f"Cannot descend through scalar at token '{token}'.")
        return cur

    def _collect_leaves(node, path: str) -> dict[str, object]:
        # Map each leaf path under ``node`` (rooted at ``path``) to its value.
        # Values are taken directly from the walk, so no path re-lookup is
        # needed (re-lookup by string token fails for non-string dict keys).
        if isinstance(node, dict):
            children = node.items()
        elif isinstance(node, list):
            children = enumerate(node)
        else:
            return {path if path else "<root>": node}
        leaves: dict[str, object] = {}
        for key, value in children:
            leaves.update(_collect_leaves(value, _join(path, key)))
        return leaves

    def _record_changed(path: str, old, new) -> None:
        key = path if path else "<root>"
        changed.add(key)
        changed_values[key] = (old, new)

    def _record_removed(node, path: str) -> None:
        leaves = _collect_leaves(node, path)
        removed.update(leaves)
        removed_values.update(leaves)

    def _record_added(node, path: str) -> None:
        leaves = _collect_leaves(node, path)
        added.update(leaves)
        added_values.update(leaves)

    def _walk(a_node, b_node, path: str = ""):
        if isinstance(a_node, dict) and isinstance(b_node, dict):
            # Structured union switch: treat as full subtree replacement so old
            # type keys are discarded in one step.
            if (
                "_type_" in a_node
                and "_type_" in b_node
                and a_node["_type_"] != b_node["_type_"]
            ):
                _record_changed(path, a_node, b_node)
                return

            keys_a = set(a_node.keys())
            keys_b = set(b_node.keys())

            for key in keys_a - keys_b:
                _record_removed(a_node[key], _join(path, key))
            for key in keys_b - keys_a:
                _record_added(b_node[key], _join(path, key))
            for key in keys_a & keys_b:
                _walk(a_node[key], b_node[key], _join(path, key))
            return

        if isinstance(a_node, list) and isinstance(b_node, list):
            common = min(len(a_node), len(b_node))
            for idx in range(common):
                _walk(a_node[idx], b_node[idx], _join(path, idx))
            for idx in range(common, len(a_node)):
                _record_removed(a_node[idx], _join(path, idx))
            for idx in range(common, len(b_node)):
                _record_added(b_node[idx], _join(path, idx))
            return

        # Container vs. scalar (or dict vs. list) is a whole-value change;
        # the short-circuit avoids comparing mismatched kinds element-wise.
        if _is_container(a_node) != _is_container(b_node) or a_node != b_node:
            _record_changed(path, a_node, b_node)

    _walk(plain_a, plain_b)

    def _is_under(key: str, parent: str) -> bool:
        if parent == "<root>":
            return True
        return key == parent or key.startswith(f"{parent}.")

    # If explicit union type changed ("..._type_"), collapse that subtree into a
    # single changed record at the parent path. This avoids emitting stale
    # per-leaf removals for the previous union member.
    # NOTE(review): _walk already returns early when both dicts carry differing
    # `_type_`, so this pass looks like a defensive fallback; kept as-is.
    union_switch_parents: set[str] = set()
    for key in changed:
        if key == "_type_":
            union_switch_parents.add("<root>")
            continue
        if key.endswith("._type_"):
            union_switch_parents.add(key.rsplit(".", 1)[0])

    for parent in sorted(union_switch_parents, key=lambda p: (p != "<root>", p)):
        for key in list(changed):
            if _is_under(key, parent):
                changed.remove(key)
                changed_values.pop(key, None)
        for key in list(added):
            if _is_under(key, parent):
                added.remove(key)
                added_values.pop(key, None)
        for key in list(removed):
            if _is_under(key, parent):
                removed.remove(key)
                removed_values.pop(key, None)

        changed.add(parent)
        if parent == "<root>":
            changed_values[parent] = (plain_a, plain_b)
        else:
            changed_values[parent] = (
                _select_plain(plain_a, parent),
                _select_plain(plain_b, parent),
            )

    out: set[str] = set()
    if "changed" in mode_set:
        out |= changed
    if "added" in mode_set:
        out |= added
    if "removed" in mode_set:
        out |= removed

    records: dict[str, dict[str, object]] = {}
    for key in sorted(out):
        if key in changed:
            old_value, new_value = changed_values[key]
            records[key] = {"status": "changed", "old": old_value, "new": new_value}
            continue
        if key in added:
            records[key] = {"status": "added", "old": None, "new": added_values[key]}
            continue
        records[key] = {"status": "removed", "old": removed_values[key], "new": None}

    return records

get_dotkey(obj, key, default=None)

Get nested attribute by dot-separated key.

Parameters:

Name Type Description Default
obj Any

Object to access (supports DictConfig/ListConfig or regular objects).

required
key str

Dot-separated path (e.g., "a.b.c").

required
default

Value to return if key not found.

None

Returns:

Type Description

Value at the nested path, or default if not found.

Source code in btorch/utils/conf.py
def get_dotkey(obj: Any, key: str, default=None):
    """Read a nested value addressed by a dot-separated key.

    OmegaConf containers are handled by ``OmegaConf.select``. Any other object
    is traversed one ``getattr`` per path segment, so plain dict keys and list
    indices are not resolved on that branch.

    Args:
        obj: Root object (``DictConfig``/``ListConfig`` or an ordinary object).
        key: Dot-separated path such as ``"a.b.c"``.
        default: Fallback returned when the path cannot be resolved.

    Returns:
        The value found at ``key``, or ``default`` if an attribute along the
        way is missing.
    """
    if isinstance(obj, (DictConfig, ListConfig)):
        return OmegaConf.select(obj, key, default=default)

    current = obj
    try:
        for attr_name in key.split("."):
            current = getattr(current, attr_name)
    except AttributeError:
        # Any missing segment means the whole path is unresolvable.
        return default
    return current

load_config(Param, use_config_file=True, search_path=Path('.'), argv_arglist=None, return_cli=False, make_concrete=True)

load_config(Param: type[ConfigT] | ConfigT, use_config_file: bool = True, search_path: Path = Path('.'), argv_arglist: list[str] | None = None, return_cli: Literal[False] = False, make_concrete: Literal[True] = True) -> ConfigT
load_config(Param: type[ConfigT] | ConfigT, use_config_file: bool = True, search_path: Path = Path('.'), argv_arglist: list[str] | None = None, return_cli: Literal[False] = False, make_concrete: Literal[False] = False) -> DictConfig | ListConfig
load_config(Param: type[ConfigT] | ConfigT, use_config_file: bool = True, search_path: Path = Path('.'), argv_arglist: list[str] | None = None, return_cli: Literal[True] = True, make_concrete: Literal[True] = True) -> tuple[ConfigT, DictConfig | ListConfig]
load_config(Param: type[ConfigT] | ConfigT, use_config_file: bool = True, search_path: Path = Path('.'), argv_arglist: list[str] | None = None, return_cli: Literal[True] = True, make_concrete: Literal[False] = False) -> tuple[DictConfig | ListConfig, DictConfig | ListConfig]

Load structured config from defaults, file, and CLI arguments.

Merges configuration in order: dataclass defaults -> config file -> CLI arguments. Config file path is read from config_path in CLI arguments.

Parameters:

Name Type Description Default
Param type[ConfigT] | ConfigT

Dataclass type or instance defining the configuration schema.

required
use_config_file bool

Whether to load from file specified by config_path CLI argument.

True
search_path Path

Directory to search for relative config paths.

Path('.')
argv_arglist list[str] | None

Optional CLI arguments list (defaults to sys.argv).

None
return_cli bool

If True, also return raw CLI config.

False
make_concrete bool

If True, convert to Python objects. If False, return OmegaConf containers.

True

Returns:

Type Description
Any

Loaded configuration, or a tuple (config, cli_config) if return_cli=True.

Note

Does not support help text or Literal types in the schema.

Source code in btorch/utils/conf.py
def load_config(
    Param: type[ConfigT] | ConfigT,
    use_config_file: bool = True,
    search_path: Path = Path("."),
    argv_arglist: list[str] | None = None,
    return_cli: bool = False,
    make_concrete: bool = True,
) -> Any:
    """Load structured config from defaults, file, and CLI arguments.

    Merges configuration in order: dataclass defaults -> config file ->
    CLI arguments. The config file path is read from the ``config_path``
    CLI argument; if that is not an existing file as given, it is looked up
    relative to ``search_path``.

    Args:
        Param: Dataclass type or instance defining the configuration schema.
        use_config_file: Whether to load from file specified by
            ``config_path`` CLI argument.
        search_path: Directory to search for relative config paths.
        argv_arglist: Optional CLI arguments list (defaults to sys.argv).
        return_cli: If True, also return raw CLI config (including
            ``config_path`` when it was given).
        make_concrete: If True, convert to Python objects. If False,
            return OmegaConf containers.

    Returns:
        Loaded configuration. Tuple of (config, cli_config) if
        ``return_cli=True``.

    Raises:
        ValueError: If the dataclass schema itself declares a ``config_path``
            field, which would clash with the reserved CLI argument.
        FileNotFoundError: If ``config_path`` exists neither as given nor
            under ``search_path``.

    Note:
        Does not support help text or Literal types in the schema.
    """
    defaults = OmegaConf.structured(Param)
    if argv_arglist is None:
        cli_cfg = OmegaConf.from_cli()
    else:
        cli_cfg = OmegaConf.from_cli(argv_arglist)
    # Returned CLI view; replaced by a copy below so it keeps ``config_path``
    # after that key is stripped from ``cli_cfg`` for merging.
    cli_cfg_ = cli_cfg

    if use_config_file and "config_path" in cli_cfg:
        if "config_path" in getattr(Param, "__dataclass_fields__", {}):
            raise ValueError(
                "'config_path' is reserved for the config-file CLI argument and "
                "must not be a field of the config schema."
            )
        requested_path = Path(cli_cfg.config_path)
        config_path = requested_path
        if not config_path.is_file():
            config_path = search_path / requested_path
            if not config_path.is_file():
                raise FileNotFoundError(
                    f"Config file '{requested_path}' not found "
                    f"(also tried '{config_path}')."
                )
        # Load the resolved path so the search_path fallback takes effect.
        cfg_cli_file = OmegaConf.load(config_path)
        if return_cli:
            cli_cfg_ = cli_cfg.copy()
        cli_cfg.pop("config_path")
    else:
        cfg_cli_file = OmegaConf.create()

    cfg = OmegaConf.unsafe_merge(defaults, cfg_cli_file, cli_cfg)
    # NOTE: from_cli doesn't treat an integer index as a dict key in some
    # cases. A re-merge via to_dotlist(cli_cfg) + cfg.merge_with_dotlist(...)
    # was tried as a workaround and is currently disabled.
    if make_concrete:
        cfg = OmegaConf.to_object(cfg)

    if return_cli:
        return cfg, cli_cfg_
    return cfg

set_dotkey(obj, key, value)

Set nested attribute by dot-separated key.

Parameters:

Name Type Description Default
obj Any

Object to modify (supports DictConfig/ListConfig or regular objects).

required
key str

Dot-separated path (e.g., "a.b.c").

required
value

Value to set.

required

Returns:

Type Description

None

Source code in btorch/utils/conf.py
def set_dotkey(obj: Any, key: str, value):
    """Assign a nested value addressed by a dot-separated key.

    OmegaConf containers are updated via ``OmegaConf.update``. Any other
    object is walked with ``getattr`` down to the parent of the final segment,
    which is then assigned with ``setattr``.

    Args:
        obj: Root object (``DictConfig``/``ListConfig`` or an ordinary object).
        key: Dot-separated path such as ``"a.b.c"``.
        value: Value to store at ``key``.

    Returns:
        None

    Raises:
        AttributeError: On the non-OmegaConf branch, if an intermediate
            attribute does not exist.
    """
    if isinstance(obj, (DictConfig, ListConfig)):
        OmegaConf.update(obj, key, value)
        return

    *ancestors, leaf_name = key.split(".")
    target = obj
    for attr_name in ancestors:
        target = getattr(target, attr_name)
    setattr(target, leaf_name, value)

to_dotlist(conf, use_equal=True, include=None, exclude=None, subfield=None, missing_subfield_policy='raise')

Flatten DictConfig/ListConfig to CLI-style dotlist.

Parameters

conf: Root OmegaConf container. Must be DictConfig or ListConfig.
use_equal: If True, emit ["a.b=1"] form. If False, emit ["a.b", "1"] pairs.
include, exclude: Optional exact-path filters applied to leaf paths. Paths are evaluated relative to subfield (if provided), otherwise relative to the root conf.
subfield: Optional dotted path used as the flattening start point. Supports list indices (e.g. "a.b.1").
missing_subfield_policy: Behavior when subfield cannot be resolved. "raise" (default) raises KeyError; "empty" returns [].

Examples

{"a": {"b": 1}} -> ["a.b=1"]
With subfield="a": {"a": {"b": 1}} -> ["b=1"]

Source code in btorch/utils/conf.py
def to_dotlist(
    conf,
    use_equal: bool = True,
    include: set | None = None,
    exclude: set | None = None,
    subfield: str | None = None,
    missing_subfield_policy: Literal["raise", "empty"] = "raise",
):
    """Flatten DictConfig/ListConfig to CLI-style dotlist.

    Parameters
    ----------
    conf:
        Root OmegaConf container. Must be ``DictConfig`` or ``ListConfig``.
    use_equal:
        If True, emit ``["a.b=1"]`` form. If False, emit ``["a.b", "1"]`` pairs.
    include, exclude:
        Optional exact-path filters applied to leaf paths.
        Paths are evaluated relative to ``subfield`` (if provided), otherwise
        relative to the root ``conf``. A leaf is emitted only if it is in
        ``include`` (when given) and not in ``exclude`` (when given).
    subfield:
        Optional dotted path used as the flattening start point.
        Supports list indices (e.g. ``"a.b.1"``). If it resolves to a scalar
        rather than a container, the result is empty.
    missing_subfield_policy:
        Behavior when ``subfield`` cannot be resolved.
        ``"raise"`` (default) raises ``KeyError``.
        ``"empty"`` returns ``[]``.

    Returns
    -------
    list[str]
        Flattened entries in traversal order. ``None`` leaves are rendered as
        ``null``; empty containers contribute no entries.

    Raises
    ------
    TypeError
        If ``conf`` is not a ``DictConfig`` or ``ListConfig``.
    ValueError
        If ``subfield`` contains an empty token (e.g. ``"a..b"``); raised
        regardless of ``missing_subfield_policy``.
    KeyError
        If ``subfield`` cannot be resolved and ``missing_subfield_policy`` is
        ``"raise"``.

    Examples
    --------
    ``{"a": {"b": 1}}`` -> ``["a.b=1"]``
    ``subfield="a"`` -> ``["b=1"]``
    """

    if not isinstance(conf, (DictConfig, ListConfig)):
        raise TypeError("to_dotlist expects DictConfig or ListConfig.")

    ret = []

    def _select_subfield(cfg, path: str):
        # Traverse the dotted path against OmegaConf containers only.
        cur = cfg
        for token in path.split("."):
            if token == "":
                raise ValueError("subfield contains an empty token.")
            if isinstance(cur, DictConfig):
                if token not in cur:
                    raise KeyError(f"subfield '{path}' not found at token '{token}'.")
                cur = cur[token]
                continue
            if isinstance(cur, ListConfig):
                try:
                    idx = int(token)
                except ValueError as exc:
                    raise KeyError(
                        f"subfield '{path}' expects a list index at token '{token}'."
                    ) from exc
                if idx < 0 or idx >= len(cur):
                    raise KeyError(f"subfield '{path}' list index out of range: {idx}.")
                cur = cur[idx]
                continue
            raise KeyError(
                f"subfield '{path}' cannot descend through non-container "
                f"at token '{token}'."
            )
        return cur

    def _emit_leaf(path, value):
        # Apply the exact-path filters, then append in the requested form.
        if include is not None and path not in include:
            return
        if exclude is not None and path in exclude:
            return
        # Keep "null" spelling to match OmegaConf textual conventions.
        text = "null" if value is None else value
        if use_equal:
            ret.append(f"{path}={text}")
        else:
            ret.extend([path, str(text)])

    def flatten_conf(cfg, path=""):
        # Recurse through OmegaConf containers and emit scalar leaves.
        if isinstance(cfg, DictConfig):
            items = cfg.items()
        elif isinstance(cfg, ListConfig):
            # For lists, indices become path tokens ("a.0.b").
            items = enumerate(cfg)
        else:
            # Base case: a scalar. A scalar at the start point (empty path)
            # has no key to emit, so it is skipped.
            if path:
                _emit_leaf(path, cfg)
            return

        for key, value in items:
            # For DictConfig, key is a string. For ListConfig, key is an int index.
            flatten_conf(value, f"{path}.{key}" if path else str(key))

    if subfield:
        try:
            start_cfg = _select_subfield(conf, subfield)
        except KeyError:
            if missing_subfield_policy == "empty":
                return []
            raise
    else:
        start_cfg = conf
    flatten_conf(start_cfg)
    return ret

btorch.utils.dict_utils

Dictionary manipulation utilities.

Helpers for transforming, flattening, and mapping nested dictionaries commonly used in configuration and data preprocessing pipelines.

Functions

flatten_dict(d, dot=False)

Flatten nested dictionary into single-level dictionary.

Parameters:

Name Type Description Default
d

Nested dictionary to flatten.

required
dot

If True, use dot-notation keys ("a.b"). If False, use tuple keys (("a", "b")).

False

Returns:

Type Description

Flattened dictionary.

Example

flatten_dict({"a": {"b": 1}, "c": 2}) -> {("a", "b"): 1, ("c",): 2}
flatten_dict({"a": {"b": 1}}, dot=True) -> {"a.b": 1}

Source code in btorch/utils/dict_utils.py
def flatten_dict(d, dot=False):
    """Flatten nested dictionary into single-level dictionary.

    Only ``dict`` instances are descended into; every other value is a leaf.
    Empty nested dicts contribute no entries.

    Args:
        d: Nested dictionary to flatten.
        dot: If True, use dot-notation keys ("a.b"); non-string keys are
            converted with ``str``. If False, use tuple keys (("a", "b")).

    Returns:
        Flattened dictionary.

    Example:
        >>> flatten_dict({"a": {"b": 1}, "c": 2})
        {("a", "b"): 1, ("c",): 2}
        >>> flatten_dict({"a": {"b": 1}}, dot=True)
        {"a.b": 1}
    """

    def _join(parent, key):
        # ``parent is None`` marks the root. Using a sentinel instead of
        # stripping a leading "." afterwards preserves keys that genuinely
        # start with "." and keeps empty-string keys as their own segment.
        if dot:
            return str(key) if parent is None else f"{parent}.{key}"
        return (key,) if parent is None else parent + (key,)

    def _flatten(sub, parent):
        for k, v in sub.items():
            new_key = _join(parent, k)
            if isinstance(v, dict):
                yield from _flatten(v, new_key)
            else:
                yield new_key, v

    return dict(_flatten(d, None))

recurse_dict(d, mapper, include_sequence=False)

Recursively apply function to dictionary leaf values.

Parameters:

Name Type Description Default
d dict

Input dictionary (potentially nested).

required
mapper Callable

Function called with (key, value) for each leaf.

required
include_sequence bool

If True, also recurse into tuples and lists.

False

Returns:

Type Description
dict

New dictionary with transformed leaf values.

Source code in btorch/utils/dict_utils.py
def recurse_dict(d: dict, mapper: Callable, include_sequence: bool = False) -> dict:
    """Rebuild a nested dict, passing every leaf through ``mapper``.

    ``mapper(key, value)`` receives the dict key the leaf is stored under.
    Leaves reached through a tuple or list (only when ``include_sequence`` is
    True) and a non-dict top-level ``d`` receive ``None`` as the key.
    Containers are rebuilt as plain ``dict``/``tuple``/``list``.

    Args:
        d: Input dictionary (potentially nested).
        mapper: Function called with (key, value) for each leaf.
        include_sequence: If True, also recurse into tuples and lists.

    Returns:
        New dictionary with transformed leaf values.
    """

    def _walk(node, key):
        if isinstance(node, dict):
            return {child_key: _walk(child, child_key) for child_key, child in node.items()}
        if include_sequence and isinstance(node, tuple):
            return tuple(_walk(elem, None) for elem in node)
        if include_sequence and isinstance(node, list):
            return [_walk(elem, None) for elem in node]
        return mapper(key, node)

    return _walk(d, None)

reverse_map(map)

Reverse a mapping, handling sequence values.

Flattens sequence values so each item maps to the original key. Non-sequence values map directly.

Parameters:

Name Type Description Default
map dict[Any, Any | Sequence[Any]]

Dictionary with scalar or sequence values.

required

Returns:

Type Description
dict[Any, Any]

Reversed mapping where each original value (or sequence item) maps to its original key.

Example

reverse_map({"a": [1, 2], "b": 3}) -> {1: "a", 2: "a", 3: "b"}

Source code in btorch/utils/dict_utils.py
def reverse_map(map: dict[Any, Any | Sequence[Any]]) -> dict[Any, Any]:
    """Reverse a mapping, handling sequence values.

    Flattens sequence values so each item maps to the original key.
    Non-sequence values map directly.

    Args:
        map: Dictionary with scalar or sequence values.

    Returns:
        Reversed mapping where each original value (or sequence item)
        maps to its original key.

    Example:
        >>> reverse_map({"a": [1, 2], "b": 3})
        {1: "a", 2: "a", 3: "b"}
    """
    ret = {}
    for key, items in map.items():
        if isinstance(items, Sequence) and not isinstance(items, str):
            for item in items:
                ret[item] = key
        else:
            ret[items] = key
    return ret

unflatten_dict(flattened_dict, dot=False)

Unflatten dictionary with compound keys into nested structure.

Parameters:

Name Type Description Default
flattened_dict

Dictionary with tuple or dot-notation keys.

required
dot

If True, split keys on dots. If False, keys are tuples.

False

Returns:

Type Description

Nested dictionary.

Example

unflatten_dict({("a",): 1, ("b", "c"): 2}) {"a": 1, "b": {"c": 2}} unflatten_dict({"a.b": 1}, dot=True) {"a": {"b": 1}}

Source code in btorch/utils/dict_utils.py
def unflatten_dict(flattened_dict, dot=False):
    """Unflatten dictionary with compound keys into nested structure.

    Entries are processed in iteration order. An empty key tuple is skipped.
    A later key whose full path lands on an existing nested dict overwrites it.

    Args:
        flattened_dict: Dictionary with tuple or dot-notation keys.
        dot: If True, split keys on dots. If False, keys are tuples.

    Returns:
        Nested dictionary.

    Raises:
        TypeError: If a key needs to descend through a position that already
            holds a non-dict value (e.g. ``{"a": 1, "a.b": 2}``).

    Example:
        >>> unflatten_dict({("a",): 1, ("b", "c"): 2})
        {"a": 1, "b": {"c": 2}}
        >>> unflatten_dict({"a.b": 1}, dot=True)
        {"a": {"b": 1}}
    """
    result = {}
    for compound_key, value in flattened_dict.items():
        parts = compound_key.split(".") if dot else compound_key
        if not parts:
            # No location to write to.
            continue
        *parents, leaf = parts
        node = result
        for part in parents:
            # Create the intermediate level on first use, then descend.
            child = node.setdefault(part, {})
            if not isinstance(child, dict):
                raise TypeError(
                    f"cannot unflatten key {compound_key!r}: component {part!r} "
                    f"already holds a non-dict value ({type(child).__name__})"
                )
            node = child
        node[leaf] = value
    return result

btorch.utils.file

File path utilities.

Helpers for resolving figure output paths based on caller location within the repository structure.

Classes

FigPathConfig dataclass

Configuration for figure output directory structure.

Attributes:

Name Type Description
root_dir str

Root directory for all figures.

benchmark_dir str

Subdirectory for benchmark script outputs.

tests_dir str

Subdirectory for test script outputs.

other_dir str

Subdirectory for other script outputs.

Source code in btorch/utils/file.py
@dataclass(frozen=True)
class FigPathConfig:
    """Configuration for figure output directory structure.

    ``fig_path`` builds output directories as
    ``<repo_root>/<root_dir>/<benchmark_dir | tests_dir | other_dir>/<rel>``,
    picking the subdirectory from where the calling script lives. Instances
    are immutable (``frozen=True``).

    Attributes:
        root_dir: Root directory for all figures, relative to the repo root.
        benchmark_dir: Subdirectory for benchmark script outputs.
        tests_dir: Subdirectory for test script outputs.
        other_dir: Subdirectory for other script outputs.
    """

    root_dir: str = "fig"  # joined onto the repository root
    benchmark_dir: str = "benchmark"  # scripts under benchmark/ or tests/benchmark/
    tests_dir: str = "tests"  # scripts under tests/ (outside tests/benchmark/)
    other_dir: str = "misc"  # any other script

Functions

_is_relative_to(path, base)

Check if path is within base directory.

Source code in btorch/utils/file.py
def _is_relative_to(path: Path, base: Path) -> bool:
    """Check if path is within base directory."""
    try:
        path.relative_to(base)
        return True
    except ValueError:
        return False

_repo_root()

Return repository root directory.

Source code in btorch/utils/file.py
def _repo_root() -> Path:
    """Return repository root directory."""
    return Path(__file__).resolve().parents[2]

_resolve_cfg(cfg)

Merge user config with defaults.

Source code in btorch/utils/file.py
def _resolve_cfg(cfg: FigPathConfig | dict | conf.DictConfig | None):
    """Return figure-path settings as an OmegaConf config with defaults filled in.

    ``None`` yields the structured ``FigPathConfig`` defaults unchanged. A
    ``FigPathConfig`` or plain ``dict`` is first converted to an OmegaConf
    node; anything else (e.g. a ``DictConfig``) is used as-is. The result is
    ``OmegaConf.merge(defaults, override)``.
    """
    base = conf.OmegaConf.structured(FigPathConfig)
    if cfg is None:
        return base
    if isinstance(cfg, FigPathConfig):
        override = conf.OmegaConf.structured(cfg)
    elif isinstance(cfg, dict):
        override = conf.OmegaConf.create(cfg)
    else:
        override = cfg
    return conf.OmegaConf.merge(base, override)

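caller_file(stack_level=2)

Return the source file path of the frame stack_level levels up the call stack (0 is caller_file itself), or a notebook path when running inside an IPython kernel.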

Source code in btorch/utils/file.py
def caller_file(stack_level: int = 2) -> str:
    """Return the source file path of a frame further up the call stack.

    ``stack_level`` counts frames from this function: 0 is ``caller_file``
    itself, 1 its caller, 2 the caller's caller (the default).

    Inside an IPython kernel (a shell exposing a ``kernel`` attribute) the
    notebook path is returned instead: ``__vsc_ipynb_file__`` (set by VS Code)
    or ``__file__`` from the user namespace, else a placeholder
    ``<cwd>/__notebook__.ipynb``. Any exception during this detection (for
    example IPython not being installed) is ignored.
    """
    try:
        from IPython.core.getipython import get_ipython

        shell = get_ipython()
        if shell is not None and hasattr(shell, "kernel"):
            user_ns = shell.user_ns
            for var_name in ("__vsc_ipynb_file__", "__file__"):
                if var_name in user_ns:
                    return user_ns[var_name]
            # Last resort: a fake path in the cwd for relative resolution.
            return str(Path.cwd() / "__notebook__.ipynb")
    except Exception:
        pass
    # Must be called directly in this function so frame 0 is caller_file.
    return sys._getframe(stack_level).f_code.co_filename

fig_path(file=None, cfg=None)

Resolve figure output directory based on caller location.

Places outputs in fig/benchmark/, fig/tests/, or fig/misc/ depending on whether the caller is in the benchmark, tests, or other directory.

Parameters:

Name Type Description Default
file str | Path | None

File path to use for path resolution. If None, uses caller file.

None
cfg FigPathConfig | dict | None

Configuration for directory naming.

None

Returns:

Type Description

Path object for the figure directory (created if needed).

Source code in btorch/utils/file.py
def fig_path(file: str | Path | None = None, cfg: FigPathConfig | dict | None = None):
    """Resolve figure output directory based on caller location.

    Places outputs in ``fig/benchmark/``, ``fig/tests/``, or ``fig/misc/``
    depending on whether the caller is in the benchmark, tests, or other
    directory. Files under ``benchmark/`` or ``tests/benchmark/`` go to the
    benchmark directory; other files under ``tests/`` go to the tests
    directory. The caller's path relative to that source root, without its
    suffix, becomes the subdirectory. Files outside the repository use only
    their file name.

    Args:
        file: File path to use for path resolution. If None, uses caller file.
        cfg: Configuration for directory naming.

    Returns:
        Path object for the figure directory (created if needed).
    """
    # ``caller_file`` returns a ``str``; wrap it so ``.resolve()`` works.
    # It is called directly from here so its default ``stack_level=2``
    # refers to the code that called ``fig_path``.
    file_path = Path(file) if file is not None else Path(caller_file())
    file_path = file_path.resolve()
    root = _repo_root()
    cfg = _resolve_cfg(cfg)

    # Order matters: tests/benchmark must be matched before tests.
    candidates = (
        (root / "benchmark", cfg.benchmark_dir),
        (root / "tests" / "benchmark", cfg.benchmark_dir),
        (root / "tests", cfg.tests_dir),
    )
    for base, sub_dir in candidates:
        if _is_relative_to(file_path, base):
            rel = file_path.relative_to(base)
            break
    else:
        sub_dir = cfg.other_dir
        if _is_relative_to(file_path, root):
            rel = file_path.relative_to(root)
        else:
            rel = Path(file_path.name)

    path = root / cfg.root_dir / sub_dir / rel.with_suffix("")
    path.mkdir(parents=True, exist_ok=True)
    return path

save_fig(fig, name=None, path=None, *, file=None, cfg=None, suffix='pdf', transparent=False)

Save matplotlib figure to appropriate directory.

Parameters:

Name Type Description Default
fig

Matplotlib figure object.

required
name str | None

Output filename (without extension). If None, uses caller stem.

None
path Path | None

Output directory. If None, uses fig_path().

None
file str | Path | None

File path for context resolution. If None, uses caller file.

None
cfg FigPathConfig | dict | None

Configuration for directory naming.

None
suffix str

File extension (default: "pdf").

'pdf'
transparent bool

Save with transparent background.

False

Returns:

Type Description
Path

Path to the saved figure file.

Source code in btorch/utils/file.py
def save_fig(
    fig,
    name: str | None = None,
    path: str | Path | None = None,
    *,
    file: str | Path | None = None,
    cfg: FigPathConfig | dict | None = None,
    suffix: str = "pdf",
    transparent: bool = False,
) -> Path:
    """Save matplotlib figure to appropriate directory.

    Args:
        fig: Matplotlib figure object.
        name: Output filename (without extension). If None, uses caller stem.
        path: Output directory (str or Path). If None, uses ``fig_path()``.
        file: File path for context resolution. If None, uses caller file.
        cfg: Configuration for directory naming.
        suffix: File extension (default: "pdf"); a leading "." is tolerated.
        transparent: Save with transparent background.

    Returns:
        Path to the saved figure file.
    """
    # ``caller_file`` returns a ``str``; wrap it so ``.stem`` works. It is
    # called directly here so the default stack level points at our caller.
    file_path = Path(file) if file is not None else Path(caller_file())
    if path is None:
        path = fig_path(file_path, cfg=cfg)
    else:
        path = Path(path)
    if name is None:
        name = file_path.stem
    path.mkdir(parents=True, exist_ok=True)
    # Accept both "pdf" and ".pdf" without producing "name..pdf".
    output_path = path / f"{name}.{suffix.lstrip('.')}"
    fig.savefig(output_path.as_posix(), transparent=transparent)
    return output_path

btorch.utils.grad_checkpoint

Functions

checkpoint_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs)

Source code in btorch/utils/grad_checkpoint/checkpoint.py
def checkpoint_wrapper(
    module: torch.nn.Module,
    checkpoint_fn=None,
    **checkpoint_fn_kwargs,
) -> torch.nn.Module:
    """Wrap ``module`` in a ``CheckpointWrapper``.

    Args:
        module: Module to wrap.
        checkpoint_fn: Passed positionally as the second argument of
            ``CheckpointWrapper``. NOTE(review): presumably the checkpoint
            function to apply; confirm against ``CheckpointWrapper``.
        **checkpoint_fn_kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The constructed ``CheckpointWrapper``.
    """
    wrapped = CheckpointWrapper(module, checkpoint_fn, **checkpoint_fn_kwargs)
    return wrapped

btorch.utils.hdf5_utils

HDF5 serialization utilities.

Helpers for saving and loading nested dictionaries containing arrays to HDF5 files with optional Blosc2 compression for large arrays.

Functions

load_dict_from_hdf5(folder_or_filename, filename=None)

Load nested dictionary from HDF5 file.

Parameters:

Name Type Description Default
folder_or_filename

Directory path if filename is provided, otherwise full file path.

required
filename Optional[str]

Optional filename when folder_or_filename is a directory.

None

Returns:

Type Description

Nested dictionary with restored array values.

Source code in btorch/utils/hdf5_utils.py
def load_dict_from_hdf5(folder_or_filename, filename: Optional[str] = None):
    """Read an HDF5 file back into a nested dictionary.

    Each HDF5 group becomes a nested dict keyed by member name. A dataset
    with shape ``()`` is read with ``[()]`` (scalar); any other dataset is
    read fully with ``[:]``.

    Args:
        folder_or_filename: Directory path if ``filename`` is provided,
            otherwise full file path.
        filename: Optional filename when ``folder_or_filename`` is a directory.

    Returns:
        Nested dictionary with restored array values.
    """
    if filename is None:
        file = folder_or_filename
    else:
        file = os.path.join(folder_or_filename, filename)

    def read_group(group):
        out = {}
        for name, node in group.items():
            if isinstance(node, h5py.Group):
                out[name] = read_group(node)
            elif node.shape == ():
                out[name] = node[()]  # scalar dataset
            else:
                out[name] = node[:]  # array dataset
        return out

    with h5py.File(file, "r") as f:
        return read_group(f)

save_dict_to_hdf5(folder_or_filename, data, compression=hdf5plugin.Blosc2(), filename=None, compression_threshold=1024 * 1024)

Save nested dictionary with array values to HDF5 file.

Recursively traverses data and saves arrays as datasets. Datasets larger than compression_threshold are compressed with the specified compression filter.

Parameters:

Name Type Description Default
folder_or_filename

Directory path if filename is provided, otherwise full file path.

required
data

Nested dictionary with array-like values to serialize.

required
compression

Compression filter (default: Blosc2).

Blosc2()
filename Optional[str]

Optional filename when folder_or_filename is a directory.

None
compression_threshold

Minimum array size in bytes to trigger compression (default: 1 MiB).

1024 * 1024

Returns:

Type Description

None

Source code in btorch/utils/hdf5_utils.py
def save_dict_to_hdf5(
    folder_or_filename,
    data,
    compression=hdf5plugin.Blosc2(),
    filename: Optional[str] = None,
    compression_threshold=1024 * 1024,  # 1MiB
):
    """Write a nested dictionary to an HDF5 file, overwriting it (mode ``"w"``).

    Nested dicts become HDF5 groups. Values exposing both ``shape`` and
    ``dtype`` are array-like; those whose ``nbytes`` exceeds
    ``compression_threshold`` are written with ``compression``. All other
    values are written as plain datasets. Entries whose value is ``None``
    are skipped.

    Args:
        folder_or_filename: Directory path if ``filename`` is provided,
            otherwise full file path.
        data: Nested dictionary with array-like values to serialize.
        compression: Compression filter (default: Blosc2).
        filename: Optional filename when ``folder_or_filename`` is a directory.
        compression_threshold: Minimum array size in bytes to trigger
            compression (default: 1 MiB).

    Returns:
        None
    """
    if filename is None:
        file = folder_or_filename
    else:
        file = os.path.join(folder_or_filename, filename)

    def write_group(h5file, prefix, mapping):
        for key, value in mapping.items():
            if value is None:
                # None has no dataset representation; drop the entry.
                continue
            target = f"{prefix}/{key}"
            if isinstance(value, dict):
                h5file.create_group(target)
                write_group(h5file, target, value)
                continue
            is_array = hasattr(value, "shape") and hasattr(value, "dtype")
            if is_array and value.nbytes > compression_threshold:
                h5file.create_dataset(target, data=value, compression=compression)
            else:
                h5file.create_dataset(target, data=value)

    with h5py.File(file, "w") as f:
        write_group(f, "", data)

btorch.utils.hex_utils

Hexagonal lattice coordinate utilities.

Functions for working with hexagonal grid coordinates (axial/cube), conversions between hex and pixel coordinates, and geometric operations on hexagonal lattices.

See Also

https://www.redblobgames.com/grids/hexagons/

Classes

HexArray

Bases: ndarray

Flat array holding Hexals as elements.

Can be constructed with

HexArray(hexals: Iterable, values: Optional[np.nan])
HexArray(u: Iterable, v: Iterable, values: Optional[np.nan])

Source code in btorch/utils/hex_utils.py
class HexArray(np.ndarray):
    """Flat array holding Hexals as elements.

    Can be constructed with:
        HexArray(hexals: Iterable, values: Optional[np.nan])
        HexArray(u: Iterable, v: Iterable, values: Optional[np.nan])

    Elements are stored sorted by ``u`` and then ``v`` (the same key as
    ``HexArray.sort``). Each value stays attached to its own ``(u, v)``
    coordinate through the sort.
    """

    def __new__(cls, hexals=None, u=None, v=None, values=0):
        if isinstance(hexals, Iterable):
            # Materialize once: hexals may be a one-shot iterator and is read
            # three times below.
            hexals = list(hexals)
            u = np.array([h.u for h in hexals])
            v = np.array([h.v for h in hexals])
            values = np.array([h.value for h in hexals])
        # Fancy indexing below needs arrays, not plain lists.
        u = np.asarray(u)
        v = np.asarray(v)
        if not isinstance(values, Iterable):
            # Broadcast a scalar value to every coordinate.
            values = np.ones_like(u) * values
        values = list(values)
        if not (len(u) == len(v) == len(values)):
            raise ValueError(
                "HexArray requires u, v and values of equal length, got "
                f"{len(u)}, {len(v)} and {len(values)}."
            )
        # One shared permutation (u primary, v secondary) for coordinates AND
        # values; sorting u/v alone would detach values from their hexals.
        order = np.lexsort((v, u))
        hexals = np.array(
            [Hexal(u[i], v[i], values[i]) for i in order],
            dtype=Hexal,
        ).view(cls)
        return hexals

    def __array_finalize__(self, obj):
        # No extra attributes to propagate to views/slices.
        if obj is None:
            return

    def __eq__(self, other):
        """Compare like an ndarray; a ``Hexal`` operand delegates to ``Hexal.__eq__``."""
        if isinstance(other, Hexal):
            return other == self
        else:
            return super().__eq__(other)

    def __getitem__(self, key):
        """Index like an ndarray; a ``HexArray`` key selects via ``where_hexarray``."""
        if isinstance(key, HexArray):
            mask = self.where_hexarray(key)
            return self[mask]
        else:
            return super().__getitem__(key)

    def __setitem__(self, key, value):
        """Assign like an ndarray.

        ``arr[:] = values`` writes each element's ``value`` attribute via the
        ``values`` setter. A ``HexArray`` key assigns through the mask from
        ``where_hexarray``.
        """
        if isinstance(key, slice) and key == slice(None):
            self.values = value
        elif isinstance(key, HexArray):
            mask = self.where_hexarray(key)
            super().__setitem__(mask, value)
        else:
            super().__setitem__(key, value)

    def where_hexarray(self, hexarray):
        """Mask over ``self`` built from the ``(u, v)`` coordinates of both arrays.

        Both coordinate sets are passed to ``matrix_mask_by_sub`` as ``(N, 2)``
        arrays of ``(u, v)`` rows. NOTE(review): presumably True where an
        element of ``self`` also appears in ``hexarray``; confirm against
        ``matrix_mask_by_sub``.
        """
        return matrix_mask_by_sub(
            np.stack((hexarray.u, hexarray.v), axis=0).T,
            np.stack((self.u, self.v), axis=0).T,
        )

    @staticmethod
    def sort(u, v):
        """Return ``u`` and ``v`` reordered by ``u``, then ``v`` (both ndarrays)."""
        sort_index = np.lexsort((v, u))
        u = u[sort_index]
        v = v[sort_index]
        return u, v

    @staticmethod
    def get_extent(hexals=None, u=None, v=None, center=Hexal(0, 0, 0)):
        """Returns the columnar extent.

        For scalar ``u``/``v`` this is the distance of that single hexal from
        ``center``. Otherwise it is the maximum distance from ``center`` over
        the hexals built from ``hexals`` or ``u``/``v``.
        """
        from numbers import Number

        if isinstance(u, Number) and isinstance(v, Number):
            h = Hexal(u, v, 0)
            return h.distance(center)
        else:
            ha = HexArray(hexals, u, v)
            distance = max([h.distance(center) for h in ha])
            return distance

    @property
    def u(self):
        """Array of the elements' ``u`` coordinates."""
        return np.array([h.u for h in self])

    @property
    def v(self):
        """Array of the elements' ``v`` coordinates."""
        return np.array([h.v for h in self])

    @property
    def values(self):
        """Array of the elements' ``value`` attributes."""
        return np.array([h.value for h in self])

    @values.setter
    def values(self, values):
        for h, val in zip(self, values):
            h.value = val

    @property
    def extent(self):
        """Maximum distance of any element from ``Hexal(0, 0, 0)``."""
        # get_extent is a staticmethod of this class; ``super()`` would
        # resolve to np.ndarray, which has no get_extent.
        return self.get_extent(self)

    def with_stride(self, u_stride=None, v_stride=None):
        """Returns a sliced instance obeying strides in u- and v-direction.

        Keeps elements whose ``u`` is divisible by ``u_stride`` and whose ``v``
        is divisible by ``v_stride``. ``None`` means stride 1 (no filtering).
        """
        u_stride = 1 if u_stride is None else u_stride
        v_stride = 1 if v_stride is None else v_stride
        # Boolean dtype keeps indexing valid even when the array is empty.
        mask = (np.asarray(self.u) % u_stride == 0) & (np.asarray(self.v) % v_stride == 0)
        return self[np.asarray(mask, dtype=bool)]

    def where(self, value):
        """Returns a mask of where values are equal to the given one.

        Note: value can be np.nan (comparison is exact and NaN matches NaN).
        """
        return np.isclose(self.values, value, rtol=0, atol=0, equal_nan=True)

    def fill(self, value):
        """Fills the values with the given one."""
        for h in self:
            h.value = value

    def to_pixel(self, scale=1, mode="default"):
        """Converts to pixel coordinates."""
        return hex_to_pixel(self.u, self.v, scale, mode=mode)

    def plot(self, figsize=(3, 3), fill=True):
        """Plots values in regular hexagonal lattice.

        Meant for debugging.
        """
        return flyvis.plots.hex_scatter(
            self.u,
            self.v,
            self.values,
            fill=fill,
            # cm.get_cmap was removed in Matplotlib 3.9; registered colormaps
            # are also exposed as attributes of the cm module.
            cmap=cm.binary,
            edgecolor="black",
            figsize=figsize,
        )
Functions
fill(value)

Fills the values with the given one.

Source code in btorch/utils/hex_utils.py
def fill(self, value):
    """Set the ``value`` attribute of every element to ``value``, in place.

    Returns None.
    """
    for hexal in self:
        hexal.value = value
get_extent(hexals=None, u=None, v=None, center=Hexal(0, 0, 0)) staticmethod

Returns the columnar extent.

Source code in btorch/utils/hex_utils.py
@staticmethod
def get_extent(hexals=None, u=None, v=None, center=Hexal(0, 0, 0)):
    """Returns the columnar extent.

    With numeric ``u`` and ``v`` this is the distance of the single hexal
    ``(u, v)`` from ``center``. Otherwise it is the maximum distance from
    ``center`` over ``HexArray(hexals, u, v)``.
    """
    from numbers import Number

    if not (isinstance(u, Number) and isinstance(v, Number)):
        distances = [h.distance(center) for h in HexArray(hexals, u, v)]
        return max(distances)
    return Hexal(u, v, 0).distance(center)
plot(figsize=[3, 3], fill=True)

Plots values in regular hexagonal lattice.

Meant for debugging.

Source code in btorch/utils/hex_utils.py
def plot(self, figsize=(3, 3), fill=True):
    """Plots values in regular hexagonal lattice.

    Meant for debugging. Passes each element's ``u``, ``v`` and ``value`` to
    ``flyvis.plots.hex_scatter`` using the "binary" colormap.
    """
    u = np.array([h.u for h in self])
    v = np.array([h.v for h in self])
    color = np.array([h.value for h in self])
    return flyvis.plots.hex_scatter(
        u,
        v,
        color,
        fill=fill,
        # cm.get_cmap was removed in Matplotlib 3.9; registered colormaps are
        # also exposed as attributes of the cm module.
        cmap=cm.binary,
        edgecolor="black",
        figsize=figsize,
    )
to_pixel(scale=1, mode='default')

Converts to pixel coordinates.

Source code in btorch/utils/hex_utils.py
def to_pixel(self, scale=1, mode="default"):
    """Convert the elements' ``(u, v)`` hex coordinates to pixel coordinates.

    Delegates to ``hex_to_pixel``; ``scale`` and ``mode`` are passed through.
    """
    u_coords, v_coords = self.u, self.v
    return hex_to_pixel(u_coords, v_coords, scale, mode=mode)
where(value)

Returns a mask of where values are equal to the given one.

Note: value can be np.nan.

Source code in btorch/utils/hex_utils.py
def where(self, value):
    """Return a boolean mask of elements whose value equals ``value``.

    The comparison is exact (``rtol=0``, ``atol=0``) and NaN matches NaN, so
    ``value`` can be ``np.nan``.
    """
    current = self.values
    return np.isclose(current, value, atol=0, rtol=0, equal_nan=True)
with_stride(u_stride=None, v_stride=None)

Returns a sliced instance obeying strides in u- and v-direction.

Source code in btorch/utils/hex_utils.py
def with_stride(self, u_stride=None, v_stride=None):
    """Returns a sliced instance obeying strides in u- and v-direction.

    Keeps elements whose ``u`` is divisible by ``u_stride`` and whose ``v`` is
    divisible by ``v_stride``. ``None`` (the default) means stride 1, i.e. no
    filtering in that direction.
    """
    u_stride = 1 if u_stride is None else u_stride
    v_stride = 1 if v_stride is None else v_stride
    keep = (np.asarray(self.u) % u_stride == 0) & (np.asarray(self.v) % v_stride == 0)
    # Boolean dtype keeps indexing valid even for an empty selection.
    return self[np.asarray(keep, dtype=bool)]

HexLattice

Bases: HexArray

Flat array of Hexals.

Parameters:

Name Type Description Default
extent

Extent of the regular hexagon grid.

15
hexals

Existing hexals to initialize with.

None
center

Center hexal of the lattice.

Hexal(0, 0, 0)
u_stride

Stride in u-direction.

1
v_stride

Stride in v-direction.

1
Source code in btorch/utils/hex_utils.py
class HexLattice(HexArray):
    """Flat array of Hexals laid out on a regular hexagonal lattice.

    Args:
        extent: Extent of the regular hexagon grid.
        hexals: Existing hexals to initialize with. When given, a regular
            lattice is built and every lattice hexal whose (u, v) matches one
            of ``hexals`` takes over that hexal's value.
        center: Center hexal of the lattice.
        u_stride: Stride in u-direction.
        v_stride: Stride in v-direction.
    """

    def __new__(
        cls,
        extent=15,
        hexals=None,
        center=Hexal(0, 0, 0),
        u_stride=1,
        v_stride=1,
    ):
        if isinstance(hexals, Iterable):
            hexals = HexArray(hexals=hexals)
            # A falsy extent (None or 0) is inferred from the given hexals.
            extent = extent or super().get_extent(hexals, center=center)
            lattice = HexLattice(
                extent=extent,
                center=center,
                u_stride=u_stride,
                v_stride=v_stride,
            )
            # Copy values from hexals with matching (u, v) onto the lattice.
            for h in lattice:
                if h in hexals:
                    h.value = hexals[h == hexals][0].value
        else:
            u, v = get_hex_coords(extent)
            u += center.u
            v += center.v
            # Lattice hexals start out as NaN; only coordinates obeying both
            # strides are kept.
            lattice = [
                Hexal(_u, _v, np.nan, u_stride, v_stride)
                for _u, _v in zip(u, v)
                if _u % u_stride == 0 and _v % v_stride == 0
            ]
            lattice = np.array(lattice, dtype=Hexal).view(cls)
        return lattice

    @property
    def center(self):
        """Middle element of the flat array (index ``len(self) // 2``)."""
        return self[len(self) // 2]

    @property
    def extent(self):
        """Extent of the lattice measured from ``self.center``."""
        return super().get_extent(self, center=self.center)

    # ----- Geometry

    def circle(self, radius=None, center=Hexal(0, 0, 0), as_lattice=False):
        """Draws a circle (ring of hexals) in hex coordinates.

        Args:
            radius: Radius in columns of the circle. A falsy radius (None or
                0) falls back to the lattice extent.
            center: Center of the circle.
            as_lattice: Returns the circle on a constrained regular lattice.

        Returns:
            HexArray (or HexLattice if ``as_lattice``) of the hexals exactly
            ``radius`` columns away from ``center``, each with value 1.
        """
        lattice = HexLattice(extent=max(radius or 0, self.extent), center=center)
        radius = radius or self.extent
        circle = [h for h in lattice if center.distance(h) == radius]
        for h in circle:
            h.value = 1
        if as_lattice:
            return HexLattice(hexals=circle)
        return HexArray(hexals=circle)

    @staticmethod
    def filled_circle(radius=None, center=Hexal(0, 0, 0), as_lattice=False):
        """Draws a filled circle (disk) in hex coordinates.

        Args:
            radius: Radius in columns of the circle. None is treated as 0,
                which selects only the center hexal.
            center: Center of the circle.
            as_lattice: Returns the circle on a constrained regular lattice.

        Returns:
            HexArray (or HexLattice if ``as_lattice``) of the hexals at most
            ``radius`` columns away from ``center``, each with value 1.
        """
        # The default radius=None previously reached `distance <= None` and
        # raised TypeError.
        radius = radius or 0
        lattice = HexLattice(extent=radius, center=center)
        disk = [h for h in lattice if center.distance(h) <= radius]
        for h in disk:
            h.value = 1
        if as_lattice:
            return HexLattice(hexals=disk)
        return HexArray(hexals=disk)

    def hull(self):
        """Returns the hull of the regular lattice."""
        return self.circle(radius=self.extent, center=self.center)

    def _line_span(self, angle):
        """Returns two points spanning a line with given angle wrt. origin.

        Args:
            angle: In [0, np.pi]

        Returns:
            HexArray
        """
        # To offset the line by simple addition of the offset,
        # radius=2 * self.extent spans the line in ways that each valid offset
        # can be added.
        # NOTE(review): `ring` is not defined on HexLattice here; presumably it
        # is inherited from HexArray — confirm.
        distant_hull = self.ring(radius=2 * self.extent)
        angles = np.array([h.angle(signed=True) for h in distant_hull])
        distance = (angles - angle) % np.pi
        index = np.argsort(distance)
        span = distant_hull[index[0:2]]
        for h in span:
            h.value = 1
        return HexArray(hexals=span)

    def line(self, angle, center=Hexal(0, 0, 1), as_lattice=False):
        """Returns a line on a HexLattice or HexArray.

        Args:
            angle: In [0, np.pi]
            center: Midpoint of the line. Only its coordinates are used.
            as_lattice: Returns the line on a constrained regular lattice.

        Returns:
            HexArray or constrained HexLattice whose line hexals have value 1.
        """
        line_span = self._line_span(angle)
        distance = line_span[0].distance(line_span[1])
        line = [
            line_span[0].interp(line_span[1], 1 / distance * i)
            for i in range(distance + 1)
        ]
        for h in line:
            h.value = 1
        # Hexal.__add__ keeps the LEFT operand's value. With the line hexal on
        # the left, every shifted hexal keeps value 1 instead of inheriting
        # center.value (NaN for a plain Hexal(u, v)).
        shifted = np.array([h + center for h in line])
        if as_lattice:
            return HexLattice(extent=self.extent, hexals=shifted)
        return HexArray(hexals=shifted)

    def _get_neighbour_indices(self, index):
        """Indices into ``self`` of the in-lattice neighbours of ``self[index]``."""
        _neighbours = self[index].neighbours()
        neighbours = ()
        for n in _neighbours:
            valid = self == n
            if valid.any():
                neighbours += (np.where(valid)[0][0],)
        return neighbours

    def valid_neighbours(self):
        """Per hexal, a tuple of the indices of its neighbours in ``self``."""
        return tuple(self._get_neighbour_indices(i) for i in range(len(self)))
Functions
circle(radius=None, center=Hexal(0, 0, 0), as_lattice=False)

Draws a circle in hex coordinates.

Parameters:

Name Type Description Default
radius

Radius in columns of the circle.

None
center

Center of the circle.

Hexal(0, 0, 0)
as_lattice

Returns the circle on a constrained regular lattice.

False
Source code in btorch/utils/hex_utils.py
def circle(self, radius=None, center=Hexal(0, 0, 0), as_lattice=False):
    """Select the ring of hexals exactly ``radius`` columns from ``center``.

    Args:
        radius: Ring radius in columns; a falsy value falls back to the
            lattice extent.
        center: Center of the ring.
        as_lattice: If True, wrap the result in a HexLattice instead of a
            HexArray.

    Returns:
        HexArray or HexLattice of the ring hexals, each with value 1.
    """
    search_space = HexLattice(extent=max(radius or 0, self.extent), center=center)
    target = radius or self.extent
    ring = [cell for cell in search_space if center.distance(cell) == target]
    for cell in ring:
        cell.value = 1
    container = HexLattice if as_lattice else HexArray
    return container(hexals=ring)
filled_circle(radius=None, center=Hexal(0, 0, 0), as_lattice=False) staticmethod

Draws a filled circle (disk) in hex coordinates.

Parameters:

Name Type Description Default
radius

Radius in columns of the circle.

None
center

Center of the circle.

Hexal(0, 0, 0)
as_lattice

Returns the circle on a constrained regular lattice.

False
Source code in btorch/utils/hex_utils.py
@staticmethod
def filled_circle(radius=None, center=Hexal(0, 0, 0), as_lattice=False):
    """Draws a filled circle (disk) in hex coordinates.

    Args:
        radius: Radius in columns of the circle. None is treated as 0,
            which selects only the center hexal.
        center: Center of the circle.
        as_lattice: Returns the circle on a constrained regular lattice.

    Returns:
        HexArray (or HexLattice if ``as_lattice``) of the hexals at most
        ``radius`` columns away from ``center``, each with value 1.
    """
    # The default radius=None previously reached `distance <= None` and
    # raised TypeError.
    radius = radius or 0
    lattice = HexLattice(extent=radius, center=center)
    disk = [h for h in lattice if center.distance(h) <= radius]
    for h in disk:
        h.value = 1
    if as_lattice:
        return HexLattice(hexals=disk)
    return HexArray(hexals=disk)
hull()

Returns the hull of the regular lattice.

Source code in btorch/utils/hex_utils.py
def hull(self):
    """Return the outermost ring of this regular lattice.

    The ring is drawn around the lattice's own center at a radius equal to
    its extent.
    """
    outer_radius = self.extent
    middle = self.center
    return self.circle(radius=outer_radius, center=middle)
line(angle, center=Hexal(0, 0, 1), as_lattice=False)

Returns a line on a HexLattice or HexArray.

Parameters:

Name Type Description Default
angle

In [0, np.pi]

required
center

Midpoint of the line

Hexal(0, 0, 1)
as_lattice

Returns the line on a constrained regular lattice.

False

Returns:

Type Description

HexArray or constrained HexLattice

Source code in btorch/utils/hex_utils.py
def line(self, angle, center=Hexal(0, 0, 1), as_lattice=False):
    """Returns a line on a HexLattice or HexArray.

    Args:
        angle: In [0, np.pi]
        center: Midpoint of the line. Only its coordinates are used.
        as_lattice: Returns the line on a constrained regular lattice.

    Returns:
        HexArray or constrained HexLattice whose line hexals have value 1.
    """
    line_span = self._line_span(angle)
    distance = line_span[0].distance(line_span[1])
    line = [
        line_span[0].interp(line_span[1], 1 / distance * i)
        for i in range(distance + 1)
    ]
    for h in line:
        h.value = 1
    # Hexal.__add__ keeps the LEFT operand's value. With the line hexal on the
    # left, every shifted hexal keeps value 1 instead of inheriting
    # center.value (NaN for a plain Hexal(u, v)).
    shifted = np.array([h + center for h in line])
    if as_lattice:
        return HexLattice(extent=self.extent, hexals=shifted)
    return HexArray(hexals=shifted)

Hexal

Hexal representation containing u, v, z coordinates and value.

Attributes:

Name Type Description
u

Coordinate in u principal direction (0 degree axis).

v

Coordinate in v principal direction (60 degree axis).

z

Coordinate in z principal direction (-60 degree axis).

value

'Hexal' value.

u_stride

Stride in u-direction.

v_stride

Stride in v-direction.

Source code in btorch/utils/hex_utils.py
class Hexal:
    """Hexal representation containing u, v, z coordinates and value.

    Attributes:
        u: Coordinate in u principal direction (0 degree axis).
        v: Coordinate in v principal direction (60 degree axis).
        z: Coordinate in z principal direction (-60 degree axis); always
            ``-(u + v)``.
        value: 'Hexal' value.
        u_stride: Stride in u-direction.
        v_stride: Stride in v-direction.
    """

    def __init__(
        self,
        u: int,
        v: int,
        value: float = np.nan,
        u_stride: int = 1,
        v_stride: int = 1,
    ):
        self.u = u
        self.v = v
        self.z = -(u + v)
        self.value = value
        self.u_stride = u_stride
        self.v_stride = v_stride

    def __repr__(self):
        return "Hexal(u={}, v={}, value={}, u_stride={}, v_stride={})".format(
            self.u, self.v, self.value, self.u_stride, self.v_stride
        )

    def __eq__(self, other):
        """Compares coordinates (not values).

        Returns a bool for a Hexal, or a boolean np.ndarray (one entry per
        element) for an iterable.
        """
        if isinstance(other, Hexal):
            return all((self.u == other.u, self.v == other.v))
        elif isinstance(other, Iterable):
            return np.array([self == h for h in other])

    def __add__(self, other):
        """Adds u and v coordinates, while keeping the value of the left
        hexal."""
        if isinstance(other, Hexal):
            return Hexal(self.u + other.u, self.v + other.v, self.value)
        elif isinstance(other, Iterable):
            return np.array([self + h for h in other])

    def __mul__(self, other):
        """Multiplies values, while preserving coordinates."""
        if isinstance(other, Hexal):
            return Hexal(self.u, self.v, self.value * other.value)
        elif isinstance(other, Iterable):
            return np.array([self * h for h in other])
        else:
            return Hexal(self.u, self.v, self.value * other)

    def eq_val(self, other):
        """Compares the values, not the coordinates."""
        if isinstance(other, Hexal):
            return self.value == other.value
        elif isinstance(other, Iterable):
            return np.array([self.eq_val(h) for h in other])

    # ----- Neighbour identification
    # Each neighbour is a new Hexal with value 0 and default (unit) strides,
    # offset by this hexal's strides.

    @property
    def east(self):
        return Hexal(self.u + self.u_stride, self.v, 0)

    @property
    def north_east(self):
        return Hexal(self.u, self.v + self.v_stride, 0)

    @property
    def north_west(self):
        return Hexal(self.u - self.u_stride, self.v + self.v_stride, 0)

    @property
    def west(self):
        return Hexal(self.u - self.u_stride, self.v, 0)

    @property
    def south_west(self):
        return Hexal(self.u, self.v - self.v_stride, 0)

    @property
    def south_east(self):
        return Hexal(self.u + self.u_stride, self.v - self.v_stride, 0)

    def neighbours(self):
        """Returns 6 neighbours sorted CCW, starting from east."""
        return (
            self.east,
            self.north_east,
            self.north_west,
            self.west,
            self.south_west,
            self.south_east,
        )

    def is_neighbour(self, other):
        """Evaluates if other is a neighbour.

        Args:
            other: A Hexal, or an iterable of Hexals.

        Returns:
            bool for a single Hexal, or a boolean np.ndarray with one entry
            per element of an iterable.
        """
        neighbours = self.neighbours()
        if isinstance(other, Hexal):
            return other in neighbours
        elif isinstance(other, Iterable):
            # Recurse element-wise. `self.neighbour` takes an angle, not a
            # Hexal, so it must not be used here.
            return np.array([self.is_neighbour(h) for h in other])

    @staticmethod
    def unit_directions():
        """Returns the six unit directions."""
        return HexArray(Hexal(0, 0, 0).neighbours())

    def neighbour(self, angle):
        """The two neighbours whose angles are closest to ``angle`` modulo pi."""
        # NOTE(review): `angle(signed=True)` does not match the visible
        # signature of `Hexal.angle(other=None, non_negative=False)` — confirm.
        neighbours = np.array(self.neighbours())
        angles = np.array([h.angle(signed=True) for h in neighbours])
        distance = (angles - angle) % np.pi
        index = np.argsort(distance)
        return HexArray(neighbours[index[:2]])

    def direction(self, angle):
        neighbours = HexArray(self.neighbour(angle))
        angles = np.array([h.angle(signed=True) for h in neighbours])
        distance = (angles - angle) % np.pi
        index = np.argsort(distance)
        return HexArray(self.unit_directions()[index[:2]])

    # ----- Geometric methods

    def interp(self, other, t):
        """Interpolates towards other.

        The fractional (u, v) position is snapped back to the lattice with
        cube-coordinate rounding.

        Args:
            other (Hexal)
            t (float): interpolation step, 0<t<1.

        Returns:
            Hexal with value 0.
        """

        def hex_round(u, v):
            z = -(u + v)
            ru = round(u)
            rv = round(v)
            rz = round(z)
            u_diff = abs(ru - u)
            v_diff = abs(rv - v)
            z_diff = abs(rz - z)
            # Recompute the coordinate with the largest rounding error so that
            # u + v + z == 0 still holds.
            if u_diff > v_diff and u_diff > z_diff:
                ru = -rv - rz
            elif v_diff > z_diff:
                rv = -ru - rz
            return ru, rv

        uprime, vprime = (
            self.u + (other.u - self.u) * t,
            self.v + (other.v - self.v) * t,
        )
        uprime, vprime = hex_round(uprime, vprime)
        return Hexal(uprime, vprime, 0)

    def angle(self, other=None, non_negative=False):
        """Returns the angle to other or the origin.

        With ``other=None``, this is the polar angle of this hexal's pixel
        position. Otherwise it is the counter-clockwise angle from this
        hexal's pixel vector to ``other``'s pixel vector.

        Args:
            other (Hexal)
            non_negative (bool): add 2pi if angle is negative.
                Default: False.

        Returns:
            float: angle in radians.
        """

        def _angle(p1, p2):
            """Counter clockwise angle from p1 to p2.

            Returns:
                float: angle in [-np.pi, np.pi] (the range of np.arctan2).
            """
            dot = p1[0] * p2[0] + p1[1] * p2[1]
            det = p1[0] * p2[1] - p1[1] * p2[0]
            angle = np.arctan2(det, dot)
            return angle

        x, y = self._to_pixel(self.u, self.v)
        theta = np.arctan2(y, x)
        if other is not None:
            xother, yother = self._to_pixel(other.u, other.v)
            theta = _angle([x, y], [xother, yother])
        if non_negative:
            theta += 2 * np.pi if theta < 0 else 0
        return theta

    def distance(self, other=None):
        """Returns the columnar distance between two hexals.

        With ``other=None``, returns the distance to the origin.
        """
        if other is not None:
            return int(
                (
                    abs(self.u - other.u)
                    + abs(self.u + self.v - other.u - other.v)
                    + abs(self.v - other.v)
                )
                / 2
            )
        return int((abs(self.u) + abs(self.u + self.v) + abs(self.v)) / 2)

    @staticmethod
    def _to_pixel(u, v, scale=1):
        """Converts to pixel coordinates."""
        return hex_to_pixel(u, v, scale)
Functions
__add__(other)

Adds u and v coordinates, while keeping the value of the left hexal.

Source code in btorch/utils/hex_utils.py
def __add__(self, other):
    """Translate this hexal by ``other``'s (u, v); the result keeps ``self.value``.

    For an iterable ``other``, one translated Hexal is produced per element
    and returned as an np.ndarray. Any other operand type yields None.
    """
    if isinstance(other, Hexal):
        return Hexal(self.u + other.u, self.v + other.v, self.value)
    if isinstance(other, Iterable):
        return np.array([self + item for item in other])
    return None
__eq__(other)

Compares coordinates (not values).

Source code in btorch/utils/hex_utils.py
def __eq__(self, other):
    """Coordinate-only equality; values are ignored.

    Returns a bool for a Hexal, an np.ndarray of bools for an iterable, and
    None for anything else.
    """
    if isinstance(other, Hexal):
        return bool(self.u == other.u and self.v == other.v)
    if isinstance(other, Iterable):
        return np.array([self == item for item in other])
    return None
__mul__(other)

Multiplies values, while preserving coordinates.

Source code in btorch/utils/hex_utils.py
def __mul__(self, other):
    """Scale this hexal's value; coordinates are kept as-is.

    ``other`` may be a Hexal (its value is the factor), an iterable
    (multiplied element-wise into an np.ndarray), or a plain scalar.
    """
    if isinstance(other, Hexal):
        factor = other.value
    elif isinstance(other, Iterable):
        return np.array([self * item for item in other])
    else:
        factor = other
    return Hexal(self.u, self.v, self.value * factor)
angle(other=None, non_negative=False)

Returns the angle to other or the origin.

Parameters:

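Name Type Description Default
other Hexal

Hexal to measure the angle to. If None, the polar angle of this hexal's pixel position is returned.

None
non_negative bool

add 2pi if angle is negative. Default: False.

False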

Returns:

Name Type Description
float

angle in radians.

Source code in btorch/utils/hex_utils.py
def angle(self, other=None, non_negative=False):
    """Angle of this hexal in pixel space, in radians.

    Args:
        other (Hexal): If None, return the polar angle of this hexal's pixel
            position. Otherwise return the counter-clockwise angle from this
            hexal's pixel vector to ``other``'s.
        non_negative (bool): Shift negative angles by 2pi. Default: False.

    Returns:
        float: angle in radians, in [-pi, pi] (the np.arctan2 range), or in
        [0, 2pi) when ``non_negative`` is set.
    """
    x, y = self._to_pixel(self.u, self.v)
    if other is None:
        theta = np.arctan2(y, x)
    else:
        ox, oy = self._to_pixel(other.u, other.v)
        cross = x * oy - y * ox
        dot = x * ox + y * oy
        theta = np.arctan2(cross, dot)
    if non_negative:
        theta += 2 * np.pi if theta < 0 else 0
    return theta
distance(other=None)

Returns the columnar distance between two hexals (or to the origin if other is None).

Source code in btorch/utils/hex_utils.py
def distance(self, other=None):
    """Column distance between two hexals, or to the origin when ``other`` is None.

    Uses the axial-coordinate form (|du| + |du + dv| + |dv|) / 2, truncated
    to an int.
    """
    if other is None:
        du, dv, dw = self.u, self.v, self.u + self.v
    else:
        du = self.u - other.u
        dv = self.v - other.v
        dw = self.u + self.v - other.u - other.v
    total = abs(du) + abs(dw) + abs(dv)
    return int(total / 2)
eq_val(other)

Compares the values, not the coordinates.

Source code in btorch/utils/hex_utils.py
def eq_val(self, other):
    """Value-only equality; coordinates are ignored.

    Returns the ``==`` result for a Hexal, an np.ndarray of element-wise
    results for an iterable, and None otherwise.
    """
    if isinstance(other, Hexal):
        return self.value == other.value
    if isinstance(other, Iterable):
        return np.array([self.eq_val(item) for item in other])
    return None
interp(other, t)

Interpolates towards other.

Parameters:

Name Type Description Default
t float

interpolation step, 0<t<1.

required

Returns:

Type Description

Hexal

Source code in btorch/utils/hex_utils.py
def interp(self, other, t):
    """Hexal at fraction ``t`` of the way from ``self`` to ``other``.

    The fractional (u, v) position is snapped back onto the lattice with
    cube-coordinate rounding.

    Args:
        other (Hexal): Target hexal.
        t (float): interpolation step, 0<t<1.

    Returns:
        Hexal with value 0.
    """

    def _cube_round(fu, fv):
        fz = -(fu + fv)
        ru, rv, rz = round(fu), round(fv), round(fz)
        du, dv, dz = abs(ru - fu), abs(rv - fv), abs(rz - fz)
        # Rebuild the coordinate with the largest rounding error so that the
        # three cube coordinates still sum to zero.
        if du > dv and du > dz:
            return -rv - rz, rv
        if dv > dz:
            return ru, -ru - rz
        return ru, rv

    fu = self.u + (other.u - self.u) * t
    fv = self.v + (other.v - self.v) * t
    return Hexal(*_cube_round(fu, fv), 0)
is_neighbour(other)

Evaluates if other is a neighbour.

Source code in btorch/utils/hex_utils.py
def is_neighbour(self, other):
    """Evaluates if other is a neighbour.

    Args:
        other: A Hexal, or an iterable of Hexals.

    Returns:
        bool for a single Hexal, or a boolean np.ndarray with one entry per
        element of an iterable.
    """
    neighbours = self.neighbours()
    if isinstance(other, Hexal):
        return other in neighbours
    elif isinstance(other, Iterable):
        # Recurse element-wise. `self.neighbour` takes an angle, not a Hexal,
        # so calling it here raised on any iterable input.
        return np.array([self.is_neighbour(h) for h in other])
neighbours()

Returns 6 neighbours sorted CCW, starting from east.

Source code in btorch/utils/hex_utils.py
def neighbours(self):
    """The six adjacent hexals, counter-clockwise beginning with east."""
    ccw_order = (
        "east",
        "north_east",
        "north_west",
        "west",
        "south_west",
        "south_east",
    )
    return tuple(getattr(self, direction) for direction in ccw_order)
unit_directions() staticmethod

Returns the six unit directions.

Source code in btorch/utils/hex_utils.py
@staticmethod
def unit_directions():
    """HexArray of the six unit steps around the origin (CCW from east)."""
    origin = Hexal(0, 0, 0)
    return HexArray(origin.neighbours())

LatticeMask

Boolean masks for lattice dimension.

Parameters:

Name Type Description Default
extent int

Extent of the hexagonal lattice.

15
u_stride int

Stride in u-direction.

1
v_stride int

Stride in v-direction.

1
Source code in btorch/utils/hex_utils.py
class LatticeMask:
    """Boolean masks for the lattice center and its six direct neighbours.

    Each property compares one hexal against every hexal of an internal
    ``HexLattice``.

    Args:
        extent: Extent of the hexagonal lattice.
        u_stride: Stride in u-direction.
        v_stride: Stride in v-direction.
    """

    def __init__(self, extent: int = 15, u_stride: int = 1, v_stride: int = 1):
        self._lattice = HexLattice(extent=extent, u_stride=u_stride, v_stride=v_stride)

    def _mask(self, hexal):
        # Hexal.__eq__ against an iterable yields one comparison per lattice hexal.
        return hexal == self._lattice

    @property
    def center(self):
        return self._mask(self._lattice.center)

    @property
    def center_east(self):
        return self._mask(self._lattice.center.east)

    @property
    def center_north_east(self):
        return self._mask(self._lattice.center.north_east)

    @property
    def center_north_west(self):
        return self._mask(self._lattice.center.north_west)

    @property
    def center_west(self):
        return self._mask(self._lattice.center.west)

    @property
    def center_south_west(self):
        return self._mask(self._lattice.center.south_west)

    @property
    def center_south_east(self):
        return self._mask(self._lattice.center.south_east)

Functions

crop_to_extent(u, v, color, max_extent)

Crop hexagonal grid data to a specified maximum extent.

Parameters:

Name Type Description Default
u ndarray

Array of hex coordinates in u direction.

required
v ndarray

Array of hex coordinates in v direction.

required
color ndarray

Array of values associated with each (u, v) coordinate.

required
max_extent int

Maximum extent to crop the hexagonal grid to.

required

Returns:

Type Description
Tuple[ndarray, ndarray, ndarray]

Tuple of cropped u, v, and color arrays.

Source code in btorch/utils/hex_utils.py
def crop_to_extent(
    u: np.ndarray, v: np.ndarray, color: np.ndarray, max_extent: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the hexals that lie within ``max_extent`` of the origin.

    A hexal is kept when u, v and u + v all fall in [-max_extent, max_extent].

    Args:
        u: Array of hex coordinates in u direction.
        v: Array of hex coordinates in v direction.
        color: Array of values associated with each (u, v) coordinate.
        max_extent: Maximum extent to crop the hexagonal grid to.

    Returns:
        Tuple of cropped u, v, and color arrays.
    """
    lo, hi = -max_extent, max_extent
    w = u + v
    inside = (lo <= u) & (u <= hi) & (lo <= v) & (v <= hi) & (lo <= w) & (w <= hi)
    return tuple(arr[inside] for arr in (u, v, color))

get_extent(u, v, astype=int)

Returns extent (integer distance to origin) of arbitrary u, v coordinates.

Parameters:

Name Type Description Default
u ndarray

U-coordinate of hexal.

required
v ndarray

V-coordinate of hexal.

required
astype type

Type to cast to.

int

Returns:

Type Description
int

Extent of hex-lattice.

Note

If u and v are arrays, returns the maximum extent.

See Also

https://www.redblobgames.com/grids/hexagons/#distances

Source code in btorch/utils/hex_utils.py
def get_extent(u: np.ndarray, v: np.ndarray, astype: type = int) -> int:
    """Largest hex distance from the origin over the given (u, v) coordinates.

    Args:
        u: U-coordinate(s) of hexal(s); a scalar is accepted together with a
            scalar ``v``.
        v: V-coordinate(s) of hexal(s).
        astype: Type to cast the result to.

    Returns:
        Extent of hex-lattice, i.e. the maximum per-hexal distance.

    See Also:
        https://www.redblobgames.com/grids/hexagons/#distances
    """
    if isinstance(u, Number) and isinstance(v, Number):
        u, v = np.array((u,)), np.array((v,))
    coords = np.stack((u, v), 1)
    cu = coords[:, 0]
    cv = coords[:, 1]
    per_hexal = (abs(-cu) + abs(-cu - cv) + abs(-cv)) / 2
    return np.max(per_hexal).astype(astype)

get_hex_coords(extent, astensor=False)

Construct hexagonal coordinates for a regular hex-lattice with extent.

Parameters:

Name Type Description Default
extent int

Integer radius of hexagonal lattice. 0 returns the single center coordinate.

required
astensor bool

If True, returns torch.Tensor, else np.array.

False

Returns:

Type Description
Tuple[ndarray, ndarray]

A tuple containing: u: Hex-coordinates in u-direction. v: Hex-coordinates in v-direction.

Note

Will return get_num_hexals(extent) coordinates.

See Also

https://www.redblobgames.com/grids/hexagons/#range-coordinate

Source code in btorch/utils/hex_utils.py
def get_hex_coords(
    extent: int, astensor: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """Axial coordinates of every hexal in a regular hex-lattice.

    Args:
        extent: Integer radius of the lattice; 0 yields only the center.
        astensor: If True, return ``torch.long`` tensors instead of
            np.ndarrays.

    Returns:
        A tuple ``(u, v)`` of hex-coordinates, ordered by u and then by v.
        There are ``get_num_hexals(extent)`` entries.

    See Also:
        https://www.redblobgames.com/grids/hexagons/#range-coordinate
    """
    pairs = [
        (q, r)
        for q in range(-extent, extent + 1)
        for r in range(max(-extent, -extent - q), min(extent, extent - q) + 1)
    ]
    u = [q for q, _ in pairs]
    v = [r for _, r in pairs]
    if not astensor:
        return np.array(u), np.array(v)
    return torch.tensor(u, dtype=torch.long), torch.tensor(v, dtype=torch.long)

get_hextent(num_hexals)

Computes the hex-lattice extent from the number of hexals.

Parameters:

Name Type Description Default
num_hexals int

Number of hexals.

required

Returns:

Type Description
int

Extent of hex-lattice.

Note

Inverse of get_num_hexals.

Source code in btorch/utils/hex_utils.py
def get_hextent(num_hexals: int) -> int:
    """Recover the lattice extent from the total number of hexals.

    Args:
        num_hexals: Number of hexals.

    Returns:
        Extent of hex-lattice.

    Note:
        Inverse of get_num_hexals.
    """
    # n = 1 + 3e(e + 1) gives sqrt(n / 3) = sqrt(e^2 + e + 1/3), which lies in
    # [e, e + 1), so flooring recovers e exactly.
    root = np.sqrt(num_hexals / 3)
    return np.floor(root).astype("int")

get_num_hexals(extent)

Returns the absolute number of hexals in a hexagonal grid with extent.

Parameters:

Name Type Description Default
extent int

Extent of hex-lattice.

required

Returns:

Type Description
int

Number of hexals.

Note

Inverse of get_hextent.

Source code in btorch/utils/hex_utils.py
def get_num_hexals(extent: int) -> int:
    """Total number of hexals in a regular hex-lattice of the given extent.

    Args:
        extent: Extent of hex-lattice.

    Returns:
        Number of hexals.

    Note:
        Inverse of get_hextent.
    """
    # The center hexal plus 6k hexals on each ring k = 1..extent.
    return 3 * extent * (extent + 1) + 1

hex_rows(n_rows, n_columns, eps=0.1, mode='pointy')

Return a hex grid in pixel coordinates.

Parameters:

Name Type Description Default
n_rows int

Number of rows.

required
n_columns int

Number of columns.

required
eps float

Small offset to avoid overlapping hexagons.

0.1
mode Literal['pointy', 'flat']

Orientation of hexagons.

'pointy'

Returns:

Type Description
Tuple[ndarray, ndarray]

A tuple containing: x: X-coordinates of hexagon centers. y: Y-coordinates of hexagon centers.

Source code in btorch/utils/hex_utils.py
def hex_rows(
    n_rows: int,
    n_columns: int,
    eps: float = 0.1,
    mode: Literal["pointy", "flat"] = "pointy",
) -> Tuple[np.ndarray, np.ndarray]:
    """Pixel centers of a rectangular block of hexagons.

    Column index c is used as u and row index r as v (row-major order). The
    result is shifted by ``eps`` in both directions.

    Args:
        n_rows: Number of rows.
        n_columns: Number of columns.
        eps: Small offset to avoid overlapping hexagons.
        mode: Orientation of hexagons.

    Returns:
        A tuple ``(x, y)`` of hexagon-center coordinates.
    """
    grid = [(c, r) for r in range(n_rows) for c in range(n_columns)]
    u = np.array([c for c, _ in grid])
    v = np.array([r for _, r in grid])
    x, y = hex_to_pixel(u, v, mode=mode)
    x += eps
    y += eps
    return x, y

hex_to_pixel(u, v, size=1, mode='default')

Returns pixel coordinates from hex coordinates.

Parameters:

Name Type Description Default
u ndarray

Hex-coordinates in u-direction.

required
v ndarray

Hex-coordinates in v-direction.

required
size float

Size of hexagon.

1
mode Literal['default', 'flat', 'pointy']

Coordinate system convention.

'default'

Returns:

Type Description
Tuple[ndarray, ndarray]

A tuple containing: x: Pixel-coordinates in x-direction. y: Pixel-coordinates in y-direction.

See Also

https://www.redblobgames.com/grids/hexagons/#hex-to-pixel

Source code in btorch/utils/hex_utils.py
def hex_to_pixel(
    u: np.ndarray,
    v: np.ndarray,
    size: float = 1,
    mode: Literal["default", "flat", "pointy"] = "default",
) -> Tuple[np.ndarray, np.ndarray]:
    """Returns pixel coordinates from hex coordinates.

    Args:
        u: Hex-coordinates in u-direction (array, list or tuple).
        v: Hex-coordinates in v-direction (array, list or tuple).
        size: Size of hexagon; scales the output in every mode.
        mode: Coordinate system convention.

    Returns:
        A tuple containing:
            x: Pixel-coordinates in x-direction.
            y: Pixel-coordinates in y-direction.

    Raises:
        ValueError: If ``mode`` is not recognized.

    See Also:
        https://www.redblobgames.com/grids/hexagons/#hex-to-pixel
    """
    # Lists/tuples have no element-wise arithmetic. Convert each argument on
    # its own so mixed list/array inputs work too.
    if isinstance(u, (list, tuple)):
        u = np.array(u)
    if isinstance(v, (list, tuple)):
        v = np.array(v)
    if mode == "default":
        # `size` was previously ignored in this mode, which made the `scale`
        # argument of callers such as `to_pixel` a no-op.
        return (3 / 2.0 * v) * size, (-np.sqrt(3) * (u + v / 2)) * size
    elif mode == "flat":
        return (3 / 2.0 * u) * size, (np.sqrt(3) / 2 * u + np.sqrt(3) * v) * size
    elif mode == "pointy":
        return (np.sqrt(3) * u + np.sqrt(3) / 2 * v) * size, (3 / 2.0 * v) * size
    else:
        raise ValueError(f"{mode} not recognized.")

max_extent_index(u, v, max_extent)

Returns a mask to constrain u and v axial-hex-coordinates by max_extent.

Parameters:

Name Type Description Default
u ndarray

Hex-coordinates in u-direction.

required
v ndarray

Hex-coordinates in v-direction.

required
max_extent int

Maximal extent.

required

Returns:

Type Description
ndarray

Boolean mask.

Source code in btorch/utils/hex_utils.py
def max_extent_index(u: np.ndarray, v: np.ndarray, max_extent: int) -> np.ndarray:
    """Boolean mask of the axial coordinates lying within ``max_extent``.

    A coordinate passes when u, v and u + v all fall in
    [-max_extent, max_extent].

    Args:
        u: Hex-coordinates in u-direction.
        v: Hex-coordinates in v-direction.
        max_extent: Maximal extent.

    Returns:
        Boolean mask.
    """
    w = u + v
    within_u = (-max_extent <= u) & (u <= max_extent)
    within_v = (-max_extent <= v) & (v <= max_extent)
    within_w = (-max_extent <= w) & (w <= max_extent)
    return within_u & within_v & within_w

pad_to_regular_hex(u, v, values, extent, value=np.nan)

Pad hexals with coordinates to a regular hex lattice.

Parameters:

Name Type Description Default
u ndarray

U-coordinate of hexal.

required
v ndarray

V-coordinate of hexal.

required
values ndarray

Value of hexal with arbitrary shape but last axis must match the hexal dimension.

required
extent int

Extent of regular hex grid to pad to.

required
value float

The pad value.

nan

Returns:

Type Description
Tuple[ndarray, ndarray, ndarray]

A tuple containing: u_padded: Padded u-coordinate. v_padded: Padded v-coordinate. values_padded: Padded value.

Note

The canonical use case here is to pad a filter, receptive field, or postsynaptic current field for visualization.

Example
u = np.array([1, 0, -1, 0, 1, 2])
v = np.array([-2, -1, 0, 0, 0, 0])
values = np.array([0.05, 0.1, 0.3, 0.5, 0.7, 0.9])
hexals = pad_to_regular_hex(u, v, values, 6)
hex_scatter(*hexals, edgecolor='k', cmap=plt.cm.Blues, vmin=0, vmax=1)
Source code in btorch/utils/hex_utils.py
def pad_to_regular_hex(
    u: np.ndarray,
    v: np.ndarray,
    values: np.ndarray,
    extent: int,
    value: float = np.nan,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Pad hexals with coordinates to a regular hex lattice.

    The output covers every hexal of the regular lattice of the given
    ``extent``; lattice positions not present in ``(u, v)`` are filled with
    ``value``.

    Args:
        u: U-coordinate of hexal.
        v: V-coordinate of hexal.
        values: Value of hexal with arbitrary shape but last axis
            must match the hexal dimension.
        extent: Extent of regular hex grid to pad to.
        value: The pad value.

    Returns:
        A tuple containing:
            u_padded: Padded u-coordinate.
            v_padded: Padded v-coordinate.
            values_padded: Padded value.

    Note:
        The canonical use case here is to pad a filter, receptive field, or
        postsynaptic current field for visualization.

    Example:
        ```python
        u = np.array([1, 0, -1, 0, 1, 2])
        v = np.array([-2, -1, 0, 0, 0, 0])
        values = np.array([0.05, 0.1, 0.3, 0.5, 0.7, 0.9])
        hexals = pad_to_regular_hex(u, v, values, 6)
        hex_scatter(*hexals, edgecolor='k', cmap=plt.cm.Blues, vmin=0, vmax=1)
        ```
    """
    u_padded, v_padded = flyvis.utils.hex_utils.get_hex_coords(extent)
    n_hexals = len(u_padded)

    # Every axis except the last (hexal) one is carried over unchanged.
    leading = values.shape[:-1]
    values_padded = np.ones([*leading, n_hexals]) * value

    # Presumably the position of each (u, v) row within the regular
    # lattice's coordinate rows -- confirm against where_equal_rows.
    index = flyvis.utils.tensor_utils.where_equal_rows(
        np.stack((u, v), axis=1), np.stack((u_padded, v_padded), axis=1)
    )

    # Take all leading axes whole and scatter along the hexal axis.
    selector = (slice(None),) * len(leading) + (index,)
    values_padded[selector] = values
    return u_padded, v_padded, values_padded

pixel_to_hex(x, y, size=1, mode='default')

Returns hex coordinates from pixel coordinates.

Parameters:

Name Type Description Default
x ndarray

Pixel-coordinates in x-direction.

required
y ndarray

Pixel-coordinates in y-direction.

required
size float

Size of hexagon.

1
mode Literal['default', 'flat', 'pointy']

Coordinate system convention.

'default'

Returns:

Type Description
Tuple[ndarray, ndarray]

A tuple containing: u: Hex-coordinates in u-direction. v: Hex-coordinates in v-direction.

See Also

https://www.redblobgames.com/grids/hexagons/#pixel-to-hex

Source code in btorch/utils/hex_utils.py
def pixel_to_hex(
    x: np.ndarray,
    y: np.ndarray,
    size: float = 1,
    mode: Literal["default", "flat", "pointy"] = "default",
) -> Tuple[np.ndarray, np.ndarray]:
    """Convert pixel coordinates into axial hex coordinates.

    Args:
        x: Pixel-coordinates in x-direction.
        y: Pixel-coordinates in y-direction.
        size: Size of hexagon. Only the ``"flat"`` and ``"pointy"`` modes
            scale by it; the ``"default"`` mode ignores it.
        mode: Coordinate system convention.

    Returns:
        A tuple containing:
            u: Hex-coordinates in u-direction.
            v: Hex-coordinates in v-direction.

    Raises:
        ValueError: If ``mode`` is not one of the supported conventions.

    See Also:
        https://www.redblobgames.com/grids/hexagons/#pixel-to-hex
    """
    if mode == "default":
        u = -x / 3 - y / np.sqrt(3)
        v = 2 / 3 * x
        return u, v
    if mode == "flat":
        u = (2 / 3 * x) / size
        v = (-1 / 3 * x + np.sqrt(3) / 3 * y) / size
        return u, v
    if mode == "pointy":
        u = (np.sqrt(3) / 3 * x - 1 / 3 * y) / size
        v = (2 / 3 * y) / size
        return u, v
    raise ValueError(f"{mode} not recognized.")

sort_u_then_v(u, v, values)

Sorts u, v, and values by u and then v.

Parameters:

Name Type Description Default
u ndarray

U-coordinate of hexal.

required
v ndarray

V-coordinate of hexal.

required
values ndarray

Value of hexal.

required

Returns:

Type Description
Tuple[ndarray, ndarray, ndarray]

A tuple containing: u: Sorted u-coordinate of hexal. v: Sorted v-coordinate of hexal. values: Sorted value of hexal.

Source code in btorch/utils/hex_utils.py
def sort_u_then_v(
    u: np.ndarray, v: np.ndarray, values: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Reorder hexal coordinates and values by ascending u, ties broken by v.

    Args:
        u: U-coordinate of hexal.
        v: V-coordinate of hexal.
        values: Value of hexal.

    Returns:
        A tuple containing:
            u: Sorted u-coordinate of hexal.
            v: Sorted v-coordinate of hexal.
            values: Sorted value of hexal.
    """
    # np.lexsort treats its *last* key as the primary one, so u goes last.
    order = np.lexsort((v, u))
    return tuple(arr[order] for arr in (u, v, values))

sort_u_then_v_index(u, v)

Index to sort u, v by u and then v.

Parameters:

Name Type Description Default
u ndarray

U-coordinate of hexal.

required
v ndarray

V-coordinate of hexal.

required

Returns:

Type Description
ndarray

Index to sort u and v.

Source code in btorch/utils/hex_utils.py
def sort_u_then_v_index(u: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Return the permutation that orders hexals by u, then by v.

    Args:
        u: U-coordinate of hexal.
        v: V-coordinate of hexal.

    Returns:
        Index to sort u and v.
    """
    # Keys are listed least-significant first: v breaks ties, u is primary.
    keys = (v, u)
    return np.lexsort(keys)

btorch.utils.pandas_utils

Pandas DataFrame utilities.

Helpers for common DataFrame operations used in connectome analysis and data aggregation workflows.

Functions

groupby_to_dict(df, column_select=None, **groupby_args)

Group DataFrame and return as dictionary mapping keys to subframes.

Parameters:

Name Type Description Default
df DataFrame

Input DataFrame to group.

required
column_select Optional[Sequence[str]]

Optional column subset to include in output values.

None
**groupby_args

Arguments passed to df.groupby().

{}

Returns:

Type Description
dict[Any, DataFrame]

Dictionary mapping group keys to DataFrame slices.

Example

>>> df = pd.DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]})
>>> groupby_to_dict(df, column_select=["b"], by="a")
{1:    b
 0  3
 1  4, 2:    b
 2  5}

Source code in btorch/utils/pandas_utils.py
def groupby_to_dict(
    df: pd.DataFrame, column_select: Optional[Sequence[str]] = None, **groupby_args
) -> dict[Any, pd.DataFrame]:
    """Group DataFrame and return as dictionary mapping keys to subframes.

    Rows are selected by position (``GroupBy.indices`` + ``iloc``) rather
    than by index label. As a result, a DataFrame with duplicate index
    labels still yields each row exactly once, in its own group.

    Args:
        df: Input DataFrame to group.
        column_select: Optional column subset to include in output values.
        **groupby_args: Arguments passed to ``df.groupby()``.

    Returns:
        Dictionary mapping group keys to DataFrame slices.

    Example:
        >>> df = pd.DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]})
        >>> groupby_to_dict(df, column_select=["b"], by="a")
        {1:    b
         0  3
         1  4, 2:    b
         2  5}
    """
    grouped = df.groupby(**groupby_args)
    result = {}
    # ``indices`` maps each group key to integer row positions. Label-based
    # selection via ``groups`` + ``loc`` would duplicate rows whenever the
    # index contains repeated labels.
    for key, positions in grouped.indices.items():
        subframe = df.iloc[positions]
        if column_select is not None:
            subframe = subframe.loc[:, column_select]
        result[key] = subframe
    return result

btorch.utils.yaml_utils

YAML serialization utilities.

Simple helpers for loading and saving Python objects to YAML files, with automatic directory creation.

Functions

load_yaml(folder_or_file, filename=None)

Load object from YAML file.

Parameters:

Name Type Description Default
folder_or_file

Directory path if filename is provided, otherwise full file path.

required
filename

Optional filename when folder_or_file is a directory.

None

Returns:

Type Description

Deserialized Python object.

Source code in btorch/utils/yaml_utils.py
def load_yaml(folder_or_file, filename=None):
    """Read a YAML file and return its parsed contents.

    Args:
        folder_or_file: Directory path if ``filename`` is provided,
            otherwise full file path.
        filename: Optional filename when ``folder_or_file`` is a directory.

    Returns:
        Deserialized Python object, as produced by ``yaml.safe_load``.
    """
    if filename is None:
        path = folder_or_file
    else:
        path = os.path.join(folder_or_file, filename)
    with open(path, "r") as stream:
        return yaml.safe_load(stream)

save_yaml(args, folder_or_file, filename=None)

Save object to YAML file.

Parameters:

Name Type Description Default
args

Object to serialize. Tries safe_dump first, falls back to dumping args.__dict__.

required
folder_or_file

Directory path if filename is provided, otherwise full file path.

required
filename

Optional filename when folder_or_file is a directory.

None

Returns:

Type Description

None

Source code in btorch/utils/yaml_utils.py
def save_yaml(args, folder_or_file, filename=None):
    """Save object to YAML file.

    The parent directory is created if it does not exist. A bare file name
    with no directory component is written to the current working directory.

    Args:
        args: Object to serialize. Tries ``safe_dump`` first, falls back
            to dumping ``args.__dict__``.
        folder_or_file: Directory path if ``filename`` is provided,
            otherwise full file path.
        filename: Optional filename when ``folder_or_file`` is a directory.

    Returns:
        None
    """
    try:
        args_text = yaml.safe_dump(args)
    except Exception:
        # Fall back to the object's attribute dict when safe_dump fails.
        args_text = yaml.dump(args.__dict__)

    if filename is None:
        file = folder_or_file
        folder = os.path.dirname(folder_or_file)
    else:
        file = os.path.join(folder_or_file, filename)
        folder = folder_or_file
    # os.path.dirname("name.yaml") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if folder:
        os.makedirs(folder, exist_ok=True)
    with open(file, "w") as f:
        f.write(args_text)