synthcity.metrics.core.metric module

class MetricEvaluator(reduction: str = 'mean', n_histogram_bins: int = 10, n_folds: int = 3, task_type: str = 'classification', random_state: int = 0, workspace: pathlib.Path = PosixPath('workspace'), use_cache: bool = True, default_metric: Optional[str] = None)

Bases: object

Base class for all metrics.

Each derived class must implement the following methods:

evaluate() - compare two datasets and return a dictionary of metrics.
evaluate_default() - compare two datasets and return a single score as a float.
direction() - direction of the metric (bigger is better or smaller is better).
type() - type of the metric.
name() - name of the metric.

If any of these methods is not implemented, instantiating the derived class fails.
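The contract above is what Python's `abc` module enforces for abstract methods. The sketch below is a simplified stand-in, not synthcity's code: `SketchMetricEvaluator`, `MeanGap`, `Incomplete`, the `"stats"` type string and the `"minimize"` direction string are hypothetical, and plain lists of floats replace `DataLoader` objects. It shows a complete subclass working and an incomplete one being rejected at instantiation time.

```python
from abc import ABCMeta, abstractmethod
from typing import Dict, List


class SketchMetricEvaluator(metaclass=ABCMeta):
    """Hypothetical stand-in for the abstract MetricEvaluator contract."""

    @staticmethod
    @abstractmethod
    def name() -> str: ...

    @staticmethod
    @abstractmethod
    def type() -> str: ...

    @staticmethod
    @abstractmethod
    def direction() -> str: ...

    @abstractmethod
    def evaluate(self, X_gt: List[float], X_syn: List[float]) -> Dict: ...

    @abstractmethod
    def evaluate_default(self, X_gt: List[float], X_syn: List[float]) -> float: ...


class MeanGap(SketchMetricEvaluator):
    """Hypothetical metric: absolute difference between the two sample means."""

    @staticmethod
    def name() -> str:
        return "mean_gap"

    @staticmethod
    def type() -> str:
        return "stats"

    @staticmethod
    def direction() -> str:
        return "minimize"  # smaller gap is better

    def evaluate(self, X_gt: List[float], X_syn: List[float]) -> Dict:
        gap = abs(sum(X_gt) / len(X_gt) - sum(X_syn) / len(X_syn))
        return {"gap": gap}

    def evaluate_default(self, X_gt: List[float], X_syn: List[float]) -> float:
        # Collapse the dictionary from evaluate() into one float score.
        return self.evaluate(X_gt, X_syn)["gap"]


class Incomplete(SketchMetricEvaluator):
    @staticmethod
    def name() -> str:
        return "incomplete"
    # type(), direction(), evaluate() and evaluate_default() are missing.


print(MeanGap().evaluate_default([1.0, 2.0, 3.0], [2.0, 2.0, 2.5]))
try:
    Incomplete()
except TypeError as err:
    # ABCMeta refuses to instantiate a class with unimplemented abstract methods.
    print("instantiation failed:", err)
```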

Constructor Args:
reduction: str

The way to aggregate metrics across folds. Default: 'mean'.

n_histogram_bins: int

The number of bins used in histogram calculation. Default: 10.

n_folds: int

The number of folds in cross validation. Default: 3.

task_type: str

The type of downstream task. Default: 'classification'.

random_state: int

The random seed. Default: 0.

workspace: Path

The directory in which to save intermediate models or results. Default: Path("workspace").

use_cache: bool

Whether to use the cache. If True, the evaluator tries to load previously saved results from the workspace directory where possible. Default: True.
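The workspace and use_cache arguments describe a save-then-reuse pattern. The sketch below shows that pattern in isolation and is not synthcity's caching code: the helper `cached_evaluate`, the JSON file format, the file name `mean_gap.json` and the `{"score": 0.87}` result are hypothetical, chosen only for illustration.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict


def cached_evaluate(
    compute: Callable[[], Dict],
    cache_file: Path,
    use_cache: bool = True,
) -> Dict:
    """Hypothetical helper: reuse a saved result from the workspace when allowed."""
    if use_cache and cache_file.exists():
        # Cache hit: load the previously saved result instead of recomputing.
        return json.loads(cache_file.read_text())
    result = compute()
    # Persist the fresh result so a later call with use_cache=True can reuse it.
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    cache_file.write_text(json.dumps(result))
    return result


calls = []  # records how many times the "expensive" computation actually runs


def expensive_metric() -> Dict:
    calls.append(1)
    return {"score": 0.87}  # made-up value for illustration


with tempfile.TemporaryDirectory() as tmp:
    workspace = Path(tmp) / "workspace"
    target = workspace / "mean_gap.json"
    first = cached_evaluate(expensive_metric, target)  # computes and saves
    second = cached_evaluate(expensive_metric, target)  # loaded from disk
    third = cached_evaluate(expensive_metric, target, use_cache=False)  # recomputes
    print(first, second, third, len(calls))
```

With use_cache=False the helper always recomputes and overwrites the saved file, which is the behavior you would want after changing the data or the metric settings.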

abstract static direction() → str
abstract evaluate(X_gt: synthcity.plugins.core.dataloader.DataLoader, X_syn: synthcity.plugins.core.dataloader.DataLoader) → Dict
abstract evaluate_default(X_gt: synthcity.plugins.core.dataloader.DataLoader, X_syn: synthcity.plugins.core.dataloader.DataLoader) → float
classmethod fqdn() → str
abstract static name() → str
reduction() → Callable
abstract static type() → str
use_cache(path: pathlib.Path) → bool
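reduction() returns a Callable, which suggests the reduction string is resolved to an aggregation function applied to the per-fold scores. The sketch below illustrates that idea together with a fully qualified name built from type() and name(). Both are assumptions: the helper names `reduction_fn` and `fqdn`, the set of supported strings, the `"type.name"` format, and the fold scores are hypothetical and not taken from synthcity.

```python
import statistics
from typing import Callable, List


def reduction_fn(reduction: str) -> Callable[[List[float]], float]:
    """Hypothetical mapping from the `reduction` string to an aggregation function."""
    options = {
        "mean": statistics.mean,
        "median": statistics.median,
        "min": min,
        "max": max,
    }
    if reduction not in options:
        raise ValueError(f"unsupported reduction: {reduction!r}")
    return options[reduction]


def fqdn(metric_type: str, metric_name: str) -> str:
    """Hypothetical fully qualified name combining type() and name()."""
    return f"{metric_type}.{metric_name}"


fold_scores = [0.80, 0.90, 0.70]  # one made-up score per fold (n_folds = 3)
print(reduction_fn("mean")(fold_scores))
print(reduction_fn("max")(fold_scores))
print(fqdn("stats", "mean_gap"))
```

Resolving the string once into a function keeps the fold loop free of branching on the reduction name and makes an unsupported value fail early.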