from __future__ import annotations
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter
from typing_extensions import Self
from covsirphy.util.validator import Validator
from covsirphy.visualization.vbase import VisualizeBase, find_args
class ComparePlot(VisualizeBase):
    """Compare two groups with specified variables.

    Args:
        filename (str or None): filename to save the figure or None (display)
        bbox_inches (str): bounding box in inches when creating the figure
        kwargs: the other arguments of matplotlib.pyplot.savefig()
    """

    def __init__(self, filename=None, bbox_inches="tight", **kwargs):
        # VisualizeBase.__init__() is not called; the attributes it would need are set directly here.
        self._filename = filename
        self._savefig_dict = {"bbox_inches": bbox_inches, **kwargs}
        # Properties
        self._title = ""
        self._variables = []
        self._ax = None

    def __enter__(self) -> Self:
        return super().__enter__()

    def __exit__(self, *exc_info):
        return super().__exit__(*exc_info)

    @staticmethod
    def _format_ax(ax):
        """Use scientific notation (math text) on the y-axis and place the legend outside the axes on the right."""
        ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
        ax.legend(bbox_to_anchor=(1.02, 0), loc="lower left", borderaxespad=0)

    def plot(self, data, variables, groups):
        """Compare two groups with specified variables.

        The figure has len(variables) + 1 rows.
        The top row shows "{group1} - {group2}" differences of all variables.
        Each following row compares the two groups for one variable.

        Args:
            data (pandas.DataFrame): data to show
                Index
                    x values
                Columns
                    y variables to show, "{variable}_{group}" for all combinations of variables and groups
            variables (list[str]): variables to compare
            groups (list[str]): the first group name and the second group name

        Raises:
            ValueError: @variables is empty

        Note:
            @data is not modified (differences are computed as separate series).
        """
        Validator(variables, "variables").sequence()
        group1, group2 = Validator(groups, "groups").sequence()
        # With no variables, plt.subplots() would return a single Axes (no .ravel()); fail clearly instead
        if len(variables) == 0:
            raise ValueError("@variables must include at least one variable name.")
        col_nest = [[f"{variable}_{group}" for group in groups] for variable in variables]
        Validator(data, "data").dataframe(columns=[col for columns in col_nest for col in columns])
        # Prepare figure object: difference panel on top, then one panel per variable
        fig_len = len(variables) + 1
        _, self._ax = plt.subplots(ncols=1, nrows=fig_len, figsize=(9, 6 * fig_len / 2))
        axes = self._ax.ravel()
        diff_ax = axes[0]
        # Compare each variable
        for ax, v, columns in zip(axes[1:], variables, col_nest):
            data[columns].plot.line(
                ax=ax, ylim=(None, None), sharex=True, title=f"Comparison regarding {v}(t)")
            self._format_ax(ax)
        # Show residuals (first group - second group) without adding columns to the caller's DataFrame
        diff_title = f"{group1.capitalize()} - {group2.capitalize()}"
        for v, columns in zip(variables, col_nest):
            # The series name becomes the legend label, so it keeps the former "{v}_diff" label
            diff = (data[columns[0]] - data[columns[1]]).rename(f"{v}_diff")
            diff.plot.line(ax=diff_ax, sharex=True, title=diff_title)
        diff_ax.axhline(y=0, color="black", linestyle="--")
        self._format_ax(diff_ax)
def compare_plot(df, variables, groups, filename=None, **kwargs):
    """Wrapper function: compare two groups with specified variables and show or save the figure.

    Args:
        df (pandas.DataFrame): data to show
            Index
                x values
            Columns
                y variables to show, "{variable}_{group}" for all combinations of variables and groups
        variables (list[str]): variables to compare
        groups (list[str]): the first group name and the second group name
        filename (str or None): filename to save the figure or None (display)
        kwargs: keyword arguments of matplotlib.pyplot.savefig()

    Note:
        @kwargs are passed through find_args(plt.savefig, ...) before reaching ComparePlot.
        Keyword arguments of matplotlib.pyplot.legend() are not forwarded anywhere.
    """
    with ComparePlot(filename=filename, **find_args(plt.savefig, **kwargs)) as cp:
        cp.plot(data=df, variables=variables, groups=groups)