synthcity.plugins.core.models.tabular_arf module
- class TabularARF(X: pandas.core.frame.DataFrame, num_trees: int = 30, delta: int = 0, max_iters: int = 10, early_stop: bool = True, verbose: bool = True, min_node_size: int = 5, dist: str = 'truncnorm', oob: bool = False, alpha: float = 0, encoder_max_clusters: int = 20, encoder_whitelist: list = [], device: Union[str, torch.device] = device(type='cpu'), learning_rate: float = 0.005, weight_decay: float = 0.001, batch_size: int = 32, logging_epoch: int = 100, random_state: int = 0, **kwargs: Any)
Bases: object
- fit(X: pandas.core.frame.DataFrame, var_threshold: int = 10) → Any
- generate(count: int) → pandas.core.frame.DataFrame
- get_categorical_cols(X: pandas.core.frame.DataFrame, var_threshold: int) → list
Finds columns with a low number of unique values and returns them as a list. This lets the model treat those columns as categorical features even when they are numeric, which matters for the ARF model because it cannot handle zero-variance float features in terminal nodes.
- Parameters
X (pd.DataFrame) – The dataframe to check for categorical columns
var_threshold (int) – The maximum number of unique values a column can have to be considered categorical
- Returns
The list of categorical columns
- Return type
list
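The low-cardinality heuristic described for get_categorical_cols can be sketched with pandas alone. This is a minimal illustration, not synthcity's implementation: the helper name find_low_cardinality_cols is hypothetical, and the inclusive `<=` comparison is an assumption based on the wording "maximum number of unique values" above. The final cast to pandas' categorical dtype only illustrates what "treat them as categorical" could look like; it is not necessarily what TabularARF does internally.

```python
import pandas as pd


def find_low_cardinality_cols(X: pd.DataFrame, var_threshold: int) -> list:
    """Hypothetical helper: names of columns with at most `var_threshold` unique values.

    The inclusive `<=` comparison is an assumption, not taken from synthcity's source.
    """
    # nunique() counts the distinct non-null values in each column
    counts = X.nunique()
    return [col for col, n in counts.items() if n <= var_threshold]


df = pd.DataFrame({
    "age": [23, 35, 47, 51, 62, 29],         # 6 distinct values -> stays numeric
    "smoker": [0, 1, 0, 0, 1, 1],             # 2 distinct values -> flagged
    "score": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5],  # 1 distinct value (zero variance) -> flagged
})

cat_cols = find_low_cardinality_cols(df, var_threshold=3)
print(cat_cols)  # → ['smoker', 'score']

# Illustration only: cast the flagged numeric columns to pandas' categorical dtype
df[cat_cols] = df[cat_cols].astype("category")
```

A constant float column such as `score` is exactly the case the description warns about: it has zero variance, so routing it down the categorical path avoids fitting a continuous distribution to it.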