#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the
# Scikit-NeuroMSI Project (https://github.com/renatoparedes/scikit-neuromsi).
# Copyright (c) 2021-2025, Renato Paredes; Cabral, Juan
# License: BSD 3-Clause
# Full Text:
# https://github.com/renatoparedes/scikit-neuromsi/blob/main/LICENSE.txt
# =============================================================================
# DOCS
# =============================================================================
"""Plot helper for the Result object."""
# =============================================================================
# IMPORTS
# =============================================================================
from collections.abc import Iterable
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from ..constants import (
D_MODES,
D_POSITIONS,
D_POSITIONS_COORDINATES,
D_TIMES,
XA_NAME,
)
from ...utils import AccessorABC
# =============================================================================
# PLOTTER OBJECT
# =============================================================================
class ResultPlotter(AccessorABC):
    """Make plots of Result.

    Kind of plot to produce:

    - line_positions
    - line_times

    Parameters
    ----------
    result : NDResult
        The NDResult object for which to create plots.

    """

    # name of the default plot method (presumably dispatched by
    # AccessorABC when the accessor is called directly — verify)
    _default_kind = "line_positions"

    def __init__(self, result):
        self._result = result

    # LINE ====================================================================

    def _resolve_axis(self, ax):
        """Resolve the axis for plotting.

        If `ax` is None, create a new figure with one subplot per position
        coordinate of the result (sharing the y axis). Otherwise, ensure
        `ax` is an iterable.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, iterable of Axes or None
            The axis to plot on. If None, a new figure and axes will
            be created.

        Returns
        -------
        ax : numpy.ndarray
            The resolved axes as a numpy array.

        """
        if ax is None:
            coords_number = len(self._result.pcoords_)
            fig, ax = plt.subplots(1, coords_number, sharey=True)
            # scale the default figure width by the number of subplots
            size_x, size_y = fig.get_size_inches()
            fig.set_size_inches(size_x * coords_number, size_y)
        if not isinstance(ax, Iterable):
            # a single Axes object (one coordinate, or user-provided)
            ax = [ax]
        return np.asarray(ax)

    def _complete_index_level(
        self, *, df, level, center, fill_value, half_width=25
    ):
        """Complete a DataFrame index level with new values around a \
        center point.

        Parameters
        ----------
        df : pandas.DataFrame
            The input DataFrame (with a MultiIndex) to be completed.
        level : str
            The name of the index level to be completed.
        center : int or float
            The value around which to generate the new index values.
        fill_value : scalar or dict
            The value(s) used to fill the new rows. A dict maps column
            names to their fill values.
        half_width : int, default 25
            Number of new values generated on each side of ``center``.

        Returns
        -------
        pandas.DataFrame
            A new DataFrame: the original rows followed by the new ones.

        Notes
        -----
        The values ``center - half_width`` to ``center + half_width``
        (step 1) are generated and ``center`` itself is dropped, so
        ``2 * half_width`` new values are used for ``level``. Every other
        level keeps its existing values, and each combination of the level
        values becomes a new row filled with ``fill_value``.

        """
        # symmetric range around the center, excluding the center itself
        values_to_complete = np.arange(
            center - half_width, center + half_width + 1
        )
        values_to_complete = values_to_complete[values_to_complete != center]

        # the completed level gets the new range; the others keep theirs
        index_names = list(df.index.names)
        new_index_values = [
            values_to_complete if lname == level else lvalues.to_numpy()
            for lname, lvalues in zip(df.index.names, df.index.levels)
        ]

        # every combination of the level values, filled with fill_value
        for_complete_df = pd.DataFrame(
            fill_value,
            columns=df.columns.copy(),
            index=pd.MultiIndex.from_product(
                new_index_values, names=index_names
            ),
        )
        return pd.concat((df, for_complete_df))

    def _scale_xtickslabels(self, *, limits, ticks, single_value):
        """Scale the x-tick labels based on the provided limits and ticks.

        The smallest tick is mapped to the lower limit and the largest tick
        to the upper limit; the ticks in between are linearly interpolated.

        Parameters
        ----------
        limits : tuple
            The lower and upper limits for scaling the x-tick labels.
        ticks : array-like
            The original tick positions.
        single_value : bool
            Whether there is a single value in the data. If True, only the
            middle label is kept and the rest are blanked.

        Returns
        -------
        labels : numpy.ndarray
            The scaled x-tick labels formatted with two decimals.

        """
        ticks_array = np.asarray(ticks, dtype=float)
        if ticks_array.size == 0:
            # nothing to relabel (np.min would fail on an empty array)
            return np.array([], dtype=str)

        ll, tl = np.sort(limits)
        tmin, tmax = np.min(ticks_array), np.max(ticks_array)
        new_ticks = np.interp(ticks_array, (tmin, tmax), (ll, tl))
        labels = np.array([f"{t:.2f}" for t in new_ticks])

        if single_value:
            mask = np.ones_like(labels, dtype=bool)
            mask[len(mask) // 2] = False
            labels[mask] = ""
        return labels

    def _plot_lines(
        self, *, x_dim, sel_dim, sel_value, x_coords, x_range, title,
        target_ax, kwargs,
    ):
        """Shared implementation of ``line_positions`` and ``line_times``.

        Parameters
        ----------
        x_dim : str
            Dimension drawn on the x axis.
        sel_dim : str
            Dimension fixed at ``sel_value`` before plotting.
        sel_value : scalar
            Coordinate of ``sel_dim`` to select.
        x_coords : sequence
            Coordinates of the result along ``x_dim``. When there is only
            one, the data is padded with zero-valued rows around it so a
            line can be drawn.
        x_range : tuple
            Limits used to relabel the x ticks.
        title : str
            Figure suptitle.
        target_ax : matplotlib.axes.Axes, iterable of Axes or None
            Axes to draw on (one per position coordinate); see
            ``_resolve_axis``.
        kwargs : dict
            Extra keyword arguments for ``seaborn.lineplot``. A default
            ``alpha`` of 0.75 is inserted if absent.

        Returns
        -------
        axes : numpy.ndarray
            Flattened array with the plotted axes.

        """
        axes = self._resolve_axis(target_ax).flatten()
        has_single_value = len(x_coords) == 1
        kwargs.setdefault("alpha", 0.75)

        xa = self._result.to_xarray()
        df = xa.sel({sel_dim: sel_value}).to_dataframe()
        if has_single_value:
            df = self._complete_index_level(
                df=df,
                level=x_dim,
                center=x_coords[0],
                fill_value={sel_dim: sel_value, XA_NAME: 0},
            )

        for pcoord, ax in zip(self._result.pcoords_, axes, strict=True):
            pcoord_df = df.xs(pcoord, level=D_POSITIONS_COORDINATES)
            sns.lineplot(
                x=x_dim,
                y=XA_NAME,
                hue=D_MODES,
                data=pcoord_df,
                ax=ax,
                legend=ax is axes[-1],  # only the last ax has the legend
                **kwargs,
            )

            # relabel the ticks by linearly mapping them onto x_range
            ticks = ax.get_xticks()
            labels = self._scale_xtickslabels(
                limits=x_range, ticks=ticks, single_value=has_single_value
            )
            ax.set_xticks(ticks)  # without this a warning will be raised
            ax.set_xticklabels(labels)

            ax.set_title(pcoord)

        axes[-1].get_figure().suptitle(title)
        return axes

    # API======================================================================

    def line_positions(self, time=None, **kwargs):
        """Create a line plot of positions at a specific time.

        Parameters
        ----------
        time : float or None, optional
            The time at which to plot the positions. If None, the maximum time
            from the result will be used. Default is None.
        **kwargs
            Additional keyword arguments to pass to seaborn.lineplot().
            ``ax`` is intercepted and used as the target axes (one per
            position coordinate).

        Returns
        -------
        axes : numpy.ndarray
            The plotted axes.

        """
        if time is None:
            time = self._result.stats.dimmax()[D_TIMES]
        return self._plot_lines(
            x_dim=D_POSITIONS,
            sel_dim=D_TIMES,
            sel_value=time,
            x_coords=self._result.positions_,
            x_range=self._result.position_range,
            title=f"{self._result.mname} - Time {time}",
            target_ax=kwargs.pop("ax", None),
            kwargs=kwargs,
        )

    # short alias
    linep = line_positions

    def line_times(self, position=None, **kwargs):
        """Create a line plot of time series at a specific position.

        Parameters
        ----------
        position : float or None, optional
            The position at which to plot the time series. If None, the
            position with the maximum value from the result will be used.
            Default is None.
        **kwargs
            Additional keyword arguments to pass to seaborn.lineplot().
            ``ax`` is intercepted and used as the target axes (one per
            position coordinate).

        Returns
        -------
        axes : numpy.ndarray
            The plotted axes.

        """
        if position is None:
            position = self._result.stats.dimmax()[D_POSITIONS]
        return self._plot_lines(
            x_dim=D_TIMES,
            sel_dim=D_POSITIONS,
            sel_value=position,
            x_coords=self._result.times_,
            x_range=self._result.time_range,
            title=f"{self._result.mname} - Position {position}",
            target_ax=kwargs.pop("ax", None),
            kwargs=kwargs,
        )

    # short alias
    linet = line_times