# Source code for covsirphy.visualization.scatter_plot

from __future__ import annotations
from matplotlib import pyplot as plt
import pandas as pd
from typing_extensions import Self
from covsirphy.util.error import UnExecutedError
from covsirphy.util.validator import Validator
from covsirphy.visualization.vbase import find_args
from covsirphy.visualization.line_plot import LinePlot


class ScatterPlot(LinePlot):
    """Create a scatter plot.

    Args:
        filename (str or None): filename to save the figure or None (display)
        bbox_inches (str): bounding box in inches when creating the figure
        kwargs: the other arguments of matplotlib.pyplot.savefig()
    """

    def __init__(self, filename=None, bbox_inches="tight", **kwargs):
        # NOTE(review): the base-class __init__ is not called; these assignments
        # replicate the state this class relies on. Confirm LinePlot.__init__ does
        # not set any other attribute that inherited methods need.
        self._filename = filename
        self._savefig_dict = {"bbox_inches": bbox_inches, **kwargs}
        # Properties
        self._title = ""
        self._ax = None
        # Empty placeholder with the expected columns: line_straight() treats an
        # empty frame as "plot() has not been called yet".
        self._data = pd.DataFrame(columns=["x", "y"])

    def __enter__(self) -> Self:
        return super().__enter__()

    def __exit__(self, *exc_info):
        return super().__exit__(*exc_info)

    def plot(self, data, colormap=None, color_dict=None, **kwargs):
        """Create a scatter plot of y values against x values.

        Args:
            data (pandas.DataFrame): data to show
                Index
                    reset index
                Columns
                    x (int or float): x values
                    y (int or float): y values
            colormap (str, matplotlib colormap object or None): colormap, please refer to
                https://matplotlib.org/examples/color/colormaps_reference.html
            color_dict (dict[str, str] or None): dictionary of column names (keys) and colors (values)
            kwargs: keyword arguments of pandas.DataFrame.plot()
        """
        self._data = Validator(data, "data").dataframe(columns=["x", "y"])
        # Use the validated frame for both the colors and the drawing so that the
        # stored data (used later by line_straight()) and the plotted data agree.
        color_args = self._plot_colors(self._data.columns, colormap=colormap, color_dict=color_dict)
        self._ax = self._data.plot.scatter(x="x", y="y", **color_args, **kwargs)

    def line_straight(self, p1=None, p2=None, color="black", linestyle=":"):
        """Connect two points with a straight line.

        Args:
            p1 (tuple(int or float, int or float) or None): (x, y) of the first point or None (min values)
            p2 (tuple(int or float, int or float) or None): (x, y) of the second point or None (max values)
            color (str): color of the line
            linestyle (str): linestyle

        Raises:
            UnExecutedError: ScatterPlot.plot() has not been called (no data is registered)

        Note:
            The same line will be shown when p1 and p2 are swapped.
            When a point is None, it is built from the minimum (p1) or maximum (p2) of
            the "x" and "y" columns independently.
        """
        if self._data.empty:
            raise UnExecutedError("ScatterPlot.plot()")
        x1, y1 = (self._data["x"].min(), self._data["y"].min()) if p1 is None else p1
        x2, y2 = (self._data["x"].max(), self._data["y"].max()) if p2 is None else p2
        self._ax.plot([x1, x2], [y1, y2], color=color, linestyle=linestyle)

    def legend(self, **kwargs):
        """ScatterPlot.legend() is not implemented.

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError("ScatterPlot.legend() is not implemented.")

    def legend_hide(self):
        """ScatterPlot.legend_hide() is not implemented.

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError("ScatterPlot.legend_hide() is not implemented.")
def scatter_plot(df, title=None, filename=None, **kwargs):
    """Wrapper function: create a scatter plot of the data with ScatterPlot.

    Args:
        df (pandas.DataFrame): data to show
            Index
                reset index
            Columns
                x (int or float): x values
                y (int or float): y values
        title (str or None): title of the figure
        filename (str or None): filename to save the figure or None (display)
        kwargs: keyword arguments of the following classes and methods.
            - covsirphy.ScatterPlot() and its methods,
            - matplotlib.pyplot.savefig(),
            - pandas.DataFrame.plot()

    Note:
        ScatterPlot.line_straight() is always called, so a straight line is drawn.
        Unless p1/p2 are given in kwargs, it connects the point built from the
        minimum x/y values with the point built from the maximum x/y values.
    """
    with ScatterPlot(filename=filename, **find_args(plt.savefig, **kwargs)) as sp:
        sp.title = title
        sp.plot(data=df, **find_args([ScatterPlot.plot, pd.DataFrame.plot], **kwargs))
        # Axis
        sp.x_axis(**find_args([ScatterPlot.x_axis], **kwargs))
        sp.y_axis(**find_args([ScatterPlot.y_axis], **kwargs))
        # Vertical/horizontal lines
        sp.line(**find_args([ScatterPlot.line], **kwargs))
        # Straight lines
        sp.line_straight(**find_args([ScatterPlot.line_straight], **kwargs))