synthcity.plugins.core.models.survival_analysis.surv_deephit module

class DeephitSurvivalAnalysis(num_durations: int = 500, batch_size: int = 100, epochs: int = 2000, lr: float = 0.01, dim_hidden: int = 300, alpha: float = 0.28, sigma: float = 0.38, dropout: float = 0.2, patience: int = 20, batch_norm: bool = False, random_state: int = 0, device: Any = device(type='cpu'), **kwargs: Any)

Bases: synthcity.plugins.core.models.survival_analysis._base.SurvivalAnalysisPlugin
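DeepHit predicts over a discrete time grid, and num_durations sets how many grid points that grid has. The sketch below illustrates the discretization step. It assumes an equidistant grid over [0, max(T)]. The helper discretize_durations is hypothetical, and synthcity may build its internal grid differently.

```python
import numpy as np


def discretize_durations(durations: np.ndarray, num_durations: int):
    """Map continuous durations onto an equidistant grid (illustrative sketch).

    Assumption: the grid has num_durations cut points spread evenly over
    [0, max(durations)]. This is a common DeepHit preprocessing choice,
    not necessarily what synthcity does internally.
    """
    durations = np.asarray(durations, dtype=float)
    grid = np.linspace(0.0, durations.max(), num_durations)
    # For each duration, take the index of the last grid point <= that duration.
    idx = np.searchsorted(grid, durations, side="right") - 1
    return grid, idx


grid, idx = discretize_durations(np.array([0.0, 2.5, 5.0, 10.0]), num_durations=5)
# grid holds 0, 2.5, 5, 7.5, 10 and idx == [0, 1, 2, 4]
```

A larger num_durations gives a finer grid, so the model has more output bins to estimate.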

fit(X: pandas.core.frame.DataFrame, T: pandas.core.series.Series, E: pandas.core.series.Series) → synthcity.plugins.core.models.survival_analysis._base.SurvivalAnalysisPlugin

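Train the model on the covariates X, the event or censoring times T, and the event indicators E (1 if the event was observed, 0 if the subject was censored). Returns the fitted plugin.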
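The three arguments to fit are aligned by row. Each subject contributes one covariate row in X, one duration in T, and one event flag in E. The sketch below builds a small synthetic dataset in that shape and runs a check on it. The column names and the helper validate_survival_inputs are illustrative assumptions, not part of synthcity.

```python
import numpy as np
import pandas as pd


def validate_survival_inputs(X: pd.DataFrame, T: pd.Series, E: pd.Series) -> None:
    """Hypothetical pre-flight check for the (X, T, E) triple passed to fit()."""
    if not (len(X) == len(T) == len(E)):
        raise ValueError("X, T and E must have the same number of rows")
    if (T < 0).any():
        raise ValueError("durations in T must be non-negative")
    if not set(E.unique()) <= {0, 1}:
        raise ValueError("E must contain only 0 (censored) and 1 (event)")


rng = np.random.default_rng(0)
n = 200
# Two illustrative covariates, exponential durations and random event flags.
X = pd.DataFrame({"age": rng.normal(60, 10, n), "dose": rng.uniform(0, 1, n)})
T = pd.Series(rng.exponential(scale=12.0, size=n), name="time")
E = pd.Series(rng.integers(0, 2, size=n), name="event")
validate_survival_inputs(X, T, E)
```

If synthcity is installed, you can pass the same objects to the model, for example DeephitSurvivalAnalysis(epochs=100).fit(X, T, E).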

static hyperparameter_space(*args: Any, **kwargs: Any) → List[synthcity.plugins.core.distribution.Distribution]

Returns the hyperparameter space for the derived plugin.

static load(buff: bytes) → Any
static load_dict(representation: dict) → Any
static name() → str

The name of the plugin.

predict(X: pandas.core.frame.DataFrame, time_horizons: List) → pandas.core.frame.DataFrame

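Predict the risk of the event for each row of X at each time in time_horizons. Returns a DataFrame with one column per time horizon.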
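The output of predict can be read as a table of per-row risks, one column per horizon. The sketch below shows one way to build such a table from discrete survival curves of the kind DeepHit produces. It assumes that risk means 1 - S(t) and that each horizon uses the last grid point at or before it. The helper risk_at_horizons is hypothetical, and synthcity's own lookup rule may differ.

```python
import numpy as np
import pandas as pd


def risk_at_horizons(surv: np.ndarray, grid: np.ndarray, time_horizons, index) -> pd.DataFrame:
    """Turn discrete survival curves into a risk table (illustrative sketch).

    surv[i, k] is the assumed survival probability S(grid[k]) for row i.
    The risk at horizon t is 1 - S at the last grid point <= t.
    """
    cols = {}
    for t in time_horizons:
        # Clamp to 0 so that horizons before the first grid point still map to a column.
        k = max(int(np.searchsorted(grid, t, side="right")) - 1, 0)
        cols[t] = 1.0 - surv[:, k]
    return pd.DataFrame(cols, index=index)


grid = np.array([0.0, 5.0, 10.0])
surv = np.array([[1.0, 0.8, 0.5],
                 [1.0, 0.9, 0.7]])
risk = risk_at_horizons(surv, grid, time_horizons=[5, 12], index=["a", "b"])
# risk.loc["a", 12] is 0.5 and risk.loc["b", 5] is approximately 0.1
```

Risk is cumulative, so each row of such a table never decreases as the horizon grows.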

classmethod sample_hyperparameters(*args: Any, **kwargs: Any) → Dict[str, Any]

Sample values from the hyperparameter space for the current plugin.
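sample_hyperparameters returns a dict that maps parameter names to values drawn from hyperparameter_space. The sketch below imitates that behaviour with a hypothetical SPACE table and the standard random module. The parameter names come from the constructor signature, but the ranges and choices are invented for illustration.

```python
import random
from typing import Any, Dict

# Hypothetical search space. The names follow the DeephitSurvivalAnalysis
# constructor, but the ranges and choices are illustrative, not synthcity's.
SPACE: Dict[str, Any] = {
    "lr": ("float", 1e-4, 1e-2),
    "dim_hidden": ("choice", [100, 200, 300]),
    "alpha": ("float", 0.0, 1.0),
    "sigma": ("float", 0.0, 1.0),
    "dropout": ("float", 0.0, 0.3),
    "batch_size": ("choice", [100, 200, 500]),
}


def sample_hyperparameters(space: Dict[str, Any], seed: int = 0) -> Dict[str, Any]:
    """Draw one value per hyperparameter and return a Dict[str, Any]."""
    rng = random.Random(seed)
    sample: Dict[str, Any] = {}
    for name, spec in space.items():
        if spec[0] == "float":
            sample[name] = rng.uniform(spec[1], spec[2])
        else:
            sample[name] = rng.choice(spec[1])
    return sample


params = sample_hyperparameters(SPACE, seed=42)
```

In synthcity itself, the equivalent call is DeephitSurvivalAnalysis.sample_hyperparameters(). You can typically unpack the resulting dict into the constructor, as in DeephitSurvivalAnalysis(**params).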

save() → bytes
save_dict() → dict
save_to_file(path: pathlib.Path) → bytes
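save() and load() form a bytes round trip, and save_dict() and load_dict() do the same with a dict. The sketch below shows that contract on a hypothetical ToyPlugin that uses pickle. It does not reproduce how synthcity serializes DeephitSurvivalAnalysis, and it assumes nothing about which serializer synthcity uses.

```python
import pickle
from typing import Any, Dict


class ToyPlugin:
    """Hypothetical stand-in used only to show the save/load contract."""

    def __init__(self, **params: Any) -> None:
        self.params = params

    def save_dict(self) -> Dict[str, Any]:
        # Plain-dict representation of the plugin state.
        return {"params": dict(self.params)}

    @staticmethod
    def load_dict(representation: Dict[str, Any]) -> "ToyPlugin":
        return ToyPlugin(**representation["params"])

    def save(self) -> bytes:
        # Serialize the dict representation to bytes.
        return pickle.dumps(self.save_dict())

    @staticmethod
    def load(buff: bytes) -> "ToyPlugin":
        return ToyPlugin.load_dict(pickle.loads(buff))


original = ToyPlugin(num_durations=500, alpha=0.28, sigma=0.38)
restored = ToyPlugin.load(original.save())
```

For the real class, the corresponding round trip is DeephitSurvivalAnalysis.load(model.save()). As with any pickle-based format, only load bytes that come from a trusted source.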
static version() → str

API version