Source code for covsirphy.visualization.compare_plot

from __future__ import annotations
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter
from typing_extensions import Self
from covsirphy.util.validator import Validator
from covsirphy.visualization.vbase import VisualizeBase, find_args


class ComparePlot(VisualizeBase):
    """Compare two groups with specified variables.

    Args:
        filename (str or None): filename to save the figure or None (display)
        bbox_inches (str): bounding box in inches when creating the figure
        kwargs: the other arguments of matplotlib.pyplot.savefig()
    """

    def __init__(self, filename=None, bbox_inches="tight", **kwargs):
        self._filename = filename
        # Keyword arguments kept for saving the figure: bbox_inches plus any extras.
        self._savefig_dict = {"bbox_inches": bbox_inches, **kwargs}
        # Properties
        self._title = ""
        self._variables = []
        self._ax = None

    def __enter__(self) -> Self:
        return super().__enter__()

    def __exit__(self, *exc_info):
        return super().__exit__(*exc_info)

    def plot(self, data, variables, groups):
        """Compare two groups with specified variables.

        The first (top) axes shows the residuals (first group - second group) of all
        variables; each following axes shows both groups for one variable.

        Args:
            data (pandas.DataFrame): data to show

                Index
                    x values
                Columns
                    y variables to show, "{variable}_{group}" for all combinations of variables and groups
            variables (list[str]): variables to compare
            groups (list[str]): the first group name and the second group name

        Raises:
            ValueError: the validated @groups does not contain exactly two names

        Note:
            @data is not modified. Residuals are calculated as separate series named
            "{variable}_diff" and no columns are added to the caller's dataframe.
        """
        Validator(variables, "variables").sequence()
        group1, group2 = Validator(groups, "groups").sequence()
        col_nest = [[f"{variable}_{group}" for group in groups] for variable in variables]
        all_columns = [column for columns in col_nest for column in columns]
        Validator(data, "data").dataframe(columns=all_columns)
        # Prepare figure object: one axes for residuals and one axes per variable
        fig_len = len(variables) + 1
        _, self._ax = plt.subplots(ncols=1, nrows=fig_len, figsize=(9, 6 * fig_len / 2))
        axes = self._ax.ravel()
        # Compare each variable
        for (ax, v, columns) in zip(axes[1:], variables, col_nest):
            data[columns].plot.line(
                ax=ax, ylim=(None, None), sharex=True, title=f"Comparison regarding {v}(t)")
            ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
            ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
            ax.legend(bbox_to_anchor=(1.02, 0), loc="lower left", borderaxespad=0)
        # Show residuals
        residual_ax = axes[0]
        residual_title = f"{group1.capitalize()} - {group2.capitalize()}"
        for (v, columns) in zip(variables, col_nest):
            # Use a standalone series so that the caller's dataframe is not mutated;
            # the series name is used as the legend label.
            residual = (data[columns[0]] - data[columns[1]]).rename(f"{v}_diff")
            residual.plot.line(ax=residual_ax, sharex=True, title=residual_title)
        residual_ax.axhline(y=0, color="black", linestyle="--")
        residual_ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        residual_ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
        residual_ax.legend(bbox_to_anchor=(1.02, 0), loc="lower left", borderaxespad=0)
def compare_plot(df, variables, groups, filename=None, **kwargs):
    """Wrapper function: compare two groups with specified variables using ComparePlot.

    Args:
        df (pandas.DataFrame): data to show

            Index
                x values
            Columns
                y variables to show, "{variable}_{group}" for all combinations of variables and groups
        variables (list[str]): variables to compare
        groups (list[str]): the first group name and the second group name
        filename (str or None): filename to save the figure or None (display)
        kwargs: keyword arguments; they are passed through find_args() with
            matplotlib.pyplot.savefig() before being given to ComparePlot
    """
    # Select the keyword arguments for ComparePlot via find_args() and savefig()
    savefig_kwargs = find_args(plt.savefig, **kwargs)
    with ComparePlot(filename=filename, **savefig_kwargs) as plotter:
        plotter.plot(data=df, variables=variables, groups=groups)