synthcity.plugins.core.models.tabular_aim module
- class TabularAIM(X: pandas.core.frame.DataFrame, epsilon: float = 1.0, delta: float = 1e-09, max_model_size: int = 80, degree: int = 2, num_marginals: Optional[int] = None, max_cells: int = 1000, encoder_max_clusters: int = 20, encoder_whitelist: list = [], device: Union[str, torch.device] = device(type='cpu'), learning_rate: float = 0.005, weight_decay: float = 0.001, logging_epoch: int = 100, random_state: int = 0, **kwargs: Any)
Bases:
object
- Adaptive and Iterative Mechanism (AIM) implementation, based on:
- Parameters
X (pd.DataFrame) – Reference dataset, used for training the tabular encoder
AIM parameters – Not described in this listing; see the class signature above for their names and defaults.
Core plugin arguments –
encoder_max_clusters (int = 20) – The max number of clusters to create for continuous columns when encoding with TabularEncoder. Defaults to 20.
encoder_whitelist (list = []) – Ignore columns from encoding with TabularEncoder. Defaults to [].
device (Union[str, torch.device] = DEVICE) – Not used for this model, as it is built with sklearn, which is CPU-only.
random_state (int, optional) – Random seed. Defaults to 0. Not used for this model.
**kwargs (Any) – The keyword arguments are passed to a SKLearn RandomForestClassifier - https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html.
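The defaults in the signature above can be collected into a plain dictionary before constructing the model. The following is a minimal sketch only: `AIM_DEFAULTS` and `build_aim_kwargs` are hypothetical names that are not part of synthcity, and the range checks on `epsilon` and `delta` are assumptions based on the usual differential privacy conventions (a positive `epsilon`, and a `delta` strictly between 0 and 1), not behaviour documented here.

```python
from typing import Any, Dict

# Defaults copied from the TabularAIM signature above.
# `device` is left out because the listing says it is not used.
AIM_DEFAULTS: Dict[str, Any] = {
    "epsilon": 1.0,
    "delta": 1e-09,
    "max_model_size": 80,
    "degree": 2,
    "num_marginals": None,
    "max_cells": 1000,
    "encoder_max_clusters": 20,
    "encoder_whitelist": [],
    "learning_rate": 0.005,
    "weight_decay": 0.001,
    "logging_epoch": 100,
    "random_state": 0,
}


def build_aim_kwargs(**overrides: Any) -> Dict[str, Any]:
    """Hypothetical helper: merge overrides into the documented defaults."""
    kwargs = {**AIM_DEFAULTS, **overrides}
    # Copy the list so callers never share the mutable default.
    kwargs["encoder_whitelist"] = list(kwargs["encoder_whitelist"])
    # Assumed sanity checks; the library may validate differently.
    if kwargs["epsilon"] <= 0:
        raise ValueError("epsilon must be positive")
    if not 0 < kwargs["delta"] < 1:
        raise ValueError("delta must lie strictly between 0 and 1")
    return kwargs
```

Assuming synthcity is installed and `X` is a pandas DataFrame, the result could then be unpacked into the constructor, for example `TabularAIM(X, **build_aim_kwargs(epsilon=0.5))`.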
- fit(X: pandas.core.frame.DataFrame, **kwargs: Any) Any
- Parameters
X (pd.DataFrame) – Pandas DataFrame that contains the tabular data
- Returns
AIMTrainer used for the fine-tuning process
- generate(count: int, start_col: Optional[str] = '', start_col_dist: Optional[Union[dict, list]] = None, temperature: float = 0.7, k: int = 100, max_length: int = 100) pandas.core.frame.DataFrame
Generates tabular data using the trained AIM model.
- Parameters
count (int) – The number of samples to generate
- Returns
count rows of generated data
- Return type
pd.DataFrame