synthcity.plugins.core.models.tabular_goggle module

class TabularGoggle(X: pandas.core.frame.DataFrame, n_iter: int = 1000, encoder_dim: int = 64, encoder_l: int = 2, het_encoding: bool = True, decoder_dim: int = 64, decoder_l: int = 2, threshold: float = 0.1, decoder_arch: str = 'gcn', graph_prior: Optional[numpy.ndarray] = None, prior_mask: Optional[numpy.ndarray] = None, device: Union[str, torch.device] = device(type='cpu'), alpha: float = 0.1, beta: float = 0.1, iter_opt: bool = True, learning_rate: float = 0.005, weight_decay: float = 0.001, batch_size: int = 32, patience: int = 50, dataloader_sampler: Optional[Any] = None, logging_epoch: int = 100, encoder_nonlin: str = 'relu', decoder_nonlin: str = 'relu', encoder_max_clusters: int = 20, encoder_whitelist: list = [], decoder_nonlin_out_discrete: str = 'softmax', decoder_nonlin_out_continuous: str = 'tanh', random_state: int = 0)

Bases: object

decode(X: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame
encode(X: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame
fit(X: pandas.core.frame.DataFrame, encoded: bool = False, **kwargs: Any) -> Any
forward(count: int) -> torch.Tensor
generate(count: int, **kwargs: Any) -> pandas.core.frame.DataFrame
get_encoder() -> synthcity.plugins.core.models.tabular_encoder.TabularEncoder