synthcity.plugins.core.models.tabular_arf module

class TabularARF(X: pandas.core.frame.DataFrame, num_trees: int = 30, delta: int = 0, max_iters: int = 10, early_stop: bool = True, verbose: bool = True, min_node_size: int = 5, dist: str = 'truncnorm', oob: bool = False, alpha: float = 0, encoder_max_clusters: int = 20, encoder_whitelist: list = [], device: Union[str, torch.device] = device(type='cpu'), learning_rate: float = 0.005, weight_decay: float = 0.001, batch_size: int = 32, logging_epoch: int = 100, random_state: int = 0, **kwargs: Any)

Bases: object

fit(X: pandas.core.frame.DataFrame, var_threshold: int = 10) → Any
generate(count: int) → pandas.core.frame.DataFrame
get_categorical_cols(X: pandas.core.frame.DataFrame, var_threshold: int) → list

Finds columns with a low number of unique values (at most var_threshold) and returns them as a list. The model uses this list to treat those columns as categorical features even when they are numeric. This matters for the ARF model because it cannot handle zero-variance float features in terminal nodes, and a numeric column with only a few distinct values is likely to be constant within some leaves.
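To see why zero-variance floats are a problem, consider a toy terminal node in which every sample has the same value for a numeric column. The sketch below assumes, for illustration only, that a continuous feature's leaf density is parameterised by the sample standard deviation (as a truncated normal, the default dist='truncnorm', would be) and that a categorical feature's leaf density is its empirical value frequencies. Neither detail is taken from this module's source.

```python
import pandas as pd

# Hypothetical rows that ended up together in one terminal node of a tree.
leaf = pd.DataFrame({"smoker": [1.0, 1.0, 1.0, 1.0, 1.0]})

# Continuous treatment: a (truncated) normal leaf density needs a positive
# standard deviation, but a constant column has a standard deviation of 0.
sigma = leaf["smoker"].std()
print(sigma)  # 0.0

# Categorical treatment: the leaf stores empirical value frequencies instead,
# which remain well defined even when only one value is present.
probs = leaf["smoker"].value_counts(normalize=True)
print(probs.to_dict())  # {1.0: 1.0}
```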

Parameters
  • X (pd.DataFrame) – The dataframe to check for categorical columns

  • var_threshold (int) – The maximum number of unique values a column can have to be considered categorical

Returns

The list of categorical columns

Return type

list
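A minimal pandas sketch of the selection rule described above. find_low_cardinality_cols is a hypothetical stand-in, not this module's implementation: it counts distinct non-null values with DataFrame.nunique() and keeps columns whose count is at most var_threshold, following the parameter description. Whether the real method uses < or <=, counts nulls, or also treats non-numeric dtypes as categorical is an assumption left open here.

```python
import pandas as pd


def find_low_cardinality_cols(X: pd.DataFrame, var_threshold: int) -> list:
    """Hypothetical illustration: names of columns with <= var_threshold unique values."""
    # nunique() returns a Series of distinct non-null counts indexed by column name.
    counts = X.nunique()
    return [col for col in X.columns if counts[col] <= var_threshold]


df = pd.DataFrame(
    {
        "age": [23.0, 35.5, 41.2, 29.9, 52.1, 38.4],  # 6 distinct floats
        "n_children": [0, 1, 2, 0, 1, 2],             # 3 distinct ints
        "smoker": [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],     # 2 distinct floats
    }
)

print(find_low_cardinality_cols(df, var_threshold=3))
# ['n_children', 'smoker']
print(find_low_cardinality_cols(df, var_threshold=10))
# ['age', 'n_children', 'smoker']
```

With var_threshold=10 (the default in fit's signature), all three toy columns would be treated as categorical, including age. On very small datasets the threshold can therefore sweep in columns that are genuinely continuous.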