synthcity.plugins.privacy.plugin_decaf module
Reference: Boris van Breugel, Trent Kyono, Jeroen Berrevoets, Mihaela van der Schaar. “DECAF: Generating Fair Synthetic Data Using Causally-Aware Generative Networks” (2021).
- class DECAFPlugin(n_iter: int = 100, n_iter_baseline: int = 1000, generator_n_layers_hidden: int = 2, generator_n_units_hidden: int = 500, generator_nonlin: str = 'relu', generator_dropout: float = 0.1, generator_opt_betas: tuple = (0.5, 0.999), discriminator_n_layers_hidden: int = 2, discriminator_n_units_hidden: int = 500, discriminator_nonlin: str = 'leaky_relu', discriminator_n_iter: int = 1, discriminator_dropout: float = 0.1, discriminator_opt_betas: tuple = (0.5, 0.999), lr: float = 0.001, batch_size: int = 200, random_state: int = 0, clipping_value: int = 1, lambda_gradient_penalty: float = 10, lambda_privacy: float = 1, eps: float = 1e-08, alpha: float = 1, rho: float = 1, weight_decay: float = 0.01, l1_g: float = 0, l1_W: float = 1, grad_dag_loss: bool = False, struct_learning_enabled: bool = True, struct_learning_n_iter: int = 1000, struct_learning_search_method: str = 'tree_search', struct_learning_score: str = 'k2', struct_max_indegree: int = 4, encoder_max_clusters: int = 10, device: Any = device(type='cpu'), workspace: pathlib.Path = PosixPath('workspace'), compress_dataset: bool = False, sampling_patience: int = 500, **kwargs: Any)
Bases: synthcity.plugins.core.plugin.Plugin
DECAF (DEbiasing CAusal Fairness) plugin.
- Parameters
n_iter – int Number of training iterations.
generator_n_layers_hidden – int Number of hidden layers in the generator.
generator_n_units_hidden – int Number of neurons in the hidden layers of the generator.
generator_nonlin – str Nonlinearity used by the generator for the hidden layers: leaky_relu, relu, gelu, etc.
generator_dropout – float Generator dropout rate.
generator_opt_betas – tuple Generator initial decay rates (betas) for the Adam optimizer.
discriminator_n_layers_hidden – int Number of hidden layers in the discriminator.
discriminator_n_units_hidden – int Number of neurons in the hidden layers of the discriminator.
discriminator_nonlin – str Nonlinearity used by the discriminator for the hidden layers: leaky_relu, relu, gelu, etc.
discriminator_n_iter – int Number of discriminator iterations (default = 1).
discriminator_dropout – float Discriminator dropout rate.
discriminator_opt_betas – tuple Discriminator initial decay rates (betas) for the Adam optimizer.
lr – float Learning rate.
weight_decay – float Optimizer weight decay.
batch_size – int Batch size.
random_state – int Random seed.
clipping_value – int Gradient clipping value.
lambda_gradient_penalty – float Gradient penalty factor used for training the GAN.
lambda_privacy – float Privacy factor used in the AdsGAN loss.
eps – float = 1e-8. Noise added to the privacy loss.
alpha – float Gradient penalty weight for real samples.
rho – float DAG loss factor.
l1_g – float = 0. L1 regularization loss for the generator.
l1_W – float = 1. L1 regularization factor for l1_g.
struct_learning_enabled – bool Enable DAG learning outside DECAF.
struct_learning_n_iter – int Number of iterations for the DAG search.
struct_learning_search_method – str DAG search strategy: hillclimb, pc, tree_search, mmhc, exhaustive, d-struct.
struct_learning_score – str DAG search scoring strategy: k2, bdeu, bic, bds.
struct_max_indegree – int Maximum number of parents per node in the DAG.
encoder_max_clusters – int Maximum number of clusters used for tabular encoding.
device – Any = DEVICE. Torch device used for training.
Core plugin arguments:
workspace – Path. Optional path for caching intermediary results.
compress_dataset – bool. Default = False. Drop redundant features before training the generator.
sampling_patience – int. Max inference iterations to wait for the generated data to match the training schema.
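The rho parameter above scales a DAG loss whose exact form this page does not give. As a hedged illustration, assuming a NOTEARS-style acyclicity measure h(W) = tr(exp(W ∘ W)) − d over a d × d weighted adjacency matrix W (an assumption, not confirmed by this page), the pure-Python sketch below computes h. It is zero when the weighted graph has no directed cycles and positive otherwise. The helpers matmul and acyclicity are hypothetical and not part of synthcity.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain square-matrix product."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def acyclicity(w: Matrix, terms: int = 30) -> float:
    """Hypothetical NOTEARS-style measure h(W) = tr(exp(W * W)) - d.

    Assumed form for illustration; not taken from synthcity. exp() is
    approximated with a truncated Taylor series, adequate for small matrices.
    """
    d = len(w)
    # Element-wise square keeps every entry non-negative.
    w_sq = [[x * x for x in row] for row in w]
    # Start from the identity, i.e. (W*W)^0 / 0!.
    term = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
    trace = float(d)
    for k in range(1, terms + 1):
        # term now holds (W*W)^k / k!.
        term = [[x / k for x in row] for row in matmul(term, w_sq)]
        trace += sum(term[i][i] for i in range(d))
    return trace - d
```

For example, acyclicity([[0.0, 0.8], [0.0, 0.0]]) evaluates to 0.0 because the single edge forms no cycle, while the two-node cycle [[0.0, 1.0], [1.0, 0.0]] gives 2·cosh(1) − 2 ≈ 1.086. The truncated series only keeps the sketch dependency-free; a production implementation would more likely call a library matrix exponential.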
Example
>>> from sklearn.datasets import load_iris
>>> from synthcity.plugins import Plugins
>>>
>>> X, y = load_iris(as_frame=True, return_X_y=True)
>>> X["target"] = y
>>>
>>> plugin = Plugins().get("decaf", n_iter=100)
>>> plugin.fit(X)
>>>
>>> plugin.generate(50)
- fit(X: Union[synthcity.plugins.core.dataloader.DataLoader, pandas.core.frame.DataFrame], *args: Any, **kwargs: Any) → Any
Training method for the synthetic data plugin.
- Parameters
X – DataLoader or pandas.DataFrame. The reference dataset.
cond – Optional, Union[pd.DataFrame, pd.Series, np.ndarray]. Optional training conditional. The training conditional can be used to control the output of some models, like GANs or VAEs. The content can be anything, as long as it maps to the training dataset X. Usage example:
>>> import numpy as np
>>> from sklearn.datasets import load_iris
>>> from synthcity.plugins.core.dataloader import GenericDataLoader
>>> from synthcity.plugins.core.constraints import Constraints
>>>
>>> # Load the generative model of choice into `test_plugin`
>>> # ....
>>>
>>> X, y = load_iris(as_frame=True, return_X_y=True)
>>> X["target"] = y
>>>
>>> X = GenericDataLoader(X)
>>> test_plugin.fit(X, cond=y)
>>>
>>> count = 10
>>> X_gen = test_plugin.generate(count, cond=np.ones(count))
>>>
>>> # The conditional only optimizes the output generation
>>> # for GANs and VAEs, but does NOT guarantee that the samples
>>> # come only from that condition.
>>> # If you want to guarantee that the output contains only
>>> # "target" == 1 samples, use Constraints.
>>>
>>> constraints = Constraints(
...     rules=[
...         ("target", "==", 1),
...     ]
... )
>>> X_gen = test_plugin.generate(
...     count,
...     cond=np.ones(count),
...     constraints=constraints,
... )
>>> assert (X_gen.dataframe()["target"] == 1).all()
- Returns
self
- classmethod fqdn() → str
The fully qualified name of the plugin.
- generate(count: Optional[int] = None, constraints: Optional[synthcity.plugins.core.constraints.Constraints] = None, random_state: Optional[int] = None, **kwargs: Any) → synthcity.plugins.core.dataloader.DataLoader
Synthetic data generation method.
- Parameters
count – optional int. The number of samples to generate. If None, it generates len(reference_dataset) samples.
cond – Optional, Union[pd.DataFrame, pd.Series, np.ndarray]. Optional generation conditional. The conditional can be used only if the model was also trained with a conditional. If provided, it must have length count. Not all models support conditionals. Conditionals can be used in VAEs or GANs to speed up generation under some constraints. For model-agnostic solutions, see the constraints parameter.
constraints – optional Constraints. Optional constraints to apply to the generated data. If None, the reference schema constraints are applied. The constraints are model-agnostic and filter the output of the generative model. They are a list of rules, each a tuple of the form (<feature>, <operation>, <value>).
- Valid Operations:
"<", "lt": less than <value>
"<=", "le": less than or equal to <value>
">", "gt": greater than <value>
">=", "ge": greater than or equal to <value>
"==", "eq": equal to <value>
"in": valid for categorical features; <value> must be an array. For example, ("target", "in", [0, 1])
"dtype": <value> can be a data type. For example, ("target", "dtype", "int")
- Usage example:
>>> from synthcity.plugins.core.constraints import Constraints
>>> constraints = Constraints(
...     rules=[
...         ("InterestingFeature", "==", 0),
...     ]
... )
>>>
>>> syn_data = syn_model.generate(
...     count=count, constraints=constraints
... ).dataframe()
>>>
>>> assert (syn_data["InterestingFeature"] == 0).all()
random_state – optional int. Optional random seed to use.
- Returns
<count> synthetic samples
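To illustrate the rule semantics listed above, here is a minimal pure-Python sketch that keeps only the rows satisfying every (feature, operation, value) rule. It is not synthcity's implementation: rule_holds, apply_rules, and the row-as-dict representation are hypothetical, and the "dtype" operation is approximated with a Python isinstance check over an assumed name-to-type mapping.

```python
import operator
from typing import Any, Callable, Dict, List, Tuple

Rule = Tuple[str, str, Any]

# Comparison operators, keyed by both spellings listed in the docs.
_COMPARISONS: Dict[str, Callable[[Any, Any], bool]] = {
    "<": operator.lt, "lt": operator.lt,
    "<=": operator.le, "le": operator.le,
    ">": operator.gt, "gt": operator.gt,
    ">=": operator.ge, "ge": operator.ge,
    "==": operator.eq, "eq": operator.eq,
}

# Hypothetical mapping used to approximate the "dtype" rule.
_DTYPES = {"int": int, "float": float, "str": str}


def rule_holds(row: Dict[str, Any], rule: Rule) -> bool:
    """Evaluate a single (feature, operation, value) rule against one row."""
    feature, op, value = rule
    cell = row[feature]
    if op in _COMPARISONS:
        return _COMPARISONS[op](cell, value)
    if op == "in":
        return cell in value
    if op == "dtype":
        return isinstance(cell, _DTYPES[value])
    raise ValueError(f"unsupported operation: {op!r}")


def apply_rules(rows: List[Dict[str, Any]], rules: List[Rule]) -> List[Dict[str, Any]]:
    """Keep only the rows that satisfy every rule (filtering, not repair)."""
    return [row for row in rows if all(rule_holds(row, r) for r in rules)]


rows = [{"target": 0, "age": 31}, {"target": 1, "age": 17}, {"target": 1, "age": 45}]
kept = apply_rules(rows, [("target", "==", 1), ("age", "ge", 18)])
print(kept)  # [{'target': 1, 'age': 45}]
```

Because non-matching rows are dropped rather than modified, a generator may have to sample more than count rows to return count valid ones, which is consistent with the plugin exposing a sampling_patience limit.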
- get_dag(X: pandas.core.frame.DataFrame, struct_learning_search_method: Optional[str] = None, as_index: bool = False) → Any
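get_dag has no description on this page. DECAF's generator produces each feature from its causal parents, so features have to be generated in an order compatible with the DAG. Assuming the DAG is available as (parent, child) edge pairs (an assumption about the return value, which is not documented here), a generation order can be derived with the standard library's graphlib (Python 3.9+). The generation_order helper and the edge list below are hypothetical, made up for illustration.

```python
from graphlib import TopologicalSorter
from typing import Dict, Iterable, List, Set, Tuple


def generation_order(features: Iterable[str], edges: Iterable[Tuple[str, str]]) -> List[str]:
    """Order features so that every parent precedes its children.

    `edges` is assumed to hold (parent, child) pairs.
    """
    parents: Dict[str, Set[str]] = {f: set() for f in features}
    for parent, child in edges:
        parents.setdefault(child, set()).add(parent)
    # TopologicalSorter takes a node -> predecessors mapping and raises
    # graphlib.CycleError if the graph is not a DAG.
    return list(TopologicalSorter(parents).static_order())


# Hypothetical edges over hypothetical features.
edges = [("age", "education"), ("education", "income"), ("age", "income")]
order = generation_order(["income", "education", "age"], edges)
print(order)  # ['age', 'education', 'income']
```

Here the order is unique because the three edges form a chain plus a shortcut; for graphs with independent branches, any valid topological order would serve equally well for parent-first generation.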
- static hyperparameter_space(**kwargs: Any) → List[synthcity.plugins.core.distribution.Distribution]
Returns the hyperparameter space for the derived plugin.
- static load(buff: bytes) → Any
- static load_dict(representation: dict) → Any
- static name() → str
The name of the plugin.
- plot(plt: Any, X: synthcity.plugins.core.dataloader.DataLoader, count: Optional[int] = None, plots: list = ['marginal', 'associations', 'tsne'], **kwargs: Any) → Any
Plot the real and synthetic distributions.
- Parameters
plt – Output plotting backend (for example, matplotlib.pyplot).
X – DataLoader. The reference dataset.
- Returns
self
- classmethod sample_hyperparameters(*args: Any, **kwargs: Any) → Dict[str, Any]
Sample values from the hyperparameter space of the current plugin.
- classmethod sample_hyperparameters_optuna(trial: Any, *args: Any, **kwargs: Any) → Dict[str, Any]
- save() → bytes
- save_dict() → dict
- save_to_file(path: pathlib.Path) → bytes
- schema() → synthcity.plugins.core.schema.Schema
The reference schema.
- schema_includes(other: Union[synthcity.plugins.core.dataloader.DataLoader, pandas.core.frame.DataFrame]) → bool
Helper method to test whether the reference schema includes a dataset.
- Parameters
other – DataLoader or pandas.DataFrame. The dataset to test.
- Returns
bool. Whether the schema includes the dataset.
- training_schema() → synthcity.plugins.core.schema.Schema
The internal schema.
- static type() → str
The type of the plugin.
- static version() → str
API version.
- plugin