#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import warnings
import matplotlib
if not hasattr(sys, "ps1"):
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# plt.style.use("seaborn-ticks")
plt.style.use("fast")
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["font.size"] = 11.0
plt.rcParams["figure.figsize"] = (9, 6)
plt.rcParams["legend.frameon"] = False
[docs]def line_plot_multiple(df, x_col, actual_col, predicted_cols,
title, ylabel, xlim=(None, None), v=None, y_logscale=False, filename=None):
"""
show multiple line graph of chronological change with actual plots.
Args:
df (pandas.DataFrame): data
Index
Date (pandas.TimeStamp): Observation date
Columns
- column defined by @x_col, values for x-axis
- column defined by @actual_col, actual values for y-axis
- columns defined by @predicted_cols, predicted values for y-axis
x_col (str): column name for x-axis
actual_col (str): column name for y-axis
predicted_cols (list[str]): list of columns which have predicted values
title (str): title of the figure
y_label (str): label for y-axis
xlim (tuple(int or float, int or float)): limit of x dimain
y_logscale (bool): whether use log-scale in y-axis or not
v (list[int]): list of Recovered values to show vertical lines
filename (str): filename of the figure, or None (show figure)
Note:
When xlim[0] is None and lower x-axis limit determined by matplotlib automatically is lower than 0,
lower x-axis limit will be set to 0.
"""
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)
x_series = df[x_col]
actual = df[actual_col]
# Plot the actual values
plt.plot(
x_series, actual,
label=actual_col, color="black", marker=".", markeredgewidth=0, linewidth=0)
# Plot the predicted values
for col in predicted_cols:
plt.plot(x_series, df[col], label=col)
# x-axis
plt.xlabel(x_col)
plt.xlim(max(plt.xlim()[0], xlim[0] or 0), xlim[1])
# y-axis
plt.ylabel(ylabel)
try:
plt.yscale("log", base=10)
except Exception:
# Matplotlib version < 3.3
plt.yscale("log", basey=10)
# Delete y-labels of log-scale (minor) axis
plt.setp(plt.gca().get_yticklabels(minor=True), visible=False)
plt.gca().tick_params(left=False, which="minor")
# Set new y-labels of major axis
if y_logscale:
ymin, ymax = plt.ylim()
ydiff_scale = int(np.log10(ymax - ymin))
yticks = np.linspace(
round(ymin, - ydiff_scale),
round(ymax, - ydiff_scale), 5, dtype=np.int64)
plt.gca().set_yticks(yticks)
fmt = matplotlib.ticker.ScalarFormatter(useOffset=False)
fmt.set_scientific(False)
plt.gca().yaxis.set_major_formatter(fmt)
# Title
plt.title(title)
# Vertical lines
if isinstance(v, (list, tuple)):
for value in v:
plt.axvline(x=value, color="black", linestyle=":")
# Legend
plt.legend(
bbox_to_anchor=(1.02, 0), loc="lower left", borderaxespad=0
)
# Save figure or show figure
plt.tight_layout()
if filename is None:
plt.show()
return None
plt.savefig(
filename, bbox_inches="tight", transparent=False, dpi=300
)
plt.clf()