synthcity.benchmark.utils module

augment_data(X_train: synthcity.plugins.core.dataloader.DataLoader, augment_generator: Any, strict: bool = False, rule: typing_extensions.Literal['equal', 'log', 'ad-hoc'] = 'equal', ad_hoc_augment_vals: Optional[Dict[Any, int]] = None, synthetic_constraints: Optional[synthcity.plugins.core.constraints.Constraints] = None, **generate_kwargs: Any) → synthcity.plugins.core.dataloader.DataLoader

Augment the real data with generated synthetic data.

Parameters
  • X_train (DataLoader) – The ground truth DataLoader to augment with synthetic data.

  • augment_generator (Any) – The synthetic model to be used to generate the synthetic portion of the augmented dataset.

  • strict (bool, optional) – Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False.

  • rule (Literal["equal", "log", "ad-hoc"]) – The rule used to achieve the desired proportion of records with each value in the fairness column. Defaults to “equal”.

  • ad_hoc_augment_vals (Dict[Union[int, str], int], optional) – A dictionary giving, for each class, the number of synthetic records to add to the real data. This is only required when using rule="ad-hoc". Defaults to None.

  • synthetic_constraints (Optional[Constraints]) – Constraints placed on the generation of the synthetic data. Defaults to None.

Returns

The augmented dataset and labels.

Return type

DataLoader
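The docstring above does not show how the real and synthetic parts are combined. As a rough sketch only, the pure-pandas code below assumes the per-group counts have already been computed (for example by calculate_fair_aug_sample_size) and appends that many generated rows for each group. augment_with_counts and make_resampler are hypothetical helpers, not synthcity APIs; the resampler merely redraws real rows of a group as a stand-in for a fitted synthetic model, and strict and synthetic_constraints are not modeled.

```python
from typing import Callable, Dict

import pandas as pd


def augment_with_counts(
    real: pd.DataFrame,
    counts: Dict[object, int],
    generate: Callable[[object, int], pd.DataFrame],
) -> pd.DataFrame:
    """Hypothetical sketch: append counts[value] generated rows per group."""
    synthetic_parts = [generate(value, n) for value, n in counts.items() if n > 0]
    # Real rows first, then the synthetic rows, with a fresh 0..N-1 index.
    return pd.concat([real, *synthetic_parts], ignore_index=True)


def make_resampler(
    real: pd.DataFrame, column: str
) -> Callable[[object, int], pd.DataFrame]:
    """Hypothetical stand-in generator conditioned on a fairness-column value.

    It resamples real rows of the requested group with replacement; an actual
    synthetic model would produce new records instead.
    """

    def generate(value: object, n: int) -> pd.DataFrame:
        group = real[real[column] == value]
        return group.sample(n=n, replace=True, random_state=0)

    return generate


real = pd.DataFrame({"group": ["a", "a", "a", "b"], "x": [1, 2, 3, 4]})
augmented = augment_with_counts(real, {"a": 0, "b": 2}, make_resampler(real, "group"))
print(len(augmented))  # 6
```

Passing the generator as a callable keeps the sketch independent of any particular synthetic model.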

calculate_fair_aug_sample_size(X_train: pandas.core.frame.DataFrame, fairness_column: Optional[str], rule: typing_extensions.Literal['equal', 'log', 'ad-hoc'], ad_hoc_augment_vals: Optional[Dict[Any, int]] = None) → Dict

Calculate how many samples to augment.

Parameters
  • X_train (pd.DataFrame) – The real dataset to be augmented.

  • fairness_column (str) – The name of the column with respect to which the fairness of a downstream model is tested.

  • rule (Literal["equal", "log", "ad-hoc"]) – The rule used to achieve the desired proportion of records with each value in the fairness column.

  • ad_hoc_augment_vals (Dict[Union[int, str], int], optional) – A dictionary giving, for each class, the number of synthetic records to add to the real data. If rule="ad-hoc", this function returns ad_hoc_augment_vals; otherwise this parameter is ignored. Defaults to None.

Returns

A dictionary containing the number of each class to augment the real data with.

Return type

Dict
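The exact arithmetic behind each rule is not spelled out above. The code below is a hypothetical pure-pandas illustration (fair_aug_counts is not part of synthcity) that assumes "equal" means topping every group up to the size of the largest group and returns ad_hoc_augment_vals unchanged for "ad-hoc", as documented. The "log" rule is deliberately left out because its formula is not given here.

```python
from typing import Any, Dict, Optional

import pandas as pd


def fair_aug_counts(
    X_train: pd.DataFrame,
    fairness_column: Optional[str],
    rule: str,
    ad_hoc_augment_vals: Optional[Dict[Any, int]] = None,
) -> Dict[Any, int]:
    """Hypothetical sketch: number of synthetic records to add per group."""
    if fairness_column is None:
        # Assumption: with no fairness column there is nothing to rebalance.
        return {}
    if rule == "ad-hoc":
        # As documented: the ad-hoc counts are returned unchanged.
        return dict(ad_hoc_augment_vals or {})
    if rule == "equal":
        # Assumed reading of "equal": bring every group up to the largest one.
        counts = X_train[fairness_column].value_counts()
        largest = int(counts.max())
        return {value: largest - int(n) for value, n in counts.items()}
    raise NotImplementedError(f"rule={rule!r} is not covered by this sketch")


df = pd.DataFrame({"sex": ["F", "M", "M", "M", "F", "M"]})
print(fair_aug_counts(df, "sex", "equal")["F"])  # 2
```

With 4 "M" rows and 2 "F" rows, the assumed "equal" rule asks for 2 extra "F" records and none for "M".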

get_json_serializable_kwargs(kwargs: Dict) → Dict

Take the kwargs for Benchmarks.evaluate and make them serializable with json.dumps. Currently the only conversion handled is pathlib.Path -> str.
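A minimal sketch of the documented conversion, assuming only top-level values need converting; to_json_serializable is a hypothetical name, not the synthcity function, and the "workspace" and "repeats" keys are purely illustrative.

```python
import json
from pathlib import Path
from typing import Any, Dict


def to_json_serializable(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical sketch: replace top-level pathlib.Path values with str."""
    return {
        key: str(value) if isinstance(value, Path) else value
        for key, value in kwargs.items()
    }


kwargs = {"workspace": Path("workspace"), "repeats": 3}
print(json.dumps(to_json_serializable(kwargs)))  # {"workspace": "workspace", "repeats": 3}
```

Passing the original kwargs straight to json.dumps would raise a TypeError, because Path objects are not JSON serializable.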