arkas
Advanced tools
| r"""Implement an analyzer that plots the content of each column.""" | ||
| from __future__ import annotations | ||
| __all__ = ["PlotColumnAnalyzer"] | ||
| import logging | ||
| from typing import TYPE_CHECKING | ||
| from grizz.utils.format import str_shape_diff | ||
| from arkas.analyzer.lazy import BaseInNLazyAnalyzer | ||
| from arkas.output.plot_column import PlotColumnOutput | ||
| from arkas.state.dataframe import DataFrameState | ||
| if TYPE_CHECKING: | ||
| from collections.abc import Sequence | ||
| import polars as pl | ||
| from arkas.figure import BaseFigureConfig | ||
| logger = logging.getLogger(__name__) | ||
class PlotColumnAnalyzer(BaseInNLazyAnalyzer):
    r"""Define an analyzer that draws the values of each selected
    column.

    Args:
        columns: The names of the columns to plot. ``None`` means
            that all the columns are analyzed.
        exclude_columns: The names of the columns to remove from
            the input ``columns``. Names that are not found are
            ignored during the filtering process.
        missing_policy: How to handle missing columns. The valid
            values are ``'ignore'``, ``'warn'``, and ``'raise'``.
            ``'raise'``: an exception is raised if at least one
            column is missing.
            ``'warn'``: a warning is raised if at least one column
            is missing, and the missing columns are ignored.
            ``'ignore'``: the missing columns are ignored and no
            warning message appears.
        figure_config: The figure configuration. It is passed
            unchanged to the state of the output.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.analyzer import PlotColumnAnalyzer
    >>> analyzer = PlotColumnAnalyzer()
    >>> analyzer
    PlotColumnAnalyzer(columns=None, exclude_columns=(), missing_policy='raise', figure_config=None)
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 0, 1],
    ...         "col2": [1, 0, 1, 0],
    ...         "col3": [1, 1, 1, 1],
    ...     },
    ...     schema={"col1": pl.Int64, "col2": pl.Int64, "col3": pl.Int64},
    ... )
    >>> output = analyzer.analyze(frame)
    >>> output
    PlotColumnOutput(
     (state): DataFrameState(dataframe=(4, 3), figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(
        self,
        columns: Sequence[str] | None = None,
        exclude_columns: Sequence[str] = (),
        missing_policy: str = "raise",
        figure_config: BaseFigureConfig | None = None,
    ) -> None:
        # The column selection logic lives in the parent class.
        super().__init__(
            columns=columns,
            exclude_columns=exclude_columns,
            missing_policy=missing_policy,
        )
        self._figure_config = figure_config

    def get_args(self) -> dict:
        # Extend the arguments of the parent class with the figure config.
        return {**super().get_args(), "figure_config": self._figure_config}

    def _analyze(self, frame: pl.DataFrame) -> PlotColumnOutput:
        n_columns = len(self.find_columns(frame))
        logger.info(f"Plotting the content of {n_columns:,} columns...")
        # Only the columns that exist in the input DataFrame are kept.
        selected = frame.select(self.find_common_columns(frame))
        logger.info(str_shape_diff(orig=frame.shape, final=selected.shape))
        state = DataFrameState(dataframe=selected, figure_config=self._figure_config)
        return PlotColumnOutput(state=state)
| r"""Implement an analyzer that plots the content of each column.""" | ||
| from __future__ import annotations | ||
| __all__ = ["ScatterColumnAnalyzer"] | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola import objects_are_equal | ||
| from coola.utils.format import repr_mapping_line | ||
| from grizz.utils.format import str_shape_diff | ||
| from arkas.analyzer.lazy import BaseLazyAnalyzer | ||
| from arkas.output.scatter_column import ScatterColumnOutput | ||
| from arkas.state.scatter_dataframe import ScatterDataFrameState | ||
| if TYPE_CHECKING: | ||
| import polars as pl | ||
| from arkas.figure import BaseFigureConfig | ||
| logger = logging.getLogger(__name__) | ||
class ScatterColumnAnalyzer(BaseLazyAnalyzer):
    r"""Implement an analyzer that makes a scatter plot from two
    columns, optionally colored by a third column.

    Args:
        x: The x-axis data column.
        y: The y-axis data column.
        color: An optional color axis data column. If it is ``None``
            (or any other falsy value), no color column is selected.
        figure_config: The figure configuration.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.analyzer import ScatterColumnAnalyzer
    >>> analyzer = ScatterColumnAnalyzer(x="col1", y="col2")
    >>> analyzer
    ScatterColumnAnalyzer(x='col1', y='col2', color=None, figure_config=None)
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 0, 1],
    ...         "col2": [1, 0, 1, 0],
    ...         "col3": [1, 1, 1, 1],
    ...     },
    ...     schema={"col1": pl.Int64, "col2": pl.Int64, "col3": pl.Int64},
    ... )
    >>> output = analyzer.analyze(frame)
    >>> output
    ScatterColumnOutput(
     (state): ScatterDataFrameState(dataframe=(4, 2), x='col1', y='col2', color=None, figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(
        self,
        x: str,
        y: str,
        color: str | None = None,
        figure_config: BaseFigureConfig | None = None,
    ) -> None:
        self._x = x
        self._y = y
        self._color = color
        self._figure_config = figure_config

    def __repr__(self) -> str:
        args = repr_mapping_line(self.get_args())
        return f"{self.__class__.__qualname__}({args})"

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Two analyzers are equal if they have the same type and arguments.
        if not isinstance(other, self.__class__):
            return False
        return objects_are_equal(self.get_args(), other.get_args(), equal_nan=equal_nan)

    def get_args(self) -> dict:
        return {
            "x": self._x,
            "y": self._y,
            "color": self._color,
            "figure_config": self._figure_config,
        }

    def _analyze(self, frame: pl.DataFrame) -> ScatterColumnOutput:
        logger.info(f"Plotting the content of {self._x!r}, {self._y!r}, and {self._color!r}...")
        requested = [self._x, self._y] + ([self._color] if self._color else [])
        # ``x``, ``y`` and ``color`` can refer to the same column (e.g. coloring
        # the points by their x value). Selecting the same column name twice
        # fails, so each name is kept only once, in the original order.
        columns = list(dict.fromkeys(requested))
        dataframe = frame.select(columns)
        logger.info(str_shape_diff(orig=frame.shape, final=dataframe.shape))
        return ScatterColumnOutput(
            state=ScatterDataFrameState(
                dataframe=dataframe,
                x=self._x,
                y=self._y,
                color=self._color,
                figure_config=self._figure_config,
            )
        )
| r"""Implement an analyzer that plots the content of each column.""" | ||
| from __future__ import annotations | ||
| __all__ = ["TemporalPlotColumnAnalyzer"] | ||
| import logging | ||
| from typing import TYPE_CHECKING | ||
| from grizz.utils.format import str_shape_diff | ||
| from arkas.analyzer.lazy import BaseInNLazyAnalyzer | ||
| from arkas.output.temporal_plot_column import TemporalPlotColumnOutput | ||
| from arkas.state.temporal_dataframe import TemporalDataFrameState | ||
| if TYPE_CHECKING: | ||
| from collections.abc import Sequence | ||
| import polars as pl | ||
| from arkas.figure import BaseFigureConfig | ||
| logger = logging.getLogger(__name__) | ||
class TemporalPlotColumnAnalyzer(BaseInNLazyAnalyzer):
    r"""Define an analyzer that draws the values of each selected
    column along a temporal column.

    Args:
        temporal_column: The name of the temporal column in the
            DataFrame. It is always part of the analyzed DataFrame,
            even if it is not in the selected columns.
        period: An optional temporal period e.g. monthly or daily.
        columns: The names of the columns to plot. ``None`` means
            that all the columns are analyzed.
        exclude_columns: The names of the columns to remove from
            the input ``columns``. Names that are not found are
            ignored during the filtering process.
        missing_policy: How to handle missing columns. The valid
            values are ``'ignore'``, ``'warn'``, and ``'raise'``.
            ``'raise'``: an exception is raised if at least one
            column is missing.
            ``'warn'``: a warning is raised if at least one column
            is missing, and the missing columns are ignored.
            ``'ignore'``: the missing columns are ignored and no
            warning message appears.
        figure_config: The figure configuration. It is passed
            unchanged to the state of the output.

    Example usage:

    ```pycon

    >>> from datetime import datetime, timezone
    >>> import polars as pl
    >>> from arkas.analyzer import TemporalPlotColumnAnalyzer
    >>> analyzer = TemporalPlotColumnAnalyzer(temporal_column="datetime")
    >>> analyzer
    TemporalPlotColumnAnalyzer(columns=None, exclude_columns=(), missing_policy='raise', temporal_column='datetime', period=None, figure_config=None)
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0],
    ...         "col2": [0, 1, 0, 1],
    ...         "col3": [1, 0, 0, 0],
    ...         "datetime": [
    ...             datetime(year=2020, month=1, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=2, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=3, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=4, day=3, tzinfo=timezone.utc),
    ...         ],
    ...     },
    ...     schema={
    ...         "col1": pl.Int64,
    ...         "col2": pl.Int64,
    ...         "col3": pl.Int64,
    ...         "datetime": pl.Datetime(time_unit="us", time_zone="UTC"),
    ...     },
    ... )
    >>> output = analyzer.analyze(frame)
    >>> output
    TemporalPlotColumnOutput(
     (state): TemporalDataFrameState(dataframe=(4, 4), temporal_column='datetime', period=None, figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(
        self,
        temporal_column: str,
        period: str | None = None,
        columns: Sequence[str] | None = None,
        exclude_columns: Sequence[str] = (),
        missing_policy: str = "raise",
        figure_config: BaseFigureConfig | None = None,
    ) -> None:
        # The column selection logic lives in the parent class.
        super().__init__(
            columns=columns,
            exclude_columns=exclude_columns,
            missing_policy=missing_policy,
        )
        self._temporal_column = temporal_column
        self._period = period
        self._figure_config = figure_config

    def get_args(self) -> dict:
        # Extend the arguments of the parent class with the extra options.
        return {
            **super().get_args(),
            "temporal_column": self._temporal_column,
            "period": self._period,
            "figure_config": self._figure_config,
        }

    def _analyze(self, frame: pl.DataFrame) -> TemporalPlotColumnOutput:
        n_columns = len(self.find_columns(frame))
        logger.info(
            f"Plotting the content of {n_columns:,} columns "
            f"using the temporal column {self._temporal_column!r} and period {self._period!r}..."
        )
        selected = [*self.find_common_columns(frame)]
        # The temporal column must be in the output DataFrame, but it must not
        # be listed twice when it is already one of the selected columns.
        if self._temporal_column not in selected:
            selected = [*selected, self._temporal_column]
        dataframe = frame.select(selected)
        logger.info(str_shape_diff(orig=frame.shape, final=dataframe.shape))
        state = TemporalDataFrameState(
            dataframe=dataframe,
            temporal_column=self._temporal_column,
            period=self._period,
            figure_config=self._figure_config,
        )
        return TemporalPlotColumnOutput(state=state)
| r"""Contain the implementation of a HTML content generator that plots | ||
| the content of each column.""" | ||
| from __future__ import annotations | ||
| __all__ = ["PlotColumnContentGenerator", "create_template"] | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from jinja2 import Template | ||
| from arkas.content.section import BaseSectionContentGenerator | ||
| from arkas.figure.utils import figure2html | ||
| from arkas.plotter.plot_column import PlotColumnPlotter | ||
| if TYPE_CHECKING: | ||
| from arkas.state.dataframe import DataFrameState | ||
| logger = logging.getLogger(__name__) | ||
class PlotColumnContentGenerator(BaseSectionContentGenerator):
    r"""Define the content generator of the section that plots the
    content of each column.

    Args:
        state: The state containing the DataFrame to analyze.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.content import PlotColumnContentGenerator
    >>> from arkas.state import DataFrameState
    >>> dataframe = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0, 0, 1, 0],
    ...         "col2": [0, 1, 0, 1, 0, 1, 0],
    ...         "col3": [0, 0, 0, 0, 1, 1, 1],
    ...     }
    ... )
    >>> content = PlotColumnContentGenerator(DataFrameState(dataframe))
    >>> content
    PlotColumnContentGenerator(
     (state): DataFrameState(dataframe=(7, 3), figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(self, state: DataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        body = repr_mapping({"state": self._state})
        args = repr_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def __str__(self) -> str:
        body = str_mapping({"state": self._state})
        args = str_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only objects of the same class are compared through their states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def generate_content(self) -> str:
        n_rows, n_cols = self._state.dataframe.shape
        logger.info(f"Generating the plot of {n_cols:,} columns...")
        figures = PlotColumnPlotter(state=self._state).plot()
        template = Template(create_template())
        # The values are formatted before rendering because the template only
        # does plain substitutions.
        context = {
            "nrows": f"{n_rows:,}",
            "ncols": f"{n_cols:,}",
            "columns": ", ".join(self._state.dataframe.columns),
            "figure": figure2html(figures["plot_column"], close_fig=True),
        }
        return template.render(context)
def create_template() -> str:
    r"""Build the template used to render the section content.

    The template has the following placeholders: ``ncols``,
    ``columns``, ``nrows``, and ``figure``.

    Returns:
        The content template.

    Example usage:

    ```pycon

    >>> from arkas.content.plot_column import create_template
    >>> template = create_template()

    ```
    """
    lines = (
        "This section plots the content of some columns.",
        "The x-axis is the row index and the y-axis shows the value.",
        "<ul>",
        "<li> {{ncols}} columns: {{columns}} </li>",
        "<li> number of rows: {{nrows}}</li>",
        "</ul>",
        "{{figure}}",
        "",  # the template ends with a newline
    )
    return "\n".join(lines)
| r"""Contain the implementation of a HTML content generator that plots | ||
| the content of each column.""" | ||
| from __future__ import annotations | ||
| __all__ = ["ScatterColumnContentGenerator", "create_template"] | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from jinja2 import Template | ||
| from arkas.content.section import BaseSectionContentGenerator | ||
| from arkas.figure.utils import figure2html | ||
| from arkas.plotter.scatter_column import ScatterColumnPlotter | ||
| if TYPE_CHECKING: | ||
| from arkas.state.scatter_dataframe import ScatterDataFrameState | ||
| logger = logging.getLogger(__name__) | ||
class ScatterColumnContentGenerator(BaseSectionContentGenerator):
    r"""Define the content generator of the section that shows a
    scatter plot of two columns.

    Args:
        state: The state containing the DataFrame to analyze, and the
            names of the x, y and (optional) color columns.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.content import ScatterColumnContentGenerator
    >>> from arkas.state import ScatterDataFrameState
    >>> dataframe = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0, 0, 1, 0],
    ...         "col2": [0, 1, 0, 1, 0, 1, 0],
    ...         "col3": [0, 0, 0, 0, 1, 1, 1],
    ...     }
    ... )
    >>> content = ScatterColumnContentGenerator(
    ...     ScatterDataFrameState(dataframe, x="col1", y="col2")
    ... )
    >>> content
    ScatterColumnContentGenerator(
     (state): ScatterDataFrameState(dataframe=(7, 3), x='col1', y='col2', color=None, figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(self, state: ScatterDataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        body = repr_mapping({"state": self._state})
        args = repr_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def __str__(self) -> str:
        body = str_mapping({"state": self._state})
        args = str_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only objects of the same class are compared through their states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def generate_content(self) -> str:
        n_rows, n_cols = self._state.dataframe.shape
        logger.info(f"Generating the plot of {n_cols:,} columns...")
        figures = ScatterColumnPlotter(state=self._state).plot()
        template = Template(create_template())
        context = {
            "color": self._state.color,
            "figure": figure2html(figures["scatter_column"], close_fig=True),
            "n_samples": f"{n_rows:,}",
            "x": self._state.x,
            "y": self._state.y,
        }
        return template.render(context)
def create_template() -> str:
    r"""Return the template of the scatter plot section.

    The template has the following placeholders: ``x``, ``y``,
    ``color``, ``n_samples``, and ``figure``.

    Returns:
        The content template.

    Example usage:

    ```pycon

    >>> from arkas.content.scatter_column import create_template
    >>> template = create_template()

    ```
    """
    return """This section plots a scatter plot for the following columns.
<ul>
<li> x: {{x}} </li>
<li> y: {{y}} </li>
<li> color: {{color}} </li>
<li> number of samples: {{n_samples}} </li>
</ul>
{{figure}}
"""
| r"""Contain the implementation of a HTML content generator that plots | ||
| the content of each column.""" | ||
| from __future__ import annotations | ||
| __all__ = ["TemporalPlotColumnContentGenerator", "create_template"] | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from jinja2 import Template | ||
| from arkas.content.section import BaseSectionContentGenerator | ||
| from arkas.figure.utils import figure2html | ||
| from arkas.plotter.temporal_plot_column import TemporalPlotColumnPlotter | ||
| if TYPE_CHECKING: | ||
| from arkas.state.temporal_dataframe import TemporalDataFrameState | ||
| logger = logging.getLogger(__name__) | ||
class TemporalPlotColumnContentGenerator(BaseSectionContentGenerator):
    r"""Define the content generator of the section that plots the
    content of each column along a temporal column.

    Args:
        state: The state containing the DataFrame to analyze and the
            name of the temporal column.

    Example usage:

    ```pycon

    >>> from datetime import datetime, timezone
    >>> import polars as pl
    >>> from arkas.content import TemporalPlotColumnContentGenerator
    >>> from arkas.state import TemporalDataFrameState
    >>> dataframe = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0],
    ...         "col2": [0, 1, 0, 1],
    ...         "col3": [1, 0, 0, 0],
    ...         "datetime": [
    ...             datetime(year=2020, month=1, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=2, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=3, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=4, day=3, tzinfo=timezone.utc),
    ...         ],
    ...     },
    ...     schema={
    ...         "col1": pl.Int64,
    ...         "col2": pl.Int64,
    ...         "col3": pl.Int64,
    ...         "datetime": pl.Datetime(time_unit="us", time_zone="UTC"),
    ...     },
    ... )
    >>> content = TemporalPlotColumnContentGenerator(
    ...     TemporalDataFrameState(dataframe, temporal_column="datetime")
    ... )
    >>> content
    TemporalPlotColumnContentGenerator(
     (state): TemporalDataFrameState(dataframe=(4, 4), temporal_column='datetime', period=None, figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(self, state: TemporalDataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        body = repr_mapping({"state": self._state})
        args = repr_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def __str__(self) -> str:
        body = str_mapping({"state": self._state})
        args = str_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only objects of the same class are compared through their states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def generate_content(self) -> str:
        n_rows, n_cols = self._state.dataframe.shape
        logger.info(
            f"Generating the temporal plot of {n_cols} columns using the "
            f"temporal column {self._state.temporal_column!r}..."
        )
        figures = TemporalPlotColumnPlotter(state=self._state).plot()
        template = Template(create_template())
        # The column count and the column list are computed on the whole
        # DataFrame of the state.
        context = {
            "nrows": f"{n_rows:,}",
            "ncols": f"{n_cols:,}",
            "columns": ", ".join(self._state.dataframe.columns),
            "figure": figure2html(figures["temporal_plot_column"], close_fig=True),
        }
        return template.render(context)
def create_template() -> str:
    r"""Return the template of the temporal plot section.

    The template has the following placeholders: ``ncols``,
    ``columns``, ``nrows``, and ``figure``.

    Returns:
        The content template.

    Example usage:

    ```pycon

    >>> from arkas.content.temporal_plot_column import create_template
    >>> template = create_template()

    ```
    """
    # The description names the temporal column as the x-axis because this
    # section plots the columns along a temporal dimension, not the row index.
    return """This section plots the content of some columns.
The x-axis is the temporal column and the y-axis shows the value.
<ul>
<li> {{ncols}} columns: {{columns}} </li>
<li> number of rows: {{nrows}}</li>
</ul>
{{figure}}
"""
| r"""Implement the pairwise column co-occurrence evaluator.""" | ||
| from __future__ import annotations | ||
| __all__ = ["ColumnCooccurrenceEvaluator"] | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.evaluator2.base import BaseEvaluator | ||
| from arkas.evaluator2.vanilla import Evaluator | ||
| if TYPE_CHECKING: | ||
| import numpy as np | ||
| from arkas.state.column_cooccurrence import ColumnCooccurrenceState | ||
class ColumnCooccurrenceEvaluator(BaseEvaluator):
    r"""Define the evaluator that exposes the pairwise column
    co-occurrence matrix as a metric.

    Args:
        state: The state with the co-occurrence matrix.

    Example usage:

    ```pycon

    >>> import numpy as np
    >>> from arkas.evaluator2 import ColumnCooccurrenceEvaluator
    >>> from arkas.state import ColumnCooccurrenceState
    >>> evaluator = ColumnCooccurrenceEvaluator(
    ...     ColumnCooccurrenceState(matrix=np.ones((3, 3)), columns=["a", "b", "c"])
    ... )
    >>> evaluator
    ColumnCooccurrenceEvaluator(
     (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig())
    )
    >>> evaluator.evaluate()
    {'column_cooccurrence': array([[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]])}

    ```
    """

    def __init__(self, state: ColumnCooccurrenceState) -> None:
        self._state = state

    def __repr__(self) -> str:
        body = repr_mapping({"state": self._state})
        args = repr_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def __str__(self) -> str:
        body = str_mapping({"state": self._state})
        args = str_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def compute(self) -> Evaluator:
        # The metrics are computed with the default (empty) prefix and suffix.
        metrics = self.evaluate()
        return Evaluator(metrics=metrics)

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only objects of the same class are compared through their states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def evaluate(self, prefix: str = "", suffix: str = "") -> dict[str, np.ndarray]:
        # The matrix of the state is returned as is under a single key.
        key = f"{prefix}column_cooccurrence{suffix}"
        return {key: self._state.matrix}
| r"""Implement an output to plot each column of a DataFrame.""" | ||
| from __future__ import annotations | ||
| __all__ = ["PlotColumnOutput"] | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.content.plot_column import PlotColumnContentGenerator | ||
| from arkas.evaluator2.vanilla import Evaluator | ||
| from arkas.output.lazy import BaseLazyOutput | ||
| from arkas.plotter.plot_column import PlotColumnPlotter | ||
| if TYPE_CHECKING: | ||
| from arkas.state.dataframe import DataFrameState | ||
class PlotColumnOutput(BaseLazyOutput):
    r"""Define the output that plots each column of a DataFrame.

    The content generator and the plotter are both built from the
    same state, and the evaluator is empty.

    Args:
        state: The state containing the DataFrame to analyze.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.output import PlotColumnOutput
    >>> from arkas.state import DataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [1.2, 4.2, 4.2, 2.2],
    ...         "col2": [1, 1, 1, 1],
    ...         "col3": [1, 2, 2, 2],
    ...     },
    ...     schema={"col1": pl.Float64, "col2": pl.Int64, "col3": pl.Int64},
    ... )
    >>> output = PlotColumnOutput(DataFrameState(frame))
    >>> output
    PlotColumnOutput(
     (state): DataFrameState(dataframe=(4, 3), figure_config=MatplotlibFigureConfig())
    )
    >>> output.get_content_generator()
    PlotColumnContentGenerator(
     (state): DataFrameState(dataframe=(4, 3), figure_config=MatplotlibFigureConfig())
    )
    >>> output.get_evaluator()
    Evaluator(count=0)
    >>> output.get_plotter()
    PlotColumnPlotter(
     (state): DataFrameState(dataframe=(4, 3), figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(self, state: DataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        body = repr_mapping({"state": self._state})
        args = repr_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def __str__(self) -> str:
        body = str_mapping({"state": self._state})
        args = str_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only objects of the same class are compared through their states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def _get_content_generator(self) -> PlotColumnContentGenerator:
        generator = PlotColumnContentGenerator(self._state)
        return generator

    def _get_evaluator(self) -> Evaluator:
        # This output has no metric, so the evaluator is empty.
        evaluator = Evaluator()
        return evaluator

    def _get_plotter(self) -> PlotColumnPlotter:
        plotter = PlotColumnPlotter(self._state)
        return plotter
| r"""Implement an output to scatter plot some columns.""" | ||
| from __future__ import annotations | ||
| __all__ = ["ScatterColumnOutput"] | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.content.scatter_column import ScatterColumnContentGenerator | ||
| from arkas.evaluator2.vanilla import Evaluator | ||
| from arkas.output.lazy import BaseLazyOutput | ||
| from arkas.plotter.scatter_column import ScatterColumnPlotter | ||
| if TYPE_CHECKING: | ||
| from arkas.state.temporal_dataframe import ScatterDataFrameState | ||
class ScatterColumnOutput(BaseLazyOutput):
    r"""Define the output that shows a scatter plot of some columns.

    The content generator and the plotter are both built from the
    same state, and the evaluator is empty.

    Args:
        state: The state containing the DataFrame to analyze.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.output import ScatterColumnOutput
    >>> from arkas.state import ScatterDataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0],
    ...         "col2": [0, 1, 0, 1],
    ...         "col3": [1, 0, 0, 0],
    ...     },
    ...     schema={"col1": pl.Int64, "col2": pl.Int64, "col3": pl.Int64},
    ... )
    >>> output = ScatterColumnOutput(ScatterDataFrameState(frame, x="col1", y="col2"))
    >>> output
    ScatterColumnOutput(
     (state): ScatterDataFrameState(dataframe=(4, 3), x='col1', y='col2', color=None, figure_config=MatplotlibFigureConfig())
    )
    >>> output.get_content_generator()
    ScatterColumnContentGenerator(
     (state): ScatterDataFrameState(dataframe=(4, 3), x='col1', y='col2', color=None, figure_config=MatplotlibFigureConfig())
    )
    >>> output.get_evaluator()
    Evaluator(count=0)
    >>> output.get_plotter()
    ScatterColumnPlotter(
     (state): ScatterDataFrameState(dataframe=(4, 3), x='col1', y='col2', color=None, figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(self, state: ScatterDataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        body = repr_mapping({"state": self._state})
        args = repr_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def __str__(self) -> str:
        body = str_mapping({"state": self._state})
        args = str_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only objects of the same class are compared through their states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def _get_content_generator(self) -> ScatterColumnContentGenerator:
        generator = ScatterColumnContentGenerator(self._state)
        return generator

    def _get_evaluator(self) -> Evaluator:
        # This output has no metric, so the evaluator is empty.
        evaluator = Evaluator()
        return evaluator

    def _get_plotter(self) -> ScatterColumnPlotter:
        plotter = ScatterColumnPlotter(self._state)
        return plotter
| r"""Implement an output to plot each column of a DataFrame along a | ||
| temporal dimension.""" | ||
| from __future__ import annotations | ||
| __all__ = ["TemporalPlotColumnOutput"] | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.content.temporal_plot_column import TemporalPlotColumnContentGenerator | ||
| from arkas.evaluator2.vanilla import Evaluator | ||
| from arkas.output.lazy import BaseLazyOutput | ||
| from arkas.plotter.temporal_plot_column import TemporalPlotColumnPlotter | ||
| if TYPE_CHECKING: | ||
| from arkas.state.temporal_dataframe import TemporalDataFrameState | ||
class TemporalPlotColumnOutput(BaseLazyOutput):
    r"""Define the output that plots each column of a DataFrame along
    a temporal dimension.

    The content generator and the plotter are both built from the
    same state, and the evaluator is empty.

    Args:
        state: The state containing the DataFrame to analyze.

    Example usage:

    ```pycon

    >>> from datetime import datetime, timezone
    >>> import polars as pl
    >>> from arkas.output import TemporalPlotColumnOutput
    >>> from arkas.state import TemporalDataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0],
    ...         "col2": [0, 1, 0, 1],
    ...         "col3": [1, 0, 0, 0],
    ...         "datetime": [
    ...             datetime(year=2020, month=1, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=2, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=3, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=4, day=3, tzinfo=timezone.utc),
    ...         ],
    ...     },
    ...     schema={
    ...         "col1": pl.Int64,
    ...         "col2": pl.Int64,
    ...         "col3": pl.Int64,
    ...         "datetime": pl.Datetime(time_unit="us", time_zone="UTC"),
    ...     },
    ... )
    >>> output = TemporalPlotColumnOutput(
    ...     TemporalDataFrameState(frame, temporal_column="datetime")
    ... )
    >>> output
    TemporalPlotColumnOutput(
     (state): TemporalDataFrameState(dataframe=(4, 4), temporal_column='datetime', period=None, figure_config=MatplotlibFigureConfig())
    )
    >>> output.get_content_generator()
    TemporalPlotColumnContentGenerator(
     (state): TemporalDataFrameState(dataframe=(4, 4), temporal_column='datetime', period=None, figure_config=MatplotlibFigureConfig())
    )
    >>> output.get_evaluator()
    Evaluator(count=0)
    >>> output.get_plotter()
    TemporalPlotColumnPlotter(
     (state): TemporalDataFrameState(dataframe=(4, 4), temporal_column='datetime', period=None, figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    def __init__(self, state: TemporalDataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        body = repr_mapping({"state": self._state})
        args = repr_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def __str__(self) -> str:
        body = str_mapping({"state": self._state})
        args = str_indent(body)
        return f"{type(self).__qualname__}(\n {args}\n)"

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only objects of the same class are compared through their states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def _get_content_generator(self) -> TemporalPlotColumnContentGenerator:
        generator = TemporalPlotColumnContentGenerator(self._state)
        return generator

    def _get_evaluator(self) -> Evaluator:
        # This output has no metric, so the evaluator is empty.
        evaluator = Evaluator()
        return evaluator

    def _get_plotter(self) -> TemporalPlotColumnPlotter:
        plotter = TemporalPlotColumnPlotter(self._state)
        return plotter
| r"""Contain the implementation of a DataFrame column plotter.""" | ||
| from __future__ import annotations | ||
| __all__ = ["BaseFigureCreator", "MatplotlibFigureCreator", "PlotColumnPlotter"] | ||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING, Any | ||
| import matplotlib.pyplot as plt | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.figure.creator import FigureCreatorRegistry | ||
| from arkas.figure.html import HtmlFigure | ||
| from arkas.figure.matplotlib import MatplotlibFigure, MatplotlibFigureConfig | ||
| from arkas.figure.utils import MISSING_FIGURE_MESSAGE | ||
| from arkas.plotter.base import BasePlotter | ||
| from arkas.plotter.vanilla import Plotter | ||
| if TYPE_CHECKING: | ||
| from arkas.figure.base import BaseFigure | ||
| from arkas.state.dataframe import DataFrameState | ||
class BaseFigureCreator(ABC):
    r"""Define the interface of the objects that create a figure with
    the content of each column.

    Concrete creators must implement ``create``.
    """

    @abstractmethod
    def create(self, state: DataFrameState) -> BaseFigure:
        r"""Create a figure with the content of each column.

        Args:
            state: The state containing the DataFrame to analyze.

        Returns:
            The generated figure.

        Example usage:

        ```pycon

        >>> import polars as pl
        >>> from arkas.figure import MatplotlibFigureConfig
        >>> from arkas.state import DataFrameState
        >>> creator = MatplotlibFigureCreator()
        >>> frame = pl.DataFrame(
        ...     {
        ...         "col1": [1.2, 4.2, 4.2, 2.2],
        ...         "col2": [1, 1, 1, 1],
        ...         "col3": [1, 2, 2, 2],
        ...     },
        ...     schema={"col1": pl.Float64, "col2": pl.Int64, "col3": pl.Int64},
        ... )
        >>> fig = creator.create(DataFrameState(frame))

        ```
        """
class MatplotlibFigureCreator(BaseFigureCreator):
    r"""Create a matplotlib figure with the content of each column.

    Each column of the DataFrame is drawn as one line on a shared axis,
    labelled with the column name.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.figure import MatplotlibFigureConfig
    >>> from arkas.state import DataFrameState
    >>> creator = MatplotlibFigureCreator()
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [1.2, 4.2, 4.2, 2.2],
    ...         "col2": [1, 1, 1, 1],
    ...         "col3": [1, 2, 2, 2],
    ...     },
    ...     schema={"col1": pl.Float64, "col2": pl.Int64, "col3": pl.Int64},
    ... )
    >>> fig = creator.create(DataFrameState(frame))

    ```
    """

    def __repr__(self) -> str:
        name = self.__class__.__qualname__
        return f"{name}()"

    def create(self, state: DataFrameState) -> BaseFigure:
        frame = state.dataframe
        # A DataFrame without rows has nothing to draw: return the
        # placeholder figure instead of an empty plot.
        if frame.shape[0] == 0:
            return HtmlFigure(MISSING_FIGURE_MESSAGE)

        config = state.figure_config
        fig, ax = plt.subplots(**config.get_arg("init", {}))
        for series in frame:
            ax.plot(series.to_numpy(), label=series.name)

        # The y-axis scale is only changed when the config provides one.
        yscale = config.get_arg("yscale")
        if yscale:
            ax.set_yscale(yscale)

        ax.legend()
        fig.tight_layout()
        return MatplotlibFigure(fig)
class PlotColumnPlotter(BasePlotter):
    r"""Implement a plotter that draws the content of each column of a
    DataFrame.

    The figure is built by the creator registered in ``registry`` for
    the backend of the state's figure configuration.

    Args:
        state: The state containing the DataFrame to analyze.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.plotter import PlotColumnPlotter
    >>> from arkas.state import DataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [1.2, 4.2, 4.2, 2.2],
    ...         "col2": [1, 1, 1, 1],
    ...         "col3": [1, 2, 2, 2],
    ...     },
    ...     schema={"col1": pl.Float64, "col2": pl.Int64, "col3": pl.Int64},
    ... )
    >>> plotter = PlotColumnPlotter(DataFrameState(frame))
    >>> plotter
    PlotColumnPlotter(
     (state): DataFrameState(dataframe=(4, 3), figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    # Maps a figure backend to the creator that builds the figure.
    registry = FigureCreatorRegistry[BaseFigureCreator](
        {MatplotlibFigureConfig.backend(): MatplotlibFigureCreator()}
    )

    def __init__(self, state: DataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        content = {"state": self._state}
        body = repr_indent(repr_mapping(content))
        return f"{self.__class__.__qualname__}(\n {body}\n)"

    def __str__(self) -> str:
        content = {"state": self._state}
        body = str_indent(str_mapping(content))
        return f"{self.__class__.__qualname__}(\n {body}\n)"

    def compute(self) -> Plotter:
        figures = self.plot()
        return Plotter(figures)

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only instances of this object's class can be equal; the
        # comparison is then delegated to the wrapped states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def plot(self, prefix: str = "", suffix: str = "") -> dict:
        backend = self._state.figure_config.backend()
        creator = self.registry.find_creator(backend)
        figure = creator.create(self._state)
        return {f"{prefix}plot_column{suffix}": figure}
| r"""Contain the implementation of a DataFrame column plotter.""" | ||
| from __future__ import annotations | ||
| __all__ = ["BaseFigureCreator", "MatplotlibFigureCreator", "ScatterColumnPlotter"] | ||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING, Any | ||
| import matplotlib.pyplot as plt | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.figure.creator import FigureCreatorRegistry | ||
| from arkas.figure.html import HtmlFigure | ||
| from arkas.figure.matplotlib import MatplotlibFigure, MatplotlibFigureConfig | ||
| from arkas.figure.utils import MISSING_FIGURE_MESSAGE | ||
| from arkas.plotter.base import BasePlotter | ||
| from arkas.plotter.vanilla import Plotter | ||
| if TYPE_CHECKING: | ||
| from arkas.figure.base import BaseFigure | ||
| from arkas.state.scatter_dataframe import ScatterDataFrameState | ||
class BaseFigureCreator(ABC):
    r"""Define the interface of the objects that create a scatter
    figure from the columns of a DataFrame.

    Concrete creators must implement ``create``.
    """

    @abstractmethod
    def create(self, state: ScatterDataFrameState) -> BaseFigure:
        r"""Create a figure with the content of each column.

        Args:
            state: The state containing the DataFrame to analyze.

        Returns:
            The generated figure.

        Example usage:

        ```pycon

        >>> import polars as pl
        >>> from arkas.figure import MatplotlibFigureConfig
        >>> from arkas.state import ScatterDataFrameState
        >>> creator = MatplotlibFigureCreator()
        >>> frame = pl.DataFrame(
        ...     {
        ...         "col1": [1.2, 4.2, 4.2, 2.2],
        ...         "col2": [1, 1, 1, 1],
        ...         "col3": [1, 2, 2, 2],
        ...     },
        ...     schema={"col1": pl.Float64, "col2": pl.Int64, "col3": pl.Int64},
        ... )
        >>> fig = creator.create(ScatterDataFrameState(frame, x="col1", y="col2", color="col3"))

        ```
        """
class MatplotlibFigureCreator(BaseFigureCreator):
    r"""Create a matplotlib scatter figure from the columns of a
    DataFrame.

    The ``x`` and ``y`` columns of the state give the point
    coordinates. When the state defines a ``color`` column, its values
    color the points, a colorbar is added and the column name is shown
    in the legend.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.figure import MatplotlibFigureConfig
    >>> from arkas.state import ScatterDataFrameState
    >>> creator = MatplotlibFigureCreator()
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [1.2, 4.2, 4.2, 2.2],
    ...         "col2": [1, 1, 1, 1],
    ...         "col3": [1, 2, 2, 2],
    ...     },
    ...     schema={"col1": pl.Float64, "col2": pl.Int64, "col3": pl.Int64},
    ... )
    >>> fig = creator.create(ScatterDataFrameState(frame, x="col1", y="col2", color="col3"))

    ```
    """

    def __repr__(self) -> str:
        return f"{self.__class__.__qualname__}()"

    def create(self, state: ScatterDataFrameState) -> BaseFigure:
        r"""Create the scatter figure.

        Args:
            state: The state containing the DataFrame and the names of
                the ``x``, ``y`` and optional ``color`` columns.

        Returns:
            A ``MatplotlibFigure`` wrapping the scatter plot, or an
            ``HtmlFigure`` with the missing-figure message if the
            DataFrame has no rows.
        """
        frame = state.dataframe
        if frame.shape[0] == 0:
            return HtmlFigure(MISSING_FIGURE_MESSAGE)

        config = state.figure_config
        fig, ax = plt.subplots(**config.get_arg("init", {}))
        # Compare against ``None`` instead of relying on truthiness so
        # that a color column whose name is the empty string is not
        # silently ignored. This matches the ``is not None`` check done
        # by ``ScatterDataFrameState`` when it validates the column.
        color = frame[state.color].to_numpy() if state.color is not None else None
        points = ax.scatter(
            frame[state.x].to_numpy(),
            frame[state.y].to_numpy(),
            c=color,
            label=state.color,
        )
        if color is not None:
            fig.colorbar(points)
            # The legend is only drawn when there is a label to show.
            ax.legend()
        ax.set_xlabel(state.x)
        ax.set_ylabel(state.y)
        # Axis scales are only changed when the config provides them.
        if xscale := config.get_arg("xscale"):
            ax.set_xscale(xscale)
        if yscale := config.get_arg("yscale"):
            ax.set_yscale(yscale)
        fig.tight_layout()
        return MatplotlibFigure(fig)
class ScatterColumnPlotter(BasePlotter):
    r"""Implement a plotter that draws a scatter plot from the columns
    of a DataFrame.

    The figure is built by the creator registered in ``registry`` for
    the backend of the state's figure configuration.

    Args:
        state: The state containing the DataFrame to analyze.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.plotter import ScatterColumnPlotter
    >>> from arkas.state import ScatterDataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [1.2, 4.2, 4.2, 2.2],
    ...         "col2": [1, 1, 1, 1],
    ...         "col3": [1, 2, 2, 2],
    ...     },
    ...     schema={"col1": pl.Float64, "col2": pl.Int64, "col3": pl.Int64},
    ... )
    >>> plotter = ScatterColumnPlotter(
    ...     ScatterDataFrameState(frame, x="col1", y="col2", color="col3")
    ... )
    >>> plotter
    ScatterColumnPlotter(
     (state): ScatterDataFrameState(dataframe=(4, 3), x='col1', y='col2', color='col3', figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    # Maps a figure backend to the creator that builds the figure.
    registry = FigureCreatorRegistry[BaseFigureCreator](
        {MatplotlibFigureConfig.backend(): MatplotlibFigureCreator()}
    )

    def __init__(self, state: ScatterDataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        content = {"state": self._state}
        body = repr_indent(repr_mapping(content))
        return f"{self.__class__.__qualname__}(\n {body}\n)"

    def __str__(self) -> str:
        content = {"state": self._state}
        body = str_indent(str_mapping(content))
        return f"{self.__class__.__qualname__}(\n {body}\n)"

    def compute(self) -> Plotter:
        figures = self.plot()
        return Plotter(figures)

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only instances of this object's class can be equal; the
        # comparison is then delegated to the wrapped states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def plot(self, prefix: str = "", suffix: str = "") -> dict:
        backend = self._state.figure_config.backend()
        creator = self.registry.find_creator(backend)
        figure = creator.create(self._state)
        return {f"{prefix}scatter_column{suffix}": figure}
| r"""Contain the implementation of a DataFrame column plotter.""" | ||
| from __future__ import annotations | ||
| __all__ = [ | ||
| "BaseFigureCreator", | ||
| "MatplotlibFigureCreator", | ||
| "TemporalPlotColumnPlotter", | ||
| "prepare_data", | ||
| ] | ||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING, Any | ||
| import matplotlib.pyplot as plt | ||
| import polars as pl | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.figure.creator import FigureCreatorRegistry | ||
| from arkas.figure.html import HtmlFigure | ||
| from arkas.figure.matplotlib import MatplotlibFigure, MatplotlibFigureConfig | ||
| from arkas.figure.utils import MISSING_FIGURE_MESSAGE | ||
| from arkas.plotter.base import BasePlotter | ||
| from arkas.plotter.vanilla import Plotter | ||
| if TYPE_CHECKING: | ||
| import numpy as np | ||
| from arkas.figure.base import BaseFigure | ||
| from arkas.state.temporal_dataframe import TemporalDataFrameState | ||
class BaseFigureCreator(ABC):
    r"""Define the interface of the objects that create a figure with
    the content of each column along a temporal dimension.

    Concrete creators must implement ``create``.
    """

    @abstractmethod
    def create(self, state: TemporalDataFrameState) -> BaseFigure:
        r"""Create a figure with the content of each column.

        Args:
            state: The state containing the DataFrame to analyze.

        Returns:
            The generated figure.

        Example usage:

        ```pycon

        >>> from datetime import datetime, timezone
        >>> import polars as pl
        >>> from arkas.plotter.temporal_plot_column import MatplotlibFigureCreator
        >>> from arkas.state import TemporalDataFrameState
        >>> creator = MatplotlibFigureCreator()
        >>> frame = pl.DataFrame(
        ...     {
        ...         "col1": [0, 1, 1, 0],
        ...         "col2": [0, 1, 0, 1],
        ...         "col3": [1, 0, 0, 0],
        ...         "datetime": [
        ...             datetime(year=2020, month=1, day=3, tzinfo=timezone.utc),
        ...             datetime(year=2020, month=2, day=3, tzinfo=timezone.utc),
        ...             datetime(year=2020, month=3, day=3, tzinfo=timezone.utc),
        ...             datetime(year=2020, month=4, day=3, tzinfo=timezone.utc),
        ...         ],
        ...     },
        ...     schema={
        ...         "col1": pl.Int64,
        ...         "col2": pl.Int64,
        ...         "col3": pl.Int64,
        ...         "datetime": pl.Datetime(time_unit="us", time_zone="UTC"),
        ...     },
        ... )
        >>> fig = creator.create(TemporalDataFrameState(frame, temporal_column="datetime"))

        ```
        """
class MatplotlibFigureCreator(BaseFigureCreator):
    r"""Create a matplotlib figure with the content of each column
    along a temporal dimension.

    The data are first prepared with ``prepare_data``. Each remaining
    column is then drawn as one line against the time steps, labelled
    with the column name.

    Example usage:

    ```pycon

    >>> from datetime import datetime, timezone
    >>> import polars as pl
    >>> from arkas.plotter.temporal_plot_column import MatplotlibFigureCreator
    >>> from arkas.state import TemporalDataFrameState
    >>> creator = MatplotlibFigureCreator()
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0],
    ...         "col2": [0, 1, 0, 1],
    ...         "col3": [1, 0, 0, 0],
    ...         "datetime": [
    ...             datetime(year=2020, month=1, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=2, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=3, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=4, day=3, tzinfo=timezone.utc),
    ...         ],
    ...     },
    ...     schema={
    ...         "col1": pl.Int64,
    ...         "col2": pl.Int64,
    ...         "col3": pl.Int64,
    ...         "datetime": pl.Datetime(time_unit="us", time_zone="UTC"),
    ...     },
    ... )
    >>> fig = creator.create(TemporalDataFrameState(frame, temporal_column="datetime"))

    ```
    """

    def __repr__(self) -> str:
        name = self.__class__.__qualname__
        return f"{name}()"

    def create(self, state: TemporalDataFrameState) -> BaseFigure:
        # A DataFrame without rows has nothing to draw: return the
        # placeholder figure instead of an empty plot.
        if state.dataframe.shape[0] == 0:
            return HtmlFigure(MISSING_FIGURE_MESSAGE)

        # Prepare the data before creating the figure, so no figure is
        # created if the preparation fails.
        columns, steps = prepare_data(
            dataframe=state.dataframe,
            temporal_column=state.temporal_column,
            period=state.period,
        )

        config = state.figure_config
        fig, ax = plt.subplots(**config.get_arg("init", {}))
        for series in columns:
            ax.plot(steps, series.to_numpy(), label=series.name)

        # The y-axis scale is only changed when the config provides one.
        yscale = config.get_arg("yscale")
        if yscale:
            ax.set_yscale(yscale)

        ax.legend()
        fig.tight_layout()
        return MatplotlibFigure(fig)
class TemporalPlotColumnPlotter(BasePlotter):
    r"""Implement a plotter that draws the content of each column of a
    DataFrame along a temporal dimension.

    The figure is built by the creator registered in ``registry`` for
    the backend of the state's figure configuration.

    Args:
        state: The state containing the DataFrame to analyze.

    Example usage:

    ```pycon

    >>> from datetime import datetime, timezone
    >>> import polars as pl
    >>> from arkas.plotter import TemporalPlotColumnPlotter
    >>> from arkas.state import TemporalDataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0],
    ...         "col2": [0, 1, 0, 1],
    ...         "col3": [1, 0, 0, 0],
    ...         "datetime": [
    ...             datetime(year=2020, month=1, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=2, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=3, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=4, day=3, tzinfo=timezone.utc),
    ...         ],
    ...     },
    ...     schema={
    ...         "col1": pl.Int64,
    ...         "col2": pl.Int64,
    ...         "col3": pl.Int64,
    ...         "datetime": pl.Datetime(time_unit="us", time_zone="UTC"),
    ...     },
    ... )
    >>> plotter = TemporalPlotColumnPlotter(
    ...     TemporalDataFrameState(frame, temporal_column="datetime")
    ... )
    >>> plotter
    TemporalPlotColumnPlotter(
     (state): TemporalDataFrameState(dataframe=(4, 4), temporal_column='datetime', period=None, figure_config=MatplotlibFigureConfig())
    )

    ```
    """

    # Maps a figure backend to the creator that builds the figure.
    registry = FigureCreatorRegistry[BaseFigureCreator](
        {MatplotlibFigureConfig.backend(): MatplotlibFigureCreator()}
    )

    def __init__(self, state: TemporalDataFrameState) -> None:
        self._state = state

    def __repr__(self) -> str:
        content = {"state": self._state}
        body = repr_indent(repr_mapping(content))
        return f"{self.__class__.__qualname__}(\n {body}\n)"

    def __str__(self) -> str:
        content = {"state": self._state}
        body = str_indent(str_mapping(content))
        return f"{self.__class__.__qualname__}(\n {body}\n)"

    def compute(self) -> Plotter:
        figures = self.plot()
        return Plotter(figures)

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        # Only instances of this object's class can be equal; the
        # comparison is then delegated to the wrapped states.
        if isinstance(other, self.__class__):
            return self._state.equal(other._state, equal_nan=equal_nan)
        return False

    def plot(self, prefix: str = "", suffix: str = "") -> dict:
        backend = self._state.figure_config.backend()
        creator = self.registry.find_creator(backend)
        figure = creator.create(self._state)
        return {f"{prefix}temporal_plot_column{suffix}": figure}
def prepare_data(
    dataframe: pl.DataFrame, temporal_column: str, period: str | None
) -> tuple[pl.DataFrame, np.ndarray]:
    """Prepare the data before plotting them.

    The DataFrame is sorted by the temporal column. If a period is
    given, the rows are then grouped by that period and every column
    is replaced by its mean over each group.

    Args:
        dataframe: The DataFrame.
        temporal_column: The temporal column in the DataFrame.
        period: An optional temporal period e.g. monthly or daily.
            A falsy value (``None`` or an empty string) disables the
            aggregation.

    Returns:
        The DataFrame to plot (without the temporal column) and the
        array with the time steps.

    Example usage:

    ```pycon

    >>> from datetime import datetime, timezone
    >>> import polars as pl
    >>> from arkas.plotter.temporal_plot_column import prepare_data
    >>> from arkas.state import TemporalDataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0],
    ...         "col2": [0, 1, 0, 1],
    ...         "col3": [1, 0, 0, 0],
    ...         "datetime": [
    ...             datetime(year=2020, month=1, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=2, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=3, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=4, day=3, tzinfo=timezone.utc),
    ...         ],
    ...     },
    ...     schema={
    ...         "col1": pl.Int64,
    ...         "col2": pl.Int64,
    ...         "col3": pl.Int64,
    ...         "datetime": pl.Datetime(time_unit="us", time_zone="UTC"),
    ...     },
    ... )
    >>> data, time = prepare_data(frame, temporal_column="datetime", period=None)
    >>> data
    shape: (4, 3)
    ┌──────┬──────┬──────┐
    │ col1 ┆ col2 ┆ col3 │
    │ ---  ┆ ---  ┆ ---  │
    │ i64  ┆ i64  ┆ i64  │
    ╞══════╪══════╪══════╡
    │ 0    ┆ 0    ┆ 1    │
    │ 1    ┆ 1    ┆ 0    │
    │ 1    ┆ 0    ┆ 0    │
    │ 0    ┆ 1    ┆ 0    │
    └──────┴──────┴──────┘
    >>> time
    array(['2020-01-03T00:00:00.000000', '2020-02-03T00:00:00.000000',
           '2020-03-03T00:00:00.000000', '2020-04-03T00:00:00.000000'],
          dtype='datetime64[us]')

    ```
    """
    ordered = dataframe.sort(temporal_column)
    if period:
        # Aggregate every column by its mean over each period window.
        windows = ordered.group_by_dynamic(temporal_column, every=period)
        ordered = windows.agg(pl.all().mean())
    # Extract the time steps before the temporal column is removed
    # from the data to plot.
    steps = ordered[temporal_column].to_numpy()
    values = ordered.drop(temporal_column)
    return values, steps
| r"""Implement the column co-occurrence state.""" | ||
| from __future__ import annotations | ||
| __all__ = ["ColumnCooccurrenceState"] | ||
| import sys | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola import objects_are_equal | ||
| from coola.utils.format import repr_mapping_line | ||
| from grizz.utils.cooccurrence import compute_pairwise_cooccurrence | ||
| from arkas.figure.utils import get_default_config | ||
| from arkas.state.base import BaseState | ||
| from arkas.utils.array import check_square_matrix | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| if TYPE_CHECKING: | ||
| from collections.abc import Sequence | ||
| import numpy as np | ||
| import polars as pl | ||
| from arkas.figure.base import BaseFigureConfig | ||
class ColumnCooccurrenceState(BaseState):
    r"""Implement the column co-occurrence state.

    The state stores a square co-occurrence matrix, the matching column
    names and a figure configuration.

    Args:
        matrix: The co-occurrence matrix.
        columns: The column names.
        figure_config: An optional figure configuration.

    Raises:
        ValueError: if the number of column names does not match the
            size of the matrix.

    Example usage:

    ```pycon

    >>> import numpy as np
    >>> from arkas.state import ColumnCooccurrenceState
    >>> state = ColumnCooccurrenceState(matrix=np.ones((3, 3)), columns=["a", "b", "c"])
    >>> state
    ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig())

    ```
    """

    def __init__(
        self,
        matrix: np.ndarray,
        columns: Sequence[str],
        figure_config: BaseFigureConfig | None = None,
    ) -> None:
        check_square_matrix(name="matrix", array=matrix)
        # There must be exactly one column name per matrix row.
        num_rows = matrix.shape[0]
        num_names = len(columns)
        if num_rows != num_names:
            msg = (
                f"The number of columns does not match the matrix shape: {num_names} "
                f"vs {num_rows}"
            )
            raise ValueError(msg)

        self._matrix = matrix
        self._columns = tuple(columns)
        self._figure_config = figure_config or get_default_config()

    def __repr__(self) -> str:
        content = {
            "matrix": self._matrix.shape,
            "figure_config": self._figure_config,
        }
        return f"{self.__class__.__qualname__}({repr_mapping_line(content)})"

    @property
    def matrix(self) -> np.ndarray:
        return self._matrix

    @property
    def columns(self) -> tuple[str, ...]:
        return self._columns

    @property
    def figure_config(self) -> BaseFigureConfig | None:
        return self._figure_config

    def clone(self, deep: bool = True) -> Self:
        # A deep clone copies the matrix and the figure config; a
        # shallow clone shares them with this state.
        if deep:
            matrix = self._matrix.copy()
            figure_config = self._figure_config.clone()
        else:
            matrix = self._matrix
            figure_config = self._figure_config
        return self.__class__(
            matrix=matrix,
            columns=self._columns,
            figure_config=figure_config,
        )

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        if not isinstance(other, self.__class__):
            return False
        # The matrix is compared first, then the column names and
        # finally the figure config; the comparison stops at the first
        # difference.
        if not objects_are_equal(self.matrix, other.matrix, equal_nan=equal_nan):
            return False
        if not objects_are_equal(self.columns, other.columns, equal_nan=equal_nan):
            return False
        return objects_are_equal(self.figure_config, other.figure_config, equal_nan=equal_nan)

    @classmethod
    def from_dataframe(
        cls,
        frame: pl.DataFrame,
        ignore_self: bool = False,
        figure_config: BaseFigureConfig | None = None,
    ) -> ColumnCooccurrenceState:
        r"""Instantiate a ``ColumnCooccurrenceState`` object from a
        DataFrame.

        Args:
            frame: The DataFrame to analyze.
            ignore_self: If ``True``, the diagonal of the co-occurrence
                matrix (a.k.a. self-co-occurrence) is set to 0.
            figure_config: An optional figure configuration.

        Returns:
            The instantiated state.

        Example usage:

        ```pycon

        >>> import polars as pl
        >>> from arkas.state import ColumnCooccurrenceState
        >>> frame = pl.DataFrame(
        ...     {
        ...         "col1": [0, 1, 1, 0, 0, 1, 0],
        ...         "col2": [0, 1, 0, 1, 0, 1, 0],
        ...         "col3": [0, 0, 0, 0, 1, 1, 1],
        ...     }
        ... )
        >>> state = ColumnCooccurrenceState.from_dataframe(frame)
        >>> state
        ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig())

        ```
        """
        cooccurrence = compute_pairwise_cooccurrence(frame=frame, ignore_self=ignore_self)
        return cls(matrix=cooccurrence, columns=frame.columns, figure_config=figure_config)
| r"""Implement the DataFrame state.""" | ||
| from __future__ import annotations | ||
| __all__ = ["DataFrameState"] | ||
| import sys | ||
| from typing import TYPE_CHECKING, Any | ||
| from coola import objects_are_equal | ||
| from coola.utils.format import repr_mapping_line, str_indent, str_mapping | ||
| from arkas.figure.utils import get_default_config | ||
| from arkas.state.base import BaseState | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| if TYPE_CHECKING: | ||
| import polars as pl | ||
| from arkas.figure.base import BaseFigureConfig | ||
class DataFrameState(BaseState):
    r"""Implement the DataFrame state.

    Args:
        dataframe: The DataFrame.
        figure_config: An optional figure configuration. If no
            configuration is given, the one returned by
            ``get_default_config()`` is used.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.state import DataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0, 0, 1, 0],
    ...         "col2": [0, 1, 0, 1, 0, 1, 0],
    ...         "col3": [0, 0, 0, 0, 1, 1, 1],
    ...     }
    ... )
    >>> state = DataFrameState(frame)
    >>> state
    DataFrameState(dataframe=(7, 3), figure_config=MatplotlibFigureConfig())

    ```
    """

    def __init__(
        self,
        dataframe: pl.DataFrame,
        figure_config: BaseFigureConfig | None = None,
    ) -> None:
        self._dataframe = dataframe
        self._figure_config = figure_config or get_default_config()

    def __repr__(self) -> str:
        args = repr_mapping_line(
            {
                "dataframe": self._dataframe.shape,
                "figure_config": self._figure_config,
            }
        )
        return f"{self.__class__.__qualname__}({args})"

    def __str__(self) -> str:
        args = str_indent(
            str_mapping(
                {
                    "dataframe": self._dataframe.shape,
                    "figure_config": self._figure_config,
                }
            )
        )
        # The mapping spans several lines, so it is placed on its own
        # lines between the parentheses. The two-space lead lines the
        # first entry up with the continuation lines indented by
        # ``str_indent``.
        return f"{self.__class__.__qualname__}(\n  {args}\n)"

    @property
    def dataframe(self) -> pl.DataFrame:
        return self._dataframe

    @property
    def figure_config(self) -> BaseFigureConfig:
        # Never ``None``: the constructor falls back to the default
        # configuration when no configuration is given.
        return self._figure_config

    def clone(self, deep: bool = True) -> Self:
        # A deep clone copies the DataFrame and the figure config; a
        # shallow clone shares them with this state.
        return self.__class__(
            dataframe=self._dataframe.clone() if deep else self._dataframe,
            figure_config=self._figure_config.clone() if deep else self._figure_config,
        )

    def equal(self, other: Any, equal_nan: bool = False) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return objects_are_equal(self.get_args(), other.get_args(), equal_nan=equal_nan)

    def get_args(self) -> dict:
        # These values are the ones compared by ``equal``.
        return {
            "dataframe": self._dataframe,
            "figure_config": self._figure_config,
        }
| r"""Implement the DataFrame state for scatter plots.""" | ||
| from __future__ import annotations | ||
| __all__ = ["ScatterDataFrameState"] | ||
| import sys | ||
| from typing import TYPE_CHECKING | ||
| from coola.utils.format import repr_mapping_line, str_indent, str_mapping | ||
| from arkas.state.dataframe import DataFrameState | ||
| from arkas.utils.dataframe import check_column_exist | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| if TYPE_CHECKING: | ||
| import polars as pl | ||
| from arkas.figure.base import BaseFigureConfig | ||
class ScatterDataFrameState(DataFrameState):
    r"""Implement the DataFrame state for scatter plots.

    Args:
        dataframe: The DataFrame.
        x: The x-axis data column.
        y: The y-axis data column.
        color: An optional color axis data column.
        figure_config: An optional figure configuration.

    Example usage:

    ```pycon

    >>> import polars as pl
    >>> from arkas.state import ScatterDataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0, 0, 1, 0],
    ...         "col2": [0, 1, 0, 1, 0, 1, 0],
    ...         "col3": [0, 0, 0, 0, 1, 1, 1],
    ...     }
    ... )
    >>> state = ScatterDataFrameState(frame, x="col1", y="col2")
    >>> state
    ScatterDataFrameState(dataframe=(7, 3), x='col1', y='col2', color=None, figure_config=MatplotlibFigureConfig())

    ```
    """

    def __init__(
        self,
        dataframe: pl.DataFrame,
        x: str,
        y: str,
        color: str | None = None,
        figure_config: BaseFigureConfig | None = None,
    ) -> None:
        super().__init__(dataframe=dataframe, figure_config=figure_config)

        # Every referenced column is checked against the DataFrame; the
        # color column is only checked when one is given.
        check_column_exist(dataframe, x)
        check_column_exist(dataframe, y)
        if color is not None:
            check_column_exist(dataframe, color)
        self._x = x
        self._y = y
        self._color = color

    def __repr__(self) -> str:
        args = repr_mapping_line(
            {
                "dataframe": self._dataframe.shape,
                "x": self._x,
                "y": self._y,
                "color": self._color,
                "figure_config": self._figure_config,
            }
        )
        return f"{self.__class__.__qualname__}({args})"

    def __str__(self) -> str:
        args = str_indent(
            str_mapping(
                {
                    "dataframe": self._dataframe.shape,
                    "x": self._x,
                    "y": self._y,
                    "color": self._color,
                    "figure_config": self._figure_config,
                }
            )
        )
        # The mapping spans several lines, so it is placed on its own
        # lines between the parentheses. The two-space lead lines the
        # first entry up with the continuation lines indented by
        # ``str_indent``.
        return f"{self.__class__.__qualname__}(\n  {args}\n)"

    @property
    def x(self) -> str:
        return self._x

    @property
    def y(self) -> str:
        return self._y

    @property
    def color(self) -> str | None:
        return self._color

    def clone(self, deep: bool = True) -> Self:
        # A deep clone copies the DataFrame and the figure config; a
        # shallow clone shares them with this state.
        return self.__class__(
            dataframe=self._dataframe.clone() if deep else self._dataframe,
            x=self._x,
            y=self._y,
            color=self._color,
            figure_config=self._figure_config.clone() if deep else self._figure_config,
        )

    def get_args(self) -> dict:
        # Extend the parent arguments with the scatter-specific ones.
        return super().get_args() | {"x": self._x, "y": self._y, "color": self._color}
| r"""Implement the temporal DataFrame state.""" | ||
| from __future__ import annotations | ||
| __all__ = ["TemporalDataFrameState"] | ||
| import sys | ||
| from typing import TYPE_CHECKING | ||
| from coola.utils.format import repr_mapping_line, str_indent, str_mapping | ||
| from arkas.state.dataframe import DataFrameState | ||
| from arkas.utils.dataframe import check_column_exist | ||
| if sys.version_info >= (3, 11): | ||
| from typing import Self | ||
| else: # pragma: no cover | ||
| from typing_extensions import ( | ||
| Self, # use backport because it was added in python 3.11 | ||
| ) | ||
| if TYPE_CHECKING: | ||
| import polars as pl | ||
| from arkas.figure.base import BaseFigureConfig | ||
class TemporalDataFrameState(DataFrameState):
    r"""Implement the temporal DataFrame state.

    Args:
        dataframe: The DataFrame.
        temporal_column: The temporal column in the DataFrame.
        period: An optional temporal period e.g. monthly or daily.
        figure_config: An optional figure configuration.

    Example usage:

    ```pycon

    >>> from datetime import datetime, timezone
    >>> import polars as pl
    >>> from arkas.state import TemporalDataFrameState
    >>> frame = pl.DataFrame(
    ...     {
    ...         "col1": [0, 1, 1, 0],
    ...         "col2": [0, 1, 0, 1],
    ...         "col3": [1, 0, 0, 0],
    ...         "datetime": [
    ...             datetime(year=2020, month=1, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=2, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=3, day=3, tzinfo=timezone.utc),
    ...             datetime(year=2020, month=4, day=3, tzinfo=timezone.utc),
    ...         ],
    ...     },
    ...     schema={
    ...         "col1": pl.Int64,
    ...         "col2": pl.Int64,
    ...         "col3": pl.Int64,
    ...         "datetime": pl.Datetime(time_unit="us", time_zone="UTC"),
    ...     },
    ... )
    >>> state = TemporalDataFrameState(frame, temporal_column="datetime")
    >>> state
    TemporalDataFrameState(dataframe=(4, 4), temporal_column='datetime', period=None, figure_config=MatplotlibFigureConfig())

    ```
    """

    def __init__(
        self,
        dataframe: pl.DataFrame,
        temporal_column: str,
        period: str | None = None,
        figure_config: BaseFigureConfig | None = None,
    ) -> None:
        super().__init__(dataframe=dataframe, figure_config=figure_config)

        # The temporal column is checked against the DataFrame before
        # it is stored.
        check_column_exist(dataframe, temporal_column)
        self._temporal_column = temporal_column
        self._period = period

    def __repr__(self) -> str:
        args = repr_mapping_line(
            {
                "dataframe": self._dataframe.shape,
                "temporal_column": self._temporal_column,
                "period": self._period,
                "figure_config": self._figure_config,
            }
        )
        return f"{self.__class__.__qualname__}({args})"

    def __str__(self) -> str:
        args = str_indent(
            str_mapping(
                {
                    "dataframe": self._dataframe.shape,
                    "temporal_column": self._temporal_column,
                    "period": self._period,
                    "figure_config": self._figure_config,
                }
            )
        )
        # The mapping spans several lines, so it is placed on its own
        # lines between the parentheses. The two-space lead lines the
        # first entry up with the continuation lines indented by
        # ``str_indent``.
        return f"{self.__class__.__qualname__}(\n  {args}\n)"

    @property
    def period(self) -> str | None:
        return self._period

    @property
    def temporal_column(self) -> str:
        return self._temporal_column

    def clone(self, deep: bool = True) -> Self:
        # A deep clone copies the DataFrame and the figure config; a
        # shallow clone shares them with this state.
        return self.__class__(
            dataframe=self._dataframe.clone() if deep else self._dataframe,
            temporal_column=self._temporal_column,
            period=self._period,
            figure_config=self._figure_config.clone() if deep else self._figure_config,
        )

    def get_args(self) -> dict:
        # Extend the parent arguments with the temporal-specific ones.
        return super().get_args() | {
            "temporal_column": self._temporal_column,
            "period": self._period,
        }
+2
-2
@@ -1,4 +0,4 @@ | ||
| Metadata-Version: 2.1 | ||
| Metadata-Version: 2.3 | ||
| Name: arkas | ||
| Version: 0.0.1a8 | ||
| Version: 0.0.1a9 | ||
| Summary: Library to evaluate ML model performances | ||
@@ -5,0 +5,0 @@ Home-page: https://github.com/durandtibo/arkas |
+1
-1
| [tool.poetry] | ||
| name = "arkas" | ||
| version = "0.0.1a8" | ||
| version = "0.0.1a9" | ||
| description = "Library to evaluate ML model performances" | ||
@@ -5,0 +5,0 @@ readme = "README.md" |
@@ -16,2 +16,5 @@ r"""Contain DataFrame analyzers.""" | ||
| "MappingAnalyzer", | ||
| "PlotColumnAnalyzer", | ||
| "ScatterColumnAnalyzer", | ||
| "TemporalPlotColumnAnalyzer", | ||
| "TransformAnalyzer", | ||
@@ -31,2 +34,5 @@ "is_analyzer_config", | ||
| from arkas.analyzer.mapping import MappingAnalyzer | ||
| from arkas.analyzer.plot_column import PlotColumnAnalyzer | ||
| from arkas.analyzer.scatter_column import ScatterColumnAnalyzer | ||
| from arkas.analyzer.temporal_plot_column import TemporalPlotColumnAnalyzer | ||
| from arkas.analyzer.transform import TransformAnalyzer |
@@ -14,2 +14,3 @@ r"""Implement a pairwise column co-occurrence analyzer.""" | ||
| from arkas.output.column_cooccurrence import ColumnCooccurrenceOutput | ||
| from arkas.state.column_cooccurrence import ColumnCooccurrenceState | ||
@@ -65,3 +66,5 @@ if TYPE_CHECKING: | ||
| >>> output | ||
| ColumnCooccurrenceOutput(shape=(7, 3), ignore_self=False) | ||
| ColumnCooccurrenceOutput( | ||
| (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
@@ -102,3 +105,5 @@ ``` | ||
| return ColumnCooccurrenceOutput( | ||
| frame=out, ignore_self=self._ignore_self, figure_config=self._figure_config | ||
| state=ColumnCooccurrenceState.from_dataframe( | ||
| frame=out, ignore_self=self._ignore_self, figure_config=self._figure_config | ||
| ) | ||
| ) |
@@ -114,3 +114,5 @@ r"""Define a base class to implement lazy analyzers.""" | ||
| >>> output | ||
| ColumnCooccurrenceOutput(shape=(7, 3), ignore_self=False) | ||
| ColumnCooccurrenceOutput( | ||
| (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
@@ -117,0 +119,0 @@ ``` |
@@ -13,2 +13,5 @@ r"""Contain HTML content generators.""" | ||
| "DataFrameSummaryContentGenerator", | ||
| "PlotColumnContentGenerator", | ||
| "ScatterColumnContentGenerator", | ||
| "TemporalPlotColumnContentGenerator", | ||
| ] | ||
@@ -22,2 +25,5 @@ | ||
| from arkas.content.mapping import ContentGeneratorDict | ||
| from arkas.content.plot_column import PlotColumnContentGenerator | ||
| from arkas.content.scatter_column import ScatterColumnContentGenerator | ||
| from arkas.content.temporal_plot_column import TemporalPlotColumnContentGenerator | ||
| from arkas.content.vanilla import ContentGenerator |
@@ -6,3 +6,9 @@ r"""Contain the implementation of a HTML content generator that returns | ||
| __all__ = ["ColumnCooccurrenceContentGenerator", "create_table", "create_template"] | ||
| __all__ = [ | ||
| "ColumnCooccurrenceContentGenerator", | ||
| "create_table", | ||
| "create_table_row", | ||
| "create_table_section", | ||
| "create_template", | ||
| ] | ||
@@ -13,5 +19,3 @@ import logging | ||
| import numpy as np | ||
| from coola import objects_are_equal | ||
| from coola.utils import str_indent | ||
| from grizz.utils.cooccurrence import compute_pairwise_cooccurrence | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from jinja2 import Template | ||
@@ -26,5 +30,4 @@ | ||
| import polars as pl | ||
| from arkas.state.column_cooccurrence import ColumnCooccurrenceState | ||
| from arkas.figure.base import BaseFigureConfig | ||
@@ -39,6 +42,3 @@ logger = logging.getLogger(__name__) | ||
| Args: | ||
| frame: The DataFrame to analyze. | ||
| ignore_self: If ``True``, the diagonal of the co-occurrence | ||
| matrix (a.k.a. self-co-occurrence) is set to 0. | ||
| figure_config: The figure configuration. | ||
| state: The state with the co-occurrence matrix. | ||
@@ -49,14 +49,12 @@ Example usage: | ||
| >>> import polars as pl | ||
| >>> import numpy as np | ||
| >>> from arkas.content import ColumnCooccurrenceContentGenerator | ||
| >>> frame = pl.DataFrame( | ||
| ... { | ||
| ... "col1": [0, 1, 1, 0, 0, 1, 0], | ||
| ... "col2": [0, 1, 0, 1, 0, 1, 0], | ||
| ... "col3": [0, 0, 0, 0, 1, 1, 1], | ||
| ... } | ||
| >>> from arkas.state import ColumnCooccurrenceState | ||
| >>> content = ColumnCooccurrenceContentGenerator( | ||
| ... ColumnCooccurrenceState(matrix=np.ones((3, 3)), columns=["a", "b", "c"]) | ||
| ... ) | ||
| >>> content = ColumnCooccurrenceContentGenerator(frame) | ||
| >>> content | ||
| ColumnCooccurrenceContentGenerator(shape=(7, 3), ignore_self=False) | ||
| ColumnCooccurrenceContentGenerator( | ||
| (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
@@ -66,52 +64,29 @@ ``` | ||
| def __init__( | ||
| self, | ||
| frame: pl.DataFrame, | ||
| ignore_self: bool = False, | ||
| figure_config: BaseFigureConfig | None = None, | ||
| ) -> None: | ||
| self._frame = frame | ||
| self._ignore_self = bool(ignore_self) | ||
| self._figure_config = figure_config | ||
| def __init__(self, state: ColumnCooccurrenceState) -> None: | ||
| self._state = state | ||
| def __repr__(self) -> str: | ||
| return ( | ||
| f"{self.__class__.__qualname__}(shape={self._frame.shape}, " | ||
| f"ignore_self={self._ignore_self})" | ||
| ) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
| @property | ||
| def frame(self) -> pl.DataFrame: | ||
| r"""The DataFrame to analyze.""" | ||
| return self._frame | ||
| def __str__(self) -> str: | ||
| args = str_indent(str_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
| @property | ||
| def ignore_self(self) -> bool: | ||
| return self._ignore_self | ||
| def equal(self, other: Any, equal_nan: bool = False) -> bool: | ||
| if not isinstance(other, self.__class__): | ||
| return False | ||
| return ( | ||
| self.ignore_self == other.ignore_self | ||
| and objects_are_equal(self.frame, other.frame, equal_nan=equal_nan) | ||
| and objects_are_equal(self._figure_config, other._figure_config, equal_nan=equal_nan) | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
| def generate_content(self) -> str: | ||
| logger.info("Generating the DataFrame summary content...") | ||
| figures = ColumnCooccurrencePlotter( | ||
| frame=self._frame, ignore_self=self._ignore_self, figure_config=self._figure_config | ||
| ).plot() | ||
| columns = list(self._frame.columns) | ||
| figures = ColumnCooccurrencePlotter(self._state).plot() | ||
| columns = self._state.columns | ||
| return Template(create_template()).render( | ||
| { | ||
| "nrows": f"{self._frame.shape[0]:,}", | ||
| "ncols": f"{self._frame.shape[1]:,}", | ||
| "columns": ", ".join([f"{x!r}" for x in columns]), | ||
| "ncols": f"{len(columns):,}", | ||
| "figure": figure2html(figures["column_cooccurrence"], close_fig=True), | ||
| "table": create_table_section( | ||
| matrix=compute_pairwise_cooccurrence( | ||
| frame=self._frame, ignore_self=self._ignore_self | ||
| ), | ||
| matrix=self._state.matrix, | ||
| columns=columns, | ||
@@ -139,9 +114,5 @@ ), | ||
| return """This section shows an analysis of the pairwise column co-occurrence. | ||
| <ul> | ||
| <li> number of columns: {{ncols}} </li> | ||
| <li> number of rows: {{nrows}}</li> | ||
| </ul> | ||
| {{figure}} | ||
| <details> | ||
| <summary>[show columns]</summary> | ||
| <summary>[show {{ncols}} columns]</summary> | ||
| {{columns}} | ||
@@ -148,0 +119,0 @@ </details> |
@@ -9,2 +9,3 @@ r"""Contain data evaluators.""" | ||
| "BaseEvaluator", | ||
| "ColumnCooccurrenceEvaluator", | ||
| "Evaluator", | ||
@@ -17,3 +18,4 @@ "EvaluatorDict", | ||
| from arkas.evaluator2.base import BaseEvaluator | ||
| from arkas.evaluator2.column_cooccurrence import ColumnCooccurrenceEvaluator | ||
| from arkas.evaluator2.mapping import EvaluatorDict | ||
| from arkas.evaluator2.vanilla import Evaluator |
@@ -8,3 +8,2 @@ r"""Contain figures.""" | ||
| "BaseFigureConfig", | ||
| "DefaultFigureConfig", | ||
| "HtmlFigure", | ||
@@ -20,3 +19,2 @@ "MatplotlibFigure", | ||
| from arkas.figure.base import BaseFigure, BaseFigureConfig | ||
| from arkas.figure.default import DefaultFigureConfig | ||
| from arkas.figure.html import HtmlFigure | ||
@@ -23,0 +21,0 @@ from arkas.figure.matplotlib import MatplotlibFigure, MatplotlibFigureConfig |
@@ -144,3 +144,3 @@ """Contain the base class to implement a figure.""" | ||
| >>> config | ||
| MatplotlibFigureConfig(color_norm=None) | ||
| MatplotlibFigureConfig() | ||
@@ -168,2 +168,20 @@ ``` | ||
| @abstractmethod | ||
| def clone(self) -> Self: | ||
| r"""Return a copy of the config. | ||
| Returns: | ||
| A copy of the config. | ||
| Example usage: | ||
| ```pycon | ||
| >>> from arkas.figure import MatplotlibFigureConfig | ||
| >>> config = MatplotlibFigureConfig() | ||
| >>> cloned_config = config.clone() | ||
| ``` | ||
| """ | ||
| @abstractmethod | ||
| def equal(self, other: Any, equal_nan: bool = False) -> bool: | ||
@@ -197,7 +215,11 @@ r"""Indicate if two configs are equal or not. | ||
| @abstractmethod | ||
| def get_args(self) -> dict: | ||
| r"""Get the config arguments. | ||
| def get_arg(self, name: str, default: Any = None) -> Any: | ||
| r"""Get a given argument from the config. | ||
| Args: | ||
| name: The argument name to get. | ||
| default: The default value to return if the argument is missing. | ||
| Returns: | ||
| The config arguments. | ||
| The argument value or the default value. | ||
@@ -209,5 +231,5 @@ Example usage: | ||
| >>> from arkas.figure import MatplotlibFigureConfig | ||
| >>> config = MatplotlibFigureConfig(dpi=300) | ||
| >>> config.get_args() | ||
| {'dpi': 300} | ||
| >>> config = MatplotlibFigureConfig(dpi=42) | ||
| >>> config.get_arg("dpi") | ||
| 42 | ||
@@ -214,0 +236,0 @@ ``` |
@@ -8,5 +8,6 @@ r"""Contain the implementation for matplotlib figures.""" | ||
| import base64 | ||
| import copy | ||
| import io | ||
| import sys | ||
| from typing import TYPE_CHECKING, Any | ||
| from typing import Any | ||
@@ -19,5 +20,2 @@ import matplotlib.pyplot as plt | ||
| if TYPE_CHECKING: | ||
| from matplotlib.colors import Normalize | ||
| if sys.version_info >= (3, 11): | ||
@@ -103,3 +101,3 @@ from typing import Self | ||
| >>> config | ||
| MatplotlibFigureConfig(color_norm=None) | ||
| MatplotlibFigureConfig() | ||
@@ -109,8 +107,7 @@ ``` | ||
| def __init__(self, color_norm: Normalize | None = None, **kwargs: Any) -> None: | ||
| self._color_norm = color_norm | ||
| def __init__(self, **kwargs: Any) -> None: | ||
| self._kwargs = kwargs | ||
| def __repr__(self) -> str: | ||
| args = repr_mapping_line({"color_norm": self._color_norm} | self.get_args()) | ||
| args = repr_mapping_line(self._kwargs) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
@@ -122,30 +119,11 @@ | ||
| def clone(self) -> Self: | ||
| return self.__class__(**self._kwargs) | ||
| def equal(self, other: Any, equal_nan: bool = False) -> bool: | ||
| if not isinstance(other, self.__class__): | ||
| return False | ||
| # color_norm is excluded from the comparison as it is not straightforward | ||
| # to compare to Normalize objects. | ||
| return objects_are_equal(self.get_args(), other.get_args(), equal_nan=equal_nan) | ||
| return objects_are_equal(self._kwargs, other._kwargs, equal_nan=equal_nan) | ||
| def get_args(self) -> dict: | ||
| return self._kwargs | ||
| def get_color_norm(self) -> Normalize | None: | ||
| r"""Get the color normalization. | ||
| Returns: | ||
| The color normalization. | ||
| Example usage: | ||
| ```pycon | ||
| >>> from matplotlib.colors import LogNorm | ||
| >>> from arkas.figure import MatplotlibFigureConfig | ||
| >>> config = MatplotlibFigureConfig(color_norm=LogNorm()) | ||
| >>> config.get_color_norm() | ||
| <matplotlib.colors.LogNorm object at 0x...> | ||
| ``` | ||
| """ | ||
| return self._color_norm | ||
| def get_arg(self, name: str, default: Any = None) -> Any: | ||
| return copy.copy(self._kwargs.get(name, default)) |
@@ -7,2 +7,3 @@ r"""Contain the implementation for plotly figures.""" | ||
| import copy | ||
| import sys | ||
@@ -109,3 +110,3 @@ from typing import Any | ||
| def __repr__(self) -> str: | ||
| args = repr_mapping_line(self.get_args()) | ||
| args = repr_mapping_line(self._kwargs) | ||
| return f"{self.__class__.__qualname__}({args})" | ||
@@ -117,8 +118,11 @@ | ||
| def clone(self) -> Self: | ||
| return self.__class__(**self._kwargs) | ||
| def equal(self, other: Any, equal_nan: bool = False) -> bool: | ||
| if not isinstance(other, self.__class__): | ||
| return False | ||
| return objects_are_equal(self.get_args(), other.get_args(), equal_nan=equal_nan) | ||
| return objects_are_equal(self._kwargs, other._kwargs, equal_nan=equal_nan) | ||
| def get_args(self) -> dict: | ||
| return self._kwargs | ||
| def get_arg(self, name: str, default: Any = None) -> Any: | ||
| return copy.copy(self._kwargs.get(name, default)) |
@@ -62,3 +62,3 @@ r"""Contain utility functions to manage figures.""" | ||
| >>> config | ||
| MatplotlibFigureConfig(color_norm=None) | ||
| MatplotlibFigureConfig() | ||
@@ -65,0 +65,0 @@ ``` |
@@ -16,2 +16,5 @@ r"""Contain data outputs.""" | ||
| "OutputDict", | ||
| "PlotColumnOutput", | ||
| "ScatterColumnOutput", | ||
| "TemporalPlotColumnOutput", | ||
| ] | ||
@@ -28,2 +31,5 @@ | ||
| from arkas.output.mapping import OutputDict | ||
| from arkas.output.plot_column import PlotColumnOutput | ||
| from arkas.output.scatter_column import ScatterColumnOutput | ||
| from arkas.output.temporal_plot_column import TemporalPlotColumnOutput | ||
| from arkas.output.vanilla import Output |
@@ -9,6 +9,6 @@ r"""Implement the pairwise column co-occurrence output.""" | ||
| from coola import objects_are_equal | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
| from arkas.content.column_cooccurrence import ColumnCooccurrenceContentGenerator | ||
| from arkas.evaluator2.vanilla import Evaluator | ||
| from arkas.evaluator2.column_cooccurrence import ColumnCooccurrenceEvaluator | ||
| from arkas.output.lazy import BaseLazyOutput | ||
@@ -18,7 +18,5 @@ from arkas.plotter.column_cooccurrence import ColumnCooccurrencePlotter | ||
| if TYPE_CHECKING: | ||
| import polars as pl | ||
| from arkas.state.column_cooccurrence import ColumnCooccurrenceState | ||
| from arkas.figure.base import BaseFigureConfig | ||
| class ColumnCooccurrenceOutput(BaseLazyOutput): | ||
@@ -28,6 +26,3 @@ r"""Implement the pairwise column co-occurrence output. | ||
| Args: | ||
| frame: The DataFrame to analyze. | ||
| ignore_self: If ``True``, the diagonal of the co-occurrence | ||
| matrix (a.k.a. self-co-occurrence) is set to 0. | ||
| figure_config: The figure configuration. | ||
| state: The state with the co-occurrence matrix. | ||
@@ -38,20 +33,24 @@ Example usage: | ||
| >>> import polars as pl | ||
| >>> import numpy as np | ||
| >>> from arkas.output import ColumnCooccurrenceOutput | ||
| >>> frame = pl.DataFrame( | ||
| ... { | ||
| ... "col1": [0, 1, 1, 0, 0, 1, 0], | ||
| ... "col2": [0, 1, 0, 1, 0, 1, 0], | ||
| ... "col3": [0, 0, 0, 0, 1, 1, 1], | ||
| ... } | ||
| >>> from arkas.state import ColumnCooccurrenceState | ||
| >>> output = ColumnCooccurrenceOutput( | ||
| ... ColumnCooccurrenceState(matrix=np.ones((3, 3)), columns=["a", "b", "c"]) | ||
| ... ) | ||
| >>> output = ColumnCooccurrenceOutput(frame) | ||
| >>> output | ||
| ColumnCooccurrenceOutput(shape=(7, 3), ignore_self=False) | ||
| ColumnCooccurrenceOutput( | ||
| (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
| >>> output.get_content_generator() | ||
| ColumnCooccurrenceContentGenerator(shape=(7, 3), ignore_self=False) | ||
| ColumnCooccurrenceContentGenerator( | ||
| (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
| >>> output.get_evaluator() | ||
| Evaluator(count=0) | ||
| ColumnCooccurrenceEvaluator( | ||
| (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
| >>> output.get_plotter() | ||
| ColumnCooccurrencePlotter(shape=(7, 3), ignore_self=False) | ||
| ColumnCooccurrencePlotter( | ||
| (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
@@ -61,38 +60,25 @@ ``` | ||
| def __init__( | ||
| self, | ||
| frame: pl.DataFrame, | ||
| ignore_self: bool = False, | ||
| figure_config: BaseFigureConfig | None = None, | ||
| ) -> None: | ||
| self._frame = frame | ||
| self._ignore_self = bool(ignore_self) | ||
| self._figure_config = figure_config | ||
| def __init__(self, state: ColumnCooccurrenceState) -> None: | ||
| self._state = state | ||
| def __repr__(self) -> str: | ||
| return ( | ||
| f"{self.__class__.__qualname__}(shape={self._frame.shape}, " | ||
| f"ignore_self={self._ignore_self})" | ||
| ) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
| def __str__(self) -> str: | ||
| args = str_indent(str_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
| def equal(self, other: Any, equal_nan: bool = False) -> bool: | ||
| if not isinstance(other, self.__class__): | ||
| return False | ||
| return ( | ||
| self._ignore_self == other._ignore_self | ||
| and objects_are_equal(self._frame, other._frame, equal_nan=equal_nan) | ||
| and objects_are_equal(self._figure_config, other._figure_config, equal_nan=equal_nan) | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
| def _get_content_generator(self) -> ColumnCooccurrenceContentGenerator: | ||
| return ColumnCooccurrenceContentGenerator( | ||
| frame=self._frame, ignore_self=self._ignore_self, figure_config=self._figure_config | ||
| ) | ||
| return ColumnCooccurrenceContentGenerator(state=self._state) | ||
| def _get_evaluator(self) -> Evaluator: | ||
| return Evaluator() | ||
| def _get_evaluator(self) -> ColumnCooccurrenceEvaluator: | ||
| return ColumnCooccurrenceEvaluator(state=self._state) | ||
| def _get_plotter(self) -> ColumnCooccurrencePlotter: | ||
| return ColumnCooccurrencePlotter( | ||
| frame=self._frame, ignore_self=self._ignore_self, figure_config=self._figure_config | ||
| ) | ||
| return ColumnCooccurrencePlotter(state=self._state) |
@@ -5,3 +5,11 @@ r"""Contain data plotters.""" | ||
| __all__ = ["BasePlotter", "ColumnCooccurrencePlotter", "Plotter", "PlotterDict"] | ||
| __all__ = [ | ||
| "BasePlotter", | ||
| "ColumnCooccurrencePlotter", | ||
| "PlotColumnPlotter", | ||
| "Plotter", | ||
| "PlotterDict", | ||
| "ScatterColumnPlotter", | ||
| "TemporalPlotColumnPlotter", | ||
| ] | ||
@@ -11,2 +19,5 @@ from arkas.plotter.base import BasePlotter | ||
| from arkas.plotter.mapping import PlotterDict | ||
| from arkas.plotter.plot_column import PlotColumnPlotter | ||
| from arkas.plotter.scatter_column import ScatterColumnPlotter | ||
| from arkas.plotter.temporal_plot_column import TemporalPlotColumnPlotter | ||
| from arkas.plotter.vanilla import Plotter |
@@ -12,4 +12,3 @@ r"""Contain the implementation of a pairwise column co-occurrence | ||
| import matplotlib.pyplot as plt | ||
| from coola import objects_are_equal | ||
| from grizz.utils.cooccurrence import compute_pairwise_cooccurrence | ||
| from coola.utils import repr_indent, repr_mapping, str_indent, str_mapping | ||
@@ -19,3 +18,3 @@ from arkas.figure.creator import FigureCreatorRegistry | ||
| from arkas.figure.matplotlib import MatplotlibFigure, MatplotlibFigureConfig | ||
| from arkas.figure.utils import MISSING_FIGURE_MESSAGE, get_default_config | ||
| from arkas.figure.utils import MISSING_FIGURE_MESSAGE | ||
| from arkas.plot.utils import readable_xticklabels, readable_yticklabels | ||
@@ -26,10 +25,6 @@ from arkas.plotter.base import BasePlotter | ||
| if TYPE_CHECKING: | ||
| from collections.abc import Sequence | ||
| from arkas.figure.base import BaseFigure | ||
| from arkas.state.column_cooccurrence import ColumnCooccurrenceState | ||
| import numpy as np | ||
| import polars as pl | ||
| from arkas.figure.base import BaseFigure, BaseFigureConfig | ||
| class BaseFigureCreator(ABC): | ||
@@ -46,7 +41,7 @@ r"""Define the base class to create a figure of the pairwise column | ||
| >>> from arkas.plotter.column_cooccurrence import MatplotlibFigureCreator | ||
| >>> from arkas.state import ColumnCooccurrenceState | ||
| >>> creator = MatplotlibFigureCreator() | ||
| >>> creator | ||
| MatplotlibFigureCreator() | ||
| >>> config = MatplotlibFigureConfig() | ||
| >>> fig = creator.create(matrix=np.ones((3, 3)), columns=["a", "b", "c"], config=config) | ||
| >>> fig = creator.create( | ||
| ... ColumnCooccurrenceState(matrix=np.ones((3, 3)), columns=["a", "b", "c"]) | ||
| ... ) | ||
@@ -57,11 +52,7 @@ ``` | ||
| @abstractmethod | ||
| def create( | ||
| self, matrix: np.ndarray, columns: Sequence[str], config: BaseFigureConfig | ||
| ) -> BaseFigure: | ||
| def create(self, state: ColumnCooccurrenceState) -> BaseFigure: | ||
| r"""Create a figure of the pairwise column co-occurrence matrix. | ||
| Args: | ||
| matrix: The co-occurrence matrix. | ||
| columns: The column names. | ||
| config: The figure config. | ||
| state: The state with the co-occurrence matrix. | ||
@@ -78,5 +69,7 @@ Returns: | ||
| >>> from arkas.plotter.column_cooccurrence import MatplotlibFigureCreator | ||
| >>> from arkas.state import ColumnCooccurrenceState | ||
| >>> creator = MatplotlibFigureCreator() | ||
| >>> config = MatplotlibFigureConfig() | ||
| >>> fig = creator.create(matrix=np.ones((3, 3)), columns=["a", "b", "c"], config=config) | ||
| >>> fig = creator.create( | ||
| ... ColumnCooccurrenceState(matrix=np.ones((3, 3)), columns=["a", "b", "c"]) | ||
| ... ) | ||
@@ -98,7 +91,7 @@ ``` | ||
| >>> from arkas.plotter.column_cooccurrence import MatplotlibFigureCreator | ||
| >>> from arkas.state import ColumnCooccurrenceState | ||
| >>> creator = MatplotlibFigureCreator() | ||
| >>> creator | ||
| MatplotlibFigureCreator() | ||
| >>> config = MatplotlibFigureConfig() | ||
| >>> fig = creator.create(matrix=np.ones((3, 3)), columns=["a", "b", "c"], config=config) | ||
| >>> fig = creator.create( | ||
| ... ColumnCooccurrenceState(matrix=np.ones((3, 3)), columns=["a", "b", "c"]) | ||
| ... ) | ||
@@ -111,14 +104,12 @@ ``` | ||
| def create( | ||
| self, matrix: np.ndarray, columns: Sequence[str], config: MatplotlibFigureConfig | ||
| ) -> BaseFigure: | ||
| if matrix.shape[0] == 0: | ||
| def create(self, state: ColumnCooccurrenceState) -> BaseFigure: | ||
| if state.matrix.shape[0] == 0: | ||
| return HtmlFigure(MISSING_FIGURE_MESSAGE) | ||
| fig, ax = plt.subplots(**config.get_args()) | ||
| im = ax.imshow(matrix, norm=config.get_color_norm()) | ||
| fig, ax = plt.subplots(**state.figure_config.get_arg("init", {})) | ||
| im = ax.imshow(state.matrix, norm=state.figure_config.get_arg("color_norm")) | ||
| fig.colorbar(im) | ||
| ax.set_xticks( | ||
| range(len(columns)), | ||
| labels=columns, | ||
| range(len(state.columns)), | ||
| labels=state.columns, | ||
| rotation=45, | ||
@@ -128,6 +119,3 @@ ha="right", | ||
| ) | ||
| ax.set_yticks( | ||
| range(len(columns)), | ||
| labels=columns, | ||
| ) | ||
| ax.set_yticks(range(len(state.columns)), labels=state.columns) | ||
| readable_xticklabels(ax, max_num_xticks=50) | ||
@@ -137,7 +125,13 @@ readable_yticklabels(ax, max_num_yticks=50) | ||
| if matrix.shape[0] < 16: | ||
| for i in range(len(columns)): | ||
| for j in range(len(columns)): | ||
| if state.matrix.shape[0] < 16: | ||
| for i in range(len(state.columns)): | ||
| for j in range(len(state.columns)): | ||
| ax.text( | ||
| j, i, matrix[i, j], ha="center", va="center", color="w", fontsize="xx-small" | ||
| j, | ||
| i, | ||
| state.matrix[i, j], | ||
| ha="center", | ||
| va="center", | ||
| color="w", | ||
| fontsize="xx-small", | ||
| ) | ||
@@ -153,6 +147,3 @@ | ||
| Args: | ||
| frame: The DataFrame to analyze. | ||
| ignore_self: If ``True``, the diagonal of the co-occurrence | ||
| matrix (a.k.a. self-co-occurrence) is set to 0. | ||
| figure_config: The figure configuration. | ||
| state: The state with the co-occurrence matrix. | ||
@@ -163,14 +154,12 @@ Example usage: | ||
| >>> import polars as pl | ||
| >>> import numpy as np | ||
| >>> from arkas.plotter import ColumnCooccurrencePlotter | ||
| >>> frame = pl.DataFrame( | ||
| ... { | ||
| ... "col1": [0, 1, 1, 0, 0, 1, 0], | ||
| ... "col2": [0, 1, 0, 1, 0, 1, 0], | ||
| ... "col3": [0, 0, 0, 0, 1, 1, 1], | ||
| ... } | ||
| >>> from arkas.state import ColumnCooccurrenceState | ||
| >>> plotter = ColumnCooccurrencePlotter( | ||
| ... ColumnCooccurrenceState(matrix=np.ones((3, 3)), columns=["a", "b", "c"]) | ||
| ... ) | ||
| >>> plotter = ColumnCooccurrencePlotter(frame) | ||
| >>> plotter | ||
| ColumnCooccurrencePlotter(shape=(7, 3), ignore_self=False) | ||
| ColumnCooccurrencePlotter( | ||
| (state): ColumnCooccurrenceState(matrix=(3, 3), figure_config=MatplotlibFigureConfig()) | ||
| ) | ||
@@ -180,22 +169,17 @@ ``` | ||
| registry: FigureCreatorRegistry = FigureCreatorRegistry( | ||
| registry = FigureCreatorRegistry[BaseFigureCreator]( | ||
| {MatplotlibFigureConfig.backend(): MatplotlibFigureCreator()} | ||
| ) | ||
| def __init__( | ||
| self, | ||
| frame: pl.DataFrame, | ||
| ignore_self: bool = False, | ||
| figure_config: BaseFigureConfig | None = None, | ||
| ) -> None: | ||
| self._frame = frame | ||
| self._ignore_self = bool(ignore_self) | ||
| self._figure_config = figure_config or get_default_config() | ||
| def __init__(self, state: ColumnCooccurrenceState) -> None: | ||
| self._state = state | ||
| def __repr__(self) -> str: | ||
| return ( | ||
| f"{self.__class__.__qualname__}(shape={self._frame.shape}, " | ||
| f"ignore_self={self._ignore_self})" | ||
| ) | ||
| args = repr_indent(repr_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
| def __str__(self) -> str: | ||
| args = str_indent(str_mapping({"state": self._state})) | ||
| return f"{self.__class__.__qualname__}(\n {args}\n)" | ||
| def compute(self) -> Plotter: | ||
@@ -207,22 +191,6 @@ return Plotter(self.plot()) | ||
| return False | ||
| return ( | ||
| self._ignore_self == other._ignore_self | ||
| and objects_are_equal(self._frame, other._frame, equal_nan=equal_nan) | ||
| and self._figure_config.equal(other._figure_config) | ||
| ) | ||
| return self._state.equal(other._state, equal_nan=equal_nan) | ||
| def plot(self, prefix: str = "", suffix: str = "") -> dict: | ||
| figure = self.registry.find_creator(self._figure_config.backend()).create( | ||
| matrix=self.cooccurrence_matrix(), | ||
| columns=self._frame.columns, | ||
| config=self._figure_config, | ||
| ) | ||
| figure = self.registry.find_creator(self._state.figure_config.backend()).create(self._state) | ||
| return {f"{prefix}column_cooccurrence{suffix}": figure} | ||
| def cooccurrence_matrix(self) -> np.ndarray: | ||
| r"""Return the pairwise column co-occurrence matrix. | ||
| Returns: | ||
| The pairwise column co-occurrence. | ||
| """ | ||
| return compute_pairwise_cooccurrence(frame=self._frame, ignore_self=self._ignore_self) |
@@ -5,6 +5,18 @@ r"""Contain states.""" | ||
| __all__ = ["AccuracyState", "BaseState", "PrecisionRecallState"] | ||
| __all__ = [ | ||
| "AccuracyState", | ||
| "BaseState", | ||
| "ColumnCooccurrenceState", | ||
| "DataFrameState", | ||
| "PrecisionRecallState", | ||
| "ScatterDataFrameState", | ||
| "TemporalDataFrameState", | ||
| ] | ||
| from arkas.state.accuracy import AccuracyState | ||
| from arkas.state.base import BaseState | ||
| from arkas.state.column_cooccurrence import ColumnCooccurrenceState | ||
| from arkas.state.dataframe import DataFrameState | ||
| from arkas.state.precision_recall import PrecisionRecallState | ||
| from arkas.state.scatter_dataframe import ScatterDataFrameState | ||
| from arkas.state.temporal_dataframe import TemporalDataFrameState |
@@ -5,3 +5,3 @@ r"""Implement some utility functions for ``numpy.ndarray``s.""" | ||
| __all__ = ["filter_range", "nonnan", "rand_replace", "to_array"] | ||
| __all__ = ["check_square_matrix", "filter_range", "nonnan", "rand_replace", "to_array"] | ||
@@ -15,2 +15,30 @@ from typing import Any | ||
| def check_square_matrix(name: str, array: np.ndarray) -> None: | ||
| r"""Check if the input array is a square matrix. | ||
| Args: | ||
| name: The name of the variable. | ||
| array: The array to check. | ||
| Raises: | ||
| ValueError: if the array is not a square matrix. | ||
| Example usage: | ||
| ```pycon | ||
| >>> import numpy as np | ||
| >>> from arkas.utils.array import check_square_matrix | ||
| >>> check_square_matrix("var", np.ones((3, 3))) | ||
| ``` | ||
| """ | ||
| if array.ndim != 2 or array.shape[0] != array.shape[1]: | ||
| msg = ( | ||
| f"Incorrect {name!r}. The array must be a square matrix but received an array of " | ||
| f"shape {array.shape}" | ||
| ) | ||
| raise ValueError(msg) | ||
| def filter_range(array: np.ndarray, xmin: float, xmax: float) -> np.ndarray: | ||
@@ -17,0 +45,0 @@ r"""Filter in the values in a given range. |
@@ -47,1 +47,34 @@ r"""Contain DataFrame utility functions.""" | ||
| return {s.name: s.to_numpy() for s in frame.iter_columns()} | ||
| def check_column_exist(frame: pl.DataFrame, col: str) -> None: | ||
| r"""Check if a column exists in the DataFrame. | ||
| Args: | ||
| frame: The DataFrame. | ||
| col: The column to check. | ||
| Raises: | ||
| ValueError: if the column is missing. | ||
| Example usage: | ||
| ```pycon | ||
| >>> import polars as pl | ||
| >>> from arkas.utils.dataframe import check_column_exist | ||
| >>> frame = pl.DataFrame( | ||
| ... { | ||
| ... "int": [1, 2, 3, 4, 5], | ||
| ... "float": [5.0, 4.0, 3.0, 2.0, 1.0], | ||
| ... "str": ["a", "b", "c", "d", "e"], | ||
| ... }, | ||
| ... schema={"int": pl.Int64, "float": pl.Float64, "str": pl.String}, | ||
| ... ) | ||
| >>> check_column_exist(frame, "int") | ||
| ``` | ||
| """ | ||
| if col not in frame: | ||
| msg = f"The column {col!r} is not in the DataFrame: {sorted(frame.columns)}" | ||
| raise ValueError(msg) |
| r"""Contain the default figure config.""" | ||
| from __future__ import annotations | ||
| __all__ = ["DefaultFigureConfig"] | ||
| from typing import Any | ||
| from arkas.figure.base import BaseFigureConfig | ||
| class DefaultFigureConfig(BaseFigureConfig): | ||
| r"""Implement the default figure config. | ||
| Example usage: | ||
| ```pycon | ||
| >>> from arkas.figure import DefaultFigureConfig | ||
| >>> config = DefaultFigureConfig() | ||
| >>> config | ||
| DefaultFigureConfig() | ||
| ``` | ||
| """ | ||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__qualname__}()" | ||
| @classmethod | ||
| def backend(cls) -> str: | ||
| return "default" | ||
| def equal(self, other: Any, equal_nan: bool = False) -> bool: # noqa: ARG002 | ||
| return isinstance(other, self.__class__) | ||
| def get_args(self) -> dict: | ||
| return {} |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
813973
9.14%230
7.48%20408
8.88%