Source code for skneuromsi.ndcollection.cplot_acc

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# This file is part of the
#   Scikit-NeuroMSI Project (https://github.com/renatoparedes/scikit-neuromsi).
# Copyright (c) 2021-2025, Renato Paredes; Cabral, Juan
# License: BSD 3-Clause
# Full Text:
#     https://github.com/renatoparedes/scikit-neuromsi/blob/main/LICENSE.txt

# =============================================================================
# DOCS
# =============================================================================

"""Utilities for plotting NDResultCollection data.

NDResultCollectionPlotter provides methods for plotting NDResultCollection
data, including line plots for reports and biases.

"""

# =============================================================================
# IMPORTS
# =============================================================================


import seaborn as sns

from ..utils import AccessorABC

# =============================================================================
# PLOTTER
# =============================================================================


class NDResultCollectionPlotter(AccessorABC):
    """NDResultCollection plotting utilities.

    Every public method draws one or more seaborn line plots from data
    obtained through the ``causes`` or ``bias`` attribute of the wrapped
    collection, and returns the matplotlib axes that were drawn on.

    Parameters
    ----------
    ndcollection : NDResultCollection
        The NDResultCollection object to be plotted.

    """

    # NOTE(review): presumably read by AccessorABC to choose the method used
    # when the accessor is invoked without an explicit kind -- confirm
    # against ``..utils.AccessorABC``.
    _default_kind = "unity_report"

    def __init__(self, ndcollection):
        self._nd_collection = ndcollection

    def _line_report(self, report, ax, kws):
        """Generate a line plot for a given report.

        The first column of ``report`` is plotted against the report index.

        Parameters
        ----------
        report : pandas.DataFrame
            The report data to be plotted. Only its first column is used.
        ax : matplotlib.axes.Axes or None
            The axes object to plot on. It is forwarded as-is to
            ``seaborn.lineplot``.
        kws : dict
            Additional keyword arguments for the plot. This dictionary is
            modified in place: a ``"label"`` entry is added when missing.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The plotted axes object.

        """
        parameter = report.columns[0]
        x = report.index
        y = report[parameter]

        # The name of the plotted column is the default legend label; an
        # explicit ``label`` given by the caller takes precedence.
        kws.setdefault("label", parameter)

        ax = sns.lineplot(x=x, y=y, ax=ax, **kws)
        return ax

    def n_report(self, n, *, parameter=None, ax=None, **kws):
        """Line plot of the N-report for a given number of causes.

        Parameters
        ----------
        n : int
            The number of causes for the N-report.
        parameter : str, optional
            The parameter to plot, by default None.
        ax : matplotlib.axes.Axes, optional
            The axes object to plot on, by default None.
        **kws
            Additional keyword arguments for the plot.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The plotted axes object.

        """
        causes_acc = self._nd_collection.causes
        the_report = causes_acc.n_report(n, parameter=parameter)

        ax = self._line_report(the_report, ax, kws)
        ax.set_ylabel(f"Proportion of {n} causes")
        return ax

    def unity_report(self, *, parameter=None, ax=None, **kws):
        """Line plot of the unity report.

        Parameters
        ----------
        parameter : str, optional
            The parameter to plot, by default None.
        ax : matplotlib.axes.Axes, optional
            The axes object to plot on, by default None.
        **kws
            Additional keyword arguments for the plot.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The plotted axes object.

        """
        causes_acc = self._nd_collection.causes
        the_report = causes_acc.unity_report(parameter=parameter)

        ax = self._line_report(the_report, ax, kws)
        ax.set_ylabel("Proportion of unique causes")
        return ax

    def mean_report(self, *, parameter=None, ax=None, **kws):
        """Line plot of the mean report.

        Parameters
        ----------
        parameter : str, optional
            The parameter to plot, by default None.
        ax : matplotlib.axes.Axes, optional
            The axes object to plot on, by default None.
        **kws
            Additional keyword arguments for the plot.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The plotted axes object.

        """
        causes_acc = self._nd_collection.causes
        the_report = causes_acc.mean_report(parameter=parameter)

        ax = self._line_report(the_report, ax, kws)
        ax.set_ylabel("Mean of causes")
        return ax

    def bias(
        self,
        influence_parameter,
        *,
        changing_parameter=None,
        dim=None,
        mode=None,
        quiet=False,
        show_iterations=True,
        show_mean=True,
        ax=None,
        legend=True,
        it_linestyle="--",
        it_alpha=0.25,
        mean_color="black",
        **kws,
    ):
        """Line plot of bias.

        Parameters
        ----------
        influence_parameter : object
            The parameter that influences the bias calculation.
        changing_parameter : object, optional
            The parameter that changes across the bias calculation. It is
            normalized with the collection's ``coerce_parameter``.
        dim : object, optional
            The dimension to use for the bias calculation. It is normalized
            with the collection's ``coerce_dimension``.
        mode : object, optional
            The mode to use for the bias calculation. It is normalized with
            the collection's ``coerce_mode``.
        quiet : bool, default False
            If True, suppress output messages. The value is forwarded to
            the bias accessor.
        show_iterations : bool, default True
            If True, plot individual iterations of bias data.
        show_mean : bool, default True
            If True, plot the mean bias data.
        ax : matplotlib.axes.Axes, optional
            The matplotlib axes to plot on. It is forwarded to
            ``seaborn.lineplot``.
        legend : bool, default True
            If True, add a legend to the plot.
        it_linestyle : str, default '--'
            The line style to use for iteration plots.
        it_alpha : float, default 0.25
            The alpha (transparency) value for iteration plots.
        mean_color : str, default 'black'
            The color to use for the mean bias plot.
        **kws : dict
            Additional keyword arguments to pass to seaborn's lineplot.
            ``estimator`` defaults to ``"mean"`` when not provided.

        Returns
        -------
        matplotlib.axes.Axes
            The matplotlib axes containing the plot.

        Raises
        ------
        ValueError
            If both show_iterations and show_mean are False.

        """
        if not (show_iterations or show_mean):
            raise ValueError(
                "If 'show_iterations' and 'show_mean' are False, "
                "your plot is empty"
            )

        coll = self._nd_collection

        # Resolve the user-facing selectors once so the very same values are
        # used for both bias computations and for the title.
        changing_parameter = coll.coerce_parameter(changing_parameter)
        dim = coll.coerce_dimension(dim)
        mode = coll.coerce_mode(mode)

        bias_acc = coll.bias

        kws.setdefault("estimator", "mean")

        if show_iterations:
            the_bias = bias_acc.bias(
                influence_parameter,
                changing_parameter=changing_parameter,
                dim=dim,
                mode=mode,
                quiet=quiet,
            )

            x = the_bias.index
            for column in the_bias.columns:
                y = the_bias[column]

                # Only the first column is labeled, so the legend holds a
                # single "Iterations" entry instead of one per column.
                label = (
                    "Iterations" if (column == the_bias.columns[0]) else None
                )
                ax = sns.lineplot(
                    x=x,
                    y=y,
                    ax=ax,
                    linestyle=it_linestyle,
                    alpha=it_alpha,
                    legend=False,
                    label=label,
                    **kws,
                )

        if show_mean:
            the_bias_mean = bias_acc.bias_mean(
                influence_parameter,
                changing_parameter=changing_parameter,
                dim=dim,
                mode=mode,
                quiet=quiet,
            )

            x_mean = the_bias_mean.index
            y_mean = the_bias_mean[the_bias_mean.columns[0]]
            ax = sns.lineplot(
                x=x_mean,
                y=y_mean,
                ax=ax,
                legend=False,
                label="Bias Mean",
                color=mean_color,
                **kws,
            )

        ax.set_ylabel("Bias")
        ax.set_title(
            f"Fixed={influence_parameter}, Target={changing_parameter}\n"
            f"Mode={mode}, Dimension={dim}"
        )

        if legend:
            # seaborn's own legend is disabled above; build it here from the
            # labeled lines.
            ax.legend()
            leg = ax.get_legend()

            # The first legend handle is painted with ``mean_color``.
            # NOTE(review): when iterations are drawn on a fresh axes this is
            # presumably the "Iterations" entry, recolored so it does not
            # show the color of one particular iteration line -- confirm.
            leg.legend_handles[0].set_color(mean_color)

        return ax