Source code for covsirphy.visualization.bar_plot

from __future__ import annotations
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter
import pandas as pd
from typing_extensions import Self
from covsirphy.util.validator import Validator
from covsirphy.visualization.vbase import VisualizeBase, find_args


[docs] class BarPlot(VisualizeBase): """Create a bar plot. Args: filename (str or None): filename to save the figure or None (display) bbox_inches (str): bounding box in inches when creating the figure kwargs: the other arguments of matplotlib.pyplot.savefig() """ def __init__(self, filename=None, bbox_inches="tight", **kwargs): self._filename = filename self._savefig_dict = {"bbox_inches": bbox_inches, **kwargs} # Properties self._title = "" self._variables = [] self._ax = None def __enter__(self) -> Self: return super().__enter__() def __exit__(self, *exc_info): return super().__exit__(*exc_info)
[docs] def plot(self, data, vertical=True, colormap=None, color_dict=None, **kwargs): """Create bar plot. Args: data (pandas.DataFrame or pandas.Series): data to show Index labels of the bars Columns variables to show vertical (bool): whether vertical bar plot (True) or horizontal bar plot (False) colormap (str, matplotlib colormap object or None): colormap, please refer to https://matplotlib.org/examples/color/colormaps_reference.html color_dict (dict[str, str] or None): dictionary of column names (keys) and colors (values) kwargs: keyword arguments of pandas.DataFrame.plot() """ if isinstance(data, pd.Series): data = pd.DataFrame(data) Validator(data, "data").dataframe() self._variables = data.columns.tolist() # Color color_args = self._plot_colors(data.columns, colormap=colormap, color_dict=color_dict) # Set plotting method_dict = {True: data.plot.bar, False: data.plot.barh} try: self._ax = method_dict[vertical](**color_args, **kwargs) except KeyError as e: raise KeyError(e.args[0]) from None # No rotation of xticks self._ax.tick_params(axis="x", rotation=0)
[docs] def x_axis(self, xlabel=None): """Set x axis. Args: xlabel (str or None): x-label """ # Label self._ax.set_xlabel(xlabel)
[docs] def y_axis(self, ylabel="Cases", y_logscale=False, ylim=(0, None), math_scale=True, y_integer=False): """Set x axis. Args: ylabel (str or None): y-label y_logscale (bool): whether use log-scale in y-axis or not ylim (tuple(int or float, int or float)): limit of y domain math_scale (bool): whether use LaTEX or not in y-label y_integer (bool): whether force to show the values as integer or not Note: If None is included in ylim, the values will be automatically determined by Matplotlib """ # Label self._ax.set_ylabel(ylabel) # Math scale if math_scale: self._ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True)) self._ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0)) # Integer scale if y_integer: fmt = ScalarFormatter(useOffset=False) fmt.set_scientific(False) self._ax.yaxis.set_major_formatter(fmt) # Log scale if y_logscale: self._ax.set_yscale("log") ylim = (None, None) if ylim[0] == 0 else ylim # limit self._ax.set_ylim(*ylim)
[docs] def line(self, v=None, h=None, color="black", linestyle=":"): """Show vertical/horizontal lines. Args: v (list[int/float] or None): list of x values of vertical lines or None h (list[int/float] or None): list of y values of horizontal lines or None color (str): color of the line linestyle (str): linestyle """ if h is not None: self._ax.axhline(y=h, color=color, linestyle=linestyle) if v is not None: v = v if isinstance(v, list) else [v] for value in v: self._ax.axvline(x=value, color=color, linestyle=linestyle)
[docs] def bar_plot(df, title=None, filename=None, show_legend=True, **kwargs): """Wrapper function: show chronological change of the data. Args: data (pandas.DataFrame or pandas.Series): data to show Index Date (pandas.Timestamp) Columns variables to show title (str): title of the figure filename (str or None): filename to save the figure or None (display) show_legend (bool): whether show legend or not kwargs: keyword arguments of the following classes and methods. - covsirphy.BarPlot() and its methods, - matplotlib.pyplot.savefig(), matplotlib.pyplot.legend(), - pandas.DataFrame.plot() """ with BarPlot(filename=filename, **find_args(plt.savefig, **kwargs)) as bp: bp.title = title bp.plot(data=df, **find_args([BarPlot.plot, pd.DataFrame.plot], **kwargs)) # Axis bp.x_axis(**find_args([BarPlot.x_axis], **kwargs)) bp.y_axis(**find_args([BarPlot.y_axis], **kwargs)) # Vertical/horizontal lines bp.line(**find_args([BarPlot.line], **kwargs)) # Legend if show_legend: bp.legend(**find_args([BarPlot.legend, plt.legend], **kwargs)) else: bp.legend_hide()