
Separate Waves Classifier

Abstract of Separate Waves (SepWav)

Extracted from "A New Longitudinal Classification Method Based on Stacking Predictions for Separate Time Points" (BCS SGAI AI-2025).

Biomedical research often uses longitudinal data with repeated measurements of variables across time (e.g. cholesterol measured across time), which is challenging for standard machine learning algorithms due to intrinsic temporal dependencies. The Separate Waves (SepWav) data-transformation method trains a base classifier for each time point ("wave") and aggregates their predictions via voting. However, the simplicity of the voting mechanism may not be enough to capture complex patterns of time-dependent interactions involving the base classifiers' predictions. Hence, we propose a novel SepWav method where the simple voting mechanism is replaced by a stacking-based meta-classifier that integrates the base classifiers' wave-specific predictions into a final predicted class label, aiming at improving predictive performance. Experiments with 20 datasets of ageing-related diseases have shown that, overall, the proposed Stacking-based SepWav method achieved significantly better predictive performance than two other methods for longitudinal classification in most cases, when using class-weight adjustment as a class-balancing method.

See More In References

SepWav

Bases: BaseEstimator, ClassifierMixin, DataPreparationMixin

SepWav stands for Separate Waves: a strategy that trains one classifier per wave (time point) of a longitudinal dataset.

The SepWav class implements the Separate Waves strategy, treating each wave (time point) as a separate dataset. A classifier is trained on each wave independently, and their predictions are combined using ensemble methods such as voting or stacking. The workflow supports both binary and multiclass classification. When stacking is selected, the base wave estimators must implement predict_proba, because the meta-learner is trained on wave-level class-probability outputs.
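
The `predict_proba` requirement for stacking can be checked before fitting. The sketch below is illustrative only: it uses plain scikit-learn estimators, and the choice of `LinearSVC` and `CalibratedClassifierCV` as examples is an assumption rather than something this page prescribes.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import LinearSVC

# Candidate base estimators (illustrative choices, not a recommendation).
candidates = {
    "random_forest": RandomForestClassifier(),
    "linear_svc": LinearSVC(),
    # Wrapping a margin-based classifier in a calibrator adds predict_proba.
    "calibrated_linear_svc": CalibratedClassifierCV(LinearSVC()),
}

# True when the estimator exposes predict_proba and could serve as a
# wave-level base estimator under the stacking strategy.
supports_stacking = {
    name: hasattr(estimator, "predict_proba")
    for name, estimator in candidates.items()
}

for name, ok in supports_stacking.items():
    print(f"{name}: predict_proba available = {ok}")
```

An estimator that fails this check should still be usable with the voting strategies, which the description above does not tie to class probabilities.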

Ensemble Strategies

Supported ensemble methods include:

  • Simple majority voting
  • Weighted voting (e.g., decaying weights for older waves)
  • Stacking with a meta-learner trained on wave-level class probabilities

Refer to LongitudinalVoting and LongitudinalStacking for mathematical details.
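
To make the difference between the two voting variants concrete, here is a minimal NumPy sketch. It is not the `LongitudinalVoting` implementation: the `vote` helper, the toy predictions, and the exponential recency weights are all assumptions chosen for illustration.

```python
import numpy as np


def vote(wave_predictions, weights=None):
    """Combine per-wave label predictions (hypothetical helper).

    wave_predictions has shape (n_waves, n_samples); weights has one
    entry per wave. With weights=None every wave counts equally.
    """
    preds = np.asarray(wave_predictions)
    if weights is None:
        weights = np.ones(preds.shape[0])
    weights = np.asarray(weights, dtype=float)
    classes = np.unique(preds)
    # scores[c, j] = total weight of the waves that predicted class c for sample j
    scores = np.array(
        [(weights[:, None] * (preds == c)).sum(axis=0) for c in classes]
    )
    return classes[scores.argmax(axis=0)]


wave_predictions = [
    [0, 1, 1],  # wave 0 (oldest)
    [0, 1, 0],  # wave 1
    [1, 0, 0],  # wave 2 (most recent)
]

majority = vote(wave_predictions)

# Assumed weighting scheme: each wave counts twice as much as the one before it.
recency_weights = 2.0 ** np.arange(3)
recency_weights /= recency_weights.sum()
weighted = vote(wave_predictions, recency_weights)

print(majority)  # [0 1 0]
print(weighted)  # [1 0 0]
```

With these weights the most recent wave outweighs the two older waves combined, so the weighted result follows wave 2 wherever the waves disagree.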

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `estimator` | `Union[ClassifierMixin, CustomClassifierMixinEstimator]` | Base classifier for each wave. | `None` |
| `features_group` | `List[List[int]]` | Temporal matrix where each sublist contains indices of a longitudinal attribute's waves. | `None` |
| `non_longitudinal_features` | `List[Union[int, str]]` | List of indices or names of non-longitudinal features. | `None` |
| `feature_list_names` | `List[str]` | List of feature names in the dataset. | `None` |
| `voting` | `LongitudinalEnsemblingStrategy` | Ensemble strategy. | `LongitudinalEnsemblingStrategy.MAJORITY_VOTING` |
| `stacking_meta_learner` | `Union[CustomClassifierMixinEstimator, ClassifierMixin, None]` | Meta-learner for stacking. | `LogisticRegression()` |
| `n_jobs` | `int` | Number of parallel jobs. | `None` |
| `parallel` | `bool` | Whether to run wave fitting in parallel. | `False` |
| `num_cpus` | `int` | Number of CPUs for parallel processing; `-1` uses all available CPUs. | `-1` |
| `class_weight` | `Any` | Class-weight specification to forward to wave estimators when supported. | `None` |
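
The `features_group` and `non_longitudinal_features` parameters together determine which columns each wave classifier sees. The sketch below mirrors the index selection in `_extract_wave` from the source listing on this page; the column layout and the `wave_feature_indices` helper name are hypothetical.

```python
# Hypothetical layout: columns 0-2 hold one attribute measured at three waves,
# columns 3-4 hold a second attribute measured at two waves, and column 5 is a
# non-longitudinal feature such as sex.
features_group = [[0, 1, 2], [3, 4]]
non_longitudinal_features = [5]


def wave_feature_indices(features_group, non_longitudinal_features, wave):
    """Return the column indices used for one wave (0-based)."""
    # Take the wave-th measurement of every attribute that has one.
    indices = [group[wave] for group in features_group if wave < len(group)]
    # Non-longitudinal features are appended to every wave.
    indices.extend(non_longitudinal_features or [])
    return indices


n_waves = max(len(group) for group in features_group)
per_wave = [
    wave_feature_indices(features_group, non_longitudinal_features, wave)
    for wave in range(n_waves)
]
print(per_wave)  # [[0, 3, 5], [1, 4, 5], [2, 5]]
```

The last wave contains only one longitudinal column because the second attribute was not measured at that time point.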

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `dataset` | `DataFrame` | Training dataset. |
| `estimator` | `BaseEstimator` | Base classifier for each wave. |
| `estimators` | `List` | List of trained classifiers for each wave. |
| `voting` | `LongitudinalEnsemblingStrategy` | Ensemble strategy used. |
| `stacking_meta_learner` | `Union[CustomClassifierMixinEstimator, ClassifierMixin]` | Meta-learner for stacking. |
| `clf_ensemble` | `BaseEstimator` | Combined ensemble classifier. |
| `n_jobs` | `int` | Number of parallel jobs. |
| `parallel` | `bool` | Whether parallel processing is enabled. |
| `num_cpus` | `int` | Number of CPUs used. |
| `class_weight` | `Any` | Requested class-weight configuration applied to compatible estimators. |

Examples:

Below are examples using the "stroke_longitudinal.csv" dataset. Replace "./stroke_longitudinal.csv" with your actual dataset path.

Basic Usage

from scikit_longitudinal.data_preparation import LongitudinalDataset
from scikit_longitudinal.data_preparation import SepWav
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from scikit_longitudinal.estimators.ensemble.longitudinal_voting.longitudinal_voting import (
    LongitudinalEnsemblingStrategy,
)

# Load dataset
dataset = LongitudinalDataset('./stroke_longitudinal.csv')
dataset.load_data()
dataset.load_target(target_column="stroke_w2")
dataset.setup_features_group("elsa")
dataset.load_train_test_split(test_size=0.2, random_state=42)

# Initialize classifier
classifier = RandomForestClassifier()

# Initialize SepWav
sepwav = SepWav(
    estimator=classifier,
    features_group=dataset.feature_groups(),
    non_longitudinal_features=dataset.non_longitudinal_features(),
    feature_list_names=dataset.data.columns.tolist(),
    voting=LongitudinalEnsemblingStrategy.MAJORITY_VOTING
)

# Fit and predict
sepwav.fit(dataset.X_train, dataset.y_train)
y_pred = sepwav.predict(dataset.X_test)

# Evaluate
accuracy = accuracy_score(dataset.y_test, y_pred)
print(f"Accuracy: {accuracy}")

Advanced: stacking ensemble

from scikit_longitudinal.data_preparation import LongitudinalDataset
from scikit_longitudinal.data_preparation import SepWav
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from scikit_longitudinal.estimators.ensemble.longitudinal_voting.longitudinal_voting import (
    LongitudinalEnsemblingStrategy,
)


# Load dataset
dataset = LongitudinalDataset('./stroke_longitudinal.csv')
dataset.load_data()
dataset.load_target(target_column="stroke_w2")
dataset.setup_features_group("elsa")
dataset.load_train_test_split(test_size=0.2, random_state=42)

# Initialize classifier
classifier = RandomForestClassifier()

# Initialize SepWav with stacking
sepwav = SepWav(
    estimator=classifier,
    features_group=dataset.feature_groups(),
    non_longitudinal_features=dataset.non_longitudinal_features(),
    feature_list_names=dataset.data.columns.tolist(),
    voting=LongitudinalEnsemblingStrategy.STACKING,
    stacking_meta_learner=LogisticRegression()
)

# Fit and predict
sepwav.fit(dataset.X_train, dataset.y_train)
y_pred = sepwav.predict(dataset.X_test)

# Evaluate
accuracy = accuracy_score(dataset.y_test, y_pred)
print(f"Accuracy: {accuracy}")
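
Conceptually, the stacking strategy feeds wave-level class probabilities into a meta-learner. The sketch below reproduces that idea with plain scikit-learn on synthetic data. It is not the `LongitudinalStackingClassifier` implementation: the column-to-wave mapping, the held-out split used to train the meta-learner, and all variable names are assumptions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic data; pretend each pair of columns belongs to one wave.
X, y = make_classification(
    n_samples=300, n_features=6, n_informative=4, random_state=0
)
wave_columns = [[0, 1], [2, 3], [4, 5]]  # hypothetical wave layout

# Assumption: train base models and meta-learner on disjoint halves so the
# meta-learner does not see probabilities from memorised training rows.
X_base, X_meta, y_base, y_meta = train_test_split(
    X, y, test_size=0.5, random_state=0, stratify=y
)

# One base classifier per wave, each restricted to its own columns.
wave_models = [
    RandomForestClassifier(n_estimators=25, random_state=0).fit(
        X_base[:, cols], y_base
    )
    for cols in wave_columns
]


def wave_probabilities(X_part):
    """Concatenate every wave model's class probabilities column-wise."""
    return np.hstack(
        [
            model.predict_proba(X_part[:, cols])
            for model, cols in zip(wave_models, wave_columns)
        ]
    )


# The meta-learner maps wave-level probabilities to the final class label.
meta_features = wave_probabilities(X_meta)
meta_learner = LogisticRegression().fit(meta_features, y_meta)
final_labels = meta_learner.predict(wave_probabilities(X_meta[:5]))
print(meta_features.shape, final_labels)
```

For a binary target and three waves, the meta-learner receives six input columns, two probabilities per wave.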

Advanced: parallel processing

# ... Imports, dataset loading, and classifier setup as in the previous examples ...

# Initialize SepWav with parallel processing
sepwav = SepWav(
    estimator=classifier,
    features_group=dataset.feature_groups(),
    non_longitudinal_features=dataset.non_longitudinal_features(),
    feature_list_names=dataset.data.columns.tolist(),
    parallel=True, # Enable parallel processing
    num_cpus=4 # Specify number of CPUs to use (or -1 for all available CPUs)
)

# ... Fit, predict, and evaluate as in the previous examples ...
Source code in scikit_longitudinal/data_preparation/separate_waves.py
class SepWav(BaseEstimator, ClassifierMixin, DataPreparationMixin):
    """SepWav stands for Separate Waves, a training done wave-by-wave for longitudinal dataset.

    The `SepWav` class implements the Separate Waves strategy, treating each wave (time point) as a separate dataset.
    A classifier is trained on each wave independently, and their predictions are combined using ensemble methods
    such as voting or stacking. The workflow supports both binary and multiclass classification. When stacking is
    selected, the base wave estimators must implement `predict_proba`, because the meta-learner is trained on
    wave-level class-probability outputs.

    !!! note "Ensemble Strategies"
        Supported ensemble methods include:

        - [x] Simple majority voting
        - [x] Weighted voting (e.g., decaying weights for older waves)
        - [x] Stacking with a meta-learner trained on wave-level class probabilities

        Refer to `LongitudinalVoting` and `LongitudinalStacking` for mathematical details.

    Args:
        estimator (Union[ClassifierMixin, CustomClassifierMixinEstimator], optional):
            Base classifier for each wave. Defaults to None.
        features_group (List[List[int]], optional):
            Temporal matrix where each sublist contains indices of a longitudinal attribute's waves. Defaults to None.
        non_longitudinal_features (List[Union[int, str]], optional):
            List of indices or names of non-longitudinal features. Defaults to None.
        feature_list_names (List[str], optional):
            List of feature names in the dataset. Defaults to None.
        voting (LongitudinalEnsemblingStrategy, optional):
            Ensemble strategy. Defaults to `LongitudinalEnsemblingStrategy.MAJORITY_VOTING`.
        stacking_meta_learner (Union[CustomClassifierMixinEstimator, ClassifierMixin, None], optional):
            Meta-learner for stacking. Defaults to `LogisticRegression()`.
        n_jobs (int, optional): Number of parallel jobs. Defaults to None.
        parallel (bool, optional): Whether to run wave fitting in parallel. Defaults to False.
        num_cpus (int, optional):
            Number of CPUs for parallel processing. Defaults to -1 (all available CPUs).
        class_weight (Any, optional): Class-weight specification to forward to wave estimators when supported.

    Attributes:
        dataset (pd.DataFrame): Training dataset.
        estimator (BaseEstimator): Base classifier for each wave.
        estimators (List): List of trained classifiers for each wave.
        voting (LongitudinalEnsemblingStrategy): Ensemble strategy used.
        stacking_meta_learner (Union[CustomClassifierMixinEstimator, ClassifierMixin]): Meta-learner for stacking.
        clf_ensemble (BaseEstimator): Combined ensemble classifier.
        n_jobs (int): Number of parallel jobs.
        parallel (bool): Whether parallel processing is enabled.
        num_cpus (int): Number of CPUs used.
        class_weight (Any): Requested class-weight configuration applied to compatible estimators.

    Examples:
        Below are examples using the "stroke_longitudinal.csv" dataset. Replace "./stroke_longitudinal.csv" with your actual dataset path.

        !!! example "Basic Usage"
            ```python
            from scikit_longitudinal.data_preparation import LongitudinalDataset
            from scikit_longitudinal.data_preparation import SepWav
            from sklearn.ensemble import RandomForestClassifier
            from sklearn.metrics import accuracy_score
            from scikit_longitudinal.estimators.ensemble.longitudinal_voting.longitudinal_voting import (
                LongitudinalEnsemblingStrategy,
            )

            # Load dataset
            dataset = LongitudinalDataset('./stroke_longitudinal.csv')
            dataset.load_data()
            dataset.load_target(target_column="stroke_w2")
            dataset.setup_features_group("elsa")
            dataset.load_train_test_split(test_size=0.2, random_state=42)

            # Initialize classifier
            classifier = RandomForestClassifier()

            # Initialize SepWav
            sepwav = SepWav(
                estimator=classifier,
                features_group=dataset.feature_groups(),
                non_longitudinal_features=dataset.non_longitudinal_features(),
                feature_list_names=dataset.data.columns.tolist(),
                voting=LongitudinalEnsemblingStrategy.MAJORITY_VOTING
            )

            # Fit and predict
            sepwav.fit(dataset.X_train, dataset.y_train)
            y_pred = sepwav.predict(dataset.X_test)

            # Evaluate
            accuracy = accuracy_score(dataset.y_test, y_pred)
            print(f"Accuracy: {accuracy}")
            ```

        !!! example "Advanced: stacking ensemble"
            ```python
            from scikit_longitudinal.data_preparation import LongitudinalDataset
            from scikit_longitudinal.data_preparation import SepWav
            from sklearn.ensemble import RandomForestClassifier
            from sklearn.metrics import accuracy_score
            from sklearn.linear_model import LogisticRegression
            from scikit_longitudinal.estimators.ensemble.longitudinal_voting.longitudinal_voting import (
                LongitudinalEnsemblingStrategy,
            )


            # Load dataset
            dataset = LongitudinalDataset('./stroke_longitudinal.csv')
            dataset.load_data()
            dataset.load_target(target_column="stroke_w2")
            dataset.setup_features_group("elsa")
            dataset.load_train_test_split(test_size=0.2, random_state=42)

            # Initialize classifier
            classifier = RandomForestClassifier()

            # Initialize SepWav with stacking
            sepwav = SepWav(
                estimator=classifier,
                features_group=dataset.feature_groups(),
                non_longitudinal_features=dataset.non_longitudinal_features(),
                feature_list_names=dataset.data.columns.tolist(),
                voting=LongitudinalEnsemblingStrategy.STACKING,
                stacking_meta_learner=LogisticRegression()
            )

            # Fit and predict
            sepwav.fit(dataset.X_train, dataset.y_train)
            y_pred = sepwav.predict(dataset.X_test)

            # Evaluate
            accuracy = accuracy_score(dataset.y_test, y_pred)
            print(f"Accuracy: {accuracy}")
            ```

        !!! example "Advanced: parallel processing"
            ```python
            # ... Similar to the previous example, but with parallel processing enabled ...

            # Initialize SepWav with parallel processing
            sepwav = SepWav(
                estimator=classifier,
                features_group=dataset.feature_groups(),
                non_longitudinal_features=dataset.non_longitudinal_features(),
                feature_list_names=dataset.data.columns.tolist(),
                parallel=True, # Enable parallel processing
                num_cpus=4 # Specify number of CPUs to use (or -1 for all available CPUs)
            )

            # ... Similar to the previous example, but with parallel processing enabled ...
            ```
    """

    def __init__(
        self,
        estimator: Union[ClassifierMixin, CustomClassifierMixinEstimator] = None,
        features_group: List[List[int]] = None,
        non_longitudinal_features: List[Union[int, str]] = None,
        feature_list_names: List[str] = None,
        voting: LongitudinalEnsemblingStrategy = LongitudinalEnsemblingStrategy.MAJORITY_VOTING,
        stacking_meta_learner: Union[
            CustomClassifierMixinEstimator, ClassifierMixin, None
        ] = LogisticRegression(),
        n_jobs: int = None,
        parallel: bool = False,
        num_cpus: int = -1,
        class_weight: Optional[Any] = None,
    ):
        self.features_group = features_group
        self.non_longitudinal_features = non_longitudinal_features
        self.feature_list_names = feature_list_names

        self.estimator = estimator
        self.voting = voting
        self.stacking_meta_learner = stacking_meta_learner

        self.n_jobs = n_jobs
        self.parallel = parallel
        self.num_cpus = num_cpus
        self.class_weight = class_weight

        self.estimators = []
        self.dataset = pd.DataFrame([])
        self.target = np.array([])
        self.clf_ensemble = None

    @override
    def _prepare_data(self, X: np.ndarray, y: np.ndarray = None) -> "SepWav":
        """Prepare the data for transformation.

        In `SepWav`, data preparation is handled within the `fit` method. This method is overridden for compatibility
        with `DataPreparationMixin` but performs no operations.

        Args:
            X (np.ndarray): Input data.
            y (np.ndarray, optional): Target data. Defaults to None.

        Returns:
            SepWav: The instance itself.
        """
        return self

    @property
    def classes_(self):
        if self.clf_ensemble is None:
            raise NotFittedError(
                "This SepWav instance is not fitted yet. Call 'fit' with appropriate arguments."
            )
        return self.clf_ensemble.classes_

    @validate_extract_wave_input
    @validate_extract_wave_output
    def _extract_wave(
        self, wave: int, extract_indices: bool = False
    ) -> Union[
        Tuple[pd.DataFrame, pd.Series], Tuple[pd.DataFrame, pd.Series, List[int]]
    ]:
        """Extract a specific wave from the dataset for training.

        Args:
            wave (int): Wave number to extract (0-based index).
            extract_indices (bool, optional): Whether to return feature indices. Defaults to False.

        Returns:
            tuple: If extract_indices is True, returns (X_wave, y_wave, feature_indices); otherwise, (X_wave, y_wave).

                - [x] X_wave (pd.DataFrame): Input samples for the wave.
                - [x] y_wave (pd.Series): Target values for the wave.
                - [x] feature_indices (list): Indices of extracted features (if extract_indices is True).

        Raises:
            ValueError: If wave number is negative.
            IndexError: If a selected feature index falls outside the dataset's columns.
        """
        feature_indices = [
            group[wave] for group in self.features_group if wave < len(group)
        ]
        if self.non_longitudinal_features is not None:
            feature_indices.extend(self.non_longitudinal_features)

        if any(idx >= self.dataset.shape[1] or idx < 0 for idx in feature_indices):
            raise IndexError(
                f"Feature index out of bounds for wave {wave}: {feature_indices} (df.shape={self.dataset.shape})"
            )

        X_wave = self.dataset.iloc[:, feature_indices]
        y_wave = self.target

        if extract_indices and feature_indices:
            return X_wave, y_wave, feature_indices
        return X_wave, y_wave

    # pylint: disable=unused-argument,too-many-branches
    @validate_fit_input
    @validate_fit_output
    def fit(
        self,
        X: Union[List[List[float]], "np.ndarray"],
        y: Union[List[float], "np.ndarray"],
        sample_weight: Union[List[float], np.ndarray, None] = None,
    ):
        """Fit the SepWav model to the training data.

        Trains a classifier for each wave and combines them using the specified ensemble strategy.

        Args:
            X (Union[List[List[float]], np.ndarray]): Input samples.
            y (Union[List[float], np.ndarray]): Target values.
            sample_weight (Union[List[float], np.ndarray], optional): Sample weights. Defaults to None.

        Returns:
            SepWav: Fitted instance.

        Raises:
            ValueError: If required parameters (estimator, features_group) are None or ensemble strategy is invalid.
        """
        self.dataset = pd.DataFrame(X, columns=self.feature_list_names)
        self.target = y

        if self.features_group is not None:
            self.features_group = clean_padding(self.features_group)

        n_waves = max(len(group) for group in self.features_group)

        if self.parallel:
            ray = get_ray_for_parallel(self.parallel, self.num_cpus)
            if sample_weight is not None:
                raise ValueError(
                    "Sample weights are not supported in parallel mode. "
                    "Please set parallel=False or remove sample_weight."
                )
            train_classifier = ray.remote(_train_classifier)
            futures = [
                train_classifier.remote(
                    self.estimator, X_train, y_train, wave, self.class_weight
                )
                for wave, (X_train, y_train) in enumerate(
                    self._extract_wave(wave=i) for i in range(n_waves)
                )
            ]
            self.estimators = ray.get(futures)
        else:
            for i in range(n_waves):
                X_wave, y_wave = self._extract_wave(wave=i)
                clf_wave = clone(self.estimator)
                clf_wave = _set_class_weight_if_supported(clf_wave, self.class_weight)
                if hasattr(X_wave, "values") and hasattr(y_wave, "values"):
                    X_wave = X_wave.values
                    y_wave = y_wave.values
                fit_params = {}
                if sample_weight is not None:
                    try:
                        if "sample_weight" in signature(clf_wave.fit).parameters:
                            fit_params["sample_weight"] = sample_weight
                    except (TypeError, ValueError):
                        pass
                clf_wave.fit(X_wave, y_wave, **fit_params)
                self.estimators.append((f"wave_{i}", clf_wave))

        if self.voting == LongitudinalEnsemblingStrategy.STACKING:
            meta_learner = None
            if self.stacking_meta_learner is not None:
                meta_learner = clone(self.stacking_meta_learner)
                meta_learner = _set_class_weight_if_supported(
                    meta_learner, self.class_weight
                )
            self.clf_ensemble = LongitudinalStackingClassifier(
                estimators=self.estimators,
                meta_learner=meta_learner,
                n_jobs=self.n_jobs,
                extract_wave=self._extract_wave,
            )
        else:
            self.clf_ensemble = LongitudinalVotingClassifier(
                estimators=self.estimators,
                voting=self.voting,
                extract_wave=self._extract_wave,
                n_jobs=self.n_jobs,
            )

        X_data = self.dataset.values

        if hasattr(X_data, "flags") and not X_data.flags["C_CONTIGUOUS"]:
            X_data = np.ascontiguousarray(X_data)

        self.clf_ensemble.fit(X_data, self.target)

        return self

    @validate_predict_input
    def predict(
        self, X: Union[List[List[float]], "np.ndarray"]
    ) -> Union[List[float], "np.ndarray"]:
        """Predict class labels for input samples.

        Uses the ensemble classifier to combine predictions from individual wave classifiers.

        Args:
            X (Union[List[List[float]], np.ndarray]): Input samples.

        Returns:
            Union[List[float], np.ndarray]: Predicted class labels.

        Raises:
            NotImplementedError: If the ensemble classifier does not support prediction.
        """
        if hasattr(self.clf_ensemble, "predict"):
            return self.clf_ensemble.predict(X)
        raise NotImplementedError(
            f"predict is not implemented for this classifier: {self.clf_ensemble} / type: {type(self.clf_ensemble)}"
        )

    @validate_predict_input
    def predict_proba(
        self, X: Union[List[List[float]], "np.ndarray"]
    ) -> Union[List[List[float]], "np.ndarray"]:
        """Predict class probabilities for input samples.

        Computes probabilities using the ensemble classifier's `predict_proba` method, if available.

        Args:
            X (Union[List[List[float]], np.ndarray]): Input samples.

        Returns:
            Union[List[List[float]], np.ndarray]: Predicted class probabilities.

        Raises:
            NotImplementedError: If the ensemble classifier does not support probability predictions.
        """
        if hasattr(self.clf_ensemble, "predict_proba"):
            return self.clf_ensemble.predict_proba(X)
        raise NotImplementedError(
            "predict_proba is not implemented for this classifier: "
            f"{self.clf_ensemble} / type: {type(self.clf_ensemble)}"
        )

    @validate_predict_wave_input
    def predict_wave(
        self, wave: int, X: Union[List[List[float]], "np.ndarray"]
    ) -> Union[List[float], "np.ndarray"]:
        """Predict class labels using the classifier for a specific wave.

        Useful for analyzing wave-specific performance or custom ensemble strategies.

        Args:
            wave (int): Wave number (0-based index).
            X (Union[List[List[float]], np.ndarray]): Input samples.

        Returns:
            Union[List[float], np.ndarray]: Predicted class labels for the specified wave.
        """
        return self.estimators[wave][1].predict(X)

fit(X, y, sample_weight=None)

Fit the SepWav model to the training data.

Trains a classifier for each wave and combines them using the specified ensemble strategy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `Union[List[List[float]], ndarray]` | Input samples. | required |
| `y` | `Union[List[float], ndarray]` | Target values. | required |
| `sample_weight` | `Union[List[float], ndarray]` | Sample weights. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `SepWav` | Fitted instance. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If required parameters (estimator, features_group) are None or ensemble strategy is invalid. |
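
In the sequential branch of `fit`, `sample_weight` is forwarded only to wave estimators whose `fit` signature accepts it; the parallel branch raises a `ValueError` if `sample_weight` is given. The standalone sketch below illustrates the signature check with two scikit-learn estimators. The helper name `fit_with_optional_weights` and the toy data are assumptions made for this example.

```python
from inspect import signature

import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier


def fit_with_optional_weights(estimator, X, y, sample_weight=None):
    """Fit a clone, passing sample_weight only if fit() accepts it."""
    model = clone(estimator)
    fit_params = {}
    if sample_weight is not None and "sample_weight" in signature(model.fit).parameters:
        fit_params["sample_weight"] = sample_weight
    model.fit(X, y, **fit_params)
    # The second value reports whether the weights were actually used.
    return model, bool(fit_params)


rng = np.random.default_rng(0)
X = rng.normal(size=(40, 3))
y = (X[:, 0] > 0).astype(int)
weights = np.where(y == 1, 2.0, 1.0)  # toy weights favouring class 1

_, used_by_logreg = fit_with_optional_weights(LogisticRegression(), X, y, weights)
_, used_by_knn = fit_with_optional_weights(KNeighborsClassifier(), X, y, weights)
print(used_by_logreg, used_by_knn)  # True False
```

The weights are dropped silently for estimators such as `KNeighborsClassifier`, so it is worth confirming that the chosen base estimator supports them when they matter.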

Source code in scikit_longitudinal/data_preparation/separate_waves.py
@validate_fit_input
@validate_fit_output
def fit(
    self,
    X: Union[List[List[float]], "np.ndarray"],
    y: Union[List[float], "np.ndarray"],
    sample_weight: Union[List[float], np.ndarray, None] = None,
):
    """Fit the SepWav model to the training data.

    Trains a classifier for each wave and combines them using the specified ensemble strategy.

    Args:
        X (Union[List[List[float]], np.ndarray]): Input samples.
        y (Union[List[float], np.ndarray]): Target values.
        sample_weight (Union[List[float], np.ndarray], optional): Sample weights. Defaults to None.

    Returns:
        SepWav: Fitted instance.

    Raises:
        ValueError: If required parameters (estimator, features_group) are None or ensemble strategy is invalid.
    """
    self.dataset = pd.DataFrame(X, columns=self.feature_list_names)
    self.target = y

    if self.features_group is not None:
        self.features_group = clean_padding(self.features_group)

    n_waves = max(len(group) for group in self.features_group)

    if self.parallel:
        ray = get_ray_for_parallel(self.parallel, self.num_cpus)
        if sample_weight is not None:
            raise ValueError(
                "Sample weights are not supported in parallel mode. "
                "Please set parallel=False or remove sample_weight."
            )
        train_classifier = ray.remote(_train_classifier)
        futures = [
            train_classifier.remote(
                self.estimator, X_train, y_train, wave, self.class_weight
            )
            for wave, (X_train, y_train) in enumerate(
                self._extract_wave(wave=i) for i in range(n_waves)
            )
        ]
        self.estimators = ray.get(futures)
    else:
        for i in range(n_waves):
            X_wave, y_wave = self._extract_wave(wave=i)
            clf_wave = clone(self.estimator)
            clf_wave = _set_class_weight_if_supported(clf_wave, self.class_weight)
            if hasattr(X_wave, "values") and hasattr(y_wave, "values"):
                X_wave = X_wave.values
                y_wave = y_wave.values
            fit_params = {}
            if sample_weight is not None:
                try:
                    if "sample_weight" in signature(clf_wave.fit).parameters:
                        fit_params["sample_weight"] = sample_weight
                except (TypeError, ValueError):
                    pass
            clf_wave.fit(X_wave, y_wave, **fit_params)
            self.estimators.append((f"wave_{i}", clf_wave))

    if self.voting == LongitudinalEnsemblingStrategy.STACKING:
        meta_learner = None
        if self.stacking_meta_learner is not None:
            meta_learner = clone(self.stacking_meta_learner)
            meta_learner = _set_class_weight_if_supported(
                meta_learner, self.class_weight
            )
        self.clf_ensemble = LongitudinalStackingClassifier(
            estimators=self.estimators,
            meta_learner=meta_learner,
            n_jobs=self.n_jobs,
            extract_wave=self._extract_wave,
        )
    else:
        self.clf_ensemble = LongitudinalVotingClassifier(
            estimators=self.estimators,
            voting=self.voting,
            extract_wave=self._extract_wave,
            n_jobs=self.n_jobs,
        )

    X_data = self.dataset.values

    if hasattr(X_data, "flags") and not X_data.flags["C_CONTIGUOUS"]:
        X_data = np.ascontiguousarray(X_data)

    self.clf_ensemble.fit(X_data, self.target)

    return self

predict(X)

Predict class labels for input samples.

Uses the ensemble classifier to combine predictions from individual wave classifiers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `Union[List[List[float]], ndarray]` | Input samples. | required |

Returns:

| Type | Description |
| --- | --- |
| `Union[List[float], ndarray]` | Predicted class labels. |

Raises:

| Type | Description |
| --- | --- |
| `NotImplementedError` | If the ensemble classifier does not support prediction. |

Source code in scikit_longitudinal/data_preparation/separate_waves.py
@validate_predict_input
def predict(
    self, X: Union[List[List[float]], "np.ndarray"]
) -> Union[List[float], "np.ndarray"]:
    """Predict class labels for input samples.

    Uses the ensemble classifier to combine predictions from individual wave classifiers.

    Args:
        X (Union[List[List[float]], np.ndarray]): Input samples.

    Returns:
        Union[List[float], np.ndarray]: Predicted class labels.

    Raises:
        NotImplementedError: If the ensemble classifier does not support prediction.
    """
    if hasattr(self.clf_ensemble, "predict"):
        return self.clf_ensemble.predict(X)
    raise NotImplementedError(
        f"predict is not implemented for this classifier: {self.clf_ensemble} / type: {type(self.clf_ensemble)}"
    )

predict_proba(X)

Predict class probabilities for input samples.

Computes probabilities using the ensemble classifier's predict_proba method, if available.

Parameters:

Name Type Description Default
X Union[List[List[float]], ndarray]

Input samples.

required

Returns:

Type Description
Union[List[List[float]], ndarray]

Predicted class probabilities.

Raises:

Type Description
NotImplementedError

If the ensemble classifier does not support probability predictions.

Source code in scikit_longitudinal/data_preparation/separate_waves.py
@validate_predict_input
def predict_proba(
    self, X: Union[List[List[float]], "np.ndarray"]
) -> Union[List[List[float]], "np.ndarray"]:
    """Predict class probabilities for input samples.

    Computes probabilities using the ensemble classifier's `predict_proba` method, if available.

    Args:
        X (Union[List[List[float]], np.ndarray]): Input samples.

    Returns:
        Union[List[List[float]], np.ndarray]: Predicted class probabilities.

    Raises:
        NotImplementedError: If the ensemble classifier does not support probability predictions.
    """
    if hasattr(self.clf_ensemble, "predict_proba"):
        return self.clf_ensemble.predict_proba(X)
    raise NotImplementedError(
        "predict_proba is not implemented for this classifier: "
        f"{self.clf_ensemble} / type: {type(self.clf_ensemble)}"
    )

predict_wave(wave, X)

Predict class labels using the classifier for a specific wave.

Useful for analyzing wave-specific performance or custom ensemble strategies.

Parameters:

Name Type Description Default
wave int

Wave number (0-based index).

required
X Union[List[List[float]], ndarray]

Input samples.

required

Returns:

Type Description
Union[List[float], ndarray]

Predicted class labels for the specified wave.

Source code in scikit_longitudinal/data_preparation/separate_waves.py
@validate_predict_wave_input
def predict_wave(
    self, wave: int, X: Union[List[List[float]], "np.ndarray"]
) -> Union[List[float], "np.ndarray"]:
    """Predict class labels using the classifier for a specific wave.

    Useful for analyzing wave-specific performance or custom ensemble strategies.

    Args:
        wave (int): Wave number (0-based index).
        X (Union[List[List[float]], np.ndarray]): Input samples.

    Returns:
        Union[List[float], np.ndarray]: Predicted class labels for the specified wave.
    """
    return self.estimators[wave][1].predict(X)

SepWav ensemble back-ends

SepWav delegates the final aggregation of per-wave predictions to one of the two classifiers below.

Longitudinal Voting Classifier

Aggregates per-wave predictions with a configurable voting rule: simple majority, linear or exponential recency decay, or cross-validation-weighted voting.

LongitudinalVotingClassifier

Bases: CustomClassifierMixinEstimator

Aggregates predictions from pre-trained base estimators using the voting rule specified by LongitudinalEnsemblingStrategy (majority, linear or exponential decay, or cross-validation-weighted). Supports both binary and multiclass targets, and wraps scikit-learn's VotingClassifier under the hood.

Parameters:

Name Type Description Default
voting LongitudinalEnsemblingStrategy, default=LongitudinalEnsemblingStrategy.MAJORITY_VOTING

The voting strategy to be used for the ensemble. Refer to the LongitudinalEnsemblingStrategy enum.

MAJORITY_VOTING
estimators List[CustomClassifierMixinEstimator]

A list of classifiers for the ensemble. Note that the classifiers need to be trained before being passed to the LongitudinalVotingClassifier.

required
extract_wave Callable

A function to extract specific wave data for training. Defaults to None. When provided, the order of estimators defines the wave order used for extraction.

None
n_jobs int, default=1

The number of jobs to run in parallel.

1

Attributes:

Name Type Description
clf_ensemble LongitudinalCustomVoting

The underlying custom voting classifier instance.

Raises:

Type Description
ValueError

If no estimators are provided or if an invalid voting strategy is specified.

NotFittedError

If attempting to predict or predict_proba before fitting the model.

Notes
  • predict_proba returns normalised vote shares across classes. These are consistent with the hard-voting decision returned by predict, but they are not calibrated probabilities.
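
To illustrate what "normalised vote shares" means, here is a minimal NumPy sketch. It is not the library's implementation: the helper `vote_shares` and the toy predictions are hypothetical, and the sketch assumes that each estimator contributes its weight to the class it predicts.

```python
import numpy as np


def vote_shares(predictions, classes, weights=None):
    """Return an (n_samples, n_classes) array of normalised vote shares.

    predictions: (n_estimators, n_samples) hard labels, one row per wave estimator.
    """
    predictions = np.asarray(predictions)
    n_estimators, n_samples = predictions.shape
    if weights is None:
        weights = np.ones(n_estimators)  # equal weights: plain majority voting
    weights = np.asarray(weights, dtype=float)

    shares = np.zeros((n_samples, len(classes)))
    for col, label in enumerate(classes):
        # Sum the weights of every estimator that voted for this class.
        shares[:, col] = ((predictions == label) * weights[:, None]).sum(axis=0)
    # Normalise each row so that the shares of one sample sum to 1.
    return shares / shares.sum(axis=1, keepdims=True)


# Hypothetical hard predictions from three wave estimators on three samples.
wave_predictions = [
    [0, 1, 1],
    [0, 1, 0],
    [1, 1, 0],
]
shares = vote_shares(wave_predictions, classes=[0, 1])
labels = np.array([0, 1])[shares.argmax(axis=1)]
```

A share of 2/3 here only says that two of three estimators voted for the class; it carries no information about how confident each estimator was, which is why these values should not be read as calibrated probabilities.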
Source code in scikit_longitudinal/estimators/ensemble/longitudinal_voting/longitudinal_voting.py
class LongitudinalVotingClassifier(CustomClassifierMixinEstimator):
    """
    Aggregates predictions from pre-trained base estimators using the voting rule specified by
    `LongitudinalEnsemblingStrategy` (majority, linear or exponential decay, or cross-validation-weighted). Supports
    both binary and multiclass targets, and wraps scikit-learn's `VotingClassifier` under the hood.

    Args:
        voting (LongitudinalEnsemblingStrategy, default=LongitudinalEnsemblingStrategy.MAJORITY_VOTING):
            The voting strategy to be used for the ensemble. Refer to the LongitudinalEnsemblingStrategy enum.
        estimators (List[CustomClassifierMixinEstimator]):
            A list of classifiers for the ensemble. Note that the classifiers need to be trained before being passed to
            the LongitudinalVotingClassifier.
        extract_wave (Callable, optional):
            A function to extract specific wave data for training. Defaults to None. When provided, the order of
            `estimators` defines the wave order used for extraction.
        n_jobs (int, default=1):
            The number of jobs to run in parallel.

    Attributes:
        clf_ensemble (LongitudinalCustomVoting):
            The underlying custom voting classifier instance.

    Raises:
        ValueError: If no estimators are provided or if an invalid voting strategy is specified.
        NotFittedError: If attempting to predict or predict_proba before fitting the model.

    Notes:
        - `predict_proba` returns normalised vote shares across classes. These are consistent with the hard-voting
          decision returned by `predict`, but they are not calibrated probabilities.
    """

    def __init__(
        self,
        estimators: List[CustomClassifierMixinEstimator],
        voting: LongitudinalEnsemblingStrategy = LongitudinalEnsemblingStrategy.MAJORITY_VOTING,
        extract_wave: Callable = None,
        n_jobs: int = 1,
    ) -> None:
        self.estimators = estimators
        self.voting = voting
        self.extract_wave = extract_wave
        self.n_jobs = n_jobs
        self.clf_ensemble = None

    @property
    def classes_(self):
        """
        Property to access the classes of the fitted ensemble model.

        Returns:
            np.ndarray: The class labels.

        Raises:
            NotFittedError: If the model is not fitted yet.
        """
        if self.clf_ensemble is None:
            raise NotFittedError(
                "This LongitudinalVotingClassifier instance is not fitted yet. Call 'fit' with appropriate arguments."
            )
        return self.clf_ensemble.classes_

    @override
    def _fit(
        self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray = None
    ) -> "LongitudinalVotingClassifier":
        """
        Fit the ensemble model.

        Trains the ensemble based on the specified voting strategy.

        Args:
            X (np.ndarray):
                The training data.
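            y (np.ndarray):
                The target values.
            sample_weight (np.ndarray, optional):
                Accepted for interface compatibility; the voting strategies ignore it.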

        Returns:
            LongitudinalVotingClassifier:
                The fitted ensemble model.

        Raises:
            ValueError: If no estimators are provided or if an invalid voting strategy is specified.

        !!! tip "Estimator Training"
            Ensure all estimators are trained before passing them to the `LongitudinalVotingClassifier`.
        """
        _ = sample_weight

        if not self.estimators:
            raise ValueError("No estimators were provided.")

        if not isinstance(self.voting, LongitudinalEnsemblingStrategy):
            raise ValueError(
                f"Invalid ensemble strategy. It must be a value from {LongitudinalEnsemblingStrategy} enum."
            )

        strategy_method = {
            LongitudinalEnsemblingStrategy.MAJORITY_VOTING: self._fit_majority_voting,
            LongitudinalEnsemblingStrategy.DECAY_LINEAR_VOTING: self._fit_decay_linear_voting,
            LongitudinalEnsemblingStrategy.DECAY_EXPONENTIAL_VOTING: self._fit_decay_exponential_voting,
            LongitudinalEnsemblingStrategy.CV_BASED_VOTING: self._fit_cv_based_voting,
        }.get(self.voting)

        if strategy_method:
            strategy_method(X, y)
        else:
            raise ValueError(f"Invalid ensemble strategy: {self.voting}")

        return self

    @override
    def _predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the ensemble model.

        Generates predictions based on the aggregated votes of the base estimators.

        Args:
            X (np.ndarray):
                The test data.

        Returns:
            np.ndarray:
                The predicted values.

        Raises:
            NotFittedError: If attempting to predict before fitting the model.

        !!! tip "Tie-Breaking"
            In case of a tie, the prediction from the most recent estimator among the tied classes is selected.
        """
        if self.clf_ensemble:
            return self.clf_ensemble.predict(X)
        raise NotFittedError("Ensemble model is not fitted yet.")

    @override
    def _predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict probabilities using the ensemble model.

        Returns normalised vote shares across classes, rather than averaged base-estimator confidence scores.

        Args:
            X (np.ndarray):
                The test data.

        Returns:
            np.ndarray:
                The predicted probabilities.

        Raises:
            NotFittedError: If attempting to predict before fitting the model.
        """
        if self.clf_ensemble:
            return self.clf_ensemble.predict_proba(X)
        raise NotFittedError("Ensemble model is not fitted yet.")

    def _extract_wave(self, X: np.ndarray, wave: int) -> np.ndarray:
        """
        Extract the data for the given wave.

        Uses the `extract_wave` function to retrieve specific wave data.

        Args:
            X (np.ndarray):
                The training data.
            wave (int):
                The wave number to extract.

        Returns:
            np.ndarray:
                The extracted data.
        """
        if self.extract_wave:
            return X[:, self.extract_wave(wave, extract_indices=True)[2]]
        return X

    def _fit_majority_voting(self, X: np.ndarray, y: np.ndarray) -> None:
        """
        Fit the ensemble model using majority voting strategy.

        Each estimator's vote is equally weighted.

        Args:
            X (np.ndarray):
                The training data.
            y (np.ndarray):
                The target values.
        """
        self.clf_ensemble = LongitudinalCustomVoting(
            self.estimators, extract_wave=self.extract_wave
        )
        self.clf_ensemble.fit(X, y)

    def _fit_decay_linear_voting(self, X: np.ndarray, y: np.ndarray) -> None:
        """
        Fit the ensemble model using linear decay weighted voting strategy.

        Weights are assigned linearly, with more recent waves having higher weights.

        Args:
            X (np.ndarray):
                The training data.
            y (np.ndarray):
                The target values.
        """
        weights = self._calculate_linear_decay_weights(len(self.estimators))
        self.clf_ensemble = LongitudinalCustomVoting(
            self.estimators, weights=weights, extract_wave=self.extract_wave
        )
        self.clf_ensemble.fit(X, y)

    def _fit_decay_exponential_voting(self, X: np.ndarray, y: np.ndarray) -> None:
        """
        Fit the ensemble model using exponential decay weighted voting strategy.

        Weights are assigned exponentially, favouring more recent waves.

        Args:
            X (np.ndarray):
                The training data.
            y (np.ndarray):
                The target values.
        """
        weights = self._calculate_exponential_decay_weights(len(self.estimators))
        self.clf_ensemble = LongitudinalCustomVoting(
            self.estimators, weights=weights, extract_wave=self.extract_wave
        )
        self.clf_ensemble.fit(X, y)

    def _fit_cv_based_voting(self, X: np.ndarray, y: np.ndarray) -> None:
        """
        Fit the ensemble model using cross-validation weighted voting strategy.

        Weights are based on each estimator's cross-validation accuracy.

        Args:
            X (np.ndarray):
                The training data.
            y (np.ndarray):
                The target values.
        """
        weights = self._calculate_cv_weights(X, y, k=5)
        self.clf_ensemble = LongitudinalCustomVoting(
            self.estimators, weights=weights, extract_wave=self.extract_wave
        )
        self.clf_ensemble.fit(X, y)

    def _calculate_cv_weights(
        self, X: np.ndarray, y: np.ndarray, k: int
    ) -> List[float]:
        """
        Calculate the weights based on cross-validation accuracy.

        Args:
            X (np.ndarray):
                The training data.
            y (np.ndarray):
                The target values.
            k (int):
                The number of folds for cross-validation.

        Returns:
            List[float]: Weights for each estimator.
        """
        accuracies = [
            cross_val_score(
                estimator, self._extract_wave(X, estimator_index), y, cv=k
            ).mean()
            for estimator_index, (_, estimator) in enumerate(self.estimators)
        ]
        total_accuracy = sum(accuracies)
        if np.isclose(total_accuracy, 0.0):
            return [1.0 / len(accuracies)] * len(accuracies)
        return [acc / total_accuracy for acc in accuracies]

    @staticmethod
    def _calculate_linear_decay_weights(N: int) -> List[float]:
        """
        Calculate the weights based on linear decay.

        Args:
            N (int):
                The number of waves.

        Returns:
            List[float]: Linear decay weights.
        """
        return [i / sum(range(1, N + 1)) for i in range(1, N + 1)]

    @staticmethod
    def _calculate_exponential_decay_weights(N: int) -> List[float]:
        """
        Calculate the weights based on exponential decay.

        Args:
            N (int):
                The number of waves.

        Returns:
            List[float]: Exponential decay weights.
        """
        return [
            np.exp(i) / sum(np.exp(j) for j in range(1, N + 1)) for i in range(1, N + 1)
        ]
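
The cross-validation weighting in `_calculate_cv_weights` can be reproduced outside the class with plain scikit-learn. The sketch below rests on assumptions: the synthetic data, the two-wave column layout in `wave_columns`, and the choice of `DecisionTreeClassifier` are all hypothetical stand-ins for real wave data and estimators.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for a longitudinal dataset.
X, y = make_classification(
    n_samples=120, n_features=6, n_informative=4, random_state=0
)

# Hypothetical layout: columns 0-2 belong to wave 0, columns 3-5 to wave 1.
wave_columns = [[0, 1, 2], [3, 4, 5]]

# Mean 5-fold accuracy of one classifier per wave, using only that wave's columns.
accuracies = [
    cross_val_score(
        DecisionTreeClassifier(random_state=0), X[:, cols], y, cv=5
    ).mean()
    for cols in wave_columns
]

# Normalise accuracies into weights; fall back to uniform weights if all are zero.
total = sum(accuracies)
if np.isclose(total, 0.0):
    cv_weights = [1.0 / len(accuracies)] * len(accuracies)
else:
    cv_weights = [acc / total for acc in accuracies]
```

Because the weights are accuracies divided by their sum, waves with similar accuracy end up with nearly equal weights, so this strategy only departs noticeably from majority voting when some waves are clearly more predictive than others.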
fit(X, y=None, sample_weight=None)

Fit the classifier to the training data.

Validates X (and y when provided) with scikit-learn's check_X_y / check_array and then delegates to the subclass implementation in _fit. sample_weight is forwarded only when the subclass's _fit declares it.
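
The conditional forwarding of sample_weight can be sketched with the standard-library `inspect` module. The names `call_fit`, `fit_with_weights`, and `fit_without_weights` are hypothetical and exist only for this illustration; the validation step is left out.

```python
import inspect


def call_fit(fit_impl, X, y, sample_weight=None):
    """Forward sample_weight only if fit_impl declares that parameter."""
    if "sample_weight" in inspect.signature(fit_impl).parameters:
        return fit_impl(X, y, sample_weight=sample_weight)
    return fit_impl(X, y)


def fit_with_weights(X, y, sample_weight=None):
    # Stand-in for a subclass _fit that accepts per-sample weights.
    return ("weighted", sample_weight)


def fit_without_weights(X, y):
    # Stand-in for a subclass _fit that does not accept per-sample weights.
    return ("unweighted", None)


with_result = call_fit(fit_with_weights, [[0.0]], [0], sample_weight=[2.0])
without_result = call_fit(fit_without_weights, [[0.0]], [0], sample_weight=[2.0])
```

Note that in the second call the weights are dropped silently rather than raising an error, so a caller cannot tell from the return value alone whether the weights were used.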

Parameters:

Name Type Description Default
X ndarray

Training input samples of shape (n_samples, n_features).

required
y ndarray

Target class labels of shape (n_samples,).

None
sample_weight ndarray

Per-sample weights of shape (n_samples,). Forwarded to _fit only when supported.

None

Returns:

Name Type Description
CustomClassifierMixinEstimator CustomClassifierMixinEstimator

The fitted estimator (self).

Source code in scikit_longitudinal/templates/custom_classifier_mixin_estimator.py
@final
def fit(
    self, X: np.ndarray, y: np.ndarray = None, sample_weight: np.ndarray = None
) -> "CustomClassifierMixinEstimator":
    """Fit the classifier to the training data.

    Validates ``X`` (and ``y`` when provided) with scikit-learn's
    ``check_X_y`` / ``check_array`` and then delegates to the subclass
    implementation in ``_fit``. ``sample_weight`` is forwarded only when
    the subclass's ``_fit`` declares it.

    Args:
        X (np.ndarray):
            Training input samples of shape ``(n_samples, n_features)``.
        y (np.ndarray, optional):
            Target class labels of shape ``(n_samples,)``.
        sample_weight (np.ndarray, optional):
            Per-sample weights of shape ``(n_samples,)``. Forwarded to
            ``_fit`` only when supported.

    Returns:
        CustomClassifierMixinEstimator: The fitted estimator (``self``).
    """
    if y is None:
        return self._check_array_decorator(self._fit)(X)
    _fit_sig = inspect.signature(self._fit)
    if "sample_weight" in _fit_sig.parameters:
        return self._check_X_y_decorator(self._fit)(
            X, y, sample_weight=sample_weight
        )
    else:
        return self._check_X_y_decorator(self._fit)(X, y)
predict(X)

Predict class labels for the input samples.

Validates X with scikit-learn's check_array and delegates to the subclass implementation in _predict.

Parameters:

Name Type Description Default
X ndarray

Input samples of shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Predicted class labels of shape (n_samples,).

Source code in scikit_longitudinal/templates/custom_classifier_mixin_estimator.py
@final
def predict(self, X: np.ndarray) -> np.ndarray:
    """Predict class labels for the input samples.

    Validates ``X`` with scikit-learn's ``check_array`` and delegates to
    the subclass implementation in ``_predict``.

    Args:
        X (np.ndarray):
            Input samples of shape ``(n_samples, n_features)``.

    Returns:
        np.ndarray: Predicted class labels of shape ``(n_samples,)``.
    """
    return self._check_array_decorator(self._predict)(X)
predict_proba(X)

Predict class probabilities for the input samples.

Validates X with scikit-learn's check_array and delegates to the subclass implementation in _predict_proba.

Parameters:

Name Type Description Default
X ndarray

Input samples of shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Class probabilities of shape (n_samples, n_classes), with columns ordered as in self.classes_.

Source code in scikit_longitudinal/templates/custom_classifier_mixin_estimator.py
@final
def predict_proba(self, X: np.ndarray) -> np.ndarray:
    """Predict class probabilities for the input samples.

    Validates ``X`` with scikit-learn's ``check_array`` and delegates to
    the subclass implementation in ``_predict_proba``.

    Args:
        X (np.ndarray):
            Input samples of shape ``(n_samples, n_features)``.

    Returns:
        np.ndarray: Class probabilities of shape ``(n_samples, n_classes)``,
        with columns ordered as in ``self.classes_``.
    """
    return self._check_array_decorator(self._predict_proba)(X)

LongitudinalEnsemblingStrategy

Bases: Enum

An enum for the different longitudinal ensembling strategies (voting rules and stacking).

Attributes:

Name Type Description
MAJORITY_VOTING int

Simple consensus voting where the most frequent prediction is selected.

DECAY_LINEAR_VOTING int

Weights each classifier's vote by the recency of its wave, with weights growing linearly towards the most recent wave. For the 1-based wave index \(i\) and \(N\) waves, the weight is:

\[w_i = \frac{i}{\sum_{j=1}^{N} j}\]
DECAY_EXPONENTIAL_VOTING int

Weights each classifier's vote by the recency of its wave, with weights growing exponentially towards the most recent wave. For the 1-based wave index \(i\) and \(N\) waves, the weight is:

\[w_i = \frac{e^{i}}{\sum_{j=1}^{N} e^{j}}\]
CV_BASED_VOTING int

Weights each classifier by its mean 5-fold cross-validation accuracy \(A_i\) on the training data. Weight formula:

\[w_i = \frac{A_i}{\sum_{j=1}^{N} A_j}\]
STACKING int

The stacking strategy uses a meta-learner to combine the predictions of the base classifiers. The meta-learner is trained on meta-features formed from the base classifiers' predicted class probabilities. This approach is suitable when there are fewer meta-features than original features.

In stacking, for each wave \(i\) (\(i \in \{1, 2, \ldots, N\}\)), a base classifier \(C_i\) is trained on \((X_i, T_N)\). The class-probability output from \(C_i\) is denoted as \(V_i\), forming the meta-features \(\mathbf{V} = [V_1, V_2, ..., V_N]\). The meta-learner \(M\) is then trained on \((\mathbf{V}, T_N)\), and for a new instance \(x\), the final prediction is \(P(x) = M(\mathbf{V}(x))\).
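
As a worked example of the two decay formulas above, the sketch below computes the weights for N = 3 waves. The function names are hypothetical, and the exponential version shifts every exponent by N, an assumption made here for numerical stability; it cancels in the ratio, so the weights match the formula.

```python
import math


def linear_decay_weights(n_waves):
    """w_i = i / (1 + 2 + ... + N) for the 1-based wave index i."""
    total = n_waves * (n_waves + 1) / 2  # closed form of 1 + 2 + ... + N
    return [i / total for i in range(1, n_waves + 1)]


def exponential_decay_weights(n_waves):
    """w_i = e^i / sum_j e^j, computed with exponents shifted by N."""
    # Shifting by N leaves every ratio unchanged and avoids overflow for large N.
    exps = [math.exp(i - n_waves) for i in range(1, n_waves + 1)]
    total = sum(exps)
    return [e / total for e in exps]


linear = linear_decay_weights(3)            # 1/6, 2/6, 3/6
exponential = exponential_decay_weights(3)  # roughly 0.090, 0.245, 0.665
```

With three waves the most recent wave gets half of the total weight under linear decay and about two thirds under exponential decay, so the exponential rule lets the latest wave outvote all earlier waves combined.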

Source code in scikit_longitudinal/estimators/ensemble/longitudinal_voting/longitudinal_voting.py
class LongitudinalEnsemblingStrategy(Enum):
    """
    An enum for the different longitudinal ensembling strategies (voting rules and stacking).

    Attributes:
        MAJORITY_VOTING (int):
            Simple consensus voting where the most frequent prediction is selected.
        DECAY_LINEAR_VOTING (int):
            Weights each classifier's vote based on the recency of its wave using a linear decay.
            Weight formula:

            $$w_i = \\frac{i}{\\sum_{j=1}^{N} j}$$

        DECAY_EXPONENTIAL_VOTING (int):
            Weights each classifier's vote based on the recency of its wave using an exponential decay.
            Weight formula:

            $$w_i = \\frac{e^{i}}{\\sum_{j=1}^{N} e^{j}}$$

        CV_BASED_VOTING (int):
            Weights each classifier based on its cross-validation accuracy on the training data.
            Weight formula:

            $$w_i = \\frac{A_i}{\\sum_{j=1}^{N} A_j}$$

        STACKING (int):
            Stacking ensemble strategy uses a meta-learner to combine predictions of base classifiers.
            The meta-learner is trained on meta-features formed from the base classifiers' predicted class
            probabilities.
            This approach is suitable when the cardinality of meta-features is smaller than the original feature set.

            In stacking, for each wave $i$ ($i \\in \\{1, 2, \\ldots, N\\}$), a base classifier $C_i$
            is trained on $(X_i, T_N)$. The class-probability output from $C_i$ is denoted as $V_i$,
            forming the meta-features $\\mathbf{V} = [V_1, V_2, ..., V_N]$. The meta-learner $M$ is then
            trained on $(\\mathbf{V}, T_N)$, and for a new instance $x$, the final prediction is
            $P(x) = M(\\mathbf{V}(x))$.

    """

    MAJORITY_VOTING = auto()
    DECAY_LINEAR_VOTING = auto()
    DECAY_EXPONENTIAL_VOTING = auto()
    CV_BASED_VOTING = auto()
    STACKING = auto()

Longitudinal Stacking Classifier

Trains a meta-learner on the class probabilities emitted by the per-wave classifiers fitted by SepWav.

LongitudinalStackingClassifier

Bases: CustomClassifierMixinEstimator

Trains a meta-learner on the class-probability outputs of the pre-trained base estimators. Each base estimator must implement predict_proba; the meta-learner is then fitted on the stacked probabilities to produce the final prediction. Supports both binary and multiclass targets, and wraps scikit-learn's StackingClassifier under the hood. When extract_wave is provided, internal refits remain wave-specific.
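
The idea of stacking wave-specific probability outputs can be sketched with scikit-learn alone. This is an illustration under assumptions, not the class's implementation: the synthetic data, the two-wave column layout in `wave_columns`, and the helper `wave_estimator` are hypothetical, and a `ColumnTransformer` stands in for the wave extraction that `extract_wave` performs.

```python
from sklearn.compose import ColumnTransformer
from sklearn.datasets import make_classification
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for a longitudinal dataset.
X, y = make_classification(
    n_samples=200, n_features=6, n_informative=4, random_state=0
)

# Hypothetical layout: columns 0-2 belong to wave 0, columns 3-5 to wave 1.
wave_columns = {"wave_0": [0, 1, 2], "wave_1": [3, 4, 5]}


def wave_estimator(columns):
    """Build a classifier that only sees the columns of one wave."""
    select = ColumnTransformer([("wave", "passthrough", columns)])
    return make_pipeline(select, DecisionTreeClassifier(max_depth=3, random_state=0))


stack = StackingClassifier(
    estimators=[(name, wave_estimator(cols)) for name, cols in wave_columns.items()],
    final_estimator=LogisticRegression(),  # meta-learner
    stack_method="predict_proba",          # meta-features are class probabilities
)
stack.fit(X, y)

proba = stack.predict_proba(X[:5])
labels = stack.predict(X[:5])
```

scikit-learn's `StackingClassifier` refits the base estimators and builds the meta-features from cross-validated predictions, so the meta-learner is not trained on probabilities that the base estimators produced for their own training samples.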

Parameters:

Name Type Description Default
estimators List[CustomClassifierMixinEstimator]

The base estimators for the ensemble. These can be passed directly, or as estimators prepared by SepWav. Each estimator must implement predict_proba.

required
meta_learner Optional[Union[CustomClassifierMixinEstimator, ClassifierMixin]], default=LogisticRegression()

The meta-learner to be used in stacking. Can be any scikit-learn compliant classifier.

LogisticRegression()
n_jobs int, default=1

The number of jobs to run in parallel for fitting base estimators.

1
extract_wave Callable

Optional wave extractor used when estimators should remain wave-specific inside stacking, such as the SepWav workflow.

None

Attributes:

Name Type Description
clf_ensemble StackingClassifier

The underlying scikit-learn StackingClassifier instance.

Raises:

Type Description
ValueError

If no base estimators are provided, if a base estimator does not implement predict_proba, or if the meta-learner is not suitable.

NotFittedError

If attempting to predict or predict_proba before fitting the model.

Source code in scikit_longitudinal/estimators/ensemble/longitudinal_stacking/longitudinal_stacking.py
class LongitudinalStackingClassifier(CustomClassifierMixinEstimator):
    """
    Trains a meta-learner on the class-probability outputs of the pre-trained base estimators. Each base estimator
    must implement `predict_proba`; the meta-learner is then fitted on the stacked probabilities to produce the
    final prediction. Supports both binary and multiclass targets, and wraps scikit-learn's `StackingClassifier`
    under the hood. When `extract_wave` is provided, internal refits remain wave-specific.

    Args:
        estimators (List[CustomClassifierMixinEstimator]):
            The base estimators for the ensemble. These can be passed directly, or as estimators prepared by `SepWav`.
            Each estimator must implement `predict_proba`.
        meta_learner (Optional[Union[CustomClassifierMixinEstimator, ClassifierMixin]], default=LogisticRegression()):
            The meta-learner to be used in stacking. Can be any scikit-learn compliant classifier.
        n_jobs (int, default=1):
            The number of jobs to run in parallel for fitting base estimators.
        extract_wave (Callable, optional):
            Optional wave extractor used when estimators should remain wave-specific inside stacking, such as the
            `SepWav` workflow.

    Attributes:
        clf_ensemble (StackingClassifier):
            The underlying scikit-learn StackingClassifier instance.

    Raises:
        ValueError: If no base estimators are provided, if a base estimator does not implement `predict_proba`, or if
            the meta-learner is not suitable.
        NotFittedError: If attempting to predict or predict_proba before fitting the model.
    """

    def __init__(
        self,
        estimators: List[CustomClassifierMixinEstimator],
        meta_learner: Optional[
            Union[CustomClassifierMixinEstimator, ClassifierMixin]
        ] = LogisticRegression(),
        n_jobs: int = 1,
        extract_wave: Callable = None,
    ) -> None:
        self.estimators = estimators
        self.meta_learner = meta_learner
        self.n_jobs = n_jobs
        self.extract_wave = extract_wave
        self.clf_ensemble = None

    @property
    def classes_(self):
        if self.clf_ensemble is None:
            raise NotFittedError(
                "This LongitudinalStackingClassifier instance is not fitted yet. Call 'fit' with appropriate arguments."
            )
        return self.clf_ensemble.classes_

    @override
    def _fit(
        self, X: np.ndarray, y: np.ndarray, sample_weight: Optional[np.ndarray] = None
    ) -> "LongitudinalStackingClassifier":
        """
        Fit the ensemble model.

        Trains the stacking ensemble by combining out-of-fold base-estimator probability predictions and fitting the
        meta-learner.

        Args:
            X (np.ndarray):
                The input data.
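            y (np.ndarray):
                The target data.
            sample_weight (np.ndarray, optional):
                Per-sample weights, passed to the underlying `StackingClassifier.fit` when provided.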

        Returns:
            LongitudinalStackingClassifier: The fitted model.

        Raises:
            ValueError: If no base estimators are provided, if a base estimator does not implement `predict_proba`, or
                if the meta-learner is not suitable.

        !!! tip "Meta-Learner Selection"
            Choose a meta-learner that complements your base estimators. For example, use Logistic Regression for linear
            decision boundaries or a Decision Tree for more complex interactions.
        """
        if not self.estimators:
            raise ValueError("No estimators were provided.")

        if not hasattr(self.meta_learner, "fit") or not hasattr(
            self.meta_learner, "predict"
        ):
            raise ValueError(
                "The meta learner must be a classifier with a fit and predict scikit-compliant format."
            )
        if any(
            not hasattr(estimator, "predict_proba") for _, estimator in self.estimators
        ):
            raise ValueError(
                "All base estimators must implement predict_proba for LongitudinalStackingClassifier."
            )

        estimators = self.estimators
        stack_method = "predict_proba"
        if self.extract_wave is not None:
            estimators = [
                (
                    name,
                    _WaveAwareEstimator(
                        estimator=estimator, wave=wave, extract_wave=self.extract_wave
                    ),
                )
                for wave, (name, estimator) in enumerate(self.estimators)
            ]

        self.clf_ensemble = StackingClassifier(
            estimators=estimators,
            final_estimator=self.meta_learner,
            n_jobs=self.n_jobs,
            stack_method=stack_method,
        )

        fit_params = {}
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        self.clf_ensemble.fit(X, y, **fit_params)
        return self

    @override
    def _predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the ensemble model.

        Generates predictions by passing stacked base-estimator probability outputs to the meta-learner.

        Args:
            X (np.ndarray):
                The input data.

        Returns:
            np.ndarray: The predicted target data.

        Raises:
            NotFittedError: If attempting to predict before fitting the model.
        """
        if self.clf_ensemble:
            return self.clf_ensemble.predict(X)
        raise NotFittedError("Ensemble model is not fitted yet.")

    @override
    def _predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict probabilities using the ensemble model.

        Generates probability estimates from the meta-learner based on stacked base-estimator probability outputs.

        Args:
            X (np.ndarray):
                The input data.

        Returns:
            np.ndarray: The predicted target data probabilities.

        Raises:
            NotFittedError: If attempting to predict before fitting the model.

        !!! tip "Probability Calibration"
            If well-calibrated probabilities matter for your use case, consider calibrating the meta-learner's
            outputs (e.g., with scikit-learn's `CalibratedClassifierCV`).
        """
        if self.clf_ensemble:
            return self.clf_ensemble.predict_proba(X)
        raise NotFittedError("Ensemble model is not fitted yet.")
fit(X, y=None, sample_weight=None)

Fit the classifier to the training data.

Validates X (and y when provided) with scikit-learn's check_X_y / check_array and then delegates to the subclass implementation in _fit. sample_weight is forwarded only when the subclass's _fit declares it.

Parameters:

Name Type Description Default
X ndarray

Training input samples of shape (n_samples, n_features).

required
y ndarray

Target class labels of shape (n_samples,).

None
sample_weight ndarray

Per-sample weights of shape (n_samples,). Forwarded to _fit only when supported.

None

Returns:

Name Type Description
CustomClassifierMixinEstimator CustomClassifierMixinEstimator

The fitted estimator (self).

Source code in scikit_longitudinal/templates/custom_classifier_mixin_estimator.py
@final
def fit(
    self, X: np.ndarray, y: np.ndarray = None, sample_weight: np.ndarray = None
) -> "CustomClassifierMixinEstimator":
    """Fit the classifier to the training data.

    Validates ``X`` (and ``y`` when provided) with scikit-learn's
    ``check_X_y`` / ``check_array`` and then delegates to the subclass
    implementation in ``_fit``. ``sample_weight`` is forwarded only when
    the subclass's ``_fit`` declares it.

    Args:
        X (np.ndarray):
            Training input samples of shape ``(n_samples, n_features)``.
        y (np.ndarray, optional):
            Target class labels of shape ``(n_samples,)``.
        sample_weight (np.ndarray, optional):
            Per-sample weights of shape ``(n_samples,)``. Forwarded to
            ``_fit`` only when supported.

    Returns:
        CustomClassifierMixinEstimator: The fitted estimator (``self``).
    """
    if y is None:
        return self._check_array_decorator(self._fit)(X)
    _fit_sig = inspect.signature(self._fit)
    if "sample_weight" in _fit_sig.parameters:
        return self._check_X_y_decorator(self._fit)(
            X, y, sample_weight=sample_weight
        )
    else:
        return self._check_X_y_decorator(self._fit)(X, y)
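
The signature check in `fit` can be tried in isolation. The sketch below uses only the standard library; `call_fit`, `fit_with_weights`, and `fit_without_weights` are hypothetical names introduced for illustration and are not part of the library.

```python
import inspect


def call_fit(fit_impl, X, y, sample_weight=None):
    """Forward sample_weight only if fit_impl declares a parameter with that name."""
    params = inspect.signature(fit_impl).parameters
    if "sample_weight" in params:
        return fit_impl(X, y, sample_weight=sample_weight)
    return fit_impl(X, y)


def fit_with_weights(X, y, sample_weight=None):
    # Stand-in for a subclass _fit that supports per-sample weights.
    return ("weighted", sample_weight)


def fit_without_weights(X, y):
    # Stand-in for a subclass _fit that does not accept weights.
    return ("unweighted", None)


X, y, w = [[0.0], [1.0]], [0, 1], [0.5, 1.5]
weighted_result = call_fit(fit_with_weights, X, y, sample_weight=w)
unweighted_result = call_fit(fit_without_weights, X, y, sample_weight=w)
```

With this kind of dispatch, weights passed to an implementation that does not declare `sample_weight` are dropped silently rather than raising an error. The lookup is by literal parameter name, so an implementation that only accepts `**kwargs` would not receive the weights either.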
predict(X)

Predict class labels for the input samples.

Validates X with scikit-learn's check_array and delegates to the subclass implementation in _predict.

Parameters:

Name Type Description Default
X ndarray

Input samples of shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Predicted class labels of shape (n_samples,).

Source code in scikit_longitudinal/templates/custom_classifier_mixin_estimator.py
@final
def predict(self, X: np.ndarray) -> np.ndarray:
    """Predict class labels for the input samples.

    Validates ``X`` with scikit-learn's ``check_array`` and delegates to
    the subclass implementation in ``_predict``.

    Args:
        X (np.ndarray):
            Input samples of shape ``(n_samples, n_features)``.

    Returns:
        np.ndarray: Predicted class labels of shape ``(n_samples,)``.
    """
    return self._check_array_decorator(self._predict)(X)
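
The validate-then-delegate pattern used by `predict` might look roughly like the sketch below. The body of `_check_array_decorator` is not shown on this page, so `validate_input` is a hypothetical stand-in built on scikit-learn's `check_array`, and `ThresholdClassifier` is a toy class invented for the example.

```python
import functools

from sklearn.utils import check_array


def validate_input(method):
    """Hypothetical wrapper: run check_array on X, then call the wrapped method."""

    @functools.wraps(method)
    def wrapper(X, *args, **kwargs):
        # check_array converts array-likes to ndarray and raises ValueError
        # for 1-D, empty, or non-finite input.
        X_checked = check_array(X)
        return method(X_checked, *args, **kwargs)

    return wrapper


class ThresholdClassifier:
    """Toy model: predicts 1 when the first feature is positive."""

    def _predict(self, X):
        return (X[:, 0] > 0).astype(int)

    def predict(self, X):
        return validate_input(self._predict)(X)


clf = ThresholdClassifier()
labels = clf.predict([[1.5, 0.0], [-2.0, 3.0]])  # nested lists are accepted

try:
    clf.predict([1.5, -2.0])  # 1-D input is rejected before _predict runs
    rejected = False
except ValueError:
    rejected = True
```

Keeping validation in the public method means a subclass's `_predict` can assume it always receives a 2-D NumPy array.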
predict_proba(X)

Predict class probabilities for the input samples.

Validates X with scikit-learn's check_array and delegates to the subclass implementation in _predict_proba.

Parameters:

Name Type Description Default
X ndarray

Input samples of shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Class probabilities of shape (n_samples, n_classes), with columns ordered as in self.classes_.

Source code in scikit_longitudinal/templates/custom_classifier_mixin_estimator.py
@final
def predict_proba(self, X: np.ndarray) -> np.ndarray:
    """Predict class probabilities for the input samples.

    Validates ``X`` with scikit-learn's ``check_array`` and delegates to
    the subclass implementation in ``_predict_proba``.

    Args:
        X (np.ndarray):
            Input samples of shape ``(n_samples, n_features)``.

    Returns:
        np.ndarray: Class probabilities of shape ``(n_samples, n_classes)``,
        with columns ordered as in ``self.classes_``.
    """
    return self._check_array_decorator(self._predict_proba)(X)
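
Because the probability columns follow the order of `classes_`, labels and per-class probabilities can be recovered by indexing. The sketch below uses hand-written arrays in place of a fitted estimator; the label names and probability values are made up for illustration.

```python
import numpy as np

# Stand-ins for estimator.classes_ and estimator.predict_proba(X).
classes_ = np.array(["at_risk", "diseased", "healthy"])
proba = np.array([
    [0.2, 0.1, 0.7],
    [0.3, 0.6, 0.1],
    [0.5, 0.3, 0.2],
])

# Column j of proba belongs to classes_[j], so the arg-max column index
# maps directly to the predicted label.
predicted = classes_[np.argmax(proba, axis=1)]

# Look up the column of one class by label rather than by a hard-coded position.
at_risk_col = int(np.flatnonzero(classes_ == "at_risk")[0])
p_at_risk = proba[:, at_risk_col]
```

Looking up the column through `classes_` avoids assuming that a particular label sits at a fixed position, which matters most in multiclass problems.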