arkas
Advanced tools
| r"""Implement the precision evaluator.""" | ||
| from __future__ import annotations | ||
| __all__ = ["PrecisionEvaluator"] | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils.format import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.evaluator2.base import BaseEvaluator | ||
| from arkas.evaluator2.vanilla import Evaluator | ||
| from arkas.metric import precision | ||
| if TYPE_CHECKING: | ||
| from arkas.state.precision_recall import PrecisionRecallState | ||
class PrecisionEvaluator(BaseEvaluator):
    r"""Evaluate the precision from the labels stored in a state.

    The evaluator supports three label layouts:

    - binary: ``y_true`` must be an array of shape ``(n_samples,)``
        with ``0`` and ``1`` values, and ``y_pred`` must be an array
        of shape ``(n_samples,)``.
    - multiclass: ``y_true`` must be an array of shape ``(n_samples,)``
        with values in ``{0, ..., n_classes-1}``, and ``y_pred`` must
        be an array of shape ``(n_samples,)``.
    - multilabel: ``y_true`` must be an array of shape
        ``(n_samples, n_classes)`` with ``0`` and ``1`` values, and
        ``y_pred`` must be an array of shape
        ``(n_samples, n_classes)``.

    Args:
        state: The state containing the ground truth and predicted
            labels.

    Example usage:

    ```pycon

    >>> import numpy as np
    >>> from arkas.evaluator2 import PrecisionEvaluator
    >>> from arkas.state import PrecisionRecallState
    >>> # binary
    >>> evaluator = PrecisionEvaluator(
    ...     PrecisionRecallState(
    ...         y_true=np.array([1, 0, 0, 1, 1]),
    ...         y_pred=np.array([1, 0, 0, 1, 1]),
    ...         y_true_name="target",
    ...         y_pred_name="pred",
    ...         label_type="binary",
    ...     ),
    ... )
    >>> evaluator
    PrecisionEvaluator(
     (state): PrecisionRecallState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', label_type='binary', nan_policy='propagate')
    )
    >>> evaluator.evaluate()
    {'count': 5, 'precision': 1.0}
    >>> # multilabel
    >>> evaluator = PrecisionEvaluator(
    ...     PrecisionRecallState(
    ...         y_true=np.array([[1, 0, 1], [0, 1, 0], [0, 1, 0], [1, 0, 1], [1, 0, 1]]),
    ...         y_pred=np.array([[1, 0, 0], [0, 1, 1], [0, 1, 1], [1, 0, 0], [1, 0, 0]]),
    ...         y_true_name="target",
    ...         y_pred_name="pred",
    ...         label_type="multilabel",
    ...     )
    ... )
    >>> evaluator
    PrecisionEvaluator(
     (state): PrecisionRecallState(y_true=(5, 3), y_pred=(5, 3), y_true_name='target', y_pred_name='pred', label_type='multilabel', nan_policy='propagate')
    )
    >>> evaluator.evaluate()
    {'count': 5,
     'macro_precision': 0.666...,
     'micro_precision': 0.714...,
     'precision': array([1., 1., 0.]),
     'weighted_precision': 0.625}
    >>> # multiclass
    >>> evaluator = PrecisionEvaluator(
    ...     PrecisionRecallState(
    ...         y_true=np.array([0, 0, 1, 1, 2, 2]),
    ...         y_pred=np.array([0, 0, 1, 1, 2, 2]),
    ...         y_true_name="target",
    ...         y_pred_name="pred",
    ...         label_type="multiclass",
    ...     ),
    ... )
    >>> evaluator
    PrecisionEvaluator(
     (state): PrecisionRecallState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred', label_type='multiclass', nan_policy='propagate')
    )
    >>> evaluator.evaluate()
    {'count': 6,
     'macro_precision': 1.0,
     'micro_precision': 1.0,
     'precision': array([1., 1., 1.]),
     'weighted_precision': 1.0}
    >>> # auto
    >>> evaluator = PrecisionEvaluator(
    ...     PrecisionRecallState(
    ...         y_true=np.array([1, 0, 0, 1, 1]),
    ...         y_pred=np.array([1, 0, 0, 1, 1]),
    ...         y_true_name="target",
    ...         y_pred_name="pred",
    ...     )
    ... )
    >>> evaluator
    PrecisionEvaluator(
     (state): PrecisionRecallState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', label_type='binary', nan_policy='propagate')
    )
    >>> evaluator.evaluate()
    {'count': 5, 'precision': 1.0}

    ```
    """

    def __init__(self, state: PrecisionRecallState) -> None:
        # The state is the only configuration: labels, label type and
        # NaN policy are all read from it at evaluation time.
        self._state = state

    def __repr__(self) -> str:
        """Return a multi-line debug representation showing the state."""
        fields = {"state": self._state}
        args = repr_indent(repr_mapping(fields))
        return f"{self.__class__.__qualname__}(\n {args}\n)"

    def __str__(self) -> str:
        """Return a multi-line human-readable representation showing the
        state."""
        fields = {"state": self._state}
        args = str_indent(str_mapping(fields))
        return f"{self.__class__.__qualname__}(\n {args}\n)"

    def compute(self) -> Evaluator:
        """Evaluate the metrics once and wrap them in an ``Evaluator``.

        Returns:
            An ``Evaluator`` built from the output of ``evaluate()``
            called with the default prefix and suffix.
        """
        metrics = self.evaluate()
        return Evaluator(metrics=metrics)

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        """Indicate if two evaluators are equal or not.

        Args:
            other: The object to compare with.
            equal_nan: Forwarded to the state comparison.

        Returns:
            ``False`` if ``other`` is not an instance of this class,
            otherwise the result of comparing the two states.
        """
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def evaluate(self, prefix: str = "", suffix: str = "") -> dict[str, float]:
        """Compute the precision metrics from the state.

        Args:
            prefix: Passed through to the ``precision`` metric function.
            suffix: Passed through to the ``precision`` metric function.

        Returns:
            The dictionary returned by the ``precision`` metric
            function.
        """
        state = self._state
        return precision(
            y_true=state.y_true,
            y_pred=state.y_pred,
            prefix=prefix,
            suffix=suffix,
            label_type=state.label_type,
            nan_policy=state.nan_policy,
        )
| r"""Contain an abstract state to more easily manage arbitrary keyword | ||
| arguments.""" | ||
| from __future__ import annotations | ||
| __all__ = ["BaseArgState"] | ||
| import copy | ||
| import sys | ||
| from typing import Any | ||
| import numpy as np | ||
| import polars as pl | ||
| from coola import objects_are_equal | ||
| from coola.utils import str_indent, str_mapping | ||
| from coola.utils.format import repr_mapping_line | ||
| from arkas.state.base import BaseState | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
class BaseArgState(BaseState):
    r"""Define a base class to manage arbitrary keyword arguments.

    The keyword arguments are stored as given at construction time and
    are exposed through ``get_arg`` and ``get_args``.

    Args:
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, **kwargs: Any) -> None:
        self._kwargs = kwargs

    def __repr__(self) -> str:
        """Return a single-line representation of the state arguments."""
        args = repr_mapping_line(self._get_summarized_args())
        return f"{self.__class__.__qualname__}({args})"

    def __str__(self) -> str:
        """Return a human-readable representation of the state
        arguments."""
        args = str_indent(str_mapping(self._get_summarized_args()))
        return f"{self.__class__.__qualname__}({args})"

    def clone(self, deep: bool = True) -> Self:
        """Create a new state of the same class from the current
        arguments.

        Args:
            deep: If ``True``, the arguments are deep-copied before
                being passed to the constructor. Otherwise the new
                state references the same argument values.

        Returns:
            The new state.
        """
        args = self.get_args()
        if deep:
            args = copy.deepcopy(args)
        return self.__class__(**args)

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        """Indicate if two states are equal or not.

        Args:
            other: The object to compare with.
            equal_nan: Forwarded to ``objects_are_equal``.

        Returns:
            ``False`` if ``other`` is not an instance of this class,
            otherwise the result of comparing the two argument
            dictionaries.
        """
        if not isinstance(other, self.__class__):
            return False
        return objects_are_equal(self.get_args(), other.get_args(), equal_nan=equal_nan)

    def get_arg(self, name: str, default: Any = None) -> Any:
        r"""Get a given argument from the state.

        Args:
            name: The argument name to get.
            default: The default value to return if the argument is missing.

        Returns:
            The argument value or the default value.

        Example usage:

        ```pycon

        >>> import polars as pl
        >>> from arkas.state import DataFrameState
        >>> frame = pl.DataFrame(
        ...     {
        ...         "col1": [0, 1, 1, 0, 0, 1, 0],
        ...         "col2": [0, 1, 0, 1, 0, 1, 0],
        ...         "col3": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
        ...     },
        ...     schema={"col1": pl.Int64, "col2": pl.Int32, "col3": pl.Float64},
        ... )
        >>> state = DataFrameState(frame, column="col3")
        >>> state.get_arg("column")
        'col3'

        ```
        """
        return self._kwargs.get(name, default)

    def get_args(self) -> dict:
        r"""Get a dictionary with all the arguments of the state.

        The returned dictionary is a shallow copy: adding, removing or
        rebinding keys in it does not modify the state. The values
        themselves are not copied.

        Returns:
            The dictionary with all the arguments.

        Example usage:

        ```pycon

        >>> import polars as pl
        >>> from arkas.state import DataFrameState
        >>> frame = pl.DataFrame(
        ...     {
        ...         "col1": [0, 1, 1, 0, 0, 1, 0],
        ...         "col2": [0, 1, 0, 1, 0, 1, 0],
        ...         "col3": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
        ...     },
        ...     schema={"col1": pl.Int64, "col2": pl.Int32, "col3": pl.Float64},
        ... )
        >>> state = DataFrameState(frame, column="col3")
        >>> args = state.get_args()

        ```
        """
        # Return a copy so callers cannot mutate the internal storage by
        # editing the returned mapping.
        return dict(self._kwargs)

    def _get_summarized_args(self) -> dict:
        """Return the arguments with DataFrames and arrays replaced by
        their shape, to keep the string representations short."""
        return {
            key: val.shape if isinstance(val, (pl.DataFrame, np.ndarray)) else val
            for key, val in self.get_args().items()
        }
| r"""Contain default styles.""" | ||
| from __future__ import annotations | ||
| __all__ = ["get_tab_number_style"] | ||
def get_tab_number_style() -> str:
    r"""Return the inline CSS declarations used for numeric table cells.

    Returns:
        The default style for numbers in a HTML table, as a string of
        CSS declarations suitable for a ``style`` attribute.

    Example usage:

    ```pycon

    >>> from arkas.utils.style import get_tab_number_style
    >>> style = get_tab_number_style()
    >>> style
    'text-align: right; font-variant-numeric: tabular-nums;'

    ```
    """
    style = "text-align: right; font-variant-numeric: tabular-nums;"
    return style
+1
-1
| Metadata-Version: 2.1 | ||
| Name: arkas | ||
| Version: 0.0.1a11 | ||
| Version: 0.0.1a12 | ||
| Summary: Library to evaluate ML model performances | ||
@@ -5,0 +5,0 @@ Home-page: https://github.com/durandtibo/arkas |
+2
-2
| [tool.poetry] | ||
| name = "arkas" | ||
| version = "0.0.1a11" | ||
| version = "0.0.1a12" | ||
| description = "Library to evaluate ML model performances" | ||
@@ -84,3 +84,3 @@ readme = "README.md" | ||
| pytest-timeout = "^2.3" | ||
| ruff = ">=0.8,<1.0" | ||
| ruff = ">=0.9,<1.0" | ||
| xdoctest = "^1.2" | ||
@@ -87,0 +87,0 @@ |
@@ -59,4 +59,3 @@ r"""Contain the accuracy analyzer.""" | ||
| AccuracyOutput( | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -104,4 +103,4 @@ | ||
| y_pred_name=self._y_pred, | ||
| nan_policy=self._nan_policy, | ||
| ), | ||
| nan_policy=self._nan_policy, | ||
| ) |
@@ -59,4 +59,3 @@ r"""Contain the balanced accuracy analyzer.""" | ||
| BalancedAccuracyOutput( | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -105,4 +104,4 @@ | ||
| y_pred_name=self._y_pred, | ||
| nan_policy=self._nan_policy, | ||
| ), | ||
| nan_policy=self._nan_policy, | ||
| ) |
@@ -38,4 +38,3 @@ r"""Contain the base class to implement an analyzer.""" | ||
| AccuracyOutput( | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -70,4 +69,3 @@ | ||
| AccuracyOutput( | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -74,0 +72,0 @@ |
@@ -45,2 +45,3 @@ r"""Implement an analyzer that analyzes the correlation between numeric | ||
| no warning message appears. | ||
| sork_key: The key used to sort the correlation table. | ||
@@ -55,3 +56,3 @@ Example usage: | ||
| >>> analyzer | ||
| ColumnCorrelationAnalyzer(target_column='col3', columns=None, exclude_columns=(), missing_policy='raise') | ||
| ColumnCorrelationAnalyzer(target_column='col3', sork_key='spearman_coeff', columns=None, exclude_columns=(), missing_policy='raise') | ||
| >>> frame = pl.DataFrame( | ||
@@ -67,3 +68,3 @@ ... { | ||
| ColumnCorrelationOutput( | ||
| (state): TargetDataFrameState(dataframe=(7, 3), target_column='col3', nan_policy='propagate', figure_config=MatplotlibFigureConfig()) | ||
| (state): TargetDataFrameState(dataframe=(7, 3), target_column='col3', nan_policy='propagate', figure_config=MatplotlibFigureConfig(), sork_key='spearman_coeff') | ||
| ) | ||
@@ -80,2 +81,3 @@ | ||
| missing_policy: str = "raise", | ||
| sork_key: str = "spearman_coeff", | ||
| ) -> None: | ||
@@ -86,2 +88,3 @@ super().__init__( | ||
| self._target_column = target_column | ||
| self._sork_key = sork_key | ||
@@ -95,3 +98,6 @@ def find_columns(self, frame: pl.DataFrame) -> tuple[str, ...]: | ||
| def get_args(self) -> dict: | ||
| return {"target_column": self._target_column} | super().get_args() | ||
| return { | ||
| "target_column": self._target_column, | ||
| "sork_key": self._sork_key, | ||
| } | super().get_args() | ||
@@ -107,3 +113,4 @@ def _analyze(self, frame: pl.DataFrame) -> ColumnCorrelationOutput | EmptyOutput: | ||
| logger.info( | ||
| f"Analyzing the correlation between {self._target_column} and {self._columns}..." | ||
| f"Analyzing the correlation between {self._target_column} and {self._columns} | " | ||
| f"sort_key={self._sork_key!r} ..." | ||
| ) | ||
@@ -114,3 +121,5 @@ columns = list(self.find_common_columns(frame)) | ||
| return ColumnCorrelationOutput( | ||
| state=TargetDataFrameState(dataframe=out, target_column=self._target_column) | ||
| state=TargetDataFrameState( | ||
| dataframe=out, target_column=self._target_column, sork_key=self._sork_key | ||
| ) | ||
| ) |
@@ -72,3 +72,3 @@ r"""Implement an analyzer that analyzes the correlation between two | ||
| CorrelationOutput( | ||
| (state): DataFrameState(dataframe=(7, 2), figure_config=MatplotlibFigureConfig()) | ||
| (state): DataFrameState(dataframe=(7, 2), nan_policy='propagate', figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
@@ -75,0 +75,0 @@ |
@@ -19,3 +19,2 @@ r"""Implement an analyzer that plots the content of each column.""" | ||
| if TYPE_CHECKING: | ||
| import polars as pl | ||
@@ -22,0 +21,0 @@ |
@@ -54,4 +54,3 @@ r"""Contain an analyzer that transforms the data before to analyze | ||
| AccuracyOutput( | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(6,), y_pred=(6,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -58,0 +57,0 @@ |
@@ -16,3 +16,2 @@ r"""Contain the implementation of a HTML content generator that analyzes | ||
| from arkas.evaluator2.accuracy import AccuracyEvaluator | ||
| from arkas.metric.utils import check_nan_policy | ||
@@ -33,5 +32,2 @@ if TYPE_CHECKING: | ||
| labels. | ||
| nan_policy: The policy on how to handle NaN values in the input | ||
| arrays. The following options are available: ``'omit'``, | ||
| ``'propagate'``, and ``'raise'``. | ||
@@ -55,4 +51,3 @@ Example usage: | ||
| AccuracyContentGenerator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -63,14 +58,11 @@ | ||
| def __init__(self, state: AccuracyState, nan_policy: str = "propagate") -> None: | ||
| def __init__(self, state: AccuracyState) -> None: | ||
| self._state = state | ||
| check_nan_policy(nan_policy) | ||
| self._nan_policy = nan_policy | ||
| def __repr__(self) -> str: | ||
| args = repr_indent(repr_mapping({"state": self._state, "nan_policy": self._nan_policy})) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
| def __str__(self) -> str: | ||
| args = str_indent(str_mapping({"state": self._state, "nan_policy": self._nan_policy})) | ||
| args = str_indent(str_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
@@ -81,10 +73,7 @@ | ||
| return False | ||
| return ( | ||
| self._state.equal(other._state, equal_nan=equal_nan) | ||
| and self._nan_policy == other._nan_policy | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
| def generate_content(self) -> str: | ||
| logger.info("Generating the accuracy content...") | ||
| metrics = AccuracyEvaluator(self._state, nan_policy=self._nan_policy).evaluate() | ||
| metrics = AccuracyEvaluator(self._state).evaluate() | ||
| return Template(create_template()).render( | ||
@@ -91,0 +80,0 @@ { |
@@ -16,3 +16,2 @@ r"""Contain the implementation of a HTML content generator that analyzes | ||
| from arkas.evaluator2.balanced_accuracy import BalancedAccuracyEvaluator | ||
| from arkas.metric.utils import check_nan_policy | ||
@@ -50,4 +49,3 @@ if TYPE_CHECKING: | ||
| BalancedAccuracyContentGenerator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -58,14 +56,11 @@ | ||
| def __init__(self, state: AccuracyState, nan_policy: str = "propagate") -> None: | ||
| def __init__(self, state: AccuracyState) -> None: | ||
| self._state = state | ||
| check_nan_policy(nan_policy) | ||
| self._nan_policy = nan_policy | ||
| def __repr__(self) -> str: | ||
| args = repr_indent(repr_mapping({"state": self._state, "nan_policy": self._nan_policy})) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
| def __str__(self) -> str: | ||
| args = str_indent(str_mapping({"state": self._state, "nan_policy": self._nan_policy})) | ||
| args = str_indent(str_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
@@ -76,10 +71,7 @@ | ||
| return False | ||
| return ( | ||
| self._state.equal(other._state, equal_nan=equal_nan) | ||
| and self._nan_policy == other._nan_policy | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
| def generate_content(self) -> str: | ||
| logger.info("Generating the balance accuracy content...") | ||
| metrics = BalancedAccuracyEvaluator(self._state, nan_policy=self._nan_policy).evaluate() | ||
| metrics = BalancedAccuracyEvaluator(self._state).evaluate() | ||
| return Template(create_template()).render( | ||
@@ -86,0 +78,0 @@ { |
@@ -40,4 +40,3 @@ r"""Contain the base class to implement a HTML Content Generator.""" | ||
| AccuracyContentGenerator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -72,4 +71,3 @@ | ||
| AccuracyContentGenerator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -76,0 +74,0 @@ >>> generator2 = generator.compute() |
@@ -24,2 +24,3 @@ r"""Contain the implementation of a HTML content generator that returns | ||
| from arkas.plotter import ColumnCooccurrencePlotter | ||
| from arkas.utils.style import get_tab_number_style | ||
@@ -241,8 +242,8 @@ if TYPE_CHECKING: | ||
| "<td>{{col2}}</td>" | ||
| '<td style="text-align: right;">{{count}}</td>' | ||
| '<td style="text-align: right;">{{pct}}</td>' | ||
| "<td {{num_style}}>{{count}}</td>" | ||
| "<td {{num_style}}>{{pct}}</td>" | ||
| "</tr>" | ||
| ).render( | ||
| { | ||
| "num_style": 'style="text-align: right;"', | ||
| "num_style": f'style="{get_tab_number_style()}"', | ||
| "rank": rank, | ||
@@ -249,0 +250,0 @@ "col1": col1, |
@@ -11,5 +11,7 @@ r"""Contain the implementation of a HTML content generator that analyzes | ||
| "create_template", | ||
| "sort_metrics", | ||
| ] | ||
| import logging | ||
| import math | ||
| from typing import TYPE_CHECKING, Any | ||
@@ -22,6 +24,5 @@ | ||
| from arkas.evaluator2.column_correlation import ColumnCorrelationEvaluator | ||
| from arkas.utils.style import get_tab_number_style | ||
| if TYPE_CHECKING: | ||
| from collections.abc import Sequence | ||
| from arkas.state.target_dataframe import TargetDataFrameState | ||
@@ -86,2 +87,6 @@ | ||
| metrics = ColumnCorrelationEvaluator(self._state).evaluate() | ||
| metrics = sort_metrics( | ||
| {key.split("_", maxsplit=1)[1]: val for key, val in metrics.items()}, | ||
| key=self._state.get_arg("sort_metric", "spearman_coeff"), | ||
| ) | ||
| columns = list(self._state.dataframe.columns) | ||
@@ -95,3 +100,3 @@ columns.remove(self._state.target_column) | ||
| "columns": ", ".join(self._state.dataframe.columns), | ||
| "table": create_table(metrics, columns=columns), | ||
| "table": create_table(metrics), | ||
| "target_column": f"{self._state.target_column}", | ||
@@ -135,3 +140,3 @@ } | ||
| def create_table(metrics: dict[str, dict], columns: Sequence[str]) -> str: | ||
| def create_table(metrics: dict[str, dict]) -> str: | ||
| r"""Return a HTML representation of a table with some statistics | ||
@@ -142,3 +147,2 @@ about each column. | ||
| metrics: The dictionary of metrics. | ||
| columns: The columns to show in the table. | ||
@@ -156,3 +160,3 @@ Returns: | ||
| ... metrics={ | ||
| ... "correlation_col1": { | ||
| ... "col1": { | ||
| ... "count": 7, | ||
@@ -164,3 +168,3 @@ ... "pearson_coeff": 1.0, | ||
| ... }, | ||
| ... "correlation_col2": { | ||
| ... "col2": { | ||
| ... "count": 7, | ||
@@ -173,3 +177,2 @@ ... "pearson_coeff": -1.0, | ||
| ... }, | ||
| ... columns=["col1", "col2"], | ||
| ... ) | ||
@@ -180,3 +183,3 @@ | ||
| rows = "\n".join( | ||
| [create_table_row(column=col, metrics=metrics[f"correlation_{col}"]) for col in columns] | ||
| [create_table_row(column=col, metrics=values) for col, values in metrics.items()] | ||
| ) | ||
@@ -189,6 +192,4 @@ return Template( | ||
| <th>num samples</th> | ||
| <th>pearson coefficient</th> | ||
| <th>pearson p-value</th> | ||
| <th>spearman coefficient</th> | ||
| <th>spearman p-value</th> | ||
| <th>pearson coefficient (p-value)</th> | ||
| <th>spearman coefficient (p-value)</th> | ||
| </tr> | ||
@@ -238,10 +239,8 @@ </thead> | ||
| <td {{num_style}}>{{count}}</td> | ||
| <td {{num_style}}>{{pearson_coeff}}</td> | ||
| <td {{num_style}}>{{pearson_pvalue}}</td> | ||
| <td {{num_style}}>{{spearman_coeff}}</td> | ||
| <td {{num_style}}>{{spearman_pvalue}}</td> | ||
| <td {{num_style}}>{{pearson_coeff}} ({{pearson_pvalue}})</td> | ||
| <td {{num_style}}>{{spearman_coeff}} ({{spearman_pvalue}})</td> | ||
| </tr>""" | ||
| ).render( | ||
| { | ||
| "num_style": 'style="text-align: right;"', | ||
| "num_style": f'style="{get_tab_number_style()}"', | ||
| "column": column, | ||
@@ -255,1 +254,22 @@ "count": f'{metrics.get("count", 0):,}', | ||
| ) | ||
def sort_metrics(
    metrics: dict[str, dict[str, float]], key: str = "spearman_coeff"
) -> dict[str, dict[str, float]]:
    r"""Sort the dictionary of metrics by a given key.

    The entries are sorted in descending order of the selected metric.
    Entries whose value is NaN, ``None`` or missing cannot be ranked
    and are placed at the end. Entries with equal values keep their
    original relative order.

    Args:
        metrics: The dictionary of metrics to sort. Each value is
            itself a dictionary of metric values.
        key: The key to use to sort the metrics.

    Returns:
        The sorted dictionary of metrics. The input dictionary is not
        modified.
    """

    def get_metric(item: Any) -> float:
        # ``item`` is a ``(name, values)`` pair coming from ``dict.items()``.
        val = item[1].get(key)
        # NaN compares false with everything, which would make the sort
        # order undefined; map unrankable values to -inf so they end up
        # last in the descending order.
        if val is None or math.isnan(val):
            return -math.inf
        return val

    return dict(sorted(metrics.items(), key=get_metric, reverse=True))
@@ -19,2 +19,3 @@ r"""Contain the implementation of a HTML content generator that analyzes | ||
| from arkas.utils.stats import compute_statistics_continuous | ||
| from arkas.utils.style import get_tab_number_style | ||
@@ -214,3 +215,3 @@ if TYPE_CHECKING: | ||
| { | ||
| "num_style": 'style="text-align: right;"', | ||
| "num_style": f'style="{get_tab_number_style()}"', | ||
| "count": f"{stats['count']:,}", | ||
@@ -217,0 +218,0 @@ "mean": f"{stats['mean']:,.4f}", |
@@ -24,2 +24,3 @@ r"""Contain the implementation of a HTML content generator that | ||
| from arkas.utils.stats import compute_statistics_continuous | ||
| from arkas.utils.style import get_tab_number_style | ||
@@ -244,3 +245,3 @@ if TYPE_CHECKING: | ||
| { | ||
| "num_style": 'style="text-align: right;"', | ||
| "num_style": f'style="{get_tab_number_style()}"', | ||
| "column": series.name, | ||
@@ -363,3 +364,3 @@ "dtype": series.dtype, | ||
| { | ||
| "num_style": 'style="text-align: right;"', | ||
| "num_style": f'style="{get_tab_number_style()}"', | ||
| "column": series.name, | ||
@@ -366,0 +367,0 @@ "min": float_to_str(stats["min"]), |
@@ -23,2 +23,4 @@ r"""Contain the implementation of a HTML content generator that returns | ||
| from arkas.content.section import BaseSectionContentGenerator | ||
| from arkas.content.utils import to_str | ||
| from arkas.utils.style import get_tab_number_style | ||
| from arkas.utils.validation import check_positive | ||
@@ -283,3 +285,3 @@ | ||
| most_frequent_values = ", ".join( | ||
| [f"{val} ({100 * c / total:.2f}%)" for val, c in most_frequent_values] | ||
| [f"{to_str(val)} ({100 * c / total:.2f}%)" for val, c in most_frequent_values] | ||
| ) | ||
@@ -296,3 +298,3 @@ return Template( | ||
| { | ||
| "num_style": 'style="text-align: right;"', | ||
| "num_style": f'style="{get_tab_number_style()}"', | ||
| "column": column, | ||
@@ -299,0 +301,0 @@ "null": null, |
@@ -18,2 +18,3 @@ r"""Contain the implementation of a HTML content generator that analyzes | ||
| from arkas.plotter.temporal_null_value import TemporalNullValuePlotter | ||
| from arkas.utils.style import get_tab_number_style | ||
@@ -244,3 +245,3 @@ if TYPE_CHECKING: | ||
| { | ||
| "num_style": 'style="text-align: right;"', | ||
| "num_style": f'style="{get_tab_number_style()}"', | ||
| "label": label, | ||
@@ -247,0 +248,0 @@ "num_nulls": f"{num_nulls:,}", |
@@ -18,3 +18,2 @@ r"""Contain the average precision evaluator for multiclass labels.""" | ||
| if TYPE_CHECKING: | ||
| import polars as pl | ||
@@ -21,0 +20,0 @@ |
@@ -14,2 +14,3 @@ r"""Contain data evaluators.""" | ||
| "EvaluatorDict", | ||
| "PrecisionEvaluator", | ||
| ] | ||
@@ -24,2 +25,3 @@ | ||
| from arkas.evaluator2.mapping import EvaluatorDict | ||
| from arkas.evaluator2.precision import PrecisionEvaluator | ||
| from arkas.evaluator2.vanilla import Evaluator |
@@ -15,3 +15,2 @@ r"""Implement the accuracy evaluator.""" | ||
| from arkas.metric import accuracy | ||
| from arkas.metric.utils import check_nan_policy | ||
@@ -28,5 +27,2 @@ if TYPE_CHECKING: | ||
| labels. | ||
| nan_policy: The policy on how to handle NaN values in the input | ||
| arrays. The following options are available: ``'omit'``, | ||
| ``'propagate'``, and ``'raise'``. | ||
@@ -50,4 +46,3 @@ Example usage: | ||
| AccuracyEvaluator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -60,9 +55,7 @@ >>> evaluator.evaluate() | ||
| def __init__(self, state: AccuracyState, nan_policy: str = "propagate") -> None: | ||
| def __init__(self, state: AccuracyState) -> None: | ||
| self._state = state | ||
| check_nan_policy(nan_policy) | ||
| self._nan_policy = nan_policy | ||
| def __repr__(self) -> str: | ||
| args = repr_indent(repr_mapping({"state": self._state, "nan_policy": self._nan_policy})) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
@@ -76,6 +69,3 @@ | ||
| return False | ||
| return ( | ||
| self._state.equal(other._state, equal_nan=equal_nan) | ||
| and self._nan_policy == other._nan_policy | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
@@ -88,3 +78,3 @@ def evaluate(self, prefix: str = "", suffix: str = "") -> dict[str, float]: | ||
| suffix=suffix, | ||
| nan_policy=self._nan_policy, | ||
| nan_policy=self._state.nan_policy, | ||
| ) |
@@ -15,3 +15,2 @@ r"""Implement the balanced accuracy evaluator.""" | ||
| from arkas.metric import balanced_accuracy | ||
| from arkas.metric.utils import check_nan_policy | ||
@@ -49,4 +48,3 @@ if TYPE_CHECKING: | ||
| BalancedAccuracyEvaluator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -59,9 +57,7 @@ >>> evaluator.evaluate() | ||
| def __init__(self, state: AccuracyState, nan_policy: str = "propagate") -> None: | ||
| def __init__(self, state: AccuracyState) -> None: | ||
| self._state = state | ||
| check_nan_policy(nan_policy) | ||
| self._nan_policy = nan_policy | ||
| def __repr__(self) -> str: | ||
| args = repr_indent(repr_mapping({"state": self._state, "nan_policy": self._nan_policy})) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
@@ -75,6 +71,3 @@ | ||
| return False | ||
| return ( | ||
| self._state.equal(other._state, equal_nan=equal_nan) | ||
| and self._nan_policy == other._nan_policy | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
@@ -87,3 +80,3 @@ def evaluate(self, prefix: str = "", suffix: str = "") -> dict[str, float]: | ||
| suffix=suffix, | ||
| nan_policy=self._nan_policy, | ||
| nan_policy=self._state.nan_policy, | ||
| ) |
@@ -38,4 +38,3 @@ r"""Contain the base class to implement an evaluator.""" | ||
| AccuracyEvaluator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -42,0 +41,0 @@ |
@@ -16,3 +16,2 @@ r"""Implement the pairwise column correlation evaluator.""" | ||
| if TYPE_CHECKING: | ||
| from arkas.state.target_dataframe import TargetDataFrameState | ||
@@ -74,3 +73,3 @@ | ||
| def evaluate(self, prefix: str = "", suffix: str = "") -> dict[str, dict]: | ||
| def evaluate(self, prefix: str = "", suffix: str = "") -> dict[str, dict[str, float]]: | ||
| target_column = self._state.target_column | ||
@@ -77,0 +76,0 @@ columns = list(self._state.dataframe.columns) |
@@ -17,3 +17,2 @@ r"""Implement the pairwise column correlation evaluator.""" | ||
| if TYPE_CHECKING: | ||
| from arkas.state.target_dataframe import DataFrameState | ||
@@ -75,7 +74,7 @@ | ||
| def evaluate(self, prefix: str = "", suffix: str = "") -> dict[str, float]: | ||
| frame = self._state.dataframe.drop_nulls().drop_nans() | ||
| frame = self._state.dataframe | ||
| x = frame[frame.columns[0]].to_numpy() | ||
| y = frame[frame.columns[1]].to_numpy() | ||
| return pearsonr(x=x, y=y, prefix=prefix, suffix=suffix) | spearmanr( | ||
| x=x, y=y, prefix=prefix, suffix=suffix | ||
| ) | ||
| return pearsonr( | ||
| x=x, y=y, prefix=prefix, suffix=suffix, nan_policy=self._state.nan_policy | ||
| ) | spearmanr(x=x, y=y, prefix=prefix, suffix=suffix, nan_policy=self._state.nan_policy) |
@@ -52,4 +52,3 @@ r"""Contain an evaluator that evaluates a mapping of evaluators.""" | ||
| (one): AccuracyEvaluator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -56,0 +55,0 @@ (two): Evaluator(count=2) |
@@ -0,1 +1,2 @@ | ||
| # noqa: A005 | ||
| r"""Contain the implementation for matplotlib figures.""" | ||
@@ -2,0 +3,0 @@ |
@@ -69,3 +69,3 @@ r"""Implement the Pearson correlation metrics.""" | ||
| coeff, pvalue = float("nan"), float("nan") | ||
| if count > 0 and not x_nan and not y_nan: | ||
| if count > 1 and not x_nan and not y_nan: | ||
| result = stats.pearsonr(x=x, y=y, alternative=alternative) | ||
@@ -72,0 +72,0 @@ coeff, pvalue = float(result.statistic), float(result.pvalue) |
@@ -72,3 +72,3 @@ r"""Implement the Spearman correlation metrics.""" | ||
| coeff, pvalue = float("nan"), float("nan") | ||
| if count > 0 and not x_nan and not y_nan: | ||
| if count > 1 and not x_nan and not y_nan: | ||
| result = stats.spearmanr(x, y, alternative=alternative) | ||
@@ -75,0 +75,0 @@ coeff, pvalue = float(result.statistic), float(result.pvalue) |
@@ -13,3 +13,2 @@ r"""Implement the accuracy output.""" | ||
| from arkas.evaluator2.accuracy import AccuracyEvaluator | ||
| from arkas.metric.utils import check_nan_policy | ||
| from arkas.output.lazy import BaseLazyOutput | ||
@@ -28,5 +27,2 @@ from arkas.plotter.vanilla import Plotter | ||
| labels. | ||
| nan_policy: The policy on how to handle NaN values in the input | ||
| arrays. The following options are available: ``'omit'``, | ||
| ``'propagate'``, and ``'raise'``. | ||
@@ -50,14 +46,11 @@ Example usage: | ||
| AccuracyOutput( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
| >>> output.get_content_generator() | ||
| AccuracyContentGenerator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
| >>> output.get_evaluator() | ||
| AccuracyEvaluator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -70,9 +63,7 @@ >>> output.get_plotter() | ||
| def __init__(self, state: AccuracyState, nan_policy: str = "propagate") -> None: | ||
| def __init__(self, state: AccuracyState) -> None: | ||
| self._state = state | ||
| check_nan_policy(nan_policy) | ||
| self._nan_policy = nan_policy | ||
| def __repr__(self) -> str: | ||
| args = repr_indent(repr_mapping({"state": self._state, "nan_policy": self._nan_policy})) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
@@ -83,14 +74,11 @@ | ||
| return False | ||
| return ( | ||
| self._state.equal(other._state, equal_nan=equal_nan) | ||
| and self._nan_policy == other._nan_policy | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
| def _get_content_generator(self) -> AccuracyContentGenerator: | ||
| return AccuracyContentGenerator(state=self._state, nan_policy=self._nan_policy) | ||
| return AccuracyContentGenerator(state=self._state) | ||
| def _get_evaluator(self) -> AccuracyEvaluator: | ||
| return AccuracyEvaluator(state=self._state, nan_policy=self._nan_policy) | ||
| return AccuracyEvaluator(state=self._state) | ||
| def _get_plotter(self) -> Plotter: | ||
| return Plotter() |
@@ -13,3 +13,2 @@ r"""Implement the balanced accuracy output.""" | ||
| from arkas.evaluator2.balanced_accuracy import BalancedAccuracyEvaluator | ||
| from arkas.metric.utils import check_nan_policy | ||
| from arkas.output.lazy import BaseLazyOutput | ||
@@ -28,5 +27,2 @@ from arkas.plotter.vanilla import Plotter | ||
| labels. | ||
| nan_policy: The policy on how to handle NaN values in the input | ||
| arrays. The following options are available: ``'omit'``, | ||
| ``'propagate'``, and ``'raise'``. | ||
@@ -50,14 +46,11 @@ Example usage: | ||
| BalancedAccuracyOutput( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
| >>> output.get_content_generator() | ||
| BalancedAccuracyContentGenerator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
| >>> output.get_evaluator() | ||
| BalancedAccuracyEvaluator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -70,9 +63,7 @@ >>> output.get_plotter() | ||
| def __init__(self, state: AccuracyState, nan_policy: str = "propagate") -> None: | ||
| def __init__(self, state: AccuracyState) -> None: | ||
| self._state = state | ||
| check_nan_policy(nan_policy) | ||
| self._nan_policy = nan_policy | ||
| def __repr__(self) -> str: | ||
| args = repr_indent(repr_mapping({"state": self._state, "nan_policy": self._nan_policy})) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
@@ -83,14 +74,11 @@ | ||
| return False | ||
| return ( | ||
| self._state.equal(other._state, equal_nan=equal_nan) | ||
| and self._nan_policy == other._nan_policy | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
| def _get_content_generator(self) -> BalancedAccuracyContentGenerator: | ||
| return BalancedAccuracyContentGenerator(state=self._state, nan_policy=self._nan_policy) | ||
| return BalancedAccuracyContentGenerator(state=self._state) | ||
| def _get_evaluator(self) -> BalancedAccuracyEvaluator: | ||
| return BalancedAccuracyEvaluator(state=self._state, nan_policy=self._nan_policy) | ||
| return BalancedAccuracyEvaluator(state=self._state) | ||
| def _get_plotter(self) -> Plotter: | ||
| return Plotter() |
@@ -42,4 +42,3 @@ r"""Contain the base class to implement an output.""" | ||
| AccuracyOutput( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -74,4 +73,3 @@ | ||
| AccuracyOutput( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -169,4 +167,3 @@ >>> out = output.compute() | ||
| AccuracyContentGenerator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -206,4 +203,3 @@ | ||
| AccuracyEvaluator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -210,0 +206,0 @@ |
@@ -62,4 +62,3 @@ r"""Implement a output that combines a mapping of output objects into a | ||
| (two): AccuracyContentGenerator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -71,4 +70,3 @@ ) | ||
| (two): AccuracyEvaluator( | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| (nan_policy): propagate | ||
| (state): AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
| ) | ||
@@ -75,0 +73,0 @@ ) |
@@ -21,3 +21,2 @@ r"""Contain the implementation of a DataFrame column plotter.""" | ||
| if TYPE_CHECKING: | ||
| from arkas.figure.base import BaseFigure | ||
@@ -24,0 +23,0 @@ from arkas.state.dataframe import DataFrameState |
@@ -24,3 +24,2 @@ r"""Contain the implementation of a DataFrame column plotter.""" | ||
| if TYPE_CHECKING: | ||
| from arkas.figure.base import BaseFigure | ||
@@ -27,0 +26,0 @@ from arkas.state.temporal_dataframe import TemporalDataFrameState |
@@ -7,2 +7,3 @@ r"""Contain states.""" | ||
| "AccuracyState", | ||
| "BaseArgState", | ||
| "BaseState", | ||
@@ -20,2 +21,3 @@ "ColumnCooccurrenceState", | ||
| from arkas.state.accuracy import AccuracyState | ||
| from arkas.state.arg import BaseArgState | ||
| from arkas.state.base import BaseState | ||
@@ -22,0 +24,0 @@ from arkas.state.column_cooccurrence import ColumnCooccurrenceState |
@@ -7,18 +7,7 @@ r"""Implement the accuracy state.""" | ||
| import sys | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola import objects_are_equal | ||
| from coola.utils.format import repr_mapping_line | ||
| from arkas.metric.utils import check_nan_policy, check_same_shape_pred | ||
| from arkas.state.arg import BaseArgState | ||
| from arkas.metric.utils import check_same_shape_pred | ||
| from arkas.state.base import BaseState | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| if TYPE_CHECKING: | ||
@@ -28,3 +17,3 @@ import numpy as np | ||
| class AccuracyState(BaseState): | ||
| class AccuracyState(BaseArgState): | ||
| r"""Implement the accuracy state. | ||
@@ -42,2 +31,5 @@ | ||
| y_pred_name: The name associated to the predicted labels. | ||
| nan_policy: The policy on how to handle NaN values in the input | ||
| arrays. The following options are available: ``'omit'``, | ||
| ``'propagate'``, and ``'raise'``. | ||
@@ -57,3 +49,3 @@ Example usage: | ||
| >>> state | ||
| AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
@@ -69,3 +61,6 @@ ``` | ||
| y_pred_name: str, | ||
| nan_policy: str = "propagate", | ||
| **kwargs: Any, | ||
| ) -> None: | ||
| super().__init__(**kwargs) | ||
| self._y_true = y_true.ravel() | ||
@@ -78,12 +73,4 @@ self._y_pred = y_pred.ravel() | ||
| def __repr__(self) -> str: | ||
| args = repr_mapping_line( | ||
| { | ||
| "y_true": self._y_true.shape, | ||
| "y_pred": self._y_pred.shape, | ||
| "y_true_name": self._y_true_name, | ||
| "y_pred_name": self._y_pred_name, | ||
| } | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| check_nan_policy(nan_policy) | ||
| self._nan_policy = nan_policy | ||
@@ -106,18 +93,13 @@ @property | ||
| def clone(self, deep: bool = True) -> Self: | ||
| return self.__class__( | ||
| y_true=self._y_true.copy() if deep else self._y_true, | ||
| y_pred=self._y_pred.copy() if deep else self._y_pred, | ||
| y_true_name=self._y_true_name, | ||
| y_pred_name=self._y_pred_name, | ||
| ) | ||
| @property | ||
| def nan_policy(self) -> str: | ||
| return self._nan_policy | ||
| def equal(self, other: Any, equal_nan: bool = False) -> bool: | ||
| if not isinstance(other, self.__class__): | ||
| return False | ||
| return ( | ||
| objects_are_equal(self.y_true, other.y_true, equal_nan=equal_nan) | ||
| and objects_are_equal(self.y_pred, other.y_pred, equal_nan=equal_nan) | ||
| and self.y_true_name == other.y_true_name | ||
| and self.y_pred_name == other.y_pred_name | ||
| ) | ||
| def get_args(self) -> dict: | ||
| return { | ||
| "y_true": self._y_true, | ||
| "y_pred": self._y_pred, | ||
| "y_true_name": self._y_true_name, | ||
| "y_pred_name": self._y_pred_name, | ||
| "nan_policy": self._nan_policy, | ||
| } | super().get_args() |
@@ -42,3 +42,3 @@ r"""Contain the base class to implement a state.""" | ||
| >>> state | ||
| AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred') | ||
| AccuracyState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', nan_policy='propagate') | ||
@@ -45,0 +45,0 @@ ``` |
@@ -7,19 +7,8 @@ r"""Implement the DataFrame state.""" | ||
| import sys | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola import objects_are_equal | ||
| from coola.utils.format import repr_mapping_line, str_indent, str_mapping | ||
| from arkas.figure.utils import get_default_config | ||
| from arkas.metric.utils import check_nan_policy | ||
| from arkas.state.base import BaseState | ||
| from arkas.state.arg import BaseArgState | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| if TYPE_CHECKING: | ||
@@ -31,3 +20,3 @@ import polars as pl | ||
| class DataFrameState(BaseState): | ||
| class DataFrameState(BaseArgState): | ||
| r"""Implement the DataFrame state. | ||
@@ -41,2 +30,3 @@ | ||
| figure_config: An optional figure configuration. | ||
| **kwargs: Additional keyword arguments. | ||
@@ -68,3 +58,5 @@ Example usage: | ||
| figure_config: BaseFigureConfig | None = None, | ||
| **kwargs: Any, | ||
| ) -> None: | ||
| super().__init__(**kwargs) | ||
| self._dataframe = dataframe | ||
@@ -75,24 +67,2 @@ check_nan_policy(nan_policy) | ||
| def __repr__(self) -> str: | ||
| args = repr_mapping_line( | ||
| { | ||
| "dataframe": self._dataframe.shape, | ||
| "nan_policy": self._nan_policy, | ||
| "figure_config": self._figure_config, | ||
| } | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| def __str__(self) -> str: | ||
| args = str_indent( | ||
| str_mapping( | ||
| { | ||
| "dataframe": self._dataframe.shape, | ||
| "nan_policy": self._nan_policy, | ||
| "figure_config": self._figure_config, | ||
| } | ||
| ) | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| @property | ||
@@ -110,14 +80,2 @@ def dataframe(self) -> pl.DataFrame: | ||
| def clone(self, deep: bool = True) -> Self: | ||
| return self.__class__( | ||
| dataframe=self._dataframe.clone() if deep else self._dataframe, | ||
| nan_policy=self._nan_policy, | ||
| figure_config=self._figure_config.clone() if deep else self._figure_config, | ||
| ) | ||
| def equal(self, other: Any, equal_nan: bool = False) -> bool: | ||
| if not isinstance(other, self.__class__): | ||
| return False | ||
| return objects_are_equal(self.get_args(), other.get_args(), equal_nan=equal_nan) | ||
| def get_args(self) -> dict: | ||
@@ -128,2 +86,2 @@ return { | ||
| "figure_config": self._figure_config, | ||
| } | ||
| } | super().get_args() |
@@ -14,3 +14,3 @@ r"""Implement the accuracy state.""" | ||
| from arkas.metric.classification.precision import find_label_type | ||
| from arkas.metric.utils import check_label_type, check_same_shape_pred | ||
| from arkas.metric.utils import check_label_type, check_nan_policy, check_same_shape_pred | ||
| from arkas.state.base import BaseState | ||
@@ -55,2 +55,9 @@ | ||
| y_pred_name: The name associated to the predicted labels. | ||
| label_type: The type of labels used to evaluate the metrics. | ||
| The valid values are: ``'binary'``, ``'multiclass'``, | ||
| and ``'multilabel'``. If ``'binary'`` or ``'multilabel'``, | ||
| ``y_true`` values must be ``0`` and ``1``. | ||
| nan_policy: The policy on how to handle NaN values in the input | ||
| arrays. The following options are available: ``'omit'``, | ||
| ``'propagate'``, and ``'raise'``. | ||
@@ -70,3 +77,3 @@ Example usage: | ||
| >>> state | ||
| PrecisionRecallState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', label_type='binary') | ||
| PrecisionRecallState(y_true=(5,), y_pred=(5,), y_true_name='target', y_pred_name='pred', label_type='binary', nan_policy='propagate') | ||
@@ -83,2 +90,3 @@ ``` | ||
| label_type: str = "auto", | ||
| nan_policy: str = "propagate", | ||
| ) -> None: | ||
@@ -92,2 +100,3 @@ self._y_true = y_true | ||
| ) | ||
| self._nan_policy = nan_policy | ||
| self._check_args() | ||
@@ -103,2 +112,3 @@ | ||
| "label_type": self._label_type, | ||
| "nan_policy": self._nan_policy, | ||
| } | ||
@@ -128,2 +138,6 @@ ) | ||
| @property | ||
| def nan_policy(self) -> str: | ||
| return self._nan_policy | ||
| def clone(self, deep: bool = True) -> Self: | ||
@@ -136,2 +150,3 @@ return self.__class__( | ||
| label_type=self._label_type, | ||
| nan_policy=self._nan_policy, | ||
| ) | ||
@@ -148,2 +163,3 @@ | ||
| and self.label_type == other.label_type | ||
| and self.nan_policy == other.nan_policy | ||
| ) | ||
@@ -166,1 +182,2 @@ | ||
| check_label_type(self._label_type) | ||
| check_nan_policy(self._nan_policy) |
@@ -8,6 +8,4 @@ r"""Implement the DataFrame state for scatter plots.""" | ||
| import sys | ||
| from typing import TYPE_CHECKING | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils.format import repr_mapping_line, str_indent, str_mapping | ||
| from arkas.state.dataframe import DataFrameState | ||
@@ -17,7 +15,5 @@ from arkas.utils.dataframe import check_column_exist | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| pass | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| pass | ||
@@ -42,2 +38,3 @@ if TYPE_CHECKING: | ||
| figure_config: An optional figure configuration. | ||
| **kwargs: Additional keyword arguments. | ||
@@ -72,4 +69,7 @@ Example usage: | ||
| figure_config: BaseFigureConfig | None = None, | ||
| **kwargs: Any, | ||
| ) -> None: | ||
| super().__init__(dataframe=dataframe, nan_policy=nan_policy, figure_config=figure_config) | ||
| super().__init__( | ||
| dataframe=dataframe, nan_policy=nan_policy, figure_config=figure_config, **kwargs | ||
| ) | ||
@@ -84,30 +84,2 @@ check_column_exist(dataframe, x) | ||
| def __repr__(self) -> str: | ||
| args = repr_mapping_line( | ||
| { | ||
| "dataframe": self._dataframe.shape, | ||
| "x": self._x, | ||
| "y": self._y, | ||
| "color": self._color, | ||
| "nan_policy": self._nan_policy, | ||
| "figure_config": self._figure_config, | ||
| } | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| def __str__(self) -> str: | ||
| args = str_indent( | ||
| str_mapping( | ||
| { | ||
| "dataframe": self._dataframe.shape, | ||
| "x": self._x, | ||
| "y": self._y, | ||
| "color": self._color, | ||
| "nan_policy": self._nan_policy, | ||
| "figure_config": self._figure_config, | ||
| } | ||
| ) | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| @property | ||
@@ -125,13 +97,9 @@ def x(self) -> str: | ||
| def clone(self, deep: bool = True) -> Self: | ||
| return self.__class__( | ||
| dataframe=self._dataframe.clone() if deep else self._dataframe, | ||
| x=self._x, | ||
| y=self._y, | ||
| color=self._color, | ||
| nan_policy=self._nan_policy, | ||
| figure_config=self._figure_config.clone() if deep else self._figure_config, | ||
| ) | ||
| def get_args(self) -> dict: | ||
| return super().get_args() | {"x": self._x, "y": self._y, "color": self._color} | ||
| args = super().get_args() | ||
| return { | ||
| "dataframe": args.pop("dataframe"), | ||
| "x": self._x, | ||
| "y": self._y, | ||
| "color": self._color, | ||
| } | args |
@@ -8,6 +8,4 @@ r"""Implement DataFrame state with a target column.""" | ||
| import sys | ||
| from typing import TYPE_CHECKING | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils.format import repr_mapping_line, str_indent, str_mapping | ||
| from arkas.state.dataframe import DataFrameState | ||
@@ -17,7 +15,5 @@ from arkas.utils.dataframe import check_column_exist | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| pass | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| pass | ||
@@ -40,2 +36,3 @@ if TYPE_CHECKING: | ||
| figure_config: An optional figure configuration. | ||
| **kwargs: Additional keyword arguments. | ||
@@ -70,4 +67,7 @@ Example usage: | ||
| figure_config: BaseFigureConfig | None = None, | ||
| **kwargs: Any, | ||
| ) -> None: | ||
| super().__init__(dataframe=dataframe, nan_policy=nan_policy, figure_config=figure_config) | ||
| super().__init__( | ||
| dataframe=dataframe, nan_policy=nan_policy, figure_config=figure_config, **kwargs | ||
| ) | ||
@@ -77,26 +77,2 @@ check_column_exist(dataframe, target_column) | ||
| def __repr__(self) -> str: | ||
| args = repr_mapping_line( | ||
| { | ||
| "dataframe": self._dataframe.shape, | ||
| "target_column": self._target_column, | ||
| "nan_policy": self._nan_policy, | ||
| "figure_config": self._figure_config, | ||
| } | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| def __str__(self) -> str: | ||
| args = str_indent( | ||
| str_mapping( | ||
| { | ||
| "dataframe": self._dataframe.shape, | ||
| "target_column": self._target_column, | ||
| "nan_policy": self._nan_policy, | ||
| "figure_config": self._figure_config, | ||
| } | ||
| ) | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| @property | ||
@@ -106,11 +82,7 @@ def target_column(self) -> str: | ||
| def clone(self, deep: bool = True) -> Self: | ||
| return self.__class__( | ||
| dataframe=self._dataframe.clone() if deep else self._dataframe, | ||
| target_column=self._target_column, | ||
| nan_policy=self._nan_policy, | ||
| figure_config=self._figure_config.clone() if deep else self._figure_config, | ||
| ) | ||
| def get_args(self) -> dict: | ||
| return super().get_args() | {"target_column": self._target_column} | ||
| args = super().get_args() | ||
| return { | ||
| "dataframe": args.pop("dataframe"), | ||
| "target_column": self._target_column, | ||
| } | args |
@@ -8,6 +8,4 @@ r"""Implement the temporal DataFrame state.""" | ||
| import sys | ||
| from typing import TYPE_CHECKING | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils.format import repr_mapping_line, str_indent, str_mapping | ||
| from arkas.state.dataframe import DataFrameState | ||
@@ -17,7 +15,5 @@ from arkas.utils.dataframe import check_column_exist | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| pass | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| pass | ||
@@ -41,2 +37,3 @@ if TYPE_CHECKING: | ||
| figure_config: An optional figure configuration. | ||
| **kwargs: Additional keyword arguments. | ||
@@ -83,4 +80,7 @@ Example usage: | ||
| figure_config: BaseFigureConfig | None = None, | ||
| **kwargs: Any, | ||
| ) -> None: | ||
| super().__init__(dataframe=dataframe, nan_policy=nan_policy, figure_config=figure_config) | ||
| super().__init__( | ||
| dataframe=dataframe, nan_policy=nan_policy, figure_config=figure_config, **kwargs | ||
| ) | ||
@@ -91,28 +91,2 @@ check_column_exist(dataframe, temporal_column) | ||
| def __repr__(self) -> str: | ||
| args = repr_mapping_line( | ||
| { | ||
| "dataframe": self._dataframe.shape, | ||
| "temporal_column": self._temporal_column, | ||
| "period": self._period, | ||
| "nan_policy": self._nan_policy, | ||
| "figure_config": self._figure_config, | ||
| } | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| def __str__(self) -> str: | ||
| args = str_indent( | ||
| str_mapping( | ||
| { | ||
| "dataframe": self._dataframe.shape, | ||
| "temporal_column": self._temporal_column, | ||
| "period": self._period, | ||
| "nan_policy": self._nan_policy, | ||
| "figure_config": self._figure_config, | ||
| } | ||
| ) | ||
| ) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
| @property | ||
@@ -126,15 +100,8 @@ def period(self) -> str | None: | ||
| def clone(self, deep: bool = True) -> Self: | ||
| return self.__class__( | ||
| dataframe=self._dataframe.clone() if deep else self._dataframe, | ||
| temporal_column=self._temporal_column, | ||
| period=self._period, | ||
| nan_policy=self._nan_policy, | ||
| figure_config=self._figure_config.clone() if deep else self._figure_config, | ||
| ) | ||
| def get_args(self) -> dict: | ||
| return super().get_args() | { | ||
| args = super().get_args() | ||
| return { | ||
| "dataframe": args.pop("dataframe"), | ||
| "temporal_column": self._temporal_column, | ||
| "period": self._period, | ||
| } | ||
| } | args |
@@ -0,1 +1,2 @@ | ||
| # noqa: A005 | ||
| r"""Implement some utility functions for ``numpy.ndarray``s.""" | ||
@@ -2,0 +3,0 @@ |
@@ -0,1 +1,2 @@ | ||
| # noqa: A005 | ||
| r"""Contain utility functions to generate sections.""" | ||
@@ -2,0 +3,0 @@ |
@@ -0,1 +1,2 @@ | ||
| # noqa: A005 | ||
| r"""Contain utility functions to configure the standard logging | ||
@@ -2,0 +3,0 @@ library.""" |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
980960
0.34%267
1.14%24539
0.37%