#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the
# Scikit-NeuroMSI Project (https://github.com/renatoparedes/scikit-neuromsi).
# Copyright (c) 2021-2025, Renato Paredes; Cabral, Juan
# License: BSD 3-Clause
# Full Text:
# https://github.com/renatoparedes/scikit-neuromsi/blob/main/LICENSE.txt
# =============================================================================
# DOCS
# =============================================================================
"""Public testing utility functions.
This module exposes "assert" functions which facilitate the comparison in a
testing environment of objects created in "scikit-neuromsi".
The functionalities are extensions of those present in "xarray.testing" and
"numpy.testing".
"""
# =============================================================================
# IMPORTS
# =============================================================================
import numpy as np
import xarray as xa
from . import core, ndcollection
from .utils import dict_cmp
# =============================================================================
# ASSERTS
# =============================================================================
def _assert(cond, err_msg):
"""Asserts that a condition is true, otherwise raises an AssertionError \
with a specified error message.
This function exists to prevent asserts from being turned off with a
"python -O."
Parameters
----------
cond : bool
The condition to be evaluated.
err_msg : str
The error message to be raised if the condition is false.
"""
if not cond:
raise AssertionError(err_msg)
[docs]
def assert_ndresult_allclose(
left, right, rtol=1e-05, atol=1e-08, equal_nan=True, decode_bytes=True
):
"""Assert that two NDResult objects are approximately equal.
Parameters
----------
left : NDResult
The first NDResult object to compare.
right : NDResult
The second NDResult object to compare.
rtol : float, optional
The relative tolerance parameter for the `assert_allclose` function
(default is 1e-05).
atol : float, optional
The absolute tolerance parameter for the `assert_allclose` function
(default is 1e-08).
equal_nan : bool, optional
Whether to compare NaN values in the arrays (default is True).
decode_bytes : bool, optional
Whether to decode bytes in the arrays (default is True).
(See `xarray.testing.assert_allclose` for details).
"""
_assert(isinstance(left, core.NDResult), "left is not an NDResult")
if left is right:
return
_assert(isinstance(right, core.NDResult), "right is not an NDResult")
_assert(left.mname == right.mname, "mname mismatch")
_assert(left.mtype == right.mtype, "mtype mismatch")
_assert(left.output_mode == right.output_mode, "output_mode mismatch")
_assert(left.causes_ == right.causes_, "causes mismatch")
_assert(left.time_res == right.time_res, "time_res mismatch")
_assert(left.position_res == right.position_res, "position_res mismatch")
np.testing.assert_allclose(
left.time_range,
right.time_range,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)
np.testing.assert_allclose(
left.position_range,
right.position_range,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)
# dicts
assert dict_cmp.dict_allclose(
left.nmap_, right.nmap_, rtol=rtol, atol=atol, equal_nan=equal_nan
), "nmap mismatch"
assert dict_cmp.dict_allclose(
left.run_parameters,
right.run_parameters,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
), "run_parameters mismatch"
assert dict_cmp.dict_allclose(
left.extra_, right.extra_, rtol=rtol, atol=atol, equal_nan=equal_nan
), "extra mismatch"
xa.testing.assert_allclose(
left.to_xarray(),
right.to_xarray(),
rtol=rtol,
atol=atol,
decode_bytes=decode_bytes,
)
[docs]
def assert_ndresult_collection_allclose(
left, right, rtol=1e-05, atol=1e-08, equal_nan=True, decode_bytes=True
):
"""Assert that two NDResultCollection objects are approximately equal.
Parameters
----------
left : NDResultCollection
The first NDResultCollection object to compare.
right : NDResultCollection
The second NDResultCollection object to compare.
rtol : float, optional
The relative tolerance parameter for the `assert_allclose` function
(default is 1e-05).
atol : float, optional
The absolute tolerance parameter for the `assert_allclose` function
(default is 1e-08).
equal_nan : bool, optional
Whether to compare NaN values in the arrays (default is True).
decode_bytes : bool, optional
Whether to decode bytes in the arrays (default is True).
(See `xarray.testing.assert_allclose` for details).
"""
_assert(
isinstance(left, ndcollection.NDResultCollection),
"left is not an NDResultCollection",
)
if left is right:
return
_assert(
isinstance(right, ndcollection.NDResultCollection),
"right is not an NDResultCollection",
)
_assert(len(left) == len(right), "length mismatch")
for idx, (left_ndres, right_ndres) in enumerate(zip(left, right)):
try:
assert_ndresult_allclose(
left_ndres,
right_ndres,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
decode_bytes=decode_bytes,
)
except AssertionError as e:
raise AssertionError(
f"NDResultCollection[{idx}] mismatch: {e.args[0]}"
)