synthcity.plugins package
- class Plugin(sampling_patience: int = 500, strict: bool = True, device: Any = device(type='cpu'), random_state: int = 0, workspace: pathlib.Path = PosixPath('workspace'), compress_dataset: bool = False, sampling_strategy: str = 'marginal')
Bases:
synthcity.plugins.core.serializable.Serializable
Base class for all plugins.
- Each derived class must implement the following methods:
type() - a static method that returns the type of the plugin, e.g. debug, generative, bayesian, etc.
name() - a static method that returns the name of the plugin, e.g. ctgan, random_noise, etc.
hyperparameter_space() - a static method that returns the hyperparameters that can be tuned during AutoML.
_fit() - internal method, called by fit() on each training set.
_generate() - internal method, called by generate().
If any method implementation is missing, the class constructor will fail.
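The contract above works like Python's abstract base classes: a subclass that leaves an abstract method unimplemented cannot be instantiated. The sketch below shows that mechanism with the standard `abc` module. `SketchPlugin`, `ConstantPlugin` and `IncompletePlugin` are hypothetical names, and the real base class may enforce the contract differently.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class SketchPlugin(ABC):
    """Hypothetical stand-in for the Plugin base class."""

    @staticmethod
    @abstractmethod
    def type() -> str:
        """Plugin type, e.g. "debug"."""

    @staticmethod
    @abstractmethod
    def name() -> str:
        """Plugin name, e.g. "ctgan"."""

    @staticmethod
    @abstractmethod
    def hyperparameter_space(**kwargs: Any) -> List[Any]:
        """Hyperparameters that can be tuned."""

    @abstractmethod
    def _fit(self, X: Any, *args: Any, **kwargs: Any) -> "SketchPlugin":
        """Internal training step, called by fit()."""

    @abstractmethod
    def _generate(self, count: int, **kwargs: Any) -> Any:
        """Internal generation step, called by generate()."""


class ConstantPlugin(SketchPlugin):
    """Toy plugin that memorises the first training row and repeats it."""

    @staticmethod
    def type() -> str:
        return "debug"

    @staticmethod
    def name() -> str:
        return "constant"

    @staticmethod
    def hyperparameter_space(**kwargs: Any) -> List[Any]:
        return []  # nothing to tune

    def _fit(self, X: Any, *args: Any, **kwargs: Any) -> "ConstantPlugin":
        self._row = X[0]
        return self

    def _generate(self, count: int, **kwargs: Any) -> Any:
        return [self._row] * count


class IncompletePlugin(SketchPlugin):
    """Implements only type() and name(), so instantiation must fail."""

    @staticmethod
    def type() -> str:
        return "debug"

    @staticmethod
    def name() -> str:
        return "incomplete"


plugin = ConstantPlugin()._fit([[1, 2], [3, 4]])
print(plugin._generate(2))  # [[1, 2], [1, 2]]

try:
    IncompletePlugin()  # abstract _fit, _generate, ... are missing
except TypeError as err:
    print("construction failed:", err)
```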
- Parameters
strict – bool. Default = True. If True, it raises an exception if the generated data does not follow the requested constraints. If False, it returns only the rows that match the constraints.
workspace – Path. Path for caching intermediary results.
compress_dataset – bool. Default = False. Drop redundant features before training the generator.
device – PyTorch device: cpu or cuda.
random_state – int. Random seed.
sampling_patience – int. Default = 500. Maximum number of inference iterations to wait for the generated data to match the training schema.
sampling_strategy – str. Default = "marginal". Internal parameter for the schema: "marginal" or "uniform".
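As a rough illustration of how `strict` and `sampling_patience` might interact, the following sketch runs a rejection-sampling loop: it draws batches, keeps only the rows that satisfy the constraints, and gives up after `sampling_patience` rounds. The function `sample_with_patience`, the toy generator and the `RuntimeError` are assumptions for illustration, not synthcity's internals.

```python
import random
from typing import Callable, Dict, List

Row = Dict[str, float]


def sample_with_patience(
    draw_batch: Callable[[int], List[Row]],
    is_valid: Callable[[Row], bool],
    count: int,
    sampling_patience: int = 500,
    strict: bool = True,
) -> List[Row]:
    """Hypothetical rejection-sampling loop, not synthcity's implementation."""
    accepted: List[Row] = []
    for _ in range(sampling_patience):
        # Keep only the generated rows that satisfy the constraints.
        accepted.extend(row for row in draw_batch(count) if is_valid(row))
        if len(accepted) >= count:
            return accepted[:count]
    if strict:
        # strict=True: fail loudly when too few valid rows were produced.
        raise RuntimeError(
            f"only {len(accepted)} of {count} rows matched the constraints"
        )
    # strict=False: return only the matching rows, possibly fewer than count.
    return accepted


def make_uniform_generator(random_state: int = 0) -> Callable[[int], List[Row]]:
    """Toy 'generative model': one feature drawn uniformly from [0, 1)."""
    rng = random.Random(random_state)
    return lambda n: [{"x": rng.random()} for _ in range(n)]


draw = make_uniform_generator(random_state=0)
rows = sample_with_patience(draw, lambda row: row["x"] > 0.5, count=10)
print(len(rows))  # 10
```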
- fit(X: Union[synthcity.plugins.core.dataloader.DataLoader, pandas.core.frame.DataFrame], *args: Any, **kwargs: Any) → Any
Training method for the synthetic data plugin.
- Parameters
X – DataLoader or pd.DataFrame. The reference dataset.
cond – Optional, Union[pd.DataFrame, pd.Series, np.ndarray]. Optional training conditional. The training conditional can be used to control the output of some models, like GANs or VAEs. The content can be anything, as long as it maps to the training dataset X. Usage example:
>>> import numpy as np
>>> from sklearn.datasets import load_iris
>>> from synthcity.plugins.core.dataloader import GenericDataLoader
>>> from synthcity.plugins.core.constraints import Constraints
>>>
>>> # Load the generative model of choice into `test_plugin`
>>> # ....
>>>
>>> X, y = load_iris(as_frame=True, return_X_y=True)
>>> X["target"] = y
>>>
>>> X = GenericDataLoader(X)
>>> test_plugin.fit(X, cond=y)
>>>
>>> count = 10
>>> X_gen = test_plugin.generate(count, cond=np.ones(count))
>>>
>>> # The conditional only optimizes the output generation
>>> # for GANs and VAEs, but does NOT guarantee that the samples
>>> # come only from that condition.
>>> # If you want to guarantee that the output contains only
>>> # "target" == 1 samples, use Constraints.
>>>
>>> constraints = Constraints(
...     rules=[
...         ("target", "==", 1),
...     ]
... )
>>> X_gen = test_plugin.generate(
...     count,
...     cond=np.ones(count),
...     constraints=constraints,
... )
>>> assert (X_gen["target"] == 1).all()
- Returns
self
- classmethod fqdn() → str
The Fully-Qualified name of the plugin.
- generate(count: Optional[int] = None, constraints: Optional[synthcity.plugins.core.constraints.Constraints] = None, random_state: Optional[int] = None, **kwargs: Any) → synthcity.plugins.core.dataloader.DataLoader
Synthetic data generation method.
- Parameters
count – optional int. The number of samples to generate. If None, it generates len(reference_dataset) samples.
cond – Optional, Union[pd.DataFrame, pd.Series, np.ndarray]. Optional generation conditional. The conditional can be used only if the model was also trained with a conditional. If provided, it must have length count. Not all models support conditionals. Conditionals can be used in VAEs or GANs to speed up generation under some constraints. For model-agnostic solutions, see the constraints parameter.
constraints – optional Constraints. Optional constraints to apply on the generated data. If None, the reference schema constraints are applied. The constraints are model agnostic and filter the output of the generative model. The constraints are a list of rules. Each rule is a tuple of the form (<feature>, <operation>, <value>).
- Valid Operations:
"<", "lt": less than <value>
"<=", "le": less than or equal to <value>
">", "gt": greater than <value>
">=", "ge": greater than or equal to <value>
"==", "eq": equal to <value>
"in": valid for categorical features; <value> must be an array. For example, ("target", "in", [0, 1])
"dtype": <value> is a data type. For example, ("target", "dtype", "int")
- Usage example:
>>> from synthcity.plugins.core.constraints import Constraints
>>> constraints = Constraints(
...     rules=[
...         ("InterestingFeature", "==", 0),
...     ]
... )
>>>
>>> syn_data = syn_model.generate(count=count, constraints=constraints).dataframe()
>>>
>>> assert (syn_data["InterestingFeature"] == 0).all()
random_state – optional int. Optional random seed to use.
- Returns
A DataLoader with <count> synthetic samples.
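The rule semantics listed for the constraints parameter can be approximated in plain Python. This sketch assumes that a row must satisfy every rule (logical AND) and, as a simplification, checks "dtype" against the Python type name of each value. The helpers `matches` and `filter_rows` are hypothetical and do not reproduce the `Constraints` class itself.

```python
import operator
from typing import Any, Callable, Dict, List, Tuple

Rule = Tuple[str, str, Any]

# Documented operation names mapped to predicates (hypothetical sketch).
OPS: Dict[str, Callable[[Any, Any], bool]] = {
    "<": operator.lt, "lt": operator.lt,
    "<=": operator.le, "le": operator.le,
    ">": operator.gt, "gt": operator.gt,
    ">=": operator.ge, "ge": operator.ge,
    "==": operator.eq, "eq": operator.eq,
    "in": lambda x, values: x in values,
    # Simplification: compare the Python type name of a single value.
    "dtype": lambda x, type_name: type(x).__name__ == type_name,
}


def matches(row: Dict[str, Any], rules: List[Rule]) -> bool:
    """True if the row satisfies every (feature, operation, value) rule."""
    return all(OPS[op](row[feature], value) for feature, op, value in rules)


def filter_rows(rows: List[Dict[str, Any]], rules: List[Rule]) -> List[Dict[str, Any]]:
    """Keep only the rows that satisfy all rules (a model-agnostic filter)."""
    return [row for row in rows if matches(row, rules)]


rows = [
    {"target": 0, "age": 30},
    {"target": 1, "age": 50},
    {"target": 2, "age": 70},
]
print(filter_rows(rows, [("target", "in", [0, 1])]))
# [{'target': 0, 'age': 30}, {'target': 1, 'age': 50}]
```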
- abstract static hyperparameter_space(**kwargs: Any) → List[synthcity.plugins.core.distribution.Distribution]
Returns the hyperparameter space for the derived plugin.
- static load(buff: bytes) → Any
- static load_dict(representation: dict) → Any
- abstract static name() → str
The name of the plugin.
- plot(plt: Any, X: synthcity.plugins.core.dataloader.DataLoader, count: Optional[int] = None, plots: list = ['marginal', 'associations', 'tsne'], **kwargs: Any) → Any
Plot the real-synthetic distributions.
- Parameters
plt – The plotting module used for the output, e.g. matplotlib.pyplot.
X – DataLoader. The reference dataset.
- Returns
self
- classmethod sample_hyperparameters(*args: Any, **kwargs: Any) → Dict[str, Any]
Sample values from the hyperparameter space of the current plugin.
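A hyperparameter space is a list of distributions, and sampling draws one value from each. The sketch below uses two hypothetical stand-ins, `IntRange` and `Choice`, in place of the `Distribution` objects returned by `hyperparameter_space()`; the hyperparameter names `n_iter` and `batch_size` are illustrative only.

```python
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Union


@dataclass
class IntRange:
    """Hypothetical integer hyperparameter with inclusive bounds."""
    name: str
    low: int
    high: int

    def sample(self, rng: random.Random) -> int:
        return rng.randint(self.low, self.high)  # randint includes both ends


@dataclass
class Choice:
    """Hypothetical categorical hyperparameter."""
    name: str
    choices: Sequence[Any]

    def sample(self, rng: random.Random) -> Any:
        return rng.choice(list(self.choices))


def hyperparameter_space() -> List[Union[IntRange, Choice]]:
    # What a derived plugin might declare; the names are illustrative.
    return [IntRange("n_iter", 100, 1000), Choice("batch_size", [64, 128, 256])]


def sample_hyperparameters(seed: int = 0) -> Dict[str, Any]:
    """Draw one value per distribution, keyed by hyperparameter name."""
    rng = random.Random(seed)
    return {dist.name: dist.sample(rng) for dist in hyperparameter_space()}


params = sample_hyperparameters(seed=42)
print(sorted(params))  # ['batch_size', 'n_iter']
```

Seeding a private `random.Random` instance keeps the draw reproducible without touching the global random state.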
- classmethod sample_hyperparameters_optuna(trial: Any, *args: Any, **kwargs: Any) → Dict[str, Any]
- save() → bytes
- save_dict() → dict
- save_to_file(path: pathlib.Path) → bytes
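The save/load methods convert a plugin to a dict or to bytes and back. A minimal sketch of that round trip, assuming `pickle` applied to a plain state dict (the real `Serializable` base class may use a different serializer and representation), could look like this; `SerializableSketch` is a hypothetical name.

```python
import pickle
from pathlib import Path
from typing import Any, Dict


class SerializableSketch:
    """Hypothetical save/load round trip; not synthcity's Serializable."""

    def __init__(self, random_state: int = 0) -> None:
        self.random_state = random_state

    def save_dict(self) -> Dict[str, Any]:
        # Plain-dict representation of the object's state.
        return dict(vars(self))

    def save(self) -> bytes:
        # Bytes representation: the pickled state dict.
        return pickle.dumps(self.save_dict())

    def save_to_file(self, path: Path) -> bytes:
        # Write the serialized bytes to disk and also return them.
        buff = self.save()
        path.write_bytes(buff)
        return buff

    @staticmethod
    def load_dict(representation: Dict[str, Any]) -> "SerializableSketch":
        # Rebuild the object without calling __init__, then restore its state.
        obj = SerializableSketch.__new__(SerializableSketch)
        obj.__dict__.update(representation)
        return obj

    @staticmethod
    def load(buff: bytes) -> "SerializableSketch":
        # Only unpickle bytes from a trusted source: pickle can execute code.
        return SerializableSketch.load_dict(pickle.loads(buff))


original = SerializableSketch(random_state=7)
restored = SerializableSketch.load(original.save())
print(restored.random_state)  # 7
```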
- schema() → synthcity.plugins.core.schema.Schema
The reference schema
- schema_includes(other: Union[synthcity.plugins.core.dataloader.DataLoader, pandas.core.frame.DataFrame]) → bool
Helper method to test whether the reference schema includes a dataset.
- Parameters
other – DataLoader or pd.DataFrame. The dataset to test.
- Returns
bool. Whether the schema includes the dataset.
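To make schema inclusion concrete, this sketch assumes a simplified schema in which numeric features store their observed (min, max) range and all other features store their set of observed values; a dataset counts as included when every feature is known and every value lies in its domain. `infer_schema` and `schema_includes` here are hypothetical helpers operating on lists of dicts, not the `Schema` class.

```python
from typing import Any, Dict, List, Set, Tuple, Union

Domain = Union[Tuple[float, float], Set[Any]]
Schema = Dict[str, Domain]


def _is_number(value: Any) -> bool:
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def infer_schema(rows: List[Dict[str, Any]]) -> Schema:
    """Numeric features -> (min, max); other features -> set of values."""
    schema: Schema = {}
    for feature in rows[0]:
        values = [row[feature] for row in rows]
        if all(_is_number(v) for v in values):
            schema[feature] = (min(values), max(values))
        else:
            schema[feature] = set(values)
    return schema


def schema_includes(reference: Schema, rows: List[Dict[str, Any]]) -> bool:
    """True if every feature is known and every value lies in its domain."""
    for row in rows:
        for feature, value in row.items():
            if feature not in reference:
                return False
            domain = reference[feature]
            if isinstance(domain, set):
                if value not in domain:
                    return False
            elif not (_is_number(value) and domain[0] <= value <= domain[1]):
                return False
    return True


reference = infer_schema([{"age": 20, "sex": "F"}, {"age": 60, "sex": "M"}])
print(reference["age"])  # (20, 60)
print(schema_includes(reference, [{"age": 35, "sex": "M"}]))  # True
```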
- training_schema() → synthcity.plugins.core.schema.Schema
The internal schema
- abstract static type() → str
The type of the plugin.
- static version() → str
API version
- class Plugins(categories: list = ['generic', 'privacy', 'survival_analysis', 'time_series', 'domain_adaptation', 'images', 'debug'])
Bases:
synthcity.plugins.core.plugin.PluginLoader
- add(name: str, cls: Type) → synthcity.plugins.core.plugin.PluginLoader
Add a new plugin
- get(name: str, *args: Any, **kwargs: Any) → Any
Create a new object from a plugin.
- Parameters
name – str. The name of the plugin.
*args, **kwargs – Plugin-specific arguments.
- Returns
The new object.
- get_type(name: str) → Type
Get the class type of a plugin.
- Parameters
name – str. The name of the plugin.
- Returns
The class of the plugin.
- list() → List[str]
Get all the available plugins.
- load(buff: bytes) → Any
Load a serialized plugin.
- types() → List[Type]
Get the types of the loaded plugins.
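With the installed package, the documented methods are typically used as `Plugins().list()` to enumerate plugin names and `Plugins().get("ctgan")` to instantiate one. The name-to-class registry pattern behind `add`, `get`, `get_type`, `list` and `types` can be sketched as follows; `PluginRegistrySketch` and `Dummy` are hypothetical, and this is not `PluginLoader`'s actual implementation.

```python
from typing import Any, Dict, List, Type


class PluginRegistrySketch:
    """Hypothetical name -> class registry (not synthcity's PluginLoader)."""

    def __init__(self) -> None:
        self._plugins: Dict[str, Type] = {}

    def add(self, name: str, cls: Type) -> "PluginRegistrySketch":
        self._plugins[name] = cls
        return self  # returning self allows chained add() calls

    def get_type(self, name: str) -> Type:
        if name not in self._plugins:
            raise ValueError(f"unknown plugin: {name}")
        return self._plugins[name]

    def get(self, name: str, *args: Any, **kwargs: Any) -> Any:
        # Instantiate the registered class with plugin-specific arguments.
        return self.get_type(name)(*args, **kwargs)

    def list(self) -> List[str]:
        return sorted(self._plugins)

    def types(self) -> List[Type]:
        # Inside a method, `list` still resolves to the builtin.
        return list(self._plugins.values())


class Dummy:
    """Toy plugin class with one constructor argument."""

    def __init__(self, n: int = 1) -> None:
        self.n = n


registry = PluginRegistrySketch().add("dummy", Dummy)
plugin = registry.get("dummy", n=3)
print(registry.list(), plugin.n)  # ['dummy'] 3
```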
Subpackages
- synthcity.plugins.core package
- synthcity.plugins.domain_adaptation package
- synthcity.plugins.generic package
- Submodules
- synthcity.plugins.generic.plugin_arf module
- synthcity.plugins.generic.plugin_bayesian_network module
- synthcity.plugins.generic.plugin_ctgan module
- synthcity.plugins.generic.plugin_ddpm module
- synthcity.plugins.generic.plugin_dummy_sampler module
- synthcity.plugins.generic.plugin_goggle module
- synthcity.plugins.generic.plugin_great module
- synthcity.plugins.generic.plugin_marginal_distributions module
- synthcity.plugins.generic.plugin_nflow module
- synthcity.plugins.generic.plugin_rtvae module
- synthcity.plugins.generic.plugin_tvae module
- synthcity.plugins.generic.plugin_uniform_sampler module
- synthcity.plugins.images package
- synthcity.plugins.privacy package
- synthcity.plugins.survival_analysis package
- Submodules
- synthcity.plugins.survival_analysis._survival_pipeline module
- synthcity.plugins.survival_analysis.plugin_survae module
- synthcity.plugins.survival_analysis.plugin_survival_ctgan module
- synthcity.plugins.survival_analysis.plugin_survival_gan module
- synthcity.plugins.survival_analysis.plugin_survival_nflow module
- synthcity.plugins.time_series package