synthcity.plugins.images.plugin_image_adsgan module
- class ImageAdsGANPlugin(n_units_latent: int = 100, n_iter: int = 1000, generator_nonlin: str = 'relu', generator_dropout: float = 0.1, generator_n_residual_units: int = 2, discriminator_nonlin: str = 'leaky_relu', discriminator_n_iter: int = 5, discriminator_dropout: float = 0.1, discriminator_n_residual_units: int = 2, lr: float = 0.0002, weight_decay: float = 0.001, opt_betas: tuple = (0.5, 0.999), batch_size: int = 200, random_state: int = 0, clipping_value: int = 1, lambda_gradient_penalty: float = 10, lambda_identifiability_penalty: float = 0.1, device: Any = device(type='cpu'), patience: int = 5, patience_metric: Optional[synthcity.metrics.weighted_metrics.WeightedMetrics] = None, n_iter_print: int = 50, n_iter_min: int = 100, plot_progress: int = False, early_stopping: bool = True, workspace: pathlib.Path = PosixPath('workspace'), sampling_patience: int = 500, **kwargs: Any)
Bases: synthcity.plugins.core.plugin.Plugin
Image AdsGAN - Anonymization through Data Synthesis using Generative Adversarial Networks.
- Parameters
n_units_latent – int Size of the latent noise vector used by the generator.
n_iter – int Maximum number of iterations in the Generator.
generator_nonlin – string, default ‘relu’ Nonlinearity to use in the generator. Can be ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.
generator_dropout – float Dropout value. If 0, the dropout is not used.
generator_n_residual_units – int The number of convolutions in residual units for the generator, 0 means no residual units
discriminator_nonlin – string, default ‘leaky_relu’ Nonlinearity to use in the discriminator. Can be ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.
discriminator_n_iter – int Maximum number of iterations in the discriminator.
discriminator_dropout – float Dropout value for the discriminator. If 0, the dropout is not used.
discriminator_n_residual_units – int The number of convolutions in residual units for the discriminator, 0 means no residual units
Training parameters:
lr – float learning rate for optimizer
weight_decay – float l2 (ridge) penalty for the weights.
batch_size – int Batch size
random_state – int random seed to use
clipping_value – int, default 1 Gradient clipping value. Zero disables the feature.
lambda_gradient_penalty – float = 10 Weight for the gradient penalty
lambda_identifiability_penalty – float = 0.1 Weight for the identifiability penalty
device – torch device Device: cpu or cuda
plot_progress – bool Plot some synthetic samples every n_iter_print
Early stopping parameters:
n_iter_print – int Number of iterations after which to print updates and check the validation loss.
n_iter_min – int Minimum number of iterations to go through before starting early stopping
early_stopping – bool Evaluate the quality of the synthetic data using patience_metric, and stop after patience iterations with no improvement.
patience – int Maximum number of iterations without any improvement before early stopping is triggered.
patience_metric – Optional[WeightedMetrics] If not None, the metric is used to evaluate the early stopping criterion.
Core Plugin arguments:
workspace – Path. Optional path for caching intermediate results.
Example
>>> from torchvision import datasets
>>> from synthcity.plugins import Plugins
>>> from synthcity.plugins.core.dataloader import ImageDataLoader
>>>
>>> model = Plugins().get("image_adsgan", n_iter=10)
>>>
>>> dataset = datasets.MNIST(".", download=True)
>>> dataloader = ImageDataLoader(dataset).sample(100)
>>>
>>> model.fit(dataloader)
>>>
>>> X_gen = model.generate(50)
>>> assert len(X_gen) == 50
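The early stopping parameters listed above (n_iter, n_iter_min, n_iter_print, patience) interact in a way that is easier to see in code. The following is a minimal pure-Python sketch, not synthcity's implementation: `train_with_patience` and its `evaluate` callback are hypothetical names, the score is assumed to be "lower is better", and patience is assumed to count evaluation rounds rather than raw iterations.

```python
from typing import Callable, Optional


def train_with_patience(
    evaluate: Callable[[int], float],
    n_iter: int = 1000,
    n_iter_min: int = 100,
    n_iter_print: int = 50,
    patience: int = 5,
) -> int:
    """Return the iteration at which the (simulated) training loop stops.

    `evaluate` is a hypothetical callback that returns a score for the
    model at the given iteration; lower is assumed to be better.
    """
    best_score: Optional[float] = None
    checks_without_improvement = 0

    for it in range(1, n_iter + 1):
        # ... one generator/discriminator update would happen here ...

        # Evaluate only every n_iter_print iterations, and never before n_iter_min.
        if it % n_iter_print != 0 or it < n_iter_min:
            continue

        score = evaluate(it)
        if best_score is None or score < best_score:
            best_score = score
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1

        # Assumption: patience counts consecutive evaluations without improvement.
        if checks_without_improvement >= patience:
            return it

    return n_iter


# Toy metric: improves until iteration 300, then plateaus at 0.
stopped_at = train_with_patience(lambda it: max(300 - it, 0))
print(stopped_at)
```

With the toy metric, the last improvement is seen at iteration 300, and the loop stops five evaluations later. A metric that keeps improving runs for the full n_iter.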
- fit(X: Union[synthcity.plugins.core.dataloader.DataLoader, pandas.core.frame.DataFrame], *args: Any, **kwargs: Any) Any
Training method for the synthetic data plugin.
- Parameters
X – DataLoader. The reference dataset.
cond –
Optional, Union[pd.DataFrame, pd.Series, np.ndarray]. Optional training conditional. The training conditional can be used to control the output of some models, like GANs or VAEs. The content can be anything, as long as it maps to the training dataset X. Usage example:
>>> import numpy as np
>>> from sklearn.datasets import load_iris
>>> from synthcity.plugins.core.dataloader import GenericDataLoader
>>> from synthcity.plugins.core.constraints import Constraints
>>>
>>> # Load in `test_plugin` the generative model of choice
>>> # ....
>>>
>>> X, y = load_iris(as_frame=True, return_X_y=True)
>>> X["target"] = y
>>>
>>> X = GenericDataLoader(X)
>>> test_plugin.fit(X, cond=y)
>>>
>>> count = 10
>>> X_gen = test_plugin.generate(count, cond=np.ones(count))
>>>
>>> # The Conditional only optimizes the output generation
>>> # for GANs and VAEs, but does NOT guarantee the samples
>>> # are only from that condition.
>>> # If you want to guarantee that output contains only
>>> # "target" == 1 samples, use Constraints.
>>>
>>> constraints = Constraints(
...     rules=[
...         ("target", "==", 1),
...     ]
... )
>>> X_gen = test_plugin.generate(count,
...     cond=np.ones(count),
...     constraints=constraints
... )
>>> assert (X_gen["target"] == 1).all()
- Returns
self
- classmethod fqdn() str
The Fully-Qualified name of the plugin.
- generate(count: Optional[int] = None, constraints: Optional[synthcity.plugins.core.constraints.Constraints] = None, random_state: Optional[int] = None, **kwargs: Any) synthcity.plugins.core.dataloader.DataLoader
Synthetic data generation method.
- Parameters
count – optional int. The number of samples to generate. If None, it generates len(reference_dataset) samples.
cond – Optional, Union[pd.DataFrame, pd.Series, np.ndarray]. Optional generation conditional. The conditional can be used only if the model was also trained with a conditional. If provided, it must have length count. Not all models support conditionals. Conditionals can be used in VAEs or GANs to speed up generation under some constraints. For model-agnostic solutions, see the constraints parameter.
constraints –
optional Constraints. Optional constraints to apply to the generated data. If None, the reference schema constraints are applied. The constraints are model agnostic and filter the output of the generative model. The constraints are a list of rules. Each rule is a tuple of the form (<feature>, <operation>, <value>).
- Valid Operations:
"<", "lt": less than <value>
"<=", "le": less than or equal to <value>
">", "gt": greater than <value>
">=", "ge": greater than or equal to <value>
"==", "eq": equal to <value>
"in": valid for categorical features; <value> must be an array. For example, ("target", "in", [0, 1])
"dtype": <value> can be a data type. For example, ("target", "dtype", "int")
- Usage example:
>>> from synthcity.plugins.core.constraints import Constraints
>>> constraints = Constraints(
...     rules=[
...         ("InterestingFeature", "==", 0),
...     ]
... )
>>>
>>> syn_data = syn_model.generate(
...     count=count,
...     constraints=constraints
... ).dataframe()
>>>
>>> assert (syn_data["InterestingFeature"] == 0).all()
random_state – optional int. Optional random seed to use.
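To make the rule format concrete, here is a small standard-library sketch of how a list of (<feature>, <operation>, <value>) tuples can filter generated rows. It is an illustration only and does not reproduce the Constraints class: `OPS` and `filter_rows` are hypothetical names, rows are plain dictionaries rather than a DataLoader, and the "dtype" operation is left out.

```python
import operator
from typing import Any, Callable, Dict, List, Tuple

Rule = Tuple[str, str, Any]

# Map both spellings of each operation to a predicate (cell, value) -> bool.
# Assumption: this table is illustrative, not the library's internal mapping.
OPS: Dict[str, Callable[[Any, Any], bool]] = {
    "<": operator.lt, "lt": operator.lt,
    "<=": operator.le, "le": operator.le,
    ">": operator.gt, "gt": operator.gt,
    ">=": operator.ge, "ge": operator.ge,
    "==": operator.eq, "eq": operator.eq,
    "in": lambda cell, allowed: cell in allowed,
}


def filter_rows(rows: List[Dict[str, Any]], rules: List[Rule]) -> List[Dict[str, Any]]:
    """Keep only the rows that satisfy every (feature, operation, value) rule."""
    return [
        row
        for row in rows
        if all(OPS[op](row[feature], value) for feature, op, value in rules)
    ]


rows = [
    {"target": 0, "age": 31},
    {"target": 1, "age": 45},
    {"target": 1, "age": 17},
]
kept = filter_rows(rows, [("target", "==", 1), ("age", ">=", 18)])
print(kept)
```

Because the rules act as a post-hoc filter, they apply to any generator's output, which matches the "model agnostic" description above.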
- Returns
<count> synthetic samples
- static hyperparameter_space(**kwargs: Any) List[synthcity.plugins.core.distribution.Distribution]
Returns the hyperparameter space for the derived plugin.
- static load(buff: bytes) Any
- static load_dict(representation: dict) Any
- static name() str
The name of the plugin.
- plot(plt: Any, X: synthcity.plugins.core.dataloader.DataLoader, count: Optional[int] = None, plots: list = ['marginal', 'associations', 'tsne'], **kwargs: Any) Any
Plot the real-synthetic distributions.
- Parameters
plt – output
X – DataLoader. The reference dataset.
- Returns
self
- classmethod sample_hyperparameters(*args: Any, **kwargs: Any) Dict[str, Any]
Sample values from the hyperparameter space for the current plugin.
- classmethod sample_hyperparameters_optuna(trial: Any, *args: Any, **kwargs: Any) Dict[str, Any]
- save() bytes
- save_dict() dict
- save_to_file(path: pathlib.Path) bytes
- schema() synthcity.plugins.core.schema.Schema
The reference schema
- schema_includes(other: Union[synthcity.plugins.core.dataloader.DataLoader, pandas.core.frame.DataFrame]) bool
Helper method to test whether the reference schema includes a dataset.
- Parameters
other – DataLoader. The dataset to test
- Returns
bool. Whether the schema includes the dataset.
- training_schema() synthcity.plugins.core.schema.Schema
The internal schema
- static type() str
The type of the plugin.
- static version() str
API version
- plugin
alias of synthcity.plugins.images.plugin_image_adsgan.ImageAdsGANPlugin