from __future__ import annotations
import numpy as np
import pandas as pd
from covsirphy.util.validator import Validator
from covsirphy.dynamics.ode import ODEModel
class SIRModel(ODEModel):
    """Class of SIR model.
    Args:
        date_range: start date and end date of simulation
        tau: tau value [min]
        initial_dict: initial values
            - Susceptible (int): the number of susceptible cases
            - Infected (int): the number of infected cases
            - Fatal or Recovered (int): the number of fatal or recovered cases
        param_dict: non-dimensional parameter values
            - rho: non-dimensional effective contact rate
            - sigma: non-dimensional recovery plus mortality rate
    """
    # Name of ODE model
    _NAME = "SIR Model"
    # Variables
    _VARIABLES = [ODEModel.S, ODEModel.CI, ODEModel.FR]
    # Non-dimensional parameters
    _PARAMETERS = ["rho", "sigma"]
    # Dimensional parameters
    _DAY_PARAMETERS = ["1/beta [day]", "1/gamma [day]"]
    # Variables that increases monotonically
    _VARS_INCREASE = [ODEModel.FR]
    # Sample data
    _SAMPLE_DICT = {
        "initial_dict": {ODEModel.S: 999_000, ODEModel.CI: 1000, ODEModel.FR: 0},
        "param_dict": {"rho": 0.2, "sigma": 0.075},
    }

    def __init__(self, date_range: tuple[str, str], tau: int, initial_dict: dict[str, int], param_dict: dict[str, float]) -> None:
        super().__init__(date_range, tau, initial_dict, param_dict)
        # Both parameters are mandatory and validated against the range (0, 1) by Validator
        self._rho = Validator(self._param_dict["rho"], "rho", accept_none=False).float(value_range=(0, 1))
        self._sigma = Validator(self._param_dict["sigma"], "sigma", accept_none=False).float(value_range=(0, 1))

    def _discretize(self, t: int, X: np.ndarray) -> np.ndarray:
        """Discretize the ODE.
        Args:
            t: discrete time-steps (not used: the SIR equations do not depend on time explicitly)
            X: the current values of the model (presumably ordered as _VARIABLES: S, I, R)
        Returns:
            numpy.array: time derivatives (dS/dt, dI/dt, dR/dt) at X, where N is self._population
                - dS/dt = - rho * S * I / N
                - dI/dt = rho * S * I / N - sigma * I
                - dR/dt = sigma * I
        """
        n = self._population
        s, i, *_ = X
        dsdt = 0 - self._rho * s * i / n
        drdt = self._sigma * i
        # Whatever leaves S enters I, whatever leaves I enters R
        didt = 0 - dsdt - drdt
        return np.array([dsdt, didt, drdt])

    def r0(self) -> float:
        """Calculate basic reproduction number.
        Raises:
            ZeroDivisionError: sigma value is 0
        Returns:
            reproduction number of the ODE model and parameters (rho / sigma, rounded to 2 decimals)
        """
        try:
            return round(self._rho / self._sigma, 2)
        except ZeroDivisionError:
            raise ZeroDivisionError(
                f"Sigma must be over 0 to calculate reproduction number with {self._NAME}.") from None

    def dimensional_parameters(self) -> dict[str, int]:
        """Calculate dimensional parameter values.
        Raises:
            ZeroDivisionError: rho or sigma value is 0
        Returns:
            dictionary of dimensional parameter values (tau converted from minutes to days, divided by the parameter)
                - "1/beta [day]" (int): inverse of the effective contact rate, in days
                - "1/gamma [day]" (int): recovery period, in days
        """
        # tau is given in minutes; convert to days once
        tau_day = self._tau / 24 / 60
        try:
            return {
                "1/beta [day]": round(tau_day / self._rho),
                "1/gamma [day]": round(tau_day / self._sigma)
            }
        except ZeroDivisionError:
            raise ZeroDivisionError(
                f"Rho and sigma must be over 0 to calculate dimensional parameters with {self._NAME}.") from None

    @classmethod
    def _param_quantile(cls, data: pd.DataFrame, q: float | pd.Series = 0.5) -> dict[str, float | pd.Series]:
        """With combinations (X, dX/dt) for X=S, I, R, calculate quantile values of ODE parameters.
        Args:
            data: transformed data with covsirphy.SIRModel.transform(data=data, tau=tau);
                its index must be numeric time, because index differences are used as dt
            q: the quantile(s) to compute, value(s) between (0, 1)
        Raises:
            IndexError: no row has positive Susceptible and Infected values
        Returns:
            parameter values at the quantile(s), clipped to the range (0, 1)
        Note:
            We can get approximate parameter values with difference equations as follows.
            - rho = - n * (dS/dt) / S / I
            - sigma = (dR/dt) / I
            dt is the actual index distance between consecutive retained rows, so that the estimate stays
            correct for short series and when rows are removed by the positivity filter below.
        """
        # Remove non-positive values of S and I: they are used as divisors below
        df = data.loc[(data[cls.S] > 0) & (data[cls.CI] > 0)]
        # Total population, taken as the sum of the model variables at the first retained row
        n = df.loc[df.index[0], cls._VARIABLES].sum()
        # Elapsed time between each retained row and the previous retained row (NaN for the first row).
        # Using real index gaps (instead of span / number of rows) avoids an off-by-one step estimate,
        # which could even round to 0 for two-row data, and accounts for rows dropped by the filter.
        elapsed = df.index.to_series(index=df.index).diff()
        # Calculate parameter values with non-dimensional difference equation
        rho_series = 0 - n * df[cls.S].diff() / elapsed / df[cls.S] / df[cls.CI]
        sigma_series = df[cls.FR].diff() / elapsed / df[cls.CI]
        # Guess representative values
        return {
            "rho": cls._clip(rho_series.quantile(q=q), 0, 1),
            "sigma": cls._clip(sigma_series.quantile(q=q), 0, 1),
        }

    @classmethod
    def sr(cls, data: pd.DataFrame) -> pd.DataFrame:
        """Return log10(S) and R of model-specific variables for S-R trend analysis.
        Args:
            data:
                Index
                    Date (pd.Timestamp): Observation date
                Columns
                    Susceptible (int): the number of susceptible cases
                    Infected (int): the number of currently infected cases
                    Recovered (int): the number of recovered cases
                    Fatal (int): the number of fatal cases
        Returns:
            Index
                Date (pandas.Timestamp): date
            Columns
                log10(S) (np.float64): common logarithm of Susceptible
                R (np.int64): Fatal or Recovered
        """
        Validator(data, "data", accept_none=False).dataframe(time_index=True, columns=cls._SIRF)
        df = data.copy()
        df[cls._logS] = np.log10(df[cls.S])
        # SIR model merges Fatal and Recovered into one compartment
        df[cls._r] = df[cls.F] + df[cls.R]
        return df.loc[:, [cls._logS, cls._r]].astype({cls._logS: np.float64, cls._r: np.int64})