from __future__ import annotations
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter
import pandas as pd
from typing_extensions import Self
from covsirphy.util.validator import Validator
from covsirphy.visualization.vbase import VisualizeBase, find_args
class LinePlot(VisualizeBase):
    """Create a line plot.

    Args:
        filename (str or None): filename to save the figure or None (display)
        bbox_inches (str): bounding box in inches when creating the figure
        kwargs: the other arguments of matplotlib.pyplot.savefig()
    """

    def __init__(self, filename=None, bbox_inches="tight", **kwargs):
        self._filename = filename
        # Keyword arguments kept for matplotlib.pyplot.savefig()
        self._savefig_dict = {"bbox_inches": bbox_inches, **kwargs}
        # Properties
        self._title = ""
        self._variables = []
        # Axes used by x_axis()/y_axis()/line(); replaced by the axes returned from pandas in plot()
        self._ax = None

    def __enter__(self) -> Self:
        # Delegates to VisualizeBase; overridden only to annotate the return type as Self
        return super().__enter__()

    def __exit__(self, *exc_info):
        return super().__exit__(*exc_info)

    def plot(self, data, colormap=None, color_dict=None, **kwargs):
        """Plot chronological change of the data.

        Args:
            data (pandas.DataFrame or pandas.Series): data to show
                Index
                    x values
                Columns
                    y variables to show
            colormap (str, matplotlib colormap object or None): colormap, please refer to https://matplotlib.org/stable/gallery/color/colormap_reference.html
            color_dict (dict[str, str] or None): dictionary of column names (keys) and colors (values)
            kwargs: keyword arguments of pandas.DataFrame.plot()

        Raises:
            KeyError: pandas.DataFrame.plot() raised KeyError (re-raised without the chained traceback)
        """
        # A Series is converted to a one-column DataFrame so both inputs share one code path
        if isinstance(data, pd.Series):
            data = pd.DataFrame(data)
        Validator(data, "data").dataframe()
        self._variables = data.columns.tolist()
        # Color
        color_args = self._plot_colors(data.columns, colormap=colormap, color_dict=color_dict)
        # Set plotting
        try:
            self._ax = data.plot(**color_args, **kwargs)
        except KeyError as e:
            # Same arguments, but hide the pandas-internal traceback chain.
            # *e.args (instead of e.args[0]) also works when the KeyError carries no arguments.
            raise KeyError(*e.args) from None

    @staticmethod
    def _log_safe_limits(lim, logscale):
        """Return axis limits that can be passed to set_xlim()/set_ylim().

        Args:
            lim (tuple(int or float or None, int or float or None) or None): (lower, upper) limits, None means (None, None)
            logscale (bool): whether the axis uses log-scale

        Returns:
            tuple: the limits; on a log-scaled axis a lower limit of 0 is replaced with None while the upper limit is kept
        """
        limits = (None, None) if lim is None else tuple(lim)
        if logscale and limits and limits[0] == 0:
            # 0 cannot be shown on a log axis: let Matplotlib choose the lower limit, but keep the requested upper one
            return (None, *limits[1:])
        return limits

    def x_axis(self, xlabel=None, x_logscale=False, xlim=(None, None)):
        """Set x axis.

        Args:
            xlabel (str or None): x-label
            x_logscale (bool): whether use log-scale in x-axis or not
            xlim (tuple(int or float, int or float) or None): limit of x domain, None is the same as (None, None)

        Note:
            If None is included in xlim, the values will be automatically determined by Matplotlib.
            With log-scale, a lower limit of 0 is replaced with None (the upper limit is kept).
        """
        # Label
        self._ax.set_xlabel(xlabel)
        # Log scale
        if x_logscale:
            self._ax.set_xscale("log")
        # limit
        self._ax.set_xlim(*self._log_safe_limits(xlim, logscale=x_logscale))

    def y_axis(self, ylabel="Cases", y_logscale=False, ylim=(0, None), math_scale=True, y_integer=False):
        """Set y axis.

        Args:
            ylabel (str or None): y-label
            y_logscale (bool): whether use log-scale in y-axis or not
            ylim (tuple(int or float, int or float) or None): limit of y domain, None is the same as (None, None)
            math_scale (bool): whether to show y tick labels in scientific notation rendered with math text
            y_integer (bool): whether to show y tick labels as plain numbers (no offset, no scientific notation)

        Note:
            If None is included in ylim, the values will be automatically determined by Matplotlib.
            With log-scale, a lower limit of 0 is replaced with None (the upper limit is kept).
        """
        # Label
        self._ax.set_ylabel(ylabel)
        # Math scale: scilimits=(0, 0) forces scientific notation for every magnitude
        if math_scale:
            self._ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
            self._ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
        # Plain numbers; applied after math_scale, so it takes precedence when both are True
        if y_integer:
            fmt = ScalarFormatter(useOffset=False)
            fmt.set_scientific(False)
            self._ax.yaxis.set_major_formatter(fmt)
        # Log scale
        # NOTE: Matplotlib's set_yscale() installs the new scale's default locators/formatters,
        # so the formatters configured above are replaced when y_logscale is True.
        if y_logscale:
            self._ax.set_yscale("log")
        # limit
        self._ax.set_ylim(*self._log_safe_limits(ylim, logscale=y_logscale))

    @staticmethod
    def _convert_to_list(x):
        """Convert None to empty list, str/float/int etc. to a list.

        Args:
            x (list/tuple[str/int/float] or None): value(s)

        Returns:
            list or tuple
        """
        return x if isinstance(x, (list, tuple)) else [] if x is None else [x]

    def line(self, v=None, h=None, color="black", linestyle=":"):
        """Show vertical/horizontal lines.

        Args:
            v (list/tuple[int/float] or None): list of x values of vertical lines or None
            h (list/tuple[int/float] or None): list of y values of horizontal lines or None
            color (str): color of the line
            linestyle (str): linestyle
        """
        # Horizontal
        h = self._convert_to_list(h)
        for value in h:
            self._ax.axhline(y=value, color=color, linestyle=linestyle)
        # Vertical
        v = self._convert_to_list(v)
        for value in v:
            self._ax.axvline(x=value, color=color, linestyle=linestyle)
def line_plot(df, title=None, filename=None, show_legend=True, **kwargs):
    """Wrapper function: show chronological change of the data.

    Args:
        df (pandas.DataFrame or pandas.Series): data to show
            Index
                Date (pandas.Timestamp)
            Columns
                variables to show
        title (str or None): title of the figure
        filename (str or None): filename to save the figure or None (display)
        show_legend (bool): whether show legend or not
        kwargs: keyword arguments of the following classes and methods.
            - covsirphy.LinePlot() and its methods,
            - matplotlib.pyplot.savefig(), matplotlib.pyplot.legend(),
            - pandas.DataFrame.plot()
    """
    # NOTE(review): find_args() is assumed to pick out the entries of kwargs accepted by the
    # given callable(s), so one kwargs dict can be routed to several targets — confirm in vbase.
    with LinePlot(filename=filename, **find_args(plt.savefig, **kwargs)) as lp:
        lp.title = title
        lp.plot(data=df, **find_args([LinePlot.plot, pd.DataFrame.plot], **kwargs))
        # Axis
        lp.x_axis(**find_args([LinePlot.x_axis], **kwargs))
        lp.y_axis(**find_args([LinePlot.y_axis], **kwargs))
        # Vertical/horizontal lines
        lp.line(**find_args([LinePlot.line], **kwargs))
        # Legend
        if show_legend:
            lp.legend(**find_args([LinePlot.legend, plt.legend], **kwargs))
        else:
            lp.legend_hide()