Skip to content

API Reference

This section provides the complete API documentation for seroepi, parsed directly from the docstrings.

seroepi.accessors

Module to handle epidemiological, geospatial and genotypic operations on isolate datasets in the form of Pandas DataFrames.

EpiAccessor

Pandas accessor for epidemiological analysis on isolate datasets.

Provides methods for generating epidemic curves, calculating prevalence, diversity, and incidence, and identifying transmission clusters.

Source code in src/seroepi/accessors.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
@pd.api.extensions.register_dataframe_accessor("epi")
class EpiAccessor:
    """
    Pandas accessor for epidemiological analysis on isolate datasets.

    Registered on DataFrames under the ``.epi`` namespace. Provides methods
    for generating epidemic curves, calculating prevalence, diversity, and
    incidence, and identifying transmission clusters.
    """
    def __init__(self, pandas_obj: pd.DataFrame):
        self._obj = pandas_obj

    # --- Internal column resolution ---

    def _temporal_columns(self) -> pd.Index:
        """Returns the prefixed temporal value columns, excluding temporal resolution ('res_') columns."""
        return self._obj.filter(regex=f"^{Domain.TEMPORAL.value}_(?!res_)").columns

    def _resolve_temporal_col(self, temporal_col: str = None) -> Union[str, None]:
        """
        Resolves a user-supplied temporal column name to its prefixed form.

        Args:
            temporal_col: Column name with or without the temporal domain prefix.
                If None, the first temporal column of the dataset is used.

        Returns:
            The prefixed column name, or None when ``temporal_col`` is None and
            the dataset has no temporal columns. Whether a supplied name exists
            in the DataFrame is NOT checked here; callers validate that.
        """
        prefix = f"{Domain.TEMPORAL.value}_"
        if temporal_col is None:
            candidates = self._temporal_columns()
            return candidates[0] if len(candidates) else None
        if temporal_col.startswith(prefix):
            return temporal_col
        return f"{prefix}{temporal_col}"

    # --- Shiny UI State Checkers ---

    @property
    def has_temporal(self) -> bool:
        """
        Checks if the dataset contains usable temporal data.

        Returns True when at least one temporal column exists and the first
        such column (the default used by the temporal methods) holds at least
        one non-null value.
        """
        cols = self._temporal_columns()
        return bool(len(cols) > 0 and self._obj[cols[0]].notna().any())

    @property
    def has_spatial(self) -> bool:
        """
        Checks if both 'latitude' and 'longitude' columns are present.

        Only column presence is tested; the coordinate values are not inspected.
        """
        return 'latitude' in self._obj.columns and 'longitude' in self._obj.columns

    # --- Spatiotemporal Helpers ---

    @property
    def temporal(self) -> pd.DataFrame:
        """
        Returns a DataFrame of all temporal columns, with the prefix removed.

        Temporal resolution ('res_') columns are excluded.
        """
        prefix = f'{Domain.TEMPORAL.value}_'
        return self._obj[self._temporal_columns()].rename(columns=lambda c: c.replace(prefix, '', 1))

    @property
    def temporal_resolution(self) -> pd.DataFrame:
        """
        Returns a DataFrame of all temporal resolution columns, with the prefix removed.
        """
        prefix = f'{Domain.TEMPORAL_RES.value}_'
        return self._obj.filter(regex=f"^{prefix}").rename(columns=lambda c: c.replace(prefix, '', 1))

    @property
    def spatial(self) -> pd.DataFrame:
        """
        Returns the core spatial coordinates (latitude, longitude) as nullable Float64.

        Raises:
            ValueError: If spatial columns are missing.
        """
        if not self.has_spatial:
            raise ValueError("Spatial columns ('latitude', 'longitude') are missing.")
        return self._obj[['latitude', 'longitude']].astype("Float64")

    # --- Time Series / Epidemic Curve Methods ---

    def _get_spatiotemporal_arrays(self, temporal_col: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Extracts coordinates and dates as NumPy arrays for spatiotemporal clustering.

        Args:
            temporal_col: Fully prefixed name of a datetime64 column.

        Returns:
            A tuple ``(coords, raw_dates, valid_mask)``:
            - ``coords``: (n_rows, 2) array of [latitude, longitude] in radians.
            - ``raw_dates``: float array of days since the Unix epoch (NaN where missing).
            - ``valid_mask``: boolean array, True where both coordinates and the date are present.

        Raises:
            TypeError: If ``temporal_col`` is not a datetime dtype.
        """
        df = self._obj

        if not pd.api.types.is_datetime64_any_dtype(df[temporal_col]):
            raise TypeError(f"Temporal column '{temporal_col}' must be a datetime type. Ensure data is parsed via seroepi.io.")

        # Schema guarantees datetime64, so we just safely strip timezones if present
        date_series = df[temporal_col].dt.tz_localize(None)
        # Coerce coordinates to standard numpy floats (nullable NA becomes NaN)
        coords = np.radians(df[['latitude', 'longitude']].astype(float).values)

        # Convert dates to raw days (use float to allow NaNs for missing dates)
        raw_dates = np.full(len(df), np.nan)
        date_mask = date_series.notna().values
        raw_dates[date_mask] = date_series[date_mask].values.astype('datetime64[D]').astype(float)

        # Create a boolean mask of rows that have all required spatiotemporal data
        valid_mask = ~(np.isnan(coords[:, 0]) | np.isnan(coords[:, 1]) | np.isnan(raw_dates))
        return coords, raw_dates, valid_mask

    def epidemic_curve(self, freq: Union[str, TemporalResolution] = TemporalResolution.WEEK,
                       stratify_by: str = None, temporal_col: str = None) -> pd.DataFrame:
        """
        Generates a time-series DataFrame for plotting epidemic curves.

        Args:
            freq: Time frequency for resampling (e.g., TemporalResolution.MONTH, 'ME', 'YE').
                Defaults to TemporalResolution.WEEK.
            stratify_by: Column name to group by before resampling.
            temporal_col: Temporal column to bin on, with or without the temporal
                prefix. Defaults to the first temporal column.

        Returns:
            A DataFrame with counts of isolates per time interval. Without
            ``stratify_by`` the counts are in a 'count' column; with it, there is
            one count column per stratum level.

        Raises:
            ValueError: If temporal data is not available.
            KeyError: If the requested temporal column does not exist.
            TypeError: If the temporal column is not a datetime type.
        """
        if not self.has_temporal:
            raise ValueError("Cannot generate epi curve: No temporal data available.")

        freq_val = freq.pandas_offset if isinstance(freq, TemporalResolution) else freq

        df = self._obj
        temporal_col = self._resolve_temporal_col(temporal_col)

        if temporal_col not in df.columns:
            raise KeyError(f"Temporal column '{temporal_col}' not found.")
        if not pd.api.types.is_datetime64_any_dtype(df[temporal_col]):
            raise TypeError(f"Temporal column '{temporal_col}' must be a datetime type. Ensure data is parsed via seroepi.io.")

        # set_index returns a new frame, so the accessor's DataFrame is not mutated
        df = df.set_index(temporal_col)

        if stratify_by:
            curve = df.groupby(stratify_by, observed=True).resample(freq_val).size().unstack(level=0, fill_value=0)
        else:
            curve = df.resample(freq_val).size().to_frame(name='count')

        return curve.reset_index()

    @property
    def metadata_columns(self) -> list[str]:
        """Returns the raw names of user-uploaded clinical/metadata columns ('meta_' prefix)."""
        return self._obj.filter(regex="^meta_").columns.tolist()

    @property
    def ui_metadata_columns(self) -> list[str]:
        """Returns clean metadata names (without 'meta_' prefix) for UI display."""
        return [c.replace('meta_', '', 1) for c in self.metadata_columns]

    @property
    def genotypes(self) -> list[str]:
        """Compiles a list of all genetic variables (genotype, phenotype, AMR and virulence columns)."""
        # Grab all the dynamically prefixed traits (genotype, phenotype, amr, virulence)
        return self._obj.filter(regex=f"^({Domain.GENOTYPE.value}|{Domain.PHENOTYPE.value}|{Domain.AMR.value}|{Domain.VIRULENCE.value})_").columns.tolist()

    @property
    def stratify_cols(self) -> list[str]:
        """Returns columns suitable for stratification (excluding QC, metadata, resolution and ID/coordinate cols)."""
        exclude_strat = ['sample_id', 'latitude', 'longitude']
        excluded_prefixes = (f'{Domain.QC.value}_', 'meta_', f'{Domain.SPATIAL_RES.value}_', f'{Domain.TEMPORAL_RES.value}_')
        return [c for c in self._obj.columns if not c.startswith(excluded_prefixes) and c not in exclude_strat]

    @property
    def cluster_cols(self) -> list[str]:
        """Returns columns suitable for cluster adjustment (cluster-prefixed columns and columns ending in '_ST')."""
        return [c for c in self._obj.columns if c.startswith(f"{Domain.CLUSTER.value}_") or c.endswith('_ST')]

    @staticmethod
    def _calculate_events_and_denoms(
        df: pd.DataFrame,
        trait_col: str,
        denom_cols: list[str],
        trait_strata: list[str],
        cluster_col: str,
        negative_indicator: Union[str, list[str]],
        event_name: str,
        denom_name: str
    ) -> tuple[Union[pd.Series, int], Union[pd.Series, int]]:
        """
        Calculates numerators (events) and denominators for aggregated metrics.

        Args:
            df: Source data.
            trait_col: Trait column. If truthy, events are rows where the trait is
                present (boolean True, or a value not in ``negative_indicator``).
                If falsy, compositional mode: events are row counts per
                ``trait_strata`` combination, after dropping rows where
                ``trait_strata[-1]`` is missing.
            denom_cols: Columns defining the denominator groups (may be empty).
            trait_strata: Columns defining the event groups.
            cluster_col: If given, unique cluster IDs are counted instead of rows.
            negative_indicator: Value(s) marking trait absence (non-boolean traits).
            event_name: Name given to the events Series.
            denom_name: Name given to the denominators Series.

        Returns:
            ``(events, denoms)``. Each is a Series when grouping columns are
            given and a scalar count otherwise.
        """
        if trait_col:
            valid_df = df.dropna(subset=[trait_col]).copy()

            if pd.api.types.is_bool_dtype(valid_df[trait_col]):
                valid_df['_trait_bool'] = valid_df[trait_col]
            else:
                neg_list = [negative_indicator] if isinstance(negative_indicator, str) else negative_indicator
                valid_df['_trait_bool'] = ~valid_df[trait_col].isin(neg_list)

            if cluster_col:
                positives = valid_df[valid_df['_trait_bool']]
                denoms = valid_df.groupby(denom_cols, observed=True)[cluster_col].nunique().rename(denom_name) if denom_cols else valid_df[cluster_col].nunique()
                events = positives.groupby(trait_strata, observed=True)[cluster_col].nunique().rename(event_name) if trait_strata else positives[cluster_col].nunique()
            else:
                denoms = valid_df.groupby(denom_cols, observed=True).size().rename(denom_name) if denom_cols else len(valid_df)
                events = valid_df.groupby(trait_strata, observed=True)['_trait_bool'].sum().rename(event_name) if trait_strata else valid_df['_trait_bool'].sum()

        else:
            valid_df = df.dropna(subset=[trait_strata[-1]]).copy()

            if cluster_col:
                denoms = valid_df.groupby(denom_cols, observed=True)[cluster_col].nunique().rename(denom_name) if denom_cols else valid_df[cluster_col].nunique()
                events = valid_df.groupby(trait_strata, observed=True)[cluster_col].nunique().rename(event_name)
            else:
                denoms = valid_df.groupby(denom_cols, observed=True).size().rename(denom_name) if denom_cols else len(valid_df)
                events = valid_df.groupby(trait_strata, observed=True).size().rename(event_name)

        return events, denoms

    def aggregate_prevalence(self, stratify_by: list[str], trait_col: str = None,
                             cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-',
                             pad_zeros: bool = False) -> pd.DataFrame:
        """
        Aggregates data to calculate event counts and denominators for prevalence.

        Supports both trait prevalence (presence/absence of a marker) and
        compositional prevalence (distribution of variants within a locus).

        Args:
            stratify_by: Columns to group by (e.g., ['spatial']).
            trait_col: The column containing the trait/marker to measure.
                If None, compositional prevalence is calculated for the last
                column in `stratify_by`.
            cluster_col: Column containing cluster IDs to adjust for (e.g., nosocomial outbreaks).
            negative_indicator: Value(s) representing the absence of a trait.
                Defaults to '-'.
            pad_zeros: If True, pads missing combinations of strata with zero counts.
                Essential for spatial/hierarchical models. If False, only includes
                observed combinations for efficiency. Defaults to False.

        Returns:
            An aggregated DataFrame with 'event', 'n' and 'target' columns.
            Metadata is stored in ``attrs['metric_meta']``.

        Raises:
            ValueError: If compositional prevalence is requested without enough strata.
        """
        df = self._obj.copy()

        if trait_col:
            denom_cols = list(stratify_by)
        else:
            if not stratify_by:
                raise ValueError("Compositional prevalence requires at least 1 stratify_by column.")
            # The last stratum is the locus whose composition is measured
            denom_cols = list(stratify_by[:-1])
        trait_strata = list(stratify_by)

        events, denoms = self._calculate_events_and_denoms(
            df=df,
            trait_col=trait_col,
            denom_cols=denom_cols,
            trait_strata=trait_strata,
            cluster_col=cluster_col,
            negative_indicator=negative_indicator,
            event_name='event',
            denom_name='n'
        )

        # Expand Grid
        if len(stratify_by) > 0:
            if pad_zeros:
                # Full Cartesian expansion (padding zeroes)
                unique_levels = [df[col].dropna().unique() for col in stratify_by]
                base_index = pd.MultiIndex.from_product(unique_levels, names=stratify_by)
            else:
                # Efficiently use only the observed strata combinations
                base_index = denoms.index if trait_col else events.index

            agg_df = pd.DataFrame(index=base_index).join(events, how='left').fillna({'event': 0}).reset_index()

            # Safely map denominators (scalar when there are no denominator strata)
            if len(denom_cols) == 0:
                agg_df['n'] = denoms
            else:
                agg_df = agg_df.set_index(denom_cols)
                agg_df['n'] = denoms
                agg_df = agg_df.reset_index().fillna({'n': 0})

            if not pad_zeros:
                agg_df = agg_df[agg_df['n'] > 0].copy()
        else:
            agg_df = pd.DataFrame({'event': [events], 'n': [denoms]})

        # Standardize to 'target' for uniform downstream handling
        if trait_col:
            agg_df['target'] = trait_col
        else:
            agg_df = agg_df.rename(columns={stratify_by[-1]: 'target'})

        agg_df.attrs = self._obj.attrs.copy()
        agg_df.attrs['metric_meta'] = {
            "metric_type": MetricType.PREVALENCE,
            "stratified_by": list(denom_cols),
            "trait": trait_col if trait_col else stratify_by[-1],
            "aggregation_type": AggregationType.TRAIT if trait_col else AggregationType.COMPOSITIONAL,
            "adjusted_for": cluster_col,
            "is_zero_padded": pad_zeros
        }

        return agg_df

    def aggregate_diversity(self, stratify_by: list[str], trait_col: str = None,
                            cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-',
                            pad_zeros: bool = False) -> pd.DataFrame:
        """
        Aggregates data to calculate counts for diversity analysis (e.g., Shannon index).

        Args:
            stratify_by: Columns to group by. If ``trait_col`` is None, the last
                column is treated as the locus whose variants are counted.
            trait_col: The locus or trait to measure diversity for.
            cluster_col: Optional cluster column to adjust for.
            negative_indicator: Value(s) to exclude from diversity counts
                (applied to non-boolean traits only).
            pad_zeros: If True, pads missing combinations of strata with zero counts.
                Defaults to False.

        Returns:
            A DataFrame with 'variant_count', 'n_total' and 'target' columns.
            Metadata is stored in ``attrs['metric_meta']``.

        Raises:
            ValueError: If compositional diversity is requested without enough strata.
        """
        df = self._obj.copy()

        is_trait = bool(trait_col)
        if is_trait:
            groupers = list(stratify_by)
            trait_strata = list(stratify_by) + [trait_col]
            valid_df = df.dropna(subset=[trait_col]).copy()

            # For trait diversity, we often want to strip out the "absence" indicators
            # so they don't count as a diversity variant mathematically
            if not pd.api.types.is_bool_dtype(valid_df[trait_col]):
                neg_list = [negative_indicator] if isinstance(negative_indicator, str) else negative_indicator
                valid_df = valid_df[~valid_df[trait_col].isin(neg_list)]
        else:
            if not stratify_by:
                raise ValueError("Compositional diversity requires at least 1 stratify_by column.")
            groupers = list(stratify_by[:-1])
            # From here on trait_col names the locus, so use is_trait for the mode
            trait_col = stratify_by[-1]
            trait_strata = list(stratify_by)
            valid_df = df.dropna(subset=[trait_col]).copy()

        if cluster_col:
            div_df = valid_df.groupby(trait_strata, observed=True)[cluster_col].nunique().reset_index()
            div_df = div_df.rename(columns={cluster_col: 'variant_count'})
        else:
            div_df = valid_df.groupby(trait_strata, observed=True).size().reset_index(name='variant_count')

        if pad_zeros and trait_strata:
            unique_levels = [df[col].dropna().unique() for col in trait_strata]
            expanded_index = pd.MultiIndex.from_product(unique_levels, names=trait_strata)
            div_df = div_df.set_index(trait_strata).reindex(expanded_index, fill_value=0).reset_index()

        if groupers:
            div_df['n_total'] = div_df.groupby(groupers, observed=True)['variant_count'].transform('sum')
        else:
            div_df['n_total'] = div_df['variant_count'].sum()

        # Standardize to 'target'
        div_df = div_df.rename(columns={trait_col: 'target'})
        if is_trait:
            div_df['target'] = trait_col

        div_df.attrs = self._obj.attrs.copy()
        div_df.attrs["metric_meta"] = {
            "metric_type": MetricType.DIVERSITY,
            "stratified_by": groupers,
            "trait": trait_col,
            # trait_col is always set by now, so the mode must come from is_trait
            "aggregation_type": AggregationType.TRAIT if is_trait else AggregationType.COMPOSITIONAL,
            "adjusted_for": cluster_col,
            "is_zero_padded": pad_zeros
        }

        return div_df

    def aggregate_incidence(self, stratify_by: list[str], trait_col: str = None, freq: Union[str, TemporalResolution] = TemporalResolution.MONTH,
                            cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-',
                            pad_zeros: bool = False, temporal_col: str = None) -> pd.DataFrame:
        """
        Aggregates data for time-series incidence analysis.

        Args:
            stratify_by: Columns to group by. If ``trait_col`` is None, the last
                column is treated as the locus whose variants are counted.
            trait_col: The marker to measure incidence for.
            freq: Time frequency for binning (e.g., TemporalResolution.MONTH, 'ME', 'QE', 'YE').
                Defaults to TemporalResolution.MONTH.
            cluster_col: Optional cluster column to adjust for.
            negative_indicator: Value(s) representing absence.
            pad_zeros: If True, pads missing combinations of strata. If False,
                maintains unbroken time grids only for observed strata combinations.
            temporal_col: Temporal column to bin on, with or without the temporal
                prefix. Defaults to the first temporal column.

        Returns:
            A DataFrame with 'date' (bin start), 'variant_count', 'total_sequenced'
            and 'target' columns. Metadata is stored in ``attrs['metric_meta']``.

        Raises:
            ValueError: If temporal data is missing or inappropriate strata provided.
        """
        df = self._obj.copy()

        resolved_col = self._resolve_temporal_col(temporal_col)
        if resolved_col is None:
            raise ValueError("Incidence aggregation requires a valid temporal column.")
        temporal_col = resolved_col

        if temporal_col not in df.columns or not pd.api.types.is_datetime64_any_dtype(df[temporal_col]):
            raise ValueError(f"Incidence aggregation requires a valid datetime64 temporal column. '{temporal_col}' invalid.")

        # Translate Pandas 2.2+ point offsets ('ME', 'QE', 'YE') to period spans ('M', 'Q', 'Y')
        if isinstance(freq, TemporalResolution):
            period_freq = freq.pandas_period
            stored_freq = freq.value
        else:
            period_freq = freq.replace('ME', 'M').replace('QE', 'Q').replace('YE', 'Y')
            stored_freq = freq

        # Snap dates to the requested frequency bin using the safe string
        df['date_bin'] = df[temporal_col].dt.to_period(period_freq).dt.to_timestamp()

        if trait_col:
            denom_cols = ['date_bin'] + stratify_by
        else:
            if not stratify_by:
                raise ValueError("Compositional incidence requires at least 1 stratify_by column.")
            denom_cols = ['date_bin'] + stratify_by[:-1]
        trait_strata = ['date_bin'] + stratify_by

        events, denoms = self._calculate_events_and_denoms(
            df=df,
            trait_col=trait_col,
            denom_cols=denom_cols,
            trait_strata=trait_strata,
            cluster_col=cluster_col,
            negative_indicator=negative_indicator,
            event_name='variant_count',
            denom_name='total_sequenced'
        )

        # --- TIME GRID EXPANSION ---
        # Generate an unbroken sequence of dates from the earliest to the latest record
        min_date, max_date = df['date_bin'].min(), df['date_bin'].max()
        if pd.isna(min_date):
            all_dates = []
        else:
            all_dates = pd.period_range(min_date, max_date, freq=period_freq).to_timestamp().tolist()

        if pad_zeros:
            unique_levels = [all_dates] + [df[col].dropna().unique() for col in trait_strata[1:]]
            expanded_index = pd.MultiIndex.from_product(unique_levels, names=trait_strata)
        elif len(trait_strata) > 1:
            # Only use strata combinations that actually appear in the data
            observed_strata = df[trait_strata[1:]].dropna().drop_duplicates()
            dates_df = pd.DataFrame({trait_strata[0]: all_dates})
            expanded_df = dates_df.merge(observed_strata, how='cross')
            expanded_index = pd.MultiIndex.from_frame(expanded_df[trait_strata])
        else:
            expanded_index = pd.Index(all_dates, name=trait_strata[0])

        inc_df = pd.DataFrame(index=expanded_index).join(events, how='left').fillna({'variant_count': 0}).reset_index()

        inc_df = inc_df.set_index(denom_cols)
        inc_df['total_sequenced'] = denoms
        inc_df = inc_df.reset_index().fillna({'total_sequenced': 0})

        # Unlike Prevalence, we do NOT drop rows where total_sequenced == 0.
        # A true 0 sequence volume is critical information for an epicurve gap.
        inc_df = inc_df.rename(columns={'date_bin': 'date'})

        # Standardize to 'target'
        if trait_col:
            inc_df['target'] = trait_col
        else:
            inc_df = inc_df.rename(columns={stratify_by[-1]: 'target'})

        inc_df.attrs = self._obj.attrs.copy()
        inc_df.attrs['metric_meta'] = {
            "metric_type": MetricType.INCIDENCE,
            "stratified_by": list(stratify_by) if trait_col else list(stratify_by[:-1]),
            "trait": trait_col if trait_col else stratify_by[-1],
            "aggregation_type": AggregationType.TRAIT if trait_col else AggregationType.COMPOSITIONAL,
            "freq": stored_freq,
            "adjusted_for": cluster_col,
            "is_zero_padded": pad_zeros
        }

        return inc_df

    def transmission_network(
            self,
            clone_col: str,
            spatial_threshold_km: float = 10.0,
            temporal_threshold_days: int = 20,
            temporal_col: str = None
    ) -> TransmissionDistances:
        """
        Builds a sparse adjacency graph of transmission links.

        Requires 'sample_id', 'latitude', 'longitude', the clone column and a
        datetime temporal column.

        Args:
            clone_col: Column containing clone IDs (e.g., 'ST' or a custom cluster).
            spatial_threshold_km: Maximum distance in kilometers. Defaults to 10.0.
            temporal_threshold_days: Maximum time difference in days. Defaults to 20.
            temporal_col: Temporal column to use, with or without the temporal
                prefix. Defaults to the first temporal column.

        Returns:
            A TransmissionDistances object representing the outbreak network.

        Raises:
            KeyError: If the clone, coordinate or temporal columns are missing.
            TypeError: If the temporal column is not a datetime type.
        """
        df = self._obj

        # 1. Validation Checks
        if clone_col not in df.columns:
            raise KeyError(f"Clone column '{clone_col}' not found in DataFrame.")

        if 'latitude' not in df.columns or 'longitude' not in df.columns:
            raise KeyError("Spatial clustering requires 'latitude' and 'longitude' columns. "
                           "Ensure geo accessors have parsed coordinates.")

        resolved_col = self._resolve_temporal_col(temporal_col)
        if resolved_col is None:
            raise KeyError("A temporal column is required for temporal clustering.")
        if resolved_col not in df.columns:
            raise KeyError(f"Temporal column '{resolved_col}' not found.")

        # 2. Build the network from radian coordinates and epoch-day dates
        coords, raw_dates, _ = self._get_spatiotemporal_arrays(resolved_col)

        return TransmissionDistances.from_spatiotemporal(
            sample_ids=df['sample_id'],
            coords=coords,
            dates=raw_dates,
            clones=df[clone_col].values,
            spatial_threshold_km=spatial_threshold_km,
            temporal_threshold_days=temporal_threshold_days
        )

    def transmission_clusters(
            self,
            clone_col: str,
            spatial_threshold_km: float = 10.0,
            temporal_threshold_days: int = 20,
            temporal_col: str = None,
            network: TransmissionDistances = None
    ) -> pd.Series:
        """
        Extracts and formats categorical cluster labels from the transmission network.

        Args:
            clone_col: Column containing clone IDs.
            spatial_threshold_km: Maximum distance in kilometers. Defaults to 10.0.
            temporal_threshold_days: Maximum time difference in days. Defaults to 20.
            temporal_col: Temporal column to use, with or without the temporal
                prefix. Defaults to the first temporal column.
            network: A pre-built network. If None, one is built with
                ``transmission_network`` using the arguments above.

        Returns:
            An ordered categorical Series aligned to the DataFrame index. Rows
            missing coordinates, a date or a clone ID get a missing label. The
            name encodes the thresholds given here; if ``network`` was built with
            different thresholds the name will not reflect them.

        Raises:
            KeyError: If no temporal column is available.
        """
        df = self._obj

        resolved_col = self._resolve_temporal_col(temporal_col)
        if resolved_col is None:
            raise KeyError("A temporal column is required for temporal clustering.")

        if network is None:
            network = self.transmission_network(clone_col, spatial_threshold_km, temporal_threshold_days, resolved_col)

        # NOTE(review): assumes get_clusters() returns labels positionally aligned with df rows
        labels = network.get_clusters()

        _, _, valid_mask = self._get_spatiotemporal_arrays(resolved_col)
        clone_mask = df[clone_col].notna().values

        labels_array = labels.astype(float).to_numpy(copy=True)
        labels_array[~(valid_mask & clone_mask)] = np.nan

        res = pd.Series(labels_array, index=df.index, dtype="Int64",
                        name=f'{Domain.CLUSTER.value}_transmission_{spatial_threshold_km}km_{temporal_threshold_days}days')
        return res.astype("category").cat.as_ordered()

cluster_cols property

cluster_cols: list[str]

Returns columns suitable for cluster adjustment (e.g., transmission clusters and ST).

genotypes property

genotypes: list[str]

Compiles a list of all genetic variables (Core + Accessory traits).

has_spatial property

has_spatial: bool

Checks whether the dataset has both 'latitude' and 'longitude' columns (only column presence is checked; coordinate values are not validated).

has_temporal property

has_temporal: bool

Checks whether the dataset has at least one 'temporal_' column (excluding resolution columns) and whether the first such column contains at least one non-null value.

metadata_columns property

metadata_columns: list[str]

Returns the raw names of user-uploaded clinical/metadata columns.

spatial property

spatial: DataFrame

Returns the core spatial coordinates (latitude, longitude).

Raises:

ValueError: If spatial columns are missing.

stratify_cols property

stratify_cols: list[str]

Returns columns suitable for stratification (excluding QC, metadata, and high-cardinality/internal cols).

temporal property

temporal: DataFrame

Returns a DataFrame of all temporal columns, with the prefix removed.

temporal_resolution property

temporal_resolution: DataFrame

Returns a DataFrame of all temporal resolution columns.

ui_metadata_columns property

ui_metadata_columns: list[str]

Returns clean metadata names (without 'meta_' prefix) for UI display.

aggregate_diversity

aggregate_diversity(stratify_by: list[str], trait_col: str = None, cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-', pad_zeros: bool = False) -> pd.DataFrame

Aggregates data to calculate counts for diversity analysis (e.g., Shannon index).

Parameters:

Name Type Description Default
stratify_by list[str]

Columns to group by.

required
trait_col str

The locus or trait to measure diversity for.

None
cluster_col str

Optional cluster column to adjust for.

None
negative_indicator Union[str, list[str]]

Value(s) to exclude from diversity counts.

'-'
pad_zeros bool

If True, pads missing combinations of strata with zero counts. Defaults to False.

False

Returns:

Type Description
DataFrame

A DataFrame with 'variant_count' and 'n_total'.

Raises:

Type Description
ValueError

If compositional diversity is requested without enough strata.

Source code in src/seroepi/accessors.py
def aggregate_diversity(self, stratify_by: list[str], trait_col: str = None,
                        cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-',
                        pad_zeros: bool = False) -> pd.DataFrame:
    """
    Aggregates data to calculate counts for diversity analysis (e.g., Shannon index).

    Two modes are supported:
      * Trait diversity (``trait_col`` given): each value of ``trait_col`` is
        counted within every ``stratify_by`` group. Values listed in
        ``negative_indicator`` are excluded unless the column is boolean.
      * Compositional diversity (``trait_col`` is None): the last
        ``stratify_by`` column is counted within groups formed by the
        remaining columns.

    Args:
        stratify_by: Columns to group by.
        trait_col: The locus or trait to measure diversity for.
        cluster_col: Optional cluster column to adjust for. When given, distinct
            clusters are counted instead of isolates.
        negative_indicator: Value(s) to exclude from diversity counts.
        pad_zeros: If True, pads missing combinations of strata with zero counts.
            Defaults to False.

    Returns:
        A DataFrame with the strata columns, 'target', 'variant_count' and
        'n_total' (the sum of 'variant_count' within each group).

    Raises:
        ValueError: If compositional diversity is requested without enough strata.
    """
    df = self._obj.copy()

    # Capture the mode now: in compositional mode `trait_col` is reassigned
    # below, so its truthiness can no longer tell the two modes apart.
    is_trait = bool(trait_col)
    if is_trait:
        groupers = stratify_by
        trait_strata = stratify_by + [trait_col]
        valid_df = df.dropna(subset=[trait_col]).copy()

        # For trait diversity, strip out the "absence" indicators so they don't
        # count as a diversity variant mathematically
        if not pd.api.types.is_bool_dtype(valid_df[trait_col]):
            neg_list = [negative_indicator] if isinstance(negative_indicator, str) else negative_indicator
            valid_df = valid_df[~valid_df[trait_col].isin(neg_list)]
    else:
        if not stratify_by:
            raise ValueError("Compositional diversity requires at least 1 stratify_by column.")
        groupers = stratify_by[:-1]
        trait_col = stratify_by[-1]
        trait_strata = stratify_by
        valid_df = df.dropna(subset=[trait_col]).copy()

    if cluster_col:
        # Count distinct clusters per stratum rather than isolates
        div_df = valid_df.groupby(trait_strata, observed=True)[cluster_col].nunique().reset_index()
        div_df = div_df.rename(columns={cluster_col: 'variant_count'})
    else:
        div_df = valid_df.groupby(trait_strata, observed=True).size().reset_index(name='variant_count')

    if pad_zeros and trait_strata:
        # Grouping levels come from the full dataset, but the counted variable's
        # levels come from the filtered data, so values excluded above (e.g. the
        # negative indicator) are not re-introduced as zero-count rows.
        unique_levels = [
            (valid_df if col == trait_col else df)[col].dropna().unique()
            for col in trait_strata
        ]
        expanded_index = pd.MultiIndex.from_product(unique_levels, names=trait_strata)
        div_df = div_df.set_index(trait_strata).reindex(expanded_index, fill_value=0).reset_index()

    if groupers:
        div_df['n_total'] = div_df.groupby(groupers, observed=True)['variant_count'].transform('sum')
    else:
        div_df['n_total'] = div_df['variant_count'].sum()

    # Standardize to 'target'
    div_df = div_df.rename(columns={trait_col: 'target'})
    if is_trait:
        # NOTE(review): in trait mode this overwrites the per-variant values that
        # were just renamed to 'target' with the trait's name; confirm downstream
        # consumers only need the counts.
        div_df['target'] = trait_col

    div_df.attrs = self._obj.attrs.copy()
    div_df.attrs["metric_meta"] = {
        "metric_type": MetricType.DIVERSITY,
        "stratified_by": groupers,
        "trait": trait_col,
        # Use the captured mode; `trait_col` is always set by this point
        "aggregation_type": AggregationType.TRAIT if is_trait else AggregationType.COMPOSITIONAL,
        "adjusted_for": cluster_col,
        "is_zero_padded": pad_zeros
    }

    return div_df

aggregate_incidence

aggregate_incidence(stratify_by: list[str], trait_col: str = None, freq: Union[str, TemporalResolution] = TemporalResolution.MONTH, cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-', pad_zeros: bool = False, temporal_col: str = None) -> pd.DataFrame

Aggregates data for time-series incidence analysis.

Parameters:

Name Type Description Default
stratify_by list[str]

Columns to group by.

required
trait_col str

The marker to measure incidence for.

None
freq Union[str, TemporalResolution]

Time frequency for binning (e.g., TemporalResolution.MONTH, 'ME'). Defaults to TemporalResolution.MONTH.

MONTH
cluster_col str

Optional cluster column to adjust for.

None
negative_indicator Union[str, list[str]]

Value(s) representing absence.

'-'
pad_zeros bool

If True, pads missing combinations of strata. If False, maintains unbroken time grids only for observed strata combinations.

False

Returns:

Type Description
DataFrame

A DataFrame with 'variant_count', 'total_sequenced', and binned dates.

Raises:

Type Description
ValueError

If temporal data is missing or inappropriate strata provided.

Source code in src/seroepi/accessors.py
def aggregate_incidence(self, stratify_by: list[str], trait_col: str = None, freq: Union[str, TemporalResolution] = TemporalResolution.MONTH,
                        cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-',
                        pad_zeros: bool = False, temporal_col: str = None) -> pd.DataFrame:
    """
    Aggregates data for time-series incidence analysis.

    Args:
        stratify_by: Columns to group by. When `trait_col` is None, the last
            column is the variable whose composition is counted.
        trait_col: The marker to measure incidence for.
        freq: Time frequency for binning (e.g., TemporalResolution.MONTH, 'ME').
            Pandas 2.2+ end-anchored aliases ('ME', 'QE', 'YE') are translated to
            their period equivalents. Defaults to TemporalResolution.MONTH.
        cluster_col: Optional cluster column to adjust for.
        negative_indicator: Value(s) representing absence.
        pad_zeros: If True, pads missing combinations of strata. If False,
            maintains unbroken time grids only for observed strata combinations.
        temporal_col: Temporal column to bin on, with or without the temporal
            domain prefix. Defaults to the first non-resolution temporal column.

    Returns:
        A DataFrame with 'variant_count', 'total_sequenced', the binned 'date'
        column and a 'target' column.

    Raises:
        ValueError: If temporal data is missing or inappropriate strata provided.
    """
    df = self._obj.copy()

    if temporal_col is None:
        # Skip resolution columns (prefix '<temporal>_res_')
        temporal_cols = df.filter(regex=f"^{Domain.TEMPORAL.value}_(?!res_)").columns
        if not len(temporal_cols):
            raise ValueError("Incidence aggregation requires a valid temporal column.")
        temporal_col = temporal_cols[0]
    elif not temporal_col.startswith(f"{Domain.TEMPORAL.value}_"):
        temporal_col = f"{Domain.TEMPORAL.value}_{temporal_col}"

    if temporal_col not in df.columns or not pd.api.types.is_datetime64_any_dtype(df[temporal_col]):
        raise ValueError(f"Incidence aggregation requires a valid datetime64 temporal column. '{temporal_col}' invalid.")

    # Translate Pandas 2.2+ point offsets ('ME', 'QE', 'YE') to period spans ('M', 'Q', 'Y');
    # Period conversion rejects the end-anchored aliases.
    if isinstance(freq, TemporalResolution):
        period_freq = freq.pandas_period
        stored_freq = freq.value
    else:
        period_freq = freq.replace('ME', 'M').replace('QE', 'Q').replace('YE', 'Y')
        stored_freq = freq

    # Snap dates to the requested frequency bin using the safe string
    df['date_bin'] = df[temporal_col].dt.to_period(period_freq).dt.to_timestamp()

    if trait_col:
        denom_cols = ['date_bin'] + stratify_by
        trait_strata = ['date_bin'] + stratify_by
    else:
        if not stratify_by:
            raise ValueError("Compositional incidence requires at least 1 stratify_by column.")

        # Compositional: the denominator excludes the variable being counted
        denom_cols = ['date_bin'] + stratify_by[:-1]
        trait_strata = ['date_bin'] + stratify_by

    events, denoms = self._calculate_events_and_denoms(
        df=df,
        trait_col=trait_col,
        denom_cols=denom_cols,
        trait_strata=trait_strata,
        cluster_col=cluster_col,
        negative_indicator=negative_indicator,
        event_name='variant_count',
        denom_name='total_sequenced'
    )

    # --- TIME GRID EXPANSION ---
    # Generate an unbroken sequence of dates from the earliest to the latest record
    min_date, max_date = df['date_bin'].min(), df['date_bin'].max()
    all_dates = pd.period_range(min_date, max_date, freq=period_freq).to_timestamp().tolist() if not pd.isna(
        min_date) else []

    if pad_zeros:
        unique_levels = [all_dates] + [df[col].dropna().unique() for col in trait_strata[1:]]
        expanded_index = pd.MultiIndex.from_product(unique_levels, names=trait_strata)
    else:
        if len(trait_strata) > 1:
            # Only use strata combinations that actually appear in the data
            observed_strata = df[trait_strata[1:]].dropna().drop_duplicates()
            dates_df = pd.DataFrame({trait_strata[0]: all_dates})
            expanded_df = dates_df.merge(observed_strata, how='cross')
            expanded_index = pd.MultiIndex.from_frame(expanded_df[trait_strata])
        else:
            expanded_index = pd.Index(all_dates, name=trait_strata[0])

    inc_df = pd.DataFrame(index=expanded_index).join(events, how='left').fillna({'variant_count': 0}).reset_index()

    # Align denominators on their own (coarser) key columns
    inc_df = inc_df.set_index(denom_cols)
    inc_df['total_sequenced'] = denoms
    inc_df = inc_df.reset_index().fillna({'total_sequenced': 0})

    # Unlike Prevalence, we do NOT drop rows where total_sequenced == 0.
    # A true 0 sequence volume is critical information for an epicurve gap.
    inc_df = inc_df.rename(columns={'date_bin': 'date'})

    # Standardize to 'target'
    if trait_col:
        inc_df['target'] = trait_col
    else:
        inc_df = inc_df.rename(columns={stratify_by[-1]: 'target'})

    inc_df.attrs = self._obj.attrs.copy()
    inc_df.attrs['metric_meta'] = {
        "metric_type": MetricType.INCIDENCE,
        "stratified_by": stratify_by if trait_col else stratify_by[:-1],
        "trait": trait_col if trait_col else stratify_by[-1],
        "aggregation_type": AggregationType.TRAIT if trait_col else AggregationType.COMPOSITIONAL,
        "freq": stored_freq,
        "adjusted_for": cluster_col,
        "is_zero_padded": pad_zeros
    }

    return inc_df

aggregate_prevalence

aggregate_prevalence(stratify_by: list[str], trait_col: str = None, cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-', pad_zeros: bool = False) -> pd.DataFrame

Aggregates data to calculate event counts and denominators for prevalence.

Supports both trait prevalence (presence/absence of a marker) and compositional prevalence (distribution of variants within a locus).

Parameters:

Name Type Description Default
stratify_by list[str]

Columns to group by (e.g., ['spatial']).

required
trait_col str

The column containing the trait/marker to measure. If None, compositional prevalence is calculated for the last column in stratify_by.

None
cluster_col str

Column containing cluster IDs to adjust for (e.g., nosocomial outbreaks).

None
negative_indicator Union[str, list[str]]

Value(s) representing the absence of a trait. Defaults to '-'.

'-'
pad_zeros bool

If True, pads missing combinations of strata with zero counts. Essential for spatial/hierarchical models. If False, only includes observed combinations for efficiency. Defaults to False.

False

Returns:

Type Description
DataFrame

An aggregated DataFrame with 'event' and 'n' columns.

Raises:

Type Description
ValueError

If compositional prevalence is requested without enough strata.

Source code in src/seroepi/accessors.py
def aggregate_prevalence(self, stratify_by: list[str], trait_col: str = None,
                         cluster_col: str = None, negative_indicator: Union[str, list[str]] = '-',
                         pad_zeros: bool = False) -> pd.DataFrame:
    """
    Aggregates data to calculate event counts and denominators for prevalence.

    Two flavours are supported:
      * Trait prevalence (``trait_col`` given): presence/absence of a marker,
        with the denominator taken over every ``stratify_by`` group.
      * Compositional prevalence (``trait_col`` is None): distribution of the
        variants in the last ``stratify_by`` column, with the denominator taken
        over the remaining columns.

    Args:
        stratify_by: Columns to group by (e.g., ['spatial']).
        trait_col: The column containing the trait/marker to measure.
            If None, compositional prevalence is calculated for the last
            column in `stratify_by`.
        cluster_col: Column containing cluster IDs to adjust for (e.g., nosocomial outbreaks).
        negative_indicator: Value(s) representing the absence of a trait.
            Defaults to '-'.
        pad_zeros: If True, pads missing combinations of strata with zero counts.
            Essential for spatial/hierarchical models. If False, only strata
            with a positive denominator are kept. Defaults to False.

    Returns:
        An aggregated DataFrame with 'event', 'n' and 'target' columns.

    Raises:
        ValueError: If compositional prevalence is requested without enough strata.
    """
    source = self._obj.copy()
    compositional = not trait_col

    if compositional and len(stratify_by) < 1:
        raise ValueError("Compositional prevalence requires at least 1 stratify_by column.")

    # Events are always counted over every stratum; compositional denominators
    # drop the last column (the variable whose composition is measured).
    trait_strata = stratify_by
    denom_cols = stratify_by[:-1] if compositional else stratify_by

    events, denoms = self._calculate_events_and_denoms(
        df=source,
        trait_col=trait_col,
        denom_cols=denom_cols,
        trait_strata=trait_strata,
        cluster_col=cluster_col,
        negative_indicator=negative_indicator,
        event_name='event',
        denom_name='n'
    )

    if len(stratify_by) == 0:
        # No strata at all: a single overall row
        agg_df = pd.DataFrame({'event': [events], 'n': [denoms]})
    else:
        if pad_zeros:
            # Full Cartesian grid of every observed level
            levels = [source[col].dropna().unique() for col in stratify_by]
            grid = pd.MultiIndex.from_product(levels, names=stratify_by)
        elif compositional:
            grid = events.index
        else:
            grid = denoms.index

        agg_df = (
            pd.DataFrame(index=grid)
            .join(events, how='left')
            .fillna({'event': 0})
            .reset_index()
        )

        if len(denom_cols) > 0:
            # Align denominators on their own key columns
            agg_df = agg_df.set_index(denom_cols)
            agg_df['n'] = denoms
            agg_df = agg_df.reset_index().fillna({'n': 0})
        else:
            agg_df['n'] = denoms

        if not pad_zeros:
            agg_df = agg_df[agg_df['n'] > 0].copy()

    # Standardize to 'target' for uniform downstream handling
    if compositional:
        agg_df = agg_df.rename(columns={stratify_by[-1]: 'target'})
    else:
        agg_df['target'] = trait_col

    agg_df.attrs = self._obj.attrs.copy()
    agg_df.attrs['metric_meta'] = {
        "metric_type": MetricType.PREVALENCE,
        "stratified_by": denom_cols,
        "trait": trait_col or stratify_by[-1],
        "aggregation_type": AggregationType.COMPOSITIONAL if compositional else AggregationType.TRAIT,
        "adjusted_for": cluster_col,
        "is_zero_padded": pad_zeros
    }

    return agg_df

epidemic_curve

epidemic_curve(freq: Union[str, TemporalResolution] = TemporalResolution.WEEK, stratify_by: str = None, temporal_col: str = None) -> pd.DataFrame

Generates a time-series DataFrame for plotting epidemic curves.

Parameters:

Name Type Description Default
freq Union[str, TemporalResolution]

Time frequency for resampling (e.g., TemporalResolution.MONTH, 'ME', 'YE'). Defaults to TemporalResolution.WEEK.

WEEK
stratify_by str

Column name to group by before resampling.

None

Returns:

Type Description
DataFrame

A DataFrame with counts of isolates per time interval.

Raises:

Type Description
ValueError

If temporal data is not available.

Source code in src/seroepi/accessors.py
def epidemic_curve(self, freq: Union[str, TemporalResolution] = TemporalResolution.WEEK,
                   stratify_by: str = None, temporal_col: str = None) -> pd.DataFrame:
    """
    Generates a time-series DataFrame for plotting epidemic curves.

    Args:
        freq: Time frequency for resampling (e.g., TemporalResolution.MONTH, 'ME', 'YE').
            Defaults to TemporalResolution.WEEK.
        stratify_by: Column name to group by before resampling; each level
            becomes its own count column.
        temporal_col: Temporal column to resample on, with or without the
            temporal domain prefix. Defaults to the first non-resolution
            temporal column.

    Returns:
        A DataFrame with counts of isolates per time interval.

    Raises:
        ValueError: If temporal data is not available.
        KeyError: If the requested temporal column does not exist.
        TypeError: If the temporal column is not a datetime type.
    """
    if not self.has_temporal:
        raise ValueError("Cannot generate epi curve: No temporal data available.")

    df = self._obj.copy()

    freq_val = freq.pandas_offset if isinstance(freq, TemporalResolution) else freq

    if temporal_col is None:
        temporal_cols = df.filter(regex=f"^{Domain.TEMPORAL.value}_(?!res_)").columns
        if not len(temporal_cols):
            # `has_temporal` is documented as matching any 'temporal_' column, which
            # may include resolution-only columns; fail with the documented error
            # rather than an IndexError.
            raise ValueError("Cannot generate epi curve: No temporal data available.")
        temporal_col = temporal_cols[0]
    elif not temporal_col.startswith(f"{Domain.TEMPORAL.value}_"):
        temporal_col = f"{Domain.TEMPORAL.value}_{temporal_col}"

    if temporal_col not in df.columns:
        raise KeyError(f"Temporal column '{temporal_col}' not found.")

    if not pd.api.types.is_datetime64_any_dtype(df[temporal_col]):
        raise TypeError(f"Temporal column '{temporal_col}' must be a datetime type. Ensure data is parsed via seroepi.io.")

    # Set time as index for resampling
    df = df.set_index(temporal_col)

    if stratify_by:
        # One column per stratum; intervals absent from a stratum become 0
        curve = df.groupby(stratify_by, observed=True).resample(freq_val).size().unstack(level=0, fill_value=0)
    else:
        curve = df.resample(freq_val).size().to_frame(name='count')

    return curve.reset_index()

transmission_clusters

transmission_clusters(clone_col: str, spatial_threshold_km: float = 10.0, temporal_threshold_days: int = 20, temporal_col: str = None, network: TransmissionDistances = None) -> pd.Series

Extracts and formats categorical cluster labels from the transmission network.

Source code in src/seroepi/accessors.py
def transmission_clusters(
        self,
        clone_col: str,
        spatial_threshold_km: float = 10.0,
        temporal_threshold_days: int = 20,
        temporal_col: str = None,
        network: TransmissionDistances = None
) -> pd.Series:
    """
    Extracts and formats categorical cluster labels from the transmission network.

    Args:
        clone_col: Column containing clone IDs.
        spatial_threshold_km: Spatial threshold passed to `transmission_network`
            when no `network` is supplied; also used in the result's name.
        temporal_threshold_days: Temporal threshold passed to
            `transmission_network` when no `network` is supplied; also used in
            the result's name.
        temporal_col: Temporal column, with or without the temporal domain
            prefix. Defaults to the first non-resolution temporal column.
        network: A precomputed network. Built on demand when None.

    Returns:
        An ordered categorical Series aligned to the DataFrame index. Rows
        rejected by the spatiotemporal validity mask or lacking a clone ID are NA.

    Raises:
        KeyError: If no temporal column is available.
    """
    frame = self._obj

    if temporal_col is not None:
        if not temporal_col.startswith(f"{Domain.TEMPORAL.value}_"):
            temporal_col = f"{Domain.TEMPORAL.value}_{temporal_col}"
    else:
        candidates = frame.filter(regex=f"^{Domain.TEMPORAL.value}_(?!res_)").columns
        if len(candidates) == 0:
            raise KeyError("A temporal column is required for temporal clustering.")
        temporal_col = candidates[0]

    if network is None:
        network = self.transmission_network(clone_col, spatial_threshold_km, temporal_threshold_days, temporal_col)

    raw_labels = network.get_clusters()

    _, _, valid_mask = self._get_spatiotemporal_arrays(temporal_col)
    has_clone = frame[clone_col].notna().values

    # Work on a float copy so NaN can mark excluded rows before the Int64 cast
    cluster_ids = raw_labels.astype(float).to_numpy(copy=True)
    for excluded in (~valid_mask, ~has_clone):
        cluster_ids[excluded] = np.nan

    series_name = f'{Domain.CLUSTER.value}_transmission_{spatial_threshold_km}km_{temporal_threshold_days}days'
    clusters = pd.Series(cluster_ids, index=frame.index, dtype="Int64", name=series_name)
    return clusters.astype("category").cat.as_ordered()

transmission_network

transmission_network(clone_col: str, spatial_threshold_km: float = 10.0, temporal_threshold_days: int = 20, temporal_col: str = None) -> TransmissionDistances

Builds a sparse adjacency graph of transmission links.

Parameters:

Name Type Description Default
clone_col str

Column containing clone IDs (e.g., 'ST' or a custom cluster).

required
spatial_threshold_km float

Maximum distance in kilometers. Defaults to 10.0.

10.0
temporal_threshold_days int

Maximum time difference in days. Defaults to 20.

20

Returns:

Type Description
TransmissionDistances

A TransmissionDistances object representing the outbreak network.

Raises:

Type Description
KeyError

If a required column (the clone column, 'latitude', 'longitude', 'sample_id', or a temporal column) is missing.

Source code in src/seroepi/accessors.py
def transmission_network(
        self,
        clone_col: str,
        spatial_threshold_km: float = 10.0,
        temporal_threshold_days: int = 20,
        temporal_col: str = None
) -> TransmissionDistances:
    """
    Builds a sparse adjacency graph of transmission links.

    Args:
        clone_col: Column containing clone IDs (e.g., 'ST' or a custom cluster).
        spatial_threshold_km: Maximum distance in kilometers. Defaults to 10.0.
        temporal_threshold_days: Maximum time difference in days. Defaults to 20.
        temporal_col: Temporal column to use, with or without the temporal
            domain prefix. Defaults to the first non-resolution temporal column.

    Returns:
        A TransmissionDistances object representing the outbreak network.

    Raises:
        KeyError: If a required column (the clone column, 'latitude',
            'longitude', 'sample_id', or a temporal column) is missing.
    """
    df = self._obj

    # 1. Validation Checks
    if clone_col not in df.columns:
        raise KeyError(f"Clone column '{clone_col}' not found in DataFrame.")

    if 'latitude' not in df.columns or 'longitude' not in df.columns:
        raise KeyError("Spatial clustering requires 'latitude' and 'longitude' columns. "
                       "Ensure geo accessors have parsed coordinates.")

    if temporal_col is None:
        temporal_cols = df.filter(regex=f"^{Domain.TEMPORAL.value}_(?!res_)").columns
        if not len(temporal_cols):
            raise KeyError("A temporal column is required for temporal clustering.")
        temporal_col = temporal_cols[0]
    elif not temporal_col.startswith(f"{Domain.TEMPORAL.value}_"):
        temporal_col = f"{Domain.TEMPORAL.value}_{temporal_col}"

    if temporal_col not in df.columns:
        raise KeyError(f"Temporal column '{temporal_col}' not found.")

    # Node identifiers for the network; validate up front for a clear message
    if 'sample_id' not in df.columns:
        raise KeyError("Transmission network requires a 'sample_id' column.")

    # 2. Build the network from coordinate/date arrays
    coords, raw_dates, _ = self._get_spatiotemporal_arrays(temporal_col)

    return TransmissionDistances.from_spatiotemporal(
        sample_ids=df['sample_id'],
        coords=coords,
        dates=raw_dates,
        clones=df[clone_col].values,
        spatial_threshold_km=spatial_threshold_km,
        temporal_threshold_days=temporal_threshold_days
    )

GenoAccessor

Pandas accessor for genetic and trait-based operations.

Provides methods for filtering determinants, checking for trait patterns, and sorting loci.

Examples:

>>> import pandas as pd
>>> import seroepi.accessors
>>> df = pd.DataFrame({'amr_blaKPC': [True, False], 'vir_ybt': [True, True]})
>>> amr_matrix = df.geno.amr
Source code in src/seroepi/accessors.py
@pd.api.extensions.register_dataframe_accessor("geno")
class GenoAccessor:
    """
    Pandas accessor for genetic and trait-based operations.

    Provides methods for filtering determinants, checking for trait patterns,
    and sorting loci.

    Examples:
        >>> import pandas as pd
        >>> import seroepi.accessors
        >>> df = pd.DataFrame({'amr_blaKPC': [True, False], 'vir_ybt': [True, True]})
        >>> amr_matrix = df.geno.amr
    """
    def __init__(self, pandas_obj: pd.DataFrame):
        self._obj = pandas_obj

    def _domain_matrix(self, domain: Domain) -> pd.DataFrame:
        """Returns the columns carrying the ``<domain>_`` prefix, with that prefix removed."""
        prefix = f"{domain.value}_"
        return self._obj.filter(regex=f"^{prefix}").rename(columns=lambda c: c.replace(prefix, '', 1))

    @property
    def genotype(self) -> pd.DataFrame:
        """Returns the Core Genotype matrix with the prefix removed from names."""
        return self._domain_matrix(Domain.GENOTYPE)

    @property
    def phenotype(self) -> pd.DataFrame:
        """Returns the Phenotype matrix with the prefix removed from names."""
        return self._domain_matrix(Domain.PHENOTYPE)

    @property
    def amr(self) -> pd.DataFrame:
        """Returns the AMR determinant matrix with the prefix removed from names."""
        return self._domain_matrix(Domain.AMR)

    @property
    def virulence(self) -> pd.DataFrame:
        """Returns the Virulence marker matrix with the prefix removed from names."""
        return self._domain_matrix(Domain.VIRULENCE)

    def has_any(self, traits: list[str], domain: Union[str, Domain] = Domain.AMR) -> pd.Series:
        """
        Checks if isolates possess ANY of the specified traits.

        Args:
            traits: List of trait names (without prefix).
            domain: The domain prefix (e.g., Domain.AMR, Domain.VIRULENCE). Defaults to Domain.AMR.

        Returns:
            A boolean Series indicating presence of any specified trait. Trait
            names with no matching column are ignored.
        """
        domain_val = domain.value if isinstance(domain, Domain) else domain
        # Safely prepend the prefix and check if the columns actually exist
        trait_cols = [f"{domain_val}_{t}" for t in traits if f"{domain_val}_{t}" in self._obj.columns]

        if not trait_cols:
            # If none of the genes exist in the dataset, no isolate has them
            return pd.Series(False, index=self._obj.index)

        # Check across the columns (axis=1) for any True values
        return self._obj[trait_cols].any(axis=1)

    def has_all(self, traits: list[str], domain: Union[str, Domain] = Domain.VIRULENCE) -> pd.Series:
        """
        Checks if isolates possess ALL of the specified traits.

        Args:
            traits: List of trait names.
            domain: Domain prefix (e.g., Domain.VIRULENCE). Defaults to Domain.VIRULENCE.

        Returns:
            A boolean Series. All False if any requested trait has no column.
        """
        domain_val = domain.value if isinstance(domain, Domain) else domain
        trait_cols = [f"{domain_val}_{t}" for t in traits]

        # If the user asks for a gene that isn't even in the dataset, they can't have 'all'
        missing_cols = set(trait_cols) - set(self._obj.columns)
        if missing_cols:
            return pd.Series(False, index=self._obj.index)

        # Check if ALL trait columns are True
        return self._obj[trait_cols].all(axis=1)

    def has_gene(self, gene_col: str, gene_name: str) -> pd.Series:
        """
        Searches for a specific gene within a comma-separated column.

        The match is a literal substring match (not regex, not whole-token).

        Args:
            gene_col: Column name containing gene lists.
            gene_name: The specific gene to find.

        Returns:
            A boolean Series; missing values count as False.
        """
        # str.contains with na=False avoids allocating a new Series with fillna
        return self._obj[gene_col].str.contains(gene_name, regex=False, na=False)

    def sort_loci(self, locus_col: str) -> pd.DataFrame:
        """
        Sorts the DataFrame numerically by locus (e.g., K2 before K10).

        Works for any index (including non-integer or duplicated labels). Loci
        with no digits sort last, and ties keep their original order.

        Args:
            locus_col: Column containing locus names.

        Returns:
            A sorted copy of the DataFrame.
        """
        df = self._obj.copy()
        # Extract the integer part of the locus (e.g., "K10" -> 10, "O2v1" -> 2)
        sort_key = df[locus_col].str.extract(r'(\d+)', expand=False).astype(float)
        # Sort on positions, not labels: `iloc` needs positional indices, and the
        # frame's own index may be non-integer or non-contiguous.
        positions = sort_key.reset_index(drop=True).sort_values(kind="stable").index
        return df.iloc[positions]

amr property

amr: DataFrame

Returns the AMR determinant matrix with the prefix removed from names.

genotype property

genotype: DataFrame

Returns the Core Genotype matrix with the prefix removed from names.

phenotype property

phenotype: DataFrame

Returns the Phenotype matrix with the prefix removed from names.

virulence property

virulence: DataFrame

Returns the Virulence marker matrix with the prefix removed from names.

has_all

has_all(traits: list[str], domain: Union[str, Domain] = Domain.VIRULENCE) -> pd.Series

Checks if isolates possess ALL of the specified traits.

Parameters:

Name Type Description Default
traits list[str]

List of trait names.

required
domain Union[str, Domain]

Domain prefix (e.g., Domain.VIRULENCE). Defaults to Domain.VIRULENCE.

VIRULENCE

Returns:

Type Description
Series

A boolean Series.

Source code in src/seroepi/accessors.py
def has_all(self, traits: list[str], domain: Union[str, Domain] = Domain.VIRULENCE) -> pd.Series:
    """
    Checks if isolates possess ALL of the specified traits.

    Args:
        traits: List of trait names.
        domain: Domain prefix (e.g., Domain.VIRULENCE). Defaults to Domain.VIRULENCE.

    Returns:
        A boolean Series aligned to the DataFrame index. It is all False when
        any requested trait has no corresponding column.
    """
    prefix = domain.value if isinstance(domain, Domain) else domain
    wanted = [f"{prefix}_{trait}" for trait in traits]

    # A trait that is not even a column cannot be possessed, so "all" fails outright
    if not set(self._obj.columns).issuperset(wanted):
        return pd.Series(False, index=self._obj.index)

    # Row-wise conjunction across the requested trait columns
    return self._obj[wanted].all(axis=1)

has_any

has_any(traits: list[str], domain: Union[str, Domain] = Domain.AMR) -> pd.Series

Checks if isolates possess ANY of the specified traits.

Parameters:

Name Type Description Default
traits list[str]

List of trait names (without prefix).

required
domain Union[str, Domain]

The domain prefix (e.g., Domain.AMR, Domain.VIRULENCE). Defaults to Domain.AMR.

AMR

Returns:

Type Description
Series

A boolean Series indicating presence of any specified trait.

Source code in src/seroepi/accessors.py
def has_any(self, traits: list[str], domain: Union[str, Domain] = Domain.AMR) -> pd.Series:
    """
    Checks if isolates possess ANY of the specified traits.

    Args:
        traits: List of trait names (without prefix).
        domain: The domain prefix (e.g., Domain.AMR, Domain.VIRULENCE). Defaults to Domain.AMR.

    Returns:
        A boolean Series indicating presence of any specified trait. Trait
        names without a matching column are silently ignored.
    """
    prefix = domain.value if isinstance(domain, Domain) else domain
    columns = self._obj.columns
    present = [name for name in (f"{prefix}_{trait}" for trait in traits) if name in columns]

    if present:
        # Row-wise disjunction across the trait columns that exist
        return self._obj[present].any(axis=1)

    # None of the requested trait columns exist, so no isolate can carry them
    return pd.Series(False, index=self._obj.index)

has_gene

has_gene(gene_col: str, gene_name: str) -> pd.Series

Searches for a specific gene within a comma-separated column.

Parameters:

Name Type Description Default
gene_col str

Column name containing gene lists.

required
gene_name str

The specific gene to find.

required

Returns:

Type Description
Series

A boolean Series.

Source code in src/seroepi/accessors.py
def has_gene(self, gene_col: str, gene_name: str) -> pd.Series:
    """
    Searches for a specific gene within a comma-separated column.

    The search is a literal substring match (``regex=False``), not a
    whole-token match, so 'KPC-2' would also match 'KPC-23'.

    Args:
        gene_col: Column name containing gene lists.
        gene_name: The specific gene to find.

    Returns:
        A boolean Series; missing values are reported as False.
    """
    gene_lists = self._obj[gene_col]
    # na=False maps missing entries straight to False without a separate fillna
    matches = gene_lists.str.contains(gene_name, na=False, regex=False)
    return matches

sort_loci

sort_loci(locus_col: str) -> pd.DataFrame

Sorts the DataFrame numerically by locus (e.g., K2 before K10).

Parameters:

Name Type Description Default
locus_col str

Column containing locus names.

required

Returns:

Type Description
DataFrame

A sorted copy of the DataFrame.

Source code in src/seroepi/accessors.py
def sort_loci(self, locus_col: str) -> pd.DataFrame:
    """
    Sorts the DataFrame numerically by locus (e.g., K2 before K10).

    Works for any index (including non-integer or duplicated labels). Loci
    with no digits sort last, and ties keep their original order.

    Args:
        locus_col: Column containing locus names.

    Returns:
        A sorted copy of the DataFrame.
    """
    df = self._obj.copy()
    # Extract the integer part of the locus (e.g., "K10" -> 10, "O2v1" -> 2)
    sort_key = df[locus_col].str.extract(r'(\d+)', expand=False).astype(float)
    # Sort on positions, not labels: `iloc` needs positional indices, and the
    # frame's own index may be non-integer or non-contiguous.
    positions = sort_key.reset_index(drop=True).sort_values(kind="stable").index
    return df.iloc[positions]

GeoAccessor

Pandas accessor for geographical operations on isolate datasets.

Provides methods for standardizing location names, imputing missing coordinates using a gazetteer, and performing reverse geocoding.

Attributes:

Name Type Description
gazetteer DataFrame

A DataFrame containing centroid coordinates and metadata for countries.

Source code in src/seroepi/accessors.py
@pd.api.extensions.register_dataframe_accessor("geo")
class GeoAccessor:
    """
    Pandas accessor for geographical operations on isolate datasets.

    Provides methods for standardizing location names, imputing missing
    coordinates using a gazetteer, and performing reverse geocoding.

    Attributes:
        gazetteer (pd.DataFrame): A DataFrame containing centroid coordinates and
            metadata for countries.
    """

    # Class-level cache to prevent expensive dataframe recreation on every accessor call
    _gazetteer_df = None

    def __init__(self, pandas_obj: pd.DataFrame):
        self._obj = pandas_obj

    @property
    def gazetteer(self) -> pd.DataFrame:
        """Returns the internal gazetteer used for coordinate imputation."""
        if GeoAccessor._gazetteer_df is None:
            GeoAccessor._gazetteer_df = pd.DataFrame.from_dict(GAZETTEER_DICT, orient='index')
        return GeoAccessor._gazetteer_df

    @property
    def spatial(self) -> pd.DataFrame:
        """Returns a DataFrame of all spatial columns, with the prefix removed."""
        return self._obj.filter(regex=f"^{Domain.SPATIAL.value}_(?!res_)").rename(columns=lambda c: c.replace(f'{Domain.SPATIAL.value}_', '', 1))

    @property
    def spatial_resolution(self) -> pd.DataFrame:
        """Returns a DataFrame of all spatial resolution columns."""
        return self._obj.filter(regex=f"^{Domain.SPATIAL_RES.value}_").rename(columns=lambda c: c.replace(f'{Domain.SPATIAL_RES.value}_', '', 1))

    def standardize_and_impute(self, spatial_col: str = None) -> pd.DataFrame:
        """
        Standardizes spatial names and imputes missing coordinates.

        Rows that already have both latitude and longitude are marked as exact.
        For the other rows, the whitespace-stripped value of the spatial column
        is looked up in the gazetteer index. When it is found, the centroid
        coordinates and the gazetteer's spatial resolution are filled in. The
        stripping is used only for the lookup; the spatial column itself is not
        rewritten.

        Args:
            spatial_col: Optional specific spatial column to impute by, given with
                or without the spatial domain prefix. Defaults to the first mapped
                spatial column.

        Returns:
            A new DataFrame with imputed coordinates and spatial resolution metadata.
            An unmodified copy is returned when no spatial column is present.
        """
        df = self._obj.copy()
        ref_data = self.gazetteer
        spatial_prefix = f"{Domain.SPATIAL.value}_"

        spatial_cols = df.filter(regex=f"^{Domain.SPATIAL.value}_(?!res_)").columns.tolist()
        if not spatial_cols:
            return df

        if spatial_col is None:
            spatial_col = spatial_cols[0]
        elif spatial_col not in df.columns and f"{spatial_prefix}{spatial_col}" in df.columns:
            spatial_col = f"{spatial_prefix}{spatial_col}"

        # Build the resolution column from the un-prefixed name. A str.replace of the
        # prefix leaves an un-prefixed column name unchanged, which would make
        # res_col == spatial_col and overwrite the location names below.
        res_col = f"{Domain.SPATIAL_RES.value}_{spatial_col.removeprefix(spatial_prefix)}"

        # Initialize tracking
        if res_col not in df.columns:
            df[res_col] = SpatialResolution.UNKNOWN.value

        # If it's a category, ensure all possible enum values are categories before assigning
        if isinstance(df[res_col].dtype, pd.CategoricalDtype):
            missing_cats = [c for c in SpatialResolution.choices() if c not in df[res_col].cat.categories]
            if missing_cats:
                df[res_col] = df[res_col].cat.add_categories(missing_cats)

        exact_mask = df['latitude'].notna() & df['longitude'].notna()
        df.loc[exact_mask, res_col] = SpatialResolution.EXACT.value

        # Look up centroids by the whitespace-stripped location name
        clean_spatial = df[spatial_col].str.strip()
        needs_imputation = (~exact_mask) & clean_spatial.notna()

        # Only map the rows that actually need imputation
        impute_spatial = clean_spatial[needs_imputation]
        df.loc[needs_imputation, 'latitude'] = impute_spatial.map(ref_data['centroid_lat'])
        df.loc[needs_imputation, 'longitude'] = impute_spatial.map(ref_data['centroid_lon'])

        # A row counts as imputed only when the gazetteer lookup produced a latitude
        imputed_mask = needs_imputation & df['latitude'].notna()
        df.loc[imputed_mask, res_col] = clean_spatial[imputed_mask].map(ref_data['spatial_resolution'])

        # Remove unused categories if it was a CategoricalDtype
        if isinstance(df[res_col].dtype, pd.CategoricalDtype):
            df[res_col] = df[res_col].cat.remove_unused_categories()

        df['latitude'] = df['latitude'].astype("Float64")
        df['longitude'] = df['longitude'].astype("Float64")

        return df

    def reverse_geocode(self, geojson_path: Union[str, Path] = None, target_spatial_name: str = 'Country') -> pd.DataFrame:
        """
        Performs reverse geocoding to determine spatial locality from coordinates.

        Every row with both latitude and longitude is matched against the boundary
        polygons. When a point intersects several polygons, the first match returned
        by the spatial index is used. The matched feature's ``ADMIN`` property
        ('Unknown' if absent) is written to the ``country_name`` column and to the
        ``<spatial prefix>_<target_spatial_name>`` column. In the latter column,
        rows without a match keep any value they already had. Row order and index
        labels are preserved.

        Args:
            geojson_path: Optional path to a GeoJSON file containing boundary polygons.
                Defaults to the built-in world_boundaries.geojson.
            target_spatial_name: The name to append to the spatial domain prefix to
                form the output column.

        Returns:
            A new DataFrame with updated 'spatial' information. An unmodified copy is
            returned when the GeoJSON file does not exist or no row has both
            coordinates.
        """
        if geojson_path is None:
            geojson_path = Path(__file__).parent / "data" / "world_boundaries.geojson"

        if not Path(geojson_path).exists():
            return self._obj.copy()

        df = self._obj.copy()
        exact_mask = df['latitude'].notna() & df['longitude'].notna()

        # Nothing to geocode: skip parsing the (potentially large) boundary file
        if not exact_mask.any():
            return df

        with open(geojson_path, 'r', encoding='utf-8') as fh:
            feature_collection = json.load(fh)

        features = [feat for feat in feature_collection.get('features', []) if feat.get('geometry')]

        # 1. Fast C-level GeoJSON geometry parsing
        polygons = from_geojson([json.dumps(feat['geometry']) for feat in features])
        # GeoJSON allows "properties": null, so fall back to an empty mapping
        country_names = np.array(
            [(feat.get('properties') or {}).get('ADMIN', 'Unknown') for feat in features],
            dtype=object,
        )

        # Reverse geocode unique coordinates only, using an STRtree spatial index
        coords = df.loc[exact_mask, ['latitude', 'longitude']]
        unique_coords = coords.drop_duplicates()

        # 2. Vectorized Point creation (astype(float) prevents Pandas Float64 extension errors)
        pts = points(unique_coords['longitude'].astype(float).to_numpy(),
                     unique_coords['latitude'].astype(float).to_numpy())

        # 3. R-Tree spatial index for Point-in-Polygon queries
        tree = STRtree(polygons)
        pt_idx, poly_idx = tree.query(pts, predicate='intersects')

        # 4. Resolve border overlaps by keeping the first matched polygon per point
        unique_pt_idx, first_hit = np.unique(pt_idx, return_index=True)

        country_results = np.full(len(unique_coords), pd.NA, dtype=object)
        country_results[unique_pt_idx] = country_names[poly_idx[first_hit]]
        unique_coords = unique_coords.assign(country_name=country_results)

        # A left merge preserves the row order of `coords`, and unique_coords holds one
        # row per coordinate pair, so the result lines up row-for-row with `coords`.
        # Results are written back positionally so df keeps its original index.
        matched = coords.merge(
            unique_coords, on=['latitude', 'longitude'], how='left', validate='many_to_one'
        )['country_name'].to_numpy()
        geocoded = np.full(len(df), pd.NA, dtype=object)
        geocoded[exact_mask.to_numpy(dtype=bool)] = matched

        # Kept for callers that read the helper column produced by earlier versions
        df['country_name'] = geocoded

        new_col = f"{Domain.SPATIAL.value}_{target_spatial_name}"
        if new_col in df.columns:
            # Keep existing values wherever geocoding found no polygon
            existing = df[new_col].astype(object).to_numpy()
            geocoded = np.where(pd.isna(geocoded), existing, geocoded)
        df[new_col] = geocoded
        # NOTE(review): no matching spatial-resolution column is written here;
        # confirm whether one is expected for the target column.

        df['latitude'] = df['latitude'].astype("Float64")
        df['longitude'] = df['longitude'].astype("Float64")

        return df

gazetteer property

gazetteer: DataFrame

Returns the internal gazetteer used for coordinate imputation.

spatial property

spatial: DataFrame

Returns a DataFrame of all spatial columns, with the prefix removed.

spatial_resolution property

spatial_resolution: DataFrame

Returns a DataFrame of all spatial resolution columns.

reverse_geocode

reverse_geocode(geojson_path: Union[str, Path] = None, target_spatial_name: str = 'Country') -> pd.DataFrame

Performs reverse geocoding to determine spatial locality from coordinates.

Parameters:

Name Type Description Default
geojson_path Union[str, Path]

Optional path to a GeoJSON file containing boundary polygons. Defaults to the built-in world_boundaries.geojson.

None
target_spatial_name str

The name to append to the spatial domain prefix.

'Country'

Returns:

Type Description
DataFrame

A new DataFrame with updated 'spatial' information.

Source code in src/seroepi/accessors.py
def reverse_geocode(self, geojson_path: Union[str, Path] = None, target_spatial_name: str = 'Country') -> pd.DataFrame:
    """
    Performs reverse geocoding to determine spatial locality from coordinates.

    Every row with both latitude and longitude is matched against the boundary
    polygons. When a point intersects several polygons, the first match returned
    by the spatial index is used. The matched feature's ``ADMIN`` property
    ('Unknown' if absent) is written to the ``country_name`` column and to the
    ``<spatial prefix>_<target_spatial_name>`` column. In the latter column, rows
    without a match keep any value they already had. Row order and index labels
    are preserved.

    Args:
        geojson_path: Optional path to a GeoJSON file containing boundary polygons.
            Defaults to the built-in world_boundaries.geojson.
        target_spatial_name: The name to append to the spatial domain prefix to
            form the output column.

    Returns:
        A new DataFrame with updated 'spatial' information. An unmodified copy is
        returned when the GeoJSON file does not exist or no row has both
        coordinates.
    """
    if geojson_path is None:
        geojson_path = Path(__file__).parent / "data" / "world_boundaries.geojson"

    if not Path(geojson_path).exists():
        return self._obj.copy()

    df = self._obj.copy()
    exact_mask = df['latitude'].notna() & df['longitude'].notna()

    # Nothing to geocode: skip parsing the (potentially large) boundary file
    if not exact_mask.any():
        return df

    with open(geojson_path, 'r', encoding='utf-8') as fh:
        feature_collection = json.load(fh)

    features = [feat for feat in feature_collection.get('features', []) if feat.get('geometry')]

    # 1. Fast C-level GeoJSON geometry parsing
    polygons = from_geojson([json.dumps(feat['geometry']) for feat in features])
    # GeoJSON allows "properties": null, so fall back to an empty mapping
    country_names = np.array(
        [(feat.get('properties') or {}).get('ADMIN', 'Unknown') for feat in features],
        dtype=object,
    )

    # Reverse geocode unique coordinates only, using an STRtree spatial index
    coords = df.loc[exact_mask, ['latitude', 'longitude']]
    unique_coords = coords.drop_duplicates()

    # 2. Vectorized Point creation (astype(float) prevents Pandas Float64 extension errors)
    pts = points(unique_coords['longitude'].astype(float).to_numpy(),
                 unique_coords['latitude'].astype(float).to_numpy())

    # 3. R-Tree spatial index for Point-in-Polygon queries
    tree = STRtree(polygons)
    pt_idx, poly_idx = tree.query(pts, predicate='intersects')

    # 4. Resolve border overlaps by keeping the first matched polygon per point
    unique_pt_idx, first_hit = np.unique(pt_idx, return_index=True)

    country_results = np.full(len(unique_coords), pd.NA, dtype=object)
    country_results[unique_pt_idx] = country_names[poly_idx[first_hit]]
    unique_coords = unique_coords.assign(country_name=country_results)

    # A left merge preserves the row order of `coords`, and unique_coords holds one
    # row per coordinate pair, so the result lines up row-for-row with `coords`.
    # Results are written back positionally so df keeps its original index.
    matched = coords.merge(
        unique_coords, on=['latitude', 'longitude'], how='left', validate='many_to_one'
    )['country_name'].to_numpy()
    geocoded = np.full(len(df), pd.NA, dtype=object)
    geocoded[exact_mask.to_numpy(dtype=bool)] = matched

    # Kept for callers that read the helper column produced by earlier versions
    df['country_name'] = geocoded

    new_col = f"{Domain.SPATIAL.value}_{target_spatial_name}"
    if new_col in df.columns:
        # Keep existing values wherever geocoding found no polygon
        existing = df[new_col].astype(object).to_numpy()
        geocoded = np.where(pd.isna(geocoded), existing, geocoded)
    df[new_col] = geocoded
    # NOTE(review): no matching spatial-resolution column is written here;
    # confirm whether one is expected for the target column.

    df['latitude'] = df['latitude'].astype("Float64")
    df['longitude'] = df['longitude'].astype("Float64")

    return df

standardize_and_impute

standardize_and_impute(spatial_col: str = None) -> pd.DataFrame

Standardizes spatial names and imputes missing coordinates.

Uses the internal gazetteer to find centroids for countries when exact latitude and longitude are missing.

Parameters:

Name Type Description Default
spatial_col str

Optional specific spatial column to impute by. Defaults to the first mapped spatial column.

None

Returns:

Type Description
DataFrame

A new DataFrame with imputed coordinates and spatial resolution metadata.

Source code in src/seroepi/accessors.py
def standardize_and_impute(self, spatial_col: str = None) -> pd.DataFrame:
    """
    Standardizes spatial names and imputes missing coordinates.

    Rows that already have both latitude and longitude are marked as exact.
    For the other rows, the whitespace-stripped value of the spatial column is
    looked up in the gazetteer index. When it is found, the centroid coordinates
    and the gazetteer's spatial resolution are filled in. The stripping is used
    only for the lookup; the spatial column itself is not rewritten.

    Args:
        spatial_col: Optional specific spatial column to impute by, given with or
            without the spatial domain prefix. Defaults to the first mapped
            spatial column.

    Returns:
        A new DataFrame with imputed coordinates and spatial resolution metadata.
        An unmodified copy is returned when no spatial column is present.
    """
    df = self._obj.copy()
    ref_data = self.gazetteer
    spatial_prefix = f"{Domain.SPATIAL.value}_"

    spatial_cols = df.filter(regex=f"^{Domain.SPATIAL.value}_(?!res_)").columns.tolist()
    if not spatial_cols:
        return df

    if spatial_col is None:
        spatial_col = spatial_cols[0]
    elif spatial_col not in df.columns and f"{spatial_prefix}{spatial_col}" in df.columns:
        spatial_col = f"{spatial_prefix}{spatial_col}"

    # Build the resolution column from the un-prefixed name. A str.replace of the
    # prefix leaves an un-prefixed column name unchanged, which would make
    # res_col == spatial_col and overwrite the location names below.
    res_col = f"{Domain.SPATIAL_RES.value}_{spatial_col.removeprefix(spatial_prefix)}"

    # Initialize tracking
    if res_col not in df.columns:
        df[res_col] = SpatialResolution.UNKNOWN.value

    # If it's a category, ensure all possible enum values are categories before assigning
    if isinstance(df[res_col].dtype, pd.CategoricalDtype):
        missing_cats = [c for c in SpatialResolution.choices() if c not in df[res_col].cat.categories]
        if missing_cats:
            df[res_col] = df[res_col].cat.add_categories(missing_cats)

    exact_mask = df['latitude'].notna() & df['longitude'].notna()
    df.loc[exact_mask, res_col] = SpatialResolution.EXACT.value

    # Look up centroids by the whitespace-stripped location name
    clean_spatial = df[spatial_col].str.strip()
    needs_imputation = (~exact_mask) & clean_spatial.notna()

    # Only map the rows that actually need imputation
    impute_spatial = clean_spatial[needs_imputation]
    df.loc[needs_imputation, 'latitude'] = impute_spatial.map(ref_data['centroid_lat'])
    df.loc[needs_imputation, 'longitude'] = impute_spatial.map(ref_data['centroid_lon'])

    # A row counts as imputed only when the gazetteer lookup produced a latitude
    imputed_mask = needs_imputation & df['latitude'].notna()
    df.loc[imputed_mask, res_col] = clean_spatial[imputed_mask].map(ref_data['spatial_resolution'])

    # Remove unused categories if it was a CategoricalDtype
    if isinstance(df[res_col].dtype, pd.CategoricalDtype):
        df[res_col] = df[res_col].cat.remove_unused_categories()

    df['latitude'] = df['latitude'].astype("Float64")
    df['longitude'] = df['longitude'].astype("Float64")

    return df

QCAccessor

Pandas accessor for quality control operations.

Provides methods for filtering assemblies based on metrics like N50 and contig count.

Examples:

>>> import pandas as pd
>>> import seroepi.accessors
>>> df = pd.DataFrame({'qc_N50': [50000, 5000]})
>>> clean_df = df.qc.filter_assemblies(min_n50=10000)
Source code in src/seroepi/accessors.py
@pd.api.extensions.register_dataframe_accessor("qc")
class QCAccessor:
    """
    Pandas accessor for quality control operations.

    Provides methods for filtering assemblies based on metrics like N50 and
    contig count.

    Examples:
        >>> import pandas as pd
        >>> import seroepi.accessors
        >>> df = pd.DataFrame({'qc_N50': [50000, 5000]})
        >>> clean_df = df.qc.filter_assemblies(min_n50=10000)
    """
    def __init__(self, pandas_obj: pd.DataFrame):
        self._obj = pandas_obj

    @property
    def metrics(self) -> pd.DataFrame:
        """Returns the QC metrics matrix with the prefix removed."""
        return self._obj.filter(regex=f"^{Domain.QC.value}_").rename(columns=lambda c: c.replace(f'{Domain.QC.value}_', '', 1))

    def filter_assemblies(self, min_n50: int = 10000, max_contigs: int = 500,
                          require_species: str = None) -> pd.DataFrame:
        """
        Filters genomes based on quality thresholds.

        Each threshold is applied only if its QC column is present. Values that
        cannot be parsed as numbers (e.g. '-') are treated as missing, and rows
        with a missing N50 or contig count are kept rather than dropped.

        Args:
            min_n50: Minimum N50 value. Defaults to 10000.
            max_contigs: Maximum number of contigs. Defaults to 500.
            require_species: Optional species name to filter for. It is matched
                case-insensitively (as a regular expression) anywhere in the
                species column; rows with a missing species are dropped.

        Returns:
            A filtered copy of the DataFrame.
        """
        df = self._obj

        masks = []

        if f'{Domain.QC.value}_N50' in df.columns:
            # Coerce to numeric in case Kleborate spat out weird strings like '-'
            n50 = pd.to_numeric(df[f'{Domain.QC.value}_N50'], errors='coerce')
            masks.append((n50 >= min_n50) | n50.isna())

        if f'{Domain.QC.value}_contig_count' in df.columns:
            contigs = pd.to_numeric(df[f'{Domain.QC.value}_contig_count'], errors='coerce')
            masks.append((contigs <= max_contigs) | contigs.isna())

        if require_species and f'{Domain.QC.value}_species' in df.columns:
            masks.append(df[f'{Domain.QC.value}_species'].str.contains(require_species, case=False, na=False))

        if not masks:
            return df.copy()

        # Efficiently reduce the conditions using numpy logic
        final_mask = np.logical_and.reduce(masks)
        return df[final_mask].copy()

    def report(self) -> pd.Series:
        """
        Generates a summary report of dataset quality.

        'Total Warnings' counts rows whose ``QC_warnings`` value is present and is
        not the '-' placeholder. 'Median N50' is the median of the N50 values that
        parse as numbers. Each entry is included only when its column exists.

        Returns:
            A Series containing summary metrics (e.g., Total Warnings, Median N50).
        """
        metrics = self.metrics
        report = {}
        if 'QC_warnings' in metrics.columns:
            qc_warnings = metrics['QC_warnings']
            # A missing value means "not reported", not "has a warning": a bare
            # `!= '-'` is True for NaN/None in object columns and over-counts.
            report['Total Warnings'] = (qc_warnings.notna() & (qc_warnings != '-')).sum()
        if 'N50' in metrics.columns:
            report['Median N50'] = pd.to_numeric(metrics['N50'], errors='coerce').median()

        return pd.Series(report)

metrics property

metrics: DataFrame

Returns the QC metrics matrix with the prefix removed.

filter_assemblies

filter_assemblies(min_n50: int = 10000, max_contigs: int = 500, require_species: str = None) -> pd.DataFrame

Filters genomes based on quality thresholds.

Parameters:

Name Type Description Default
min_n50 int

Minimum N50 value. Defaults to 10000.

10000
max_contigs int

Maximum number of contigs. Defaults to 500.

500
require_species str

Optional species name to filter for.

None

Returns:

Type Description
DataFrame

A filtered copy of the DataFrame.

Source code in src/seroepi/accessors.py
def filter_assemblies(self, min_n50: int = 10000, max_contigs: int = 500,
                      require_species: str = None) -> pd.DataFrame:
    """
    Keep only the assemblies that pass the configured quality thresholds.

    A check runs only when its QC column exists. N50 and contig-count values
    that cannot be parsed as numbers (such as '-') count as missing, and a
    missing value never causes a row to be dropped.

    Args:
        min_n50: Minimum N50 value. Defaults to 10000.
        max_contigs: Maximum number of contigs. Defaults to 500.
        require_species: Optional species name to filter for; matched
            case-insensitively (as a regular expression) in the species column,
            and rows with no species are dropped.

    Returns:
        A filtered copy of the DataFrame.
    """
    data = self._obj
    prefix = Domain.QC.value
    keep = []

    n50_col = f'{prefix}_N50'
    if n50_col in data.columns:
        # Tool output may contain placeholders like '-', so parse leniently
        n50_values = pd.to_numeric(data[n50_col], errors='coerce')
        keep.append(n50_values.isna() | (n50_values >= min_n50))

    contig_col = f'{prefix}_contig_count'
    if contig_col in data.columns:
        contig_values = pd.to_numeric(data[contig_col], errors='coerce')
        keep.append(contig_values.isna() | (contig_values <= max_contigs))

    species_col = f'{prefix}_species'
    if require_species and species_col in data.columns:
        keep.append(data[species_col].str.contains(require_species, case=False, na=False))

    if keep:
        # Combine every active check into a single boolean row mask
        return data[np.logical_and.reduce(keep)].copy()
    return data.copy()

report

report() -> pd.Series

Generates a summary report of dataset quality.

Returns:

Type Description
Series

A Series containing summary metrics (e.g., Total Warnings, Median N50).

Source code in src/seroepi/accessors.py
def report(self) -> pd.Series:
    """
    Generates a summary report of dataset quality.

    'Total Warnings' counts rows whose ``QC_warnings`` value is present and is
    not the '-' placeholder. 'Median N50' is the median of the N50 values that
    parse as numbers. Each entry is included only when its column exists.

    Returns:
        A Series containing summary metrics (e.g., Total Warnings, Median N50).
    """
    metrics = self.metrics
    report = {}
    if 'QC_warnings' in metrics.columns:
        qc_warnings = metrics['QC_warnings']
        # A missing value means "not reported", not "has a warning": a bare
        # `!= '-'` is True for NaN/None in object columns and over-counts.
        report['Total Warnings'] = (qc_warnings.notna() & (qc_warnings != '-')).sum()
    if 'N50' in metrics.columns:
        report['Median N50'] = pd.to_numeric(metrics['N50'], errors='coerce').median()

    return pd.Series(report)

seroepi.formulation

Module for abstracting a vaccine formulation using trait prevalence and stability.

BaseFormulationDesigner

Bases: ModelledMixin, ABC

Abstract base class for formulation designers.

Designers are responsible for evaluating prevalence estimates and generating a vaccine formulation with stability metrics.

Attributes:

Name Type Description
valency

The target valency (number of targets) for the vaccine.

n_jobs

The number of CPU cores to use for processing (-1 for all).

formulation_ Optional[Formulation]

The resulting Formulation object after fitting.

Source code in src/seroepi/formulation.py
class BaseFormulationDesigner(ModelledMixin, ABC):
    """
    Abstract base class for formulation designers.

    Designers are responsible for evaluating prevalence estimates and generating
    a vaccine formulation with stability metrics.

    Attributes:
        valency: The target valency (number of targets) for the vaccine.
        n_jobs: The number of CPU cores to use for processing (-1 for all).
        formulation_: The resulting Formulation object after fitting.
    """

    def __init__(self, valency: int = 6, n_jobs: int = -1):
        """
        Initializes the designer.

        Args:
            valency: The target valency. Defaults to 6.
            n_jobs: Number of concurrent workers. Defaults to -1 (all available).
        """
        self.valency = valency
        self.n_jobs = n_jobs
        self.formulation_: Optional[Formulation] = None

    def fit(self, *args, progress_callback: Optional[Callable[[int, int], None]] = None, **kwargs) -> 'BaseFormulationDesigner':
        """
        Calculates the formulation and stores it in ``self.formulation_``.

        The base class provides no algorithm. Concrete designers must override
        this method, set ``self.formulation_`` and return ``self``.

        Raises:
            NotImplementedError: Always, when called on a designer that does not
                override it (previously this silently returned None).
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement fit().")

    def predict(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Uses the fitted formulation to predict vaccine coverage on a given DataFrame.

        Returns only the rows whose trait value is one of the targets in the
        designed formulation.

        Raises:
            KeyError: If the formulation's trait column is absent from ``df``.
        """
        self.check_is_fitted()
        trait_name = self.formulation_.trait
        if trait_name not in df.columns:
            raise KeyError(f"The trait column '{trait_name}' was not found in the provided DataFrame.")

        covered_targets = self.formulation_.get_formulation()
        return df[df[trait_name].isin(covered_targets)].copy()

__init__

__init__(valency: int = 6, n_jobs: int = -1)

Initializes the designer.

Parameters:

Name Type Description Default
valency int

The target valency. Defaults to 6.

6
n_jobs int

Number of concurrent workers. Defaults to -1 (all available).

-1
Source code in src/seroepi/formulation.py
def __init__(self, valency: int = 6, n_jobs: int = -1):
    """
    Store the designer's configuration; no formulation exists until fitting.

    Args:
        valency: Number of targets the vaccine should contain. Defaults to 6.
        n_jobs: Number of concurrent workers. Defaults to -1 (all available).
    """
    self.valency, self.n_jobs = valency, n_jobs
    # Filled in by a concrete designer's fit()
    self.formulation_: Optional[Formulation] = None

fit

fit(*args, progress_callback: Optional[Callable[[int, int], None]] = None, **kwargs) -> BaseFormulationDesigner

Calculates the formulation and stores it in self.formulation_.

Source code in src/seroepi/formulation.py
def fit(self, *args, progress_callback: Optional[Callable[[int, int], None]] = None, **kwargs) -> 'BaseFormulationDesigner':
    """
    Calculates the formulation and stores it in ``self.formulation_``.

    The base class provides no algorithm. Concrete designers must override
    this method, set ``self.formulation_`` and return ``self``.

    Raises:
        NotImplementedError: Always, when called on a designer that does not
            override it (previously this silently returned None).
    """
    raise NotImplementedError(f"{type(self).__name__} does not implement fit().")

predict

predict(df: DataFrame) -> pd.DataFrame

Uses the fitted formulation to predict vaccine coverage on a given DataFrame. Returns only the rows that are covered by the designed formulation.

Source code in src/seroepi/formulation.py
def predict(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Uses the fitted formulation to predict vaccine coverage on a given DataFrame.

    Returns a copy containing only the rows whose trait value is one of the
    targets in the designed formulation.

    Raises:
        KeyError: If the formulation's trait column is absent from ``df``.
    """
    self.check_is_fitted()
    formulation = self.formulation_
    column = formulation.trait
    if column in df.columns:
        is_covered = df[column].isin(formulation.get_formulation())
        return df.loc[is_covered].copy()
    raise KeyError(f"The trait column '{column}' was not found in the provided DataFrame.")

CVFormulationDesigner

Bases: BaseFormulationDesigner

Rigorous formulation design using true Leave-One-Out (LOO) cross-validation.

This method retrains the model for each LOO permutation, which is more computationally expensive but necessary for complex models.

Source code in src/seroepi/formulation.py
class CVFormulationDesigner(BaseFormulationDesigner):
    """
    Rigorous formulation design using true Leave-One-Out (LOO) cross-validation.

    This method retrains the model for each LOO permutation, which is more
    computationally expensive but necessary for complex models.
    """
    def fit(self, estimator: BaseEstimator, agg_df: pd.DataFrame, loo_col: str,
            progress_callback: Optional[Callable[[int, int], None]] = None) -> 'CVFormulationDesigner':
        """
        Evaluates an estimator using LOO cross-validation to design a formulation.

        A baseline ranking is computed on all of ``agg_df``. Then one fold is run
        per unique value of ``loo_col`` (each via ``_run_cv_fold``, in parallel
        with joblib). The fold results are concatenated and compiled into
        ``self.formulation_``.

        Args:
            estimator: The estimator instance to use.
            agg_df: The aggregated data for the estimator.
            loo_col: The column name to use for Leave-One-Out cross-validation.
            progress_callback: Optional callable invoked as
                ``progress_callback(completed, total)``. It is called after every
                fold when joblib can stream results, otherwise once after all
                folds have finished.

        Returns:
            The fitted designer instance.

        Raises:
            ValueError: If ``loo_col`` yields no groups (e.g. ``agg_df`` is empty).
        """
        # 1. Baseline
        baseline_result = _clone_estimator(estimator).calculate(agg_df)
        trait_name = baseline_result.trait
        baseline = _extract_ranks(baseline_result.data, 'baseline_rank')

        # 2. Permutations (Parallel Processing)
        groups = agg_df[loo_col].unique()
        total_groups = len(groups)
        if total_groups == 0:
            raise ValueError(f"Column '{loo_col}' has no groups to leave out.")

        jobs = (delayed(_run_cv_fold)(estimator, agg_df, loo_col, group) for group in groups)

        # Guard only the constructor: joblib < 1.3 rejects return_as="generator".
        # Wrapping the whole loop would also catch TypeErrors raised by a fold or by
        # the callback, and the fallback would then re-run a partially consumed
        # `jobs` generator, silently dropping folds.
        try:
            streaming = Parallel(n_jobs=self.n_jobs, return_as="generator")
        except TypeError:
            streaming = None

        if streaming is not None:
            loo_records = []
            with streaming as parallel:
                for i, result in enumerate(parallel(jobs), 1):
                    loo_records.append(result)
                    if progress_callback:
                        progress_callback(i, total_groups)
        else:
            with Parallel(n_jobs=self.n_jobs) as parallel:
                loo_records = parallel(jobs)
                if progress_callback:
                    progress_callback(total_groups, total_groups)

        history = pd.concat(loo_records, ignore_index=True)

        # 3. Compile
        self.formulation_ = _compile_stability_metrics(baseline, history, trait_name, self.valency)
        self.is_fitted_ = True
        return self

fit

fit(estimator: BaseEstimator, agg_df: DataFrame, loo_col: str, progress_callback: Optional[Callable[[int, int], None]] = None) -> CVFormulationDesigner

Evaluates an estimator using LOO cross-validation to design a formulation.

Parameters:

Name Type Description Default
estimator BaseEstimator

The estimator instance to use.

required
agg_df DataFrame

The aggregated data for the estimator.

required
loo_col str

The column name to use for Leave-One-Out cross-validation.

required

Returns:

Type Description
CVFormulationDesigner

The fitted designer instance.

Source code in src/seroepi/formulation.py
def fit(self, estimator: BaseEstimator, agg_df: pd.DataFrame, loo_col: str,
        progress_callback: Optional[Callable[[int, int], None]] = None) -> 'CVFormulationDesigner':
    """
    Evaluates an estimator using LOO cross-validation to design a formulation.

    A baseline ranking is computed on all of ``agg_df``. Then one fold is run
    per unique value of ``loo_col`` (each via ``_run_cv_fold``, in parallel with
    joblib). The fold results are concatenated and compiled into
    ``self.formulation_``.

    Args:
        estimator: The estimator instance to use.
        agg_df: The aggregated data for the estimator.
        loo_col: The column name to use for Leave-One-Out cross-validation.
        progress_callback: Optional callable invoked as
            ``progress_callback(completed, total)``. It is called after every fold
            when joblib can stream results, otherwise once after all folds have
            finished.

    Returns:
        The fitted designer instance.

    Raises:
        ValueError: If ``loo_col`` yields no groups (e.g. ``agg_df`` is empty).
    """
    # 1. Baseline
    baseline_result = _clone_estimator(estimator).calculate(agg_df)
    trait_name = baseline_result.trait
    baseline = _extract_ranks(baseline_result.data, 'baseline_rank')

    # 2. Permutations (Parallel Processing)
    groups = agg_df[loo_col].unique()
    total_groups = len(groups)
    if total_groups == 0:
        raise ValueError(f"Column '{loo_col}' has no groups to leave out.")

    jobs = (delayed(_run_cv_fold)(estimator, agg_df, loo_col, group) for group in groups)

    # Guard only the constructor: joblib < 1.3 rejects return_as="generator".
    # Wrapping the whole loop would also catch TypeErrors raised by a fold or by
    # the callback, and the fallback would then re-run a partially consumed
    # `jobs` generator, silently dropping folds.
    try:
        streaming = Parallel(n_jobs=self.n_jobs, return_as="generator")
    except TypeError:
        streaming = None

    if streaming is not None:
        loo_records = []
        with streaming as parallel:
            for i, result in enumerate(parallel(jobs), 1):
                loo_records.append(result)
                if progress_callback:
                    progress_callback(i, total_groups)
    else:
        with Parallel(n_jobs=self.n_jobs) as parallel:
            loo_records = parallel(jobs)
            if progress_callback:
                progress_callback(total_groups, total_groups)

    history = pd.concat(loo_records, ignore_index=True)

    # 3. Compile
    self.formulation_ = _compile_stability_metrics(baseline, history, trait_name, self.valency)
    self.is_fitted_ = True
    return self

Formulation dataclass

Represents a proposed vaccine formulation based on target prevalence and stability.

This class holds the results of a formulation design process, including rankings, stability metrics from cross-validation, and permutation history.

Attributes:

Name Type Description
trait str

The trait type (e.g., 'K_locus').

max_valency int

The maximum number of targets in the formulation.

rankings DataFrame

A DataFrame containing the definitive ranking of targets (Trait, Rank, Prevalence, Cumulative Coverage).

stability_metrics DataFrame

A DataFrame containing metrics from LOO stability analysis (e.g., Mean LOO Rank, Rank Variance, Probability in Top N).

permutation_history DataFrame

A DataFrame containing the full history of ranks across all LOO permutations.

Examples:

>>> import pandas as pd
>>> from seroepi.formulation import Formulation
>>> rankings = pd.DataFrame({'trait': ['K1', 'K2'], 'estimate': [0.5, 0.3]})
>>> _formulation = Formulation(
...     trait='K_locus',
...     max_valency=2,
...     rankings=rankings,
...     stability_metrics=pd.DataFrame(),
...     permutation_history=pd.DataFrame()
... )
>>> print(_formulation.get_formulation())
['K1', 'K2']
Source code in src/seroepi/formulation.py
@dataclass(frozen=True, slots=True)
class Formulation:
    """
    Represents a proposed vaccine _formulation based on target prevalence and stability.

    This class holds the results of a _formulation design process, including rankings,
    stability metrics from cross-validation, and permutation history.

    Attributes:
        trait: The trait type (e.g., 'K_locus').
        max_valency: The maximum number of targets in the _formulation.
        rankings: A DataFrame containing the definitive ranking of targets
            (Trait, Rank, Prevalence, Cumulative Coverage).
        stability_metrics: A DataFrame containing metrics from LOO stability analysis
            (e.g., Mean LOO Rank, Rank Variance, Probability in Top N).
        permutation_history: A DataFrame containing the full history of ranks
            across all LOO permutations.

    Examples:
        >>> import pandas as pd
        >>> from seroepi.formulation import Formulation
        >>> rankings = pd.DataFrame({'trait': ['K1', 'K2'], 'estimate': [0.5, 0.3]})
        >>> _formulation = Formulation(
        ...     trait='K_locus',
        ...     max_valency=2,
        ...     rankings=rankings,
        ...     stability_metrics=pd.DataFrame(),
        ...     permutation_history=pd.DataFrame()
        ... )
        >>> print(_formulation.get_formulation())
        ['K1', 'K2']
    """
    trait: str  # e.g., 'K_locus'
    max_valency: int  # e.g., 6 (for a hexavalent vaccine)
    # The definitive ranking matrix (Antigen, Rank, Prevalence, Cumulative Coverage)
    rankings: pd.DataFrame
    # Leave-One-Out (LOO) Stability Data
    # Rows = Antigens, Cols = 'Rank Variance', 'Prob in Top N', etc.
    stability_metrics: pd.DataFrame
    # Full history of ranks across all LOO permutations (used for plotting)
    permutation_history: pd.DataFrame

    @classmethod
    def from_custom(
            cls,
            custom_targets: list[str],
            baseline_result: 'PrevalenceEstimates'
    ) -> 'Formulation':
        """
        Creates a custom Formulation from a user-defined list of targets.
        Calculates the baseline coverage for these specific targets.
        """
        trait_name = baseline_result.trait
        raw_df = baseline_result.data

        # 1. Calculate the true baseline prevalence for everything
        baseline = raw_df.groupby('target')['estimate'].sum().sort_values(ascending=False).reset_index()

        # 2. Filter and reorder the baseline to match the user's custom list exactly
        # We use pd.Categorical to ensure the dataframe retains the exact order the user requested
        baseline['target_cat'] = pd.Categorical(baseline['target'], categories=custom_targets, ordered=True)
        custom_rankings = baseline.dropna(subset=['target_cat']).sort_values('target_cat').drop(
            columns=['target_cat'])

        # The new rank is simply the order they requested them in
        custom_rankings['baseline_rank'] = range(1, len(custom_targets) + 1)

        # 3. Create empty stability metrics (since this is a manual override, not a LOO calculation)
        empty_df = pd.DataFrame()

        return cls(
            trait=trait_name,
            max_valency=len(custom_targets),
            rankings=custom_rankings,
            stability_metrics=empty_df,
            permutation_history=empty_df
        )

    def save(self, filepath: Union[str, Path]) -> None:
        """Serializes the Formulation instance to disk."""
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        joblib_dump(self, path)

    @classmethod
    def load(cls: Type['Formulation'], filepath: Union[str, Path]) -> 'Formulation':
        """Loads a serialized Formulation from disk."""
        path = Path(filepath)
        if not path.exists():
            raise FileNotFoundError(f"No formulation found at {path}")
        formulation = joblib_load(path)
        if not isinstance(formulation, cls):
            raise TypeError(f"Type mismatch: Expected {cls.__name__}, got {type(formulation).__name__}.")
        return formulation

    def get_formulation(self) -> list[str]:
        """
        Returns the top N targets for the proposed vaccine.

        Returns:
            A list of target names.
        """
        return self.rankings.head(self.max_valency)['target'].tolist()

    def evaluate_longevity(self, forecast: 'IncidenceEstimates') -> pd.DataFrame:
        """
        Evaluates the formulation against a time-series incidence forecast to
        determine its historical and projected longevity.

        Returns a DataFrame tracking the absolute case burden and the percentage
        of that burden covered by this formulation over time.
        """
        df = forecast.data.copy()

        # 1. Extract the specific antigens in this vaccine
        targets = self.get_formulation()

        # 2. Calculate the total expected cases across ALL circulating strains per time step
        total_cases = df.groupby('date')['estimate'].sum().rename('total_cases')

        # 3. Calculate the expected cases caused ONLY by strains in our vaccine
        covered_df = df[df['target'].isin(targets)]
        covered_cases = covered_df.groupby('date')['estimate'].sum().rename('covered_cases')

        # 4. Merge and calculate the moving coverage percentage
        longevity = pd.merge(total_cases, covered_cases, left_index=True, right_index=True, how='left').fillna(0)

        # Safe division to prevent NaNs if a specific month has 0 total cases projected
        longevity['coverage_pct'] = np.where(
            longevity['total_cases'] > 0,
            (longevity['covered_cases'] / longevity['total_cases']) * 100,
            0.0
        )

        return longevity.reset_index()

evaluate_longevity

evaluate_longevity(forecast: IncidenceEstimates) -> pd.DataFrame

Evaluates the formulation against a time-series incidence forecast to determine its historical and projected longevity.

Returns a DataFrame tracking the absolute case burden and the percentage of that burden covered by this formulation over time.

Source code in src/seroepi/formulation.py
def evaluate_longevity(self, forecast: 'IncidenceEstimates') -> pd.DataFrame:
    """
    Track, per date, how much of a forecast's case burden this formulation covers.

    Args:
        forecast: Incidence estimates whose ``data`` holds 'date', 'target' and
            'estimate' columns.

    Returns:
        One row per date with 'total_cases' (all targets), 'covered_cases'
        (targets in this formulation, 0 when none) and 'coverage_pct'
        (0.0 whenever the date has no projected cases).
    """
    data = forecast.data.copy()
    vaccine_targets = self.get_formulation()

    # Burden from every circulating target, per date
    total_cases = data.groupby('date')['estimate'].sum().rename('total_cases')

    # Burden restricted to the targets included in this formulation
    in_vaccine = data['target'].isin(vaccine_targets)
    covered_cases = data[in_vaccine].groupby('date')['estimate'].sum().rename('covered_cases')

    # Left-join on the date index so dates with no covered cases still appear (as 0)
    longevity = total_cases.to_frame().join(covered_cases, how='left').fillna(0)

    # Guard against 0/0 for dates with no projected burden
    has_cases = longevity['total_cases'] > 0
    share = (longevity['covered_cases'] / longevity['total_cases']) * 100
    longevity['coverage_pct'] = np.where(has_cases, share, 0.0)

    return longevity.reset_index()

from_custom classmethod

from_custom(custom_targets: list[str], baseline_result: PrevalenceEstimates) -> Formulation

Creates a custom Formulation from a user-defined list of targets. Calculates the baseline coverage for these specific targets.

Source code in src/seroepi/formulation.py
@classmethod
def from_custom(
        cls,
        custom_targets: list[str],
        baseline_result: 'PrevalenceEstimates'
) -> 'Formulation':
    """
    Creates a custom Formulation from a user-defined list of targets.

    Calculates the baseline coverage for these specific targets while keeping
    the exact order in which the user listed them.

    Args:
        custom_targets: Targets to include, in priority order. Repeated names
            are kept only at their first occurrence.
        baseline_result: Prevalence estimates whose ``data`` supplies the
            per-target 'estimate' values and whose ``trait`` names the trait.

    Returns:
        A Formulation whose rankings follow ``custom_targets``. Targets that do
        not appear in ``baseline_result.data`` are kept with an estimate of 0.
    """
    trait_name = baseline_result.trait
    raw_df = baseline_result.data

    # De-duplicate while preserving the requested order; repeated names would
    # otherwise produce repeated ranks (and break categorical ordering).
    ordered_targets = list(dict.fromkeys(custom_targets))

    # 1. Calculate the true baseline prevalence for every target
    totals = raw_df.groupby('target')['estimate'].sum()

    # 2. Align to the user's order. Unobserved targets get 0 instead of being
    #    dropped, so the number of rows always matches the requested list.
    custom_rankings = (
        totals.reindex(ordered_targets, fill_value=0)
        .rename_axis('target')
        .reset_index(name='estimate')
    )

    # The new rank is simply the order they requested them in
    custom_rankings['baseline_rank'] = range(1, len(custom_rankings) + 1)

    # 3. No LOO calculation is run for a manual override, so the stability
    #    frames are empty; use distinct objects so they never alias each other.
    return cls(
        trait=trait_name,
        max_valency=len(ordered_targets),
        rankings=custom_rankings,
        stability_metrics=pd.DataFrame(),
        permutation_history=pd.DataFrame()
    )

get_formulation

get_formulation() -> list[str]

Returns the top N targets for the proposed vaccine.

Returns:

Type Description
list[str]

A list of target names.

Source code in src/seroepi/formulation.py
def get_formulation(self) -> list[str]:
    """
    List the targets chosen for the proposed vaccine.

    Returns:
        The first ``max_valency`` entries of the 'target' column of ``rankings``,
        in ranking order.
    """
    top_n = self.max_valency
    return self.rankings.iloc[:top_n]['target'].tolist()

load classmethod

load(filepath: Union[str, Path]) -> Formulation

Loads a serialized Formulation from disk.

Source code in src/seroepi/formulation.py
@classmethod
def load(cls: Type['Formulation'], filepath: Union[str, Path]) -> 'Formulation':
    """
    Read a Formulation previously written with ``save``.

    Joblib persistence is pickle-based, so only load files from trusted sources.

    Raises:
        FileNotFoundError: When nothing exists at ``filepath``.
        TypeError: When the stored object is not an instance of ``cls``.
    """
    source = Path(filepath)
    if not source.exists():
        raise FileNotFoundError(f"No formulation found at {source}")

    loaded = joblib_load(source)
    if isinstance(loaded, cls):
        return loaded
    raise TypeError(f"Type mismatch: Expected {cls.__name__}, got {type(loaded).__name__}.")

save

save(filepath: Union[str, Path]) -> None

Serializes the Formulation instance to disk.

Source code in src/seroepi/formulation.py
def save(self, filepath: Union[str, Path]) -> None:
    """
    Write this Formulation to ``filepath`` using joblib.

    Any missing parent directories are created before writing.
    """
    destination = Path(filepath)
    # Ensure the target directory exists so the dump doesn't fail on a fresh path
    destination.parent.mkdir(exist_ok=True, parents=True)
    joblib_dump(self, destination)

PostHocFormulationDesigner

Bases: BaseFormulationDesigner

Fast formulation design using post-hoc estimation.

This method computes stability exactly for Frequentist estimates where retraining is not required for Leave-One-Out (LOO) analysis. For complex modelled estimates (e.g., Bayesian, Spatial), it can be used as a fast, linear approximation of stability (ignoring non-linear shrinkage and spatial correlation).

Source code in src/seroepi/formulation.py
class PostHocFormulationDesigner(BaseFormulationDesigner):
    """
    Fast formulation design using post-hoc estimation.

    Leave-One-Out (LOO) rankings are obtained by subtracting each holdout group's
    contribution from the global per-target totals, so nothing is retrained. This
    is exact for Frequentist estimates where retraining is not required; for
    complex modelled estimates (e.g., Bayesian, Spatial) it serves as a fast,
    linear approximation of stability (ignoring non-linear shrinkage and spatial
    correlation).
    """

    def fit(self, result: 'PrevalenceEstimates', loo_col: str,
            progress_callback: Optional[Callable[[int, int], None]] = None) -> 'PostHocFormulationDesigner':
        """
        Evaluate prevalence results and design a formulation from them.

        Args:
            result: The prevalence estimates to evaluate.
            loo_col: Column whose distinct values define the Leave-One-Out holdout groups.
            progress_callback: Optional ``callback(done, total)`` invoked after each
                holdout group has been ranked.

        Returns:
            The fitted designer instance (``formulation_`` is populated).
        """
        data = result.data
        trait_name = result.trait

        # Rank every target on the full dataset
        baseline = _extract_ranks(data, 'baseline_rank')

        # Global per-target totals, and per-(group, target) totals as a wide table
        grand_totals = data.groupby('target')['estimate'].sum()
        per_group = data.groupby([loo_col, 'target'])['estimate'].sum().unstack(fill_value=0)

        def _rank_without(group):
            # Remove the holdout group's contribution; groups absent from the
            # grouped table (e.g. NaN keys) leave the totals untouched.
            if group in per_group.index:
                remaining = grand_totals - per_group.loc[group]
            else:
                remaining = grand_totals.copy()
            ranked = remaining.sort_values(ascending=False).reset_index()
            ranked.columns = ['target', 'estimate']
            ranked['loo_rank'] = ranked.index + 1
            ranked['holdout_group'] = group
            return ranked

        holdouts = data[loo_col].unique()
        n_holdouts = len(holdouts)
        frames = []
        for step, group in enumerate(holdouts, start=1):
            frames.append(_rank_without(group))
            if progress_callback:
                progress_callback(step, n_holdouts)

        history = pd.concat(frames, ignore_index=True)

        # Summarise rank stability across all holdouts
        self.formulation_ = _compile_stability_metrics(baseline, history, trait_name, self.valency)
        self.is_fitted_ = True
        return self

fit

fit(result: PrevalenceEstimates, loo_col: str, progress_callback: Optional[Callable[[int, int], None]] = None) -> PostHocFormulationDesigner

Evaluates prevalence results to design a formulation.

Parameters:

Name Type Description Default
result PrevalenceEstimates

The prevalence estimates to evaluate.

required
loo_col str

The column name to use for Leave-One-Out cross-validation.

required

Returns:

Type Description
PostHocFormulationDesigner

The fitted designer instance.

Source code in src/seroepi/formulation.py
def fit(self, result: 'PrevalenceEstimates', loo_col: str,
        progress_callback: Optional[Callable[[int, int], None]] = None) -> 'PostHocFormulationDesigner':
    """
    Evaluate prevalence results and design a formulation from them.

    Args:
        result: The prevalence estimates to evaluate.
        loo_col: Column whose distinct values define the Leave-One-Out holdout groups.
        progress_callback: Optional ``callback(done, total)`` invoked after each
            holdout group has been ranked.

    Returns:
        The fitted designer instance (``formulation_`` is populated).
    """
    data = result.data
    trait_name = result.trait

    # Rank every target on the full dataset
    baseline = _extract_ranks(data, 'baseline_rank')

    # Global per-target totals, and per-(group, target) totals as a wide table
    grand_totals = data.groupby('target')['estimate'].sum()
    per_group = data.groupby([loo_col, 'target'])['estimate'].sum().unstack(fill_value=0)

    def _rank_without(group):
        # Remove the holdout group's contribution; groups absent from the
        # grouped table (e.g. NaN keys) leave the totals untouched.
        if group in per_group.index:
            remaining = grand_totals - per_group.loc[group]
        else:
            remaining = grand_totals.copy()
        ranked = remaining.sort_values(ascending=False).reset_index()
        ranked.columns = ['target', 'estimate']
        ranked['loo_rank'] = ranked.index + 1
        ranked['holdout_group'] = group
        return ranked

    holdouts = data[loo_col].unique()
    n_holdouts = len(holdouts)
    frames = []
    for step, group in enumerate(holdouts, start=1):
        frames.append(_rank_without(group))
        if progress_callback:
            progress_callback(step, n_holdouts)

    history = pd.concat(frames, ignore_index=True)

    # Summarise rank stability across all holdouts
    self.formulation_ = _compile_stability_metrics(baseline, history, trait_name, self.valency)
    self.is_fitted_ = True
    return self

seroepi.estimators

Module for estimating trait prevalence, diversity and incidence among isolates.

AlphaDiversityEstimates dataclass

Bases: Estimates

Container for Alpha Diversity results.

Attributes:

Name Type Description
metrics list[str]

List of diversity metrics calculated (e.g., ['shannon', 'simpson']).

Source code in src/seroepi/estimators/_base.py
@dataclass(frozen=True, slots=True)
class AlphaDiversityEstimates(Estimates):
    """
    Container for Alpha Diversity results.

    Adds the list of computed metrics to the fields inherited from ``Estimates``.

    Attributes:
        metrics: List of diversity metrics calculated (e.g., ['shannon', 'simpson']).
    """
    # Metric names that were requested. When built by AlphaDiversityEstimator,
    # each recognised name ('shannon', 'simpson', 'richness') is also a column of ``data``.
    metrics: list[str]

AlphaDiversityEstimator

Bases: BaseEstimator[AlphaDiversityEstimates]

Source code in src/seroepi/estimators/_core.py
class AlphaDiversityEstimator(BaseEstimator[AlphaDiversityEstimates]):
    """
    Computes alpha-diversity metrics (Shannon, Simpson, richness) of trait variants.

    The input is a table of variant counts ('variant_count' column), optionally
    carrying ``attrs['metric_meta']`` with the trait and stratification.
    """
    Metric = Literal['shannon', 'simpson', 'richness']
    _DEFAULT_METRICS = ['shannon', 'simpson', 'richness']

    def __init__(self, target: str = None, metrics: list[Metric] = None):
        """
        Args:
            target: Trait name, used when the input carries no 'trait' metadata.
            metrics: Metrics to compute; ``None`` or an empty list selects all defaults.
        """
        self.target = target
        # Copy so instances never share (and mutate) the class-level default list
        # or the caller's list.
        self.metrics = list(metrics or self._DEFAULT_METRICS)
        self._method_label = "alpha_diversity"

    def get_params(self) -> dict:
        """Returns parameters for cloning compatibility."""
        return {'target': self.target, 'metrics': self.metrics}

    def calculate(self, div_df: pd.DataFrame) -> AlphaDiversityEstimates:
        """
        Computes the configured diversity metrics per stratum (or globally).

        Args:
            div_df: Variant counts with a 'variant_count' column. Its
                ``attrs['metric_meta']`` may provide 'trait', 'stratified_by',
                'adjusted_for' and 'aggregation_type'.

        Returns:
            AlphaDiversityEstimates with one row per stratum that has non-zero counts.

        Raises:
            ValueError: If no trait is available from the metadata or ``target``.
        """
        # Extract the metadata attached by the accessor
        meta = div_df.attrs.get("metric_meta", {})
        target_col = meta.get("trait", self.target)
        strata = meta.get("stratified_by", [])

        if not target_col:
            raise ValueError("Target trait must be defined either in init or via accessor metadata.")

        results = []

        # If stratified, group by the strata. Otherwise, treat as one global group.
        groups = div_df.groupby(strata, observed=True) if strata else [('Global', div_df)]

        for name, group in groups:
            # Input rows are already aggregated counts, so no value_counts() needed
            counts = group['variant_count'].values

            # Filter out true zeroes (important for richness)
            counts = counts[counts > 0]
            if len(counts) == 0:
                continue

            p = counts / counts.sum()

            # Label the row with its stratum values (groupby may yield scalars or tuples)
            if not strata:
                row = {}
            elif isinstance(name, tuple):
                row = dict(zip(strata, name))
            else:
                row = {strata[0]: name}

            if 'shannon' in self.metrics:
                row['shannon'] = entropy(p, base=np.e)
            if 'simpson' in self.metrics:
                row['simpson'] = 1.0 - np.sum(p ** 2)
            if 'richness' in self.metrics:
                row['richness'] = len(counts)

            row['n_samples'] = counts.sum()
            results.append(row)

        res_df = pd.DataFrame(results)

        return AlphaDiversityEstimates(
            data=res_df,
            stratified_by=strata,
            adjusted_for=meta.get("adjusted_for", 'unknown'),
            trait=target_col,
            aggregation_type=meta.get("aggregation_type", AggregationType.TRAIT),
            metrics=self.metrics
        )

get_params

get_params() -> dict

Returns parameters for cloning compatibility.

Source code in src/seroepi/estimators/_core.py
def get_params(self) -> dict:
    """Expose the constructor arguments so the estimator can be cloned."""
    return dict(target=self.target, metrics=self.metrics)

BaseEstimator

Bases: ABC, Generic[T_Result]

The universal contract for all seroepi statistical models.

All prevalence, diversity, and incidence estimators must inherit from this class and implement the calculate method.

Source code in src/seroepi/estimators/_base.py
class BaseEstimator(ABC, Generic[T_Result]):
    """
    The universal contract for all seroepi statistical models.

    All prevalence, diversity, and incidence estimators must inherit from this
    class and implement the `calculate` method.
    """

    def _extract_strata(self, agg_df: pd.DataFrame, exclude_cols: list[str] = None) -> Tuple[list[str], dict]:
        """
        Extracts stratification columns and metadata from an aggregated DataFrame.

        Args:
            agg_df: The aggregated DataFrame.
            exclude_cols: Column names to exclude from stratification.

        Returns:
            A tuple containing the list of strata columns and the metadata dictionary.
        """
        if exclude_cols is None:
            exclude_cols = []

        meta = agg_df.attrs.get("metric_meta", {})
        # Without metadata, every non-excluded column is treated as a stratum
        inferred_strata = [col for col in agg_df.columns if col not in exclude_cols]
        stratified_by = meta.get("stratified_by", inferred_strata)
        self._validate_suitability(agg_df, stratified_by)
        return stratified_by, meta

    @staticmethod
    def _validate_suitability(df: pd.DataFrame, strata: list[str]):
        """
        Validates that the stratification columns are suitable for modeling.

        Checks for common statistical traps such as continuous variables being
        used as strata or over-stratification by unique identifiers.

        Args:
            df: The DataFrame to check.
            strata: The list of stratification columns.

        Raises:
            ValueError: If a stratum is a continuous float.

        Warns:
            UserWarning: If a stratum looks like a raw date column, or if the mean
                of an 'n' column is below 1.5 (likely a unique-identifier stratum).
        """
        for col in strata:
            # Column labels need not be strings (e.g. integer labels); normalise
            # before the string heuristics below so they cannot raise AttributeError.
            col_name = str(col)

            # 1. The Continuous Float Trap
            if pd.api.types.is_float_dtype(df[col]):
                raise ValueError(
                    f"Strata column '{col}' is a continuous float. "
                    "Prevalence estimators require discrete categorical groups. "
                    f"Please bin this variable (e.g., using pd.cut()) before aggregating."
                )

            # 2. The Raw Datetime Trap (Checks dtype, standard 'date' name, OR the 'temporal_' prefix)
            # NOTE: the name check is a substring match, so names like 'candidate'
            # or 'updated' also trigger this warning.
            looks_temporal = (
                pd.api.types.is_datetime64_any_dtype(df[col])
                or 'date' in col_name.lower()
                or col_name.startswith(f"{Domain.TEMPORAL.value}_")
            )
            if looks_temporal:
                # We only warn here, because sometimes daily is intentional during a rapid outbreak
                warn(
                    f"You are stratifying on a raw date column '{col}'. "
                    "This will calculate prevalence for every single day. "
                    "Consider bucketing by month/year using .dt.to_period('M') first.",
                    UserWarning
                )

        # 3. The Primary Key / Over-Stratification Trap
        if 'n' in df.columns:
            # If the average group size is close to 1, they shattered the dataset
            avg_group_size = df['n'].mean()
            if len(df) > 1 and avg_group_size < 1.5:
                warn(
                    f"The average group size is {avg_group_size:.2f}. "
                    "Did you accidentally stratify by a unique identifier (like 'sample_id')? "
                    "This will cause your models to overfit or crash.",
                    UserWarning
                )

    @abstractmethod
    def calculate(self, df: pd.DataFrame) -> T_Result:
        """
        Executes the estimator's logic on the provided DataFrame.

        Args:
            df: The input DataFrame (usually aggregated).

        Returns:
            An Estimates object (e.g., PrevalenceEstimates).
        """
        pass

calculate abstractmethod

calculate(df: DataFrame) -> T_Result

Executes the estimator's logic on the provided DataFrame.

Parameters:

Name Type Description Default
df DataFrame

The input DataFrame (usually aggregated).

required

Returns:

Type Description
T_Result

An Estimates object (e.g., PrevalenceEstimates).

Source code in src/seroepi/estimators/_base.py
@abstractmethod
def calculate(self, df: pd.DataFrame) -> T_Result:
    """
    Run the estimator on ``df`` and return its results container.

    Concrete estimators must override this; the base body does nothing.

    Args:
        df: The input DataFrame (usually aggregated).

    Returns:
        An Estimates object (e.g., PrevalenceEstimates).
    """
    ...

BayesianIncidenceEstimator

Bases: ModelledMixin, BayesianMixin, BaseEstimator[IncidenceEstimates]

Bayesian Structural Time Series (BSTS) for incidence forecasting.

Uses a Gaussian Random Walk with drift to model latent log-incidence, and a Negative Binomial likelihood to handle overdispersed clinical count data.

Source code in src/seroepi/estimators/_modelled.py
class BayesianIncidenceEstimator(ModelledMixin, BayesianMixin, BaseEstimator[IncidenceEstimates]):
    """
    Bayesian Structural Time Series (BSTS) for incidence forecasting.

    Uses a Gaussian Random Walk with drift to model latent log-incidence,
    and a Negative Binomial likelihood to handle overdispersed clinical count data.
    """

    def __init__(self,
                 forecast_horizon: int = 12,
                 method: BayesianInferenceMethod = BayesianInferenceMethod.MCMC,
                 num_samples: int = 1500,
                 num_chains: int = 4,
                 num_warmup: int = 1000,
                 svi_steps: int = 3000,
                 seed: int = 42):
        """
        Args:
            forecast_horizon: Number of future time steps projected by ``predict``.
            method: Inference method passed to ``_init_bayesian`` (MCMC by default).
            num_samples, num_chains, num_warmup, svi_steps, seed: Passed to
                ``_init_bayesian``.
        """
        self.forecast_horizon = forecast_horizon
        self._init_bayesian(method, num_samples, num_chains, num_warmup, svi_steps, seed)
        self._method_label = f'bsts_forecast_{self.method.value}'

        # Internal state
        self.strata_ = []
        self.meta_ = {}

    def _model(self, T, n_strata, Y=None, forecast_horizon=0):
        """The NumPyro BSTS Model."""
        total_T = T + forecast_horizon
        # Priors for the latent state parameters
        with plate("strata", n_strata, dim=-1):
            # Base log-rate at t=0
            mu_0 = samp("mu_0", dist.Normal(0, 2))
            # The overarching trend (drift)
            drift = samp("drift", dist.Normal(0, 0.5))
            # How volatile the month-to-month changes are
            sigma_rw = samp("sigma_rw", dist.HalfNormal(0.5))
            # Overdispersion for the Negative Binomial count data
            dispersion = samp("dispersion", dist.HalfNormal(2))

        # Gaussian Random Walk (Innovations)
        with plate("time", T, dim=-2):
            with plate("strata_inner", n_strata, dim=-1):
                innovations_hist = samp("innovations", dist.Normal(0, 1))

        if forecast_horizon > 0:
            with plate("time_future", forecast_horizon, dim=-2):
                with plate("strata_inner_future", n_strata, dim=-1):
                    innovations_future = samp("innovations_future", dist.Normal(0, 1))
            innovations = jnp.concatenate([innovations_hist, innovations_future], axis=-2)
        else:
            innovations = innovations_hist

        # Construct the Latent Log-Rate, use JAX's cumulative sum to build the random walk over time
        rw = jnp.cumsum(innovations * sigma_rw, axis=-2)
        time_steps = jnp.arange(total_T)[:, None]
        # Latent state equation: Baseline + Trend + Random Walk
        log_rate = mu_0 + (time_steps * drift) + rw

        # Likelihood function
        if Y is not None:  # INFERENCE MODE (Fitting historical data)
            historical_rate = log_rate[:T]
            with plate("obs_time", T, dim=-2):
                with plate("obs_strata", n_strata, dim=-1):
                    samp("obs", dist.NegativeBinomial2(mean=jnp.exp(historical_rate), concentration=dispersion), obs=Y)
        else:  # FORECASTING MODE (Projecting the future)
            with plate("obs_time", total_T, dim=-2):
                with plate("obs_strata", n_strata, dim=-1):
                    samp("obs", dist.NegativeBinomial2(mean=jnp.exp(log_rate), concentration=dispersion))

    @staticmethod
    def _group_labels(strata_val, group_cols: list) -> dict:
        """Maps a pivot column label back to ``{group column: value}``."""
        if isinstance(strata_val, tuple):
            return dict(zip(group_cols, strata_val))
        return {group_cols[0]: strata_val}

    def fit(self, inc_df: pd.DataFrame) -> 'BayesianIncidenceEstimator':
        """
        Pivots the count data into a (date x group) matrix and fits the BSTS model.

        Args:
            inc_df: Long-format counts with 'date' and 'variant_count' columns, the
                columns listed in ``attrs['metric_meta']['stratified_by']`` and an
                optional 'trait' column.

        Returns:
            The fitted estimator.
        """
        self.meta_ = inc_df.attrs.get("metric_meta", {})
        # Copy into a list so tuple metadata still supports `+ ['trait']` and the
        # caller's metadata is never aliased.
        self.strata_ = list(self.meta_.get("stratified_by", []))

        df = inc_df.copy()
        group_cols = self.strata_ + ['trait'] if 'trait' in df.columns else list(self.strata_)

        if not group_cols:
            df['_dummy_group'] = 'Global'
            group_cols = ['_dummy_group']

        # Remember the grouping so predict() labels columns exactly as they were fitted
        self.group_cols_ = group_cols

        # Pivot to a T x n_strata matrix. Counts are summed rather than averaged
        # (pivot_table's default), so duplicate rows for one cell stay integer
        # counts suitable for the Negative Binomial likelihood.
        # NOTE(review): only dates present in inc_df become rows; gaps are not
        # filled, while predict() rebuilds dates with a regular date_range.
        pivot_df = df.pivot_table(
            index='date',
            columns=group_cols,
            values='variant_count',
            aggfunc='sum',
            fill_value=0
        )

        self.dates_ = pivot_df.index
        self.T_ = len(self.dates_)
        self.n_strata_ = pivot_df.shape[1]
        self.strata_labels_ = pivot_df.columns

        jax_data = {
            "T": self.T_,
            "n_strata": self.n_strata_,
            "Y": jnp.array(pivot_df.values),
            "forecast_horizon": 0  # 0 during inference
        }

        rng_key = random.PRNGKey(self.seed)
        self.samples_ = self._run_inference(jax_data, rng_key)
        self.is_fitted_ = True
        return self

    def predict(self, inc_df: pd.DataFrame) -> IncidenceEstimates:
        """
        Forecasts future incidence counts using the fitted posterior samples.

        Args:
            inc_df: The observed counts; merged onto the predictions so historical
                rows keep their original columns.

        Returns:
            IncidenceEstimates covering the fitted dates plus ``forecast_horizon``
            future steps, with per-group IRR summaries in ``model_results``.
        """
        self.check_is_fitted()

        pred = Predictive(self._model, self.samples_)
        pred_key = random.PRNGKey(self.seed + 1)

        # Run forward for predictions
        pred_samples = pred(
            pred_key,
            T=self.T_,
            n_strata=self.n_strata_,
            Y=None,
            forecast_horizon=self.forecast_horizon
        )

        # obs_draws is shape (num_samples, T + forecast_horizon, n_strata)
        estimate, lower, upper, drifts, p_vals = _summarize_incidence_posterior(
            pred_samples['obs'],
            self.samples_['drift']
        )

        # Reconstruct continuous dates, adding the future horizon
        freq_str = self.meta_.get("freq", TemporalResolution.MONTH.value)
        try:
            freq = TemporalResolution(freq_str).pandas_offset
            if not freq:
                freq = 'MS'
        except ValueError:
            freq = 'MS'

        all_dates = pd.date_range(start=self.dates_[0], periods=self.T_ + self.forecast_horizon, freq=freq)

        # Use the grouping captured at fit time so labels line up with the fitted
        # matrix columns; recompute only for estimators fitted before group_cols_ existed.
        group_cols = getattr(self, 'group_cols_', None)
        if group_cols is None:
            group_cols = self.strata_ + ['trait'] if 'trait' in inc_df.columns else self.strata_
            if not group_cols:
                group_cols = ['_dummy_group']

        # Melt the predictions back into long format
        results = []
        for i, date in enumerate(all_dates):
            for j, strata_val in enumerate(self.strata_labels_):
                row = {'date': date, **self._group_labels(strata_val, group_cols)}
                row['estimate'] = float(estimate[i, j])
                row['lower'] = float(lower[i, j])
                row['upper'] = float(upper[i, j])
                results.append(row)

        res_df = pd.DataFrame(results)

        # Merge back the original counts (historical rows keep variant_count, total_sequenced).
        # The synthetic '_dummy_group' column never exists in inc_df, so it is not a key.
        merge_cols = ['date'] + [col for col in group_cols if col != '_dummy_group']
        final_df = res_df.merge(inc_df, on=merge_cols, how='left')
        target = self.meta_.get("trait")

        # Build the model summary with incidence rate ratios (drift exponential)
        model_res = []
        for j, strata_val in enumerate(self.strata_labels_):
            row = self._group_labels(strata_val, group_cols)
            row['IRR'] = float(jnp.exp(drifts[j]))
            row['prob_increasing'] = float(p_vals[j])
            # NOTE(review): hard-coded; no convergence diagnostic is checked here.
            row['status'] = 'Converged'
            model_res.append(row)

        return IncidenceEstimates(
            data=final_df,
            stratified_by=self.strata_,
            adjusted_for=self.meta_.get("adjusted_for", "unknown"),
            trait=target or "unknown",
            freq=freq_str,
            aggregation_type=self.meta_.get("aggregation_type", "unknown"),
            model_results=pd.DataFrame(model_res)
        )

fit

fit(inc_df: DataFrame) -> BayesianIncidenceEstimator

Pivots the count data into a matrix and fits the BSTS model.

Source code in src/seroepi/estimators/_modelled.py
def fit(self, inc_df: pd.DataFrame) -> 'BayesianIncidenceEstimator':
    """Pivots the count data into a matrix and fits the BSTS model.

    Builds a ``T x n_strata`` matrix of ``variant_count`` values (rows are
    dates, columns are stratum/trait combinations, absent cells are 0) and
    runs posterior inference on it.

    Args:
        inc_df: Long-format incidence counts with ``date`` and
            ``variant_count`` columns plus the stratification columns listed in
            ``inc_df.attrs['metric_meta']['stratified_by']`` (and optionally a
            ``trait`` column).

    Returns:
        The fitted estimator (``self``).
    """
    self.meta_ = inc_df.attrs.get("metric_meta", {})
    self.strata_ = self.meta_.get("stratified_by", [])

    df = inc_df.copy()
    group_cols = self.strata_ + ['trait'] if 'trait' in df.columns else self.strata_

    if not group_cols:
        # Unstratified input: model a single series under one constant label.
        df['_dummy_group'] = 'Global'
        group_cols = ['_dummy_group']

    # Pivot to get a T x n_strata continuous matrix. Counts are additive, so
    # duplicate (date, stratum) rows are summed; pivot_table's default 'mean'
    # would silently under-count them. Missing cells become explicit zeros.
    pivot_df = df.pivot_table(
        index='date',
        columns=group_cols,
        values='variant_count',
        fill_value=0,
        aggfunc='sum'
    )

    self.dates_ = pivot_df.index
    self.T_ = len(self.dates_)
    self.n_strata_ = pivot_df.shape[1]
    # Column j of the matrix maps back to this label (a tuple when several
    # group columns are used).
    self.strata_labels_ = pivot_df.columns

    jax_data = {
        "T": self.T_,
        "n_strata": self.n_strata_,
        "Y": jnp.array(pivot_df.values),
        "forecast_horizon": 0  # 0 during inference
    }

    rng_key = random.PRNGKey(self.seed)
    self.samples_ = self._run_inference(jax_data, rng_key)
    self.is_fitted_ = True
    return self

predict

predict(inc_df: DataFrame) -> IncidenceEstimates

Forecasts future incidence counts using the fitted posterior samples.

Source code in src/seroepi/estimators/_modelled.py
def predict(self, inc_df: pd.DataFrame) -> IncidenceEstimates:
    """Forecasts future incidence counts using the fitted posterior samples.

    Runs the fitted model forward with ``Predictive`` (``Y=None``) over the
    ``T_`` fitted steps plus ``self.forecast_horizon`` future steps,
    summarizes the predictive draws per (date, stratum), and left-merges
    ``inc_df`` back on so historical rows keep their observed columns.

    Args:
        inc_df: Incidence frame. Its columns decide whether 'trait' is part of
            the stratum key, and its rows are merged onto matching
            (date, stratum) predictions.

    Returns:
        IncidenceEstimates whose ``data`` holds 'estimate', 'lower' and
        'upper' per (date, stratum), and whose ``model_results`` holds one
        row per stratum with 'IRR' (exp of the drift summary),
        'prob_increasing' and 'status'.
    """
    self.check_is_fitted()

    pred = Predictive(self._model, self.samples_)
    # Seed offset by one so the predictive key differs from the fit key (PRNGKey(self.seed)).
    pred_key = random.PRNGKey(self.seed + 1)

    # Run forward for predictions
    pred_samples = pred(
        pred_key,
        T=self.T_,
        n_strata=self.n_strata_,
        Y=None,
        forecast_horizon=self.forecast_horizon
    )

    # pred_samples['obs'] is expected to be (num_samples, T + forecast_horizon, n_strata);
    # estimate/lower/upper are indexed [time, stratum] below and drifts/p_vals [stratum].
    estimate, lower, upper, drifts, p_vals = _summarize_incidence_posterior(
        pred_samples['obs'], 
        self.samples_['drift']
    )

    # Reconstruct continuous dates, adding the future horizon.
    # Falls back to month-start ('MS') when the resolution is unknown or has no offset.
    freq_str = self.meta_.get("freq", TemporalResolution.MONTH.value)
    try:
        freq = TemporalResolution(freq_str).pandas_offset
        if not freq: freq = 'MS'
    except ValueError:
        freq = 'MS'

    # NOTE(review): dates are regenerated from self.dates_[0] instead of reusing
    # self.dates_. This assumes the fitted dates are contiguous and aligned to
    # `freq`; with gaps or misaligned dates, row i would no longer correspond to
    # the i-th fitted date and the merge below would miss observed counts.
    # Confirm inputs are always zero-padded and offset-aligned.
    all_dates = pd.date_range(start=self.dates_[0], periods=self.T_ + self.forecast_horizon, freq=freq)

    # Melt the predictions back into long format
    results = []
    group_cols = self.strata_ + ['trait'] if 'trait' in inc_df.columns else self.strata_
    if not group_cols:
        # Same placeholder label as the unstratified fit. The column survives into
        # the output because the merge below then joins on 'date' only.
        group_cols = ['_dummy_group']

    for i, date in enumerate(all_dates):
        for j, strata_val in enumerate(self.strata_labels_):
            row = {'date': date}
            # group_cols is never empty at this point, so this branch always runs.
            if group_cols:
                # Multi-column strata labels come from the fit-time pivot as tuples.
                if isinstance(strata_val, tuple):
                    row.update(zip(group_cols, strata_val))
                else:
                    row[group_cols[0]] = strata_val

            row['estimate'] = float(estimate[i, j])
            row['lower'] = float(lower[i, j])
            row['upper'] = float(upper[i, j])
            results.append(row)

    res_df = pd.DataFrame(results)

    # Merge back the original counts (this ensures that historical rows keep variant_count, total_sequenced)
    # Forecast dates have no match in inc_df, so its columns are NaN there (left merge).
    merge_cols = ['date'] + (self.strata_ + ['trait'] if 'trait' in inc_df.columns else self.strata_)
    final_df = res_df.merge(inc_df, on=merge_cols, how='left')
    target = self.meta_.get("trait")
    # Build the model summary with incidence rate ratios (drift exponential)

    model_res = []
    for j, strata_val in enumerate(self.strata_labels_):
        row = {}
        if group_cols:
            if isinstance(strata_val, tuple):
                row.update(zip(group_cols, strata_val))
            else:
                row[group_cols[0]] = strata_val
        row['IRR'] = float(jnp.exp(drifts[j]))
        row['prob_increasing'] = float(p_vals[j])
        # Constant label: no convergence check is performed in this method.
        row['status'] = 'Converged'
        model_res.append(row)

    return IncidenceEstimates(
        data=final_df,
        stratified_by=self.strata_,
        adjusted_for=self.meta_.get("adjusted_for", "unknown"),
        trait=target or "unknown",
        freq=freq_str,
        aggregation_type=self.meta_.get("aggregation_type", "unknown"),
        model_results=pd.DataFrame(model_res)
    )

BayesianMixin

Shared inference logic for NumPyro-based Bayesian estimators.

Source code in src/seroepi/estimators/_modelled.py
class BayesianMixin:
    """
    Shared inference logic for NumPyro-based Bayesian estimators.

    Host classes call ``_init_bayesian`` from their ``__init__`` and expose a
    NumPyro model as ``self._model``. ``_run_inference`` then dispatches to
    NUTS MCMC or SVI according to ``self.method``.
    """
    def _init_bayesian(self, method: BayesianInferenceMethod, num_samples: int, num_chains: int,
                       num_warmup: int, svi_steps: int, seed: int):
        """
        Stores the inference configuration and clears any fitted state.

        Args:
            method: Inference engine, as a ``BayesianInferenceMethod`` or its string value.
            num_samples: Posterior draws per chain (MCMC) or in total (SVI).
            num_chains: Number of MCMC chains.
            num_warmup: MCMC warmup iterations.
            svi_steps: Optimisation steps for SVI.
            seed: Seed for the JAX PRNG key.
        """
        self.method = BayesianInferenceMethod(method) if isinstance(method, str) else method
        self.num_samples = num_samples
        self.num_chains = num_chains
        self.num_warmup = num_warmup
        self.svi_steps = svi_steps
        self.seed = seed
        self.samples_ = None
        # MCMC draws with the chain axis kept, (num_chains, num_samples, ...),
        # so diagnostics() can compute between-chain R-hat / ESS.
        self.chain_samples_ = None
        self.extra_fields_ = None

    def _check_zero_padding(self, df: pd.DataFrame):
        """Ensures the incoming dataframe has been properly padded for Bayesian inference."""
        meta = df.attrs.get('metric_meta', {})

        # If metadata is entirely missing, we assume it's a raw pandas df and warn,
        # but let the math fail naturally if it's jagged.
        if 'is_zero_padded' in meta and not meta['is_zero_padded']:
            raise ValueError(
                f"Mathematical Integrity Error: {self.__class__.__name__} requires a strictly rectangular, "
                "zero-padded matrix to construct the posterior geometry. "
                "Please regenerate the dataset using `.epi.aggregate_...(pad_zeros=True)`."
            )

    def _run_inference(self, jax_data: dict, rng_key: random.PRNGKey):
        """Routes to the correct inference engine based on self.method."""
        if self.method == BayesianInferenceMethod.MCMC:
            return self._mcmc_inference(jax_data, rng_key)
        elif self.method == BayesianInferenceMethod.SVI:
            return self._svi_inference(jax_data, rng_key)
        else:
            raise ValueError(f"Unknown method: {self.method}. Choose from MCMC or SVI.")

    def _mcmc_inference(self, jax_data: dict, rng_key: random.PRNGKey):
        """Runs NUTS MCMC and returns the posterior draws flattened across chains."""
        mcmc = MCMC(NUTS(self._model), num_warmup=self.num_warmup, num_samples=self.num_samples,
                    num_chains=self.num_chains, progress_bar=False)
        mcmc.run(rng_key, **jax_data)
        self.extra_fields_ = mcmc.get_extra_fields()
        # Flattened draws (get_samples' default) hide disagreement between chains,
        # so keep the per-chain view for convergence diagnostics.
        self.chain_samples_ = mcmc.get_samples(group_by_chain=True)
        return mcmc.get_samples()

    def _svi_inference(self, jax_data: dict, rng_key: random.PRNGKey):
        """Fits an AutoNormal guide with SVI and returns ``num_samples`` draws from it."""
        # SVI has no chains or sampler statistics; drop any left over from an MCMC fit.
        self.chain_samples_ = None
        self.extra_fields_ = None
        opt_key, pred_key = random.split(rng_key)
        guide = autoguide.AutoNormal(self._model)
        optimizer = optim.Adam(step_size=0.01)
        svi = SVI(self._model, guide, optimizer, loss=Trace_ELBO())
        svi_result = svi.run(opt_key, num_steps=self.svi_steps, **jax_data)
        predictive = Predictive(self._model, guide=guide, params=svi_result.params, num_samples=self.num_samples)
        return predictive(pred_key, **jax_data)

    def diagnostics(self) -> pd.DataFrame:
        """
        Returns MCMC diagnostics (R-hat, ESS) as a formatted DataFrame.

        Produces one row per scalar parameter element, with vector parameters
        expanded to ``name[i,j,...]``. It also adds ``sampler_<field>`` rows
        summarising any extra fields collected by the sampler.

        Raises:
            TypeError: If the estimator was not configured for MCMC.
        """
        if hasattr(self, 'check_is_fitted'):
            self.check_is_fitted()
        if self.method != BayesianInferenceMethod.MCMC:
            raise TypeError("Diagnostics are only available for MCMC inference.")

        chain_samples = getattr(self, 'chain_samples_', None)
        if chain_samples is not None:
            # Per-chain draws: R-hat / ESS account for between-chain variation.
            summary_dict = diag.summary(chain_samples, prob=0.95, group_by_chain=True)
        else:
            # Only flattened draws available: they are treated as a single chain.
            summary_dict = diag.summary(self.samples_, prob=0.95, group_by_chain=False)

        # Flatten multi-dimensional parameters (like fixed/random effects) into individual rows
        rows = []
        for param, stats in summary_dict.items():
            param_shape = np.shape(stats['mean'])

            if len(param_shape) == 0:
                row = {'Parameter': param}
                row.update({k: float(v) for k, v in stats.items()})
                rows.append(row)
            else:
                it = np.nditer(np.empty(param_shape), flags=['multi_index'])
                for _ in it:
                    idx = it.multi_index
                    idx_str = ",".join(map(str, idx))
                    row = {'Parameter': f"{param}[{idx_str}]"}
                    row.update({k: float(np.asarray(v)[idx]) for k, v in stats.items()})
                    rows.append(row)

        # Safely extract and append the MCMC sampler's internal extra fields
        if getattr(self, 'extra_fields_', None) is not None:
            for field, values in self.extra_fields_.items():
                val_array = np.asarray(values, dtype=float)
                rows.append({
                    'Parameter': f"sampler_{field}",
                    'mean': float(np.mean(val_array)),
                    'std': float(np.std(val_array)),
                    'sum': float(np.sum(val_array))
                })

        return pd.DataFrame(rows)

diagnostics

diagnostics() -> pd.DataFrame

Returns MCMC diagnostics (R-hat, ESS) as a formatted DataFrame.

Source code in src/seroepi/estimators/_modelled.py
def diagnostics(self) -> pd.DataFrame:
    """
    Returns MCMC diagnostics (R-hat, ESS) as a formatted DataFrame.

    Produces one row per scalar parameter element, with vector parameters
    expanded to ``name[i,j,...]``. It also adds ``sampler_<field>`` rows
    summarising any extra fields collected by the sampler.

    Raises:
        TypeError: If the estimator was not configured for MCMC.
    """
    if hasattr(self, 'check_is_fitted'):
        self.check_is_fitted()
    if self.method != BayesianInferenceMethod.MCMC:
        raise TypeError("Diagnostics are only available for MCMC inference.")

    chain_samples = getattr(self, 'chain_samples_', None)
    if chain_samples is not None:
        # Per-chain draws (chain, draw, ...): R-hat / ESS see between-chain variation.
        summary_dict = diag.summary(chain_samples, prob=0.95, group_by_chain=True)
    else:
        # Only flattened draws available: they are treated as a single chain.
        summary_dict = diag.summary(self.samples_, prob=0.95, group_by_chain=False)

    # Flatten multi-dimensional parameters (like fixed/random effects) into individual rows
    rows = []
    for param, stats in summary_dict.items():
        param_shape = np.shape(stats['mean'])

        if len(param_shape) == 0:
            row = {'Parameter': param}
            row.update({k: float(v) for k, v in stats.items()})
            rows.append(row)
        else:
            it = np.nditer(np.empty(param_shape), flags=['multi_index'])
            for _ in it:
                idx = it.multi_index
                idx_str = ",".join(map(str, idx))
                row = {'Parameter': f"{param}[{idx_str}]"}
                row.update({k: float(np.asarray(v)[idx]) for k, v in stats.items()})
                rows.append(row)

    # Safely extract and append the MCMC sampler's internal extra fields
    if getattr(self, 'extra_fields_', None) is not None:
        for field, values in self.extra_fields_.items():
            val_array = np.asarray(values, dtype=float)
            rows.append({
                'Parameter': f"sampler_{field}",
                'mean': float(np.mean(val_array)),
                'std': float(np.std(val_array)),
                'sum': float(np.sum(val_array))
            })

    return pd.DataFrame(rows)

BayesianPrevalenceEstimator

Bases: ModelledMixin, BayesianMixin, BaseEstimator[PrevalenceEstimates]

Bayesian hierarchical model for prevalence estimation.

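This estimator uses MCMC or SVI to fit a logit-link binomial model with a random effect per group and a fixed effect per target. The group random effects absorb between-group heterogeneity (extra-binomial variation across strata), and the posterior draws provide credible intervals.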

Examples:

>>> from seroepi.estimators import BayesianPrevalenceEstimator
>>> estimator = BayesianPrevalenceEstimator(method='mcmc')
>>> # result = estimator.calculate(agg_df)
Source code in src/seroepi/estimators/_modelled.py
class BayesianPrevalenceEstimator(ModelledMixin, BayesianMixin, BaseEstimator[PrevalenceEstimates]):
    """
    Bayesian hierarchical model for prevalence estimation.

    This estimator uses MCMC or SVI to fit a logit-link binomial model with a
    global intercept, a fixed effect per target and a non-centred random effect
    per group. The group random effects absorb between-group heterogeneity
    (extra-binomial variation across strata), and the posterior draws provide
    credible intervals.

    Only the first stratification column is used as the random-effect
    grouping. Unstratified input is modelled as a single 'Global' group.

    Examples:
        >>> from seroepi.estimators import BayesianPrevalenceEstimator
        >>> estimator = BayesianPrevalenceEstimator(method='mcmc')
        >>> # result = estimator.calculate(agg_df)
    """
    def __init__(self, method: BayesianInferenceMethod = BayesianInferenceMethod.MCMC, num_samples: int = 1500, num_chains: int = 4,
                 num_warmup: int = 1000, svi_steps: int = 3000, target_event: str = 'event', target_n: str = 'n', seed: int = 42):
        """
        Initializes the BayesianPrevalenceEstimator.

        Args:
            method: Inference method ('mcmc' or 'svi'). Defaults to 'mcmc'.
            num_samples: Number of posterior samples to draw. Defaults to 1500.
            num_chains: Number of MCMC chains. Defaults to 4.
            num_warmup: Number of warmup steps for MCMC. Defaults to 1000.
            svi_steps: Number of optimization steps for SVI. Defaults to 3000.
            target_event: Column name for event counts. Defaults to 'event'.
            target_n: Column name for total counts (denominators). Defaults to 'n'.
            seed: Random seed for reproducibility. Defaults to 42.
        """
        self._init_bayesian(method, num_samples, num_chains, num_warmup, svi_steps, seed)
        self._method_label = f'bayesian_{self.method.value}'

        self.target_event = target_event
        self.target_n = target_n

        # Fitted attributes (trailing underscores)
        self.encoders_ = {}
        self.strata_ = []
        self.meta_ = {}

    def _random_effect_groups(self, df: pd.DataFrame) -> tuple[str, pd.Series]:
        """
        Returns the random-effect grouping column name and its per-row labels.

        Unstratified inputs get a constant 'Global' label so the model still has
        one (trivial) group; ``df`` itself is neither copied nor modified.
        """
        if not self.strata_:
            return '_dummy_group', pd.Series('Global', index=df.index)
        group_col = self.strata_[0]
        return group_col, df[group_col]

    def _model(self, target_idx, group_idx, n, n_targets, n_groups, event=None):
        """Internal NumPyro model definition (``event=None`` leaves 'obs' unobserved)."""
        # 1. Global Intercept (Regularized)
        alpha = samp("alpha", dist.Normal(0, 1.5))

        # 2. Fixed effect: Target deviation from baseline
        b_target = samp("b_target", dist.Normal(0, 1).expand([n_targets]))

        # 3. Random effect: Group variation (Non-centered parameterization)
        sd_group = samp("sd_group", dist.HalfNormal(1))
        z_group = samp("z_group", dist.Normal(0, 1).expand([n_groups]))
        r_group = z_group * sd_group

        # 4. Generalized Logit link
        logit_p = alpha + b_target[target_idx] + r_group[group_idx]

        # 5. Likelihood
        samp("obs", dist.Binomial(total_count=n, logits=logit_p), obs=event)

    def fit(self, agg_df: pd.DataFrame) -> 'BayesianPrevalenceEstimator':
        """
        Parses data, fits encoders, and runs inference to get posterior samples.

        Args:
            agg_df: The aggregated DataFrame.

        Returns:
            The fitted estimator instance.

        Raises:
            ValueError: If ``agg_df``'s metadata marks it as not zero-padded.
        """
        self._check_zero_padding(agg_df)

        self.strata_, self.meta_ = self._extract_strata(agg_df, exclude_cols=[self.target_event, self.target_n, 'trait'])

        group_col, group_values = self._random_effect_groups(agg_df)
        target_col = 'trait'

        # Start from a clean mapping so encoders from a previous fit never linger.
        self.encoders_ = {}
        codes = {}
        for col, values in ((group_col, group_values), (target_col, agg_df[target_col])):
            le = LabelEncoder()
            # fit_transform establishes the label -> index mapping predict() reuses.
            codes[col] = le.fit_transform(values.astype(str))
            self.encoders_[col] = le

        jax_data = {
            "target_idx": jnp.array(codes[target_col]),
            "group_idx": jnp.array(codes[group_col]),
            "n": jnp.array(agg_df[self.target_n].values),
            "event": jnp.array(agg_df[self.target_event].values),
            "n_targets": len(self.encoders_[target_col].classes_),
            "n_groups": len(self.encoders_[group_col].classes_)
        }

        # Run inference
        rng_key = random.PRNGKey(self.seed)
        self.samples_ = self._run_inference(jax_data, rng_key)

        self.is_fitted_ = True
        return self

    def predict(self, agg_df: pd.DataFrame) -> PrevalenceEstimates:
        """
        Uses the fitted samples and encoders to calculate prevalence bounds.

        Args:
            agg_df: The DataFrame to generate predictions for. Its group and
                'trait' labels must have been seen during ``fit``.

        Returns:
            A PrevalenceEstimates object whose data is ``agg_df`` with
            'estimate', 'lower' and 'upper' columns appended.
        """
        self.check_is_fitted()

        group_col, group_values = self._random_effect_groups(agg_df)
        target_col = 'trait'

        # Reuse the fitted encoders so indices line up with the posterior's
        # b_target / z_group entries.
        target_idx = jnp.array(self.encoders_[target_col].transform(agg_df[target_col].astype(str)))
        group_idx = jnp.array(self.encoders_[group_col].transform(group_values.astype(str)))

        estimate, lower, upper = _compute_prevalence_posterior(
            self.samples_["alpha"],
            self.samples_["b_target"],
            self.samples_["z_group"],
            self.samples_["sd_group"],
            target_idx,
            group_idx
        )

        new_cols = {
            'estimate': np.array(estimate),
            'lower': np.array(lower),
            'upper': np.array(upper)
        }

        # Append the estimate columns side by side, keeping agg_df's index.
        result_df = pd.concat([agg_df, pd.DataFrame(new_cols, index=agg_df.index)], axis=1)

        return PrevalenceEstimates(
            data=result_df,
            stratified_by=self.strata_,
            adjusted_for=self.meta_.get("adjusted_for", 'unknown'),
            method=self._method_label,
            aggregation_type=self.meta_.get("aggregation_type", "unknown"),
            trait=self.meta_.get("trait", "unknown")
        )

__init__

__init__(method: BayesianInferenceMethod = BayesianInferenceMethod.MCMC, num_samples: int = 1500, num_chains: int = 4, num_warmup: int = 1000, svi_steps: int = 3000, target_event: str = 'event', target_n: str = 'n', seed: int = 42)

Initializes the BayesianPrevalenceEstimator.

Parameters:

Name Type Description Default
method BayesianInferenceMethod

Inference method ('mcmc' or 'svi'). Defaults to 'mcmc'.

MCMC
num_samples int

Number of posterior samples to draw. Defaults to 1500.

1500
num_chains int

Number of MCMC chains. Defaults to 4.

4
num_warmup int

Number of warmup steps for MCMC. Defaults to 1000.

1000
svi_steps int

Number of optimization steps for SVI. Defaults to 3000.

3000
target_event str

Column name for event counts. Defaults to 'event'.

'event'
target_n str

Column name for total counts (denominators). Defaults to 'n'.

'n'
seed int

Random seed for reproducibility. Defaults to 42.

42
Source code in src/seroepi/estimators/_modelled.py
def __init__(self, method: BayesianInferenceMethod = BayesianInferenceMethod.MCMC, num_samples: int = 1500, num_chains: int = 4,
             num_warmup: int = 1000, svi_steps: int = 3000, target_event: str = 'event', target_n: str = 'n', seed: int = 42):
    """
    Configures the estimator; no data is read until ``fit`` is called.

    Args:
        method: Inference engine, 'mcmc' or 'svi' (default: MCMC).
        num_samples: Posterior draws to collect (default 1500).
        num_chains: MCMC chain count (default 4).
        num_warmup: MCMC warmup iterations (default 1000).
        svi_steps: SVI optimisation steps (default 3000).
        target_event: Column holding the event counts (default 'event').
        target_n: Column holding the denominators (default 'n').
        seed: PRNG seed for reproducible inference (default 42).
    """
    # Shared Bayesian configuration comes first: the method label depends on it.
    self._init_bayesian(method, num_samples, num_chains, num_warmup, svi_steps, seed)
    self._method_label = 'bayesian_{}'.format(self.method.value)

    # Numerator / denominator column names in the aggregated frame.
    self.target_event, self.target_n = target_event, target_n

    # Fitted state, populated by fit() (sklearn-style trailing underscores).
    self.encoders_, self.strata_, self.meta_ = {}, [], {}

fit

fit(agg_df: DataFrame) -> BayesianPrevalenceEstimator

Parses data, fits encoders, and runs inference to get posterior samples.

Parameters:

Name Type Description Default
agg_df DataFrame

The aggregated DataFrame.

required

Returns:

Type Description
BayesianPrevalenceEstimator

The fitted estimator instance.

Source code in src/seroepi/estimators/_modelled.py
def fit(self, agg_df: pd.DataFrame) -> 'BayesianPrevalenceEstimator':
    """
    Parses data, fits encoders, and runs inference to get posterior samples.

    Args:
        agg_df: The aggregated DataFrame.

    Returns:
        The fitted estimator instance.

    Raises:
        ValueError: If ``agg_df``'s metadata marks it as not zero-padded.
    """
    self._check_zero_padding(agg_df)

    self.strata_, self.meta_ = self._extract_strata(agg_df, exclude_cols=[self.target_event, self.target_n, 'trait'])

    # Read the grouping labels directly instead of copying the whole frame.
    if not self.strata_:
        # Unstratified: every row belongs to a single 'Global' group.
        group_col = '_dummy_group'
        group_values = pd.Series('Global', index=agg_df.index)
    else:
        # Only the first stratification column acts as the random-effect grouping.
        group_col = self.strata_[0]
        group_values = agg_df[group_col]

    target_col = 'trait'

    # Start from a clean mapping so encoders from a previous fit never linger.
    self.encoders_ = {}
    codes = {}
    for col, values in ((group_col, group_values), (target_col, agg_df[target_col])):
        le = LabelEncoder()
        # fit_transform establishes the label -> index mapping predict() reuses.
        codes[col] = le.fit_transform(values.astype(str))
        self.encoders_[col] = le

    jax_data = {
        "target_idx": jnp.array(codes[target_col]),
        "group_idx": jnp.array(codes[group_col]),
        "n": jnp.array(agg_df[self.target_n].values),
        "event": jnp.array(agg_df[self.target_event].values),
        "n_targets": len(self.encoders_[target_col].classes_),
        "n_groups": len(self.encoders_[group_col].classes_)
    }

    # Run inference
    rng_key = random.PRNGKey(self.seed)
    self.samples_ = self._run_inference(jax_data, rng_key)

    self.is_fitted_ = True
    return self

predict

predict(agg_df: DataFrame) -> PrevalenceEstimates

Uses the fitted samples and encoders to calculate prevalence bounds.

Parameters:

Name Type Description Default
agg_df DataFrame

The DataFrame to generate predictions for.

required

Returns:

Type Description
PrevalenceEstimates

A PrevalenceEstimates object.

Source code in src/seroepi/estimators/_modelled.py
def predict(self, agg_df: pd.DataFrame) -> PrevalenceEstimates:
    """
    Uses the fitted samples and encoders to calculate prevalence bounds.

    Args:
        agg_df: The DataFrame to generate predictions for. Its group and
            'trait' labels must have been seen during ``fit``.

    Returns:
        A PrevalenceEstimates object whose data is ``agg_df`` with 'estimate',
        'lower' and 'upper' columns appended.
    """
    self.check_is_fitted()

    # Read labels straight from agg_df (no full-frame copy needed). The
    # unstratified case uses the same constant 'Global' group as fitting.
    if not self.strata_:
        group_col = '_dummy_group'
        group_values = pd.Series('Global', index=agg_df.index)
    else:
        group_col = self.strata_[0]
        group_values = agg_df[group_col]

    target_col = 'trait'

    # Reuse the fitted encoders so indices line up with the posterior's
    # b_target / z_group entries.
    target_idx = jnp.array(self.encoders_[target_col].transform(agg_df[target_col].astype(str)))
    group_idx = jnp.array(self.encoders_[group_col].transform(group_values.astype(str)))

    estimate, lower, upper = _compute_prevalence_posterior(
        self.samples_["alpha"],
        self.samples_["b_target"],
        self.samples_["z_group"],
        self.samples_["sd_group"],
        target_idx,
        group_idx
    )

    new_cols = {
        'estimate': np.array(estimate),
        'lower': np.array(lower),
        'upper': np.array(upper)
    }

    # Append the estimate columns side by side, keeping agg_df's index.
    result_df = pd.concat([agg_df, pd.DataFrame(new_cols, index=agg_df.index)], axis=1)

    return PrevalenceEstimates(
        data=result_df,
        stratified_by=self.strata_,
        adjusted_for=self.meta_.get("adjusted_for", 'unknown'),
        method=self._method_label,
        aggregation_type=self.meta_.get("aggregation_type", "unknown"),
        trait=self.meta_.get("trait", "unknown")
    )

BetaDiversityEstimates dataclass

Bases: Estimates

Container for Beta Diversity results (distance matrices).

Attributes:

Name Type Description
metric str

The distance metric used (e.g., 'braycurtis').

Source code in src/seroepi/estimators/_base.py
@dataclass(frozen=True, slots=True)
class BetaDiversityEstimates(Estimates):
    """
    Container for Beta Diversity results (distance matrices).

    Inherits ``data``, ``stratified_by``, ``adjusted_for``, ``trait`` and
    ``aggregation_type`` from ``Estimates``. When produced by
    ``BetaDiversityEstimator.calculate``, ``data`` is a square DataFrame of
    pairwise distances whose index and columns are the stratum labels.

    Attributes:
        metric: The distance metric used (e.g., 'braycurtis').
    """
    metric: str

BetaDiversityEstimator

Bases: BaseEstimator[BetaDiversityEstimates]

Source code in src/seroepi/estimators/_core.py
class BetaDiversityEstimator(BaseEstimator[BetaDiversityEstimates]):
    """
    Between-group (beta) diversity: pairwise dissimilarity between strata.

    The aggregated counts are pivoted into a strata x trait-level matrix, and
    its rows are compared with ``scipy.spatial.distance.pdist``.
    """

    # SciPy metrics defined on boolean vectors. Given raw counts, SciPy compares
    # the count values themselves; for 'jaccard', any position where two
    # non-zero counts differ is a mismatch. These metrics therefore receive a
    # presence/absence matrix instead.
    _BOOLEAN_METRICS = frozenset({
        'dice', 'jaccard', 'kulczynski1', 'rogerstanimoto', 'russellrao',
        'sokalmichener', 'sokalsneath', 'yule',
    })

    def __init__(self, target: str = None, metric: str = 'braycurtis'):
        """
        Calculates between-group dissimilarity.
        Common metrics: 'braycurtis' (abundance-weighted), 'jaccard' (presence/absence).

        Args:
            target: Trait name used when the input's metadata does not provide one.
            metric: Metric accepted by ``scipy.spatial.distance.pdist``.
        """
        self.target = target
        self.metric = metric
        self._method_label = f"beta_diversity_{self.metric}"

    def get_params(self) -> dict:
        """Returns parameters for cloning compatibility."""
        return {'target': self.target, 'metric': self.metric}

    def calculate(self, div_df: pd.DataFrame) -> BetaDiversityEstimates:
        """
        Computes pairwise dissimilarities between strata.

        Args:
            div_df: Aggregated data with 'trait' and 'variant_count' columns plus
                one column per stratification level, carrying ``metric_meta``
                in ``attrs``.

        Returns:
            BetaDiversityEstimates whose ``data`` is a square DataFrame indexed
            and columned by stratum label (multi-level labels joined with ' | ').

        Raises:
            ValueError: If no trait is known or no stratification is given.
        """
        # 1. Extract metadata from the accessor
        meta = div_df.attrs.get("metric_meta", {})
        target_col = meta.get("trait", self.target)
        strata = meta.get("stratified_by", [])

        if not target_col:
            raise ValueError("Target trait must be defined either in init or via accessor metadata.")
        if not strata:
            raise ValueError("Beta diversity requires at least one stratification level to compare groups.")

        # 2. Pivot the data into a Wide Matrix
        # Rows = Strata (e.g., Hospitals), Columns = Variants (e.g., K_loci), Values = Counts
        pivot_df = div_df.pivot_table(
            index=strata,
            columns='trait',
            values='variant_count',
            fill_value=0,  # CRITICAL: Missing variants in a group must be explicitly 0
            aggfunc='sum'
        )

        # 3. Calculate Pairwise Distances
        counts = pivot_df.values
        if self.metric in self._BOOLEAN_METRICS:
            # Presence/absence: a variant counts once it is observed at all.
            counts = counts > 0
        # pdist calculates the condensed distance vector, squareform turns it into an NxN matrix
        distances = pdist(counts, metric=self.metric)
        dist_matrix = squareform(distances)

        # Sanitize any NaNs generated by distance metrics on empty strata
        dist_matrix = np.nan_to_num(dist_matrix, nan=0.0)

        # 4. Format the row/column names for the UI
        # If stratified by multiple columns (e.g., ['Region', 'Year']), we flatten the tuple for display
        strata_names = [
            " | ".join(map(str, idx)) if isinstance(idx, tuple) else str(idx)
            for idx in pivot_df.index
        ]

        # 5. Wrap back into an explicitly labeled Pandas DataFrame
        result_matrix = pd.DataFrame(
            dist_matrix,
            index=strata_names,
            columns=strata_names
        )

        return BetaDiversityEstimates(
            data=result_matrix,
            stratified_by=strata,
            adjusted_for=meta.get("adjusted_for", 'unknown'),
            trait=target_col,
            aggregation_type=meta.get("aggregation_type", AggregationType.TRAIT),
            metric=self.metric
        )

__init__

__init__(target: str = None, metric: str = 'braycurtis')

Calculates between-group dissimilarity. Common metrics: 'braycurtis' (abundance-weighted), 'jaccard' (presence/absence).

Source code in src/seroepi/estimators/_core.py
def __init__(self, target: str = None, metric: str = 'braycurtis'):
    """
    Configures a between-group dissimilarity calculation.

    Common metrics: 'braycurtis' (abundance-weighted), 'jaccard' (presence/absence).

    Args:
        target: Fallback trait name, used when the input metadata names none.
        metric: Distance metric handed to SciPy's ``pdist``.
    """
    self.target, self.metric = target, metric
    self._method_label = "beta_diversity_{}".format(metric)

get_params

get_params() -> dict

Returns parameters for cloning compatibility.

Source code in src/seroepi/estimators/_core.py
def get_params(self) -> dict:
    """Exposes the constructor arguments so the estimator can be cloned."""
    return dict(target=self.target, metric=self.metric)

Estimates dataclass

Base container for statistical estimates.

Attributes:

Name Type Description
data DataFrame

A DataFrame containing the estimates and original strata.

stratified_by list[str]

List of columns used for stratification.

adjusted_for Optional[str]

Column name used for cluster adjustment, if any.

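trait str

The trait variable for which estimates were calculated.

aggregation_type AggregationType

The kind of aggregation the estimates were computed from (e.g., "trait" or "compositional").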

Source code in src/seroepi/estimators/_base.py
@dataclass(frozen=True, slots=True)
class Estimates:
    """
    Base container for statistical estimates.

    Fields cannot be reassigned after construction (``frozen=True``);
    subclasses such as ``BetaDiversityEstimates`` add estimator-specific fields.

    Attributes:
        data: A DataFrame containing the estimates and original strata.
        stratified_by: List of columns used for stratification.
        adjusted_for: Column name used for cluster adjustment, if any.
        trait: The trait variable for which estimates were calculated.
        aggregation_type: The kind of aggregation the estimates were computed
            from (e.g. "trait" or "compositional").
    """
    data: pd.DataFrame
    stratified_by: list[str]
    adjusted_for: Optional[str]
    trait: str  # e.g., "blaKPC" or "Serotype"
    aggregation_type: AggregationType  # "trait" or "compositional"

GLMIncidenceEstimator

Bases: ModelledMixin, BaseEstimator[IncidenceEstimates]

Negative Binomial GLM for time-series incidence estimation.

Fits a Negative Binomial model to count data over time, optionally adjusting for sequencing volume (relative incidence).

Examples:

>>> from seroepi.estimators import GLMIncidenceEstimator
>>> estimator = GLMIncidenceEstimator(use_relative_incidence=True)
>>> # result = estimator.calculate(inc_df)
Source code in src/seroepi/estimators/_modelled.py
class GLMIncidenceEstimator(ModelledMixin, BaseEstimator[IncidenceEstimates]):
    """
    Negative Binomial GLM for time-series incidence estimation.

    Fits a Negative Binomial model to count data over time, optionally
    adjusting for sequencing volume (relative incidence).

    Examples:
        >>> from seroepi.estimators import GLMIncidenceEstimator
        >>> estimator = GLMIncidenceEstimator(use_relative_incidence=True)
        >>> # result = estimator.calculate(inc_df)
    """
    def __init__(self, use_relative_incidence: bool = True, forecast_horizon: int = 0):
        """
        Initializes the GLMIncidenceEstimator.

        Args:
            use_relative_incidence: If True, models cases adjusting for total
                sequencing volume (offset). If False, models absolute counts.
            forecast_horizon: Number of future time steps to project. Defaults to 0.
        """
        self.use_relative_incidence = use_relative_incidence
        self.forecast_horizon = forecast_horizon
        self._method_label = "neg_binomial_glm"

        # Internal state tracking
        self.fit_results_ = {}
        self.strata_ = []
        self.meta_ = {}

    def fit(self, inc_df: pd.DataFrame) -> 'GLMIncidenceEstimator':
        """
        Fits the Negative Binomial GLM to each stratum.

        Each stratum (plus ``trait`` when that column exists) is fitted
        independently as ``variant_count ~ 1 + time_step``. When relative
        incidence is requested, ``log(total_sequenced)`` is used as an offset.
        Periods with zero sequencing volume are excluded. Strata with fewer
        than three usable periods are stored as ``None``.

        Args:
            inc_df: Aggregated incidence table whose ``attrs["metric_meta"]``
                carries ``trait`` and ``freq`` (optionally ``stratified_by``),
                with columns ``date``, ``variant_count`` and ``total_sequenced``.

        Returns:
            The fitted estimator (``self``).

        Raises:
            ValueError: If the incidence metadata is missing.
        """
        self.meta_ = inc_df.attrs.get("metric_meta", {})
        target_col = self.meta_.get("trait")
        self.freq_ = self.meta_.get("freq")
        self.strata_ = self.meta_.get("stratified_by", [])

        if not target_col or not self.freq_:
            raise ValueError("Incidence metadata missing. Ensure data came from `epi.aggregate_incidence`.")

        # Start from a clean slate so a re-fit never keeps strata from an earlier dataset.
        self.fit_results_ = {}
        # Date mapped to time_step == 0 for each stratum. predict() reuses it
        # so that the fitted intercept refers to the same point in time.
        self.time_origins_ = {}

        # Sort once up front so every group is already in chronological order.
        inc_df_sorted = inc_df.sort_values('date')
        group_cols = self.strata_ + ['trait'] if 'trait' in inc_df_sorted.columns else self.strata_
        groups = inc_df_sorted.groupby(group_cols, observed=True) if group_cols else [('Global', inc_df_sorted)]

        for name, group in groups:
            # Zero-volume periods would give log(0) offsets, so leave them out of the fit.
            df_model = group[group['total_sequenced'] > 0].copy()

            if len(df_model) < 3:
                self.fit_results_[name] = None  # Not enough data to model
                continue

            # Numeric time step for the slope, anchored at the first usable period.
            min_date = df_model['date'].min()
            self.time_origins_[name] = min_date

            if self.freq_ == TemporalResolution.MONTH.value or self.freq_.startswith('M'):
                df_model['time_step'] = (df_model['date'].dt.year - min_date.year) * 12 + (df_model['date'].dt.month - min_date.month)
            elif self.freq_ == TemporalResolution.WEEK.value or self.freq_.startswith('W'):
                df_model['time_step'] = (df_model['date'] - min_date).dt.days // 7
            elif self.freq_ == TemporalResolution.YEAR.value or self.freq_.startswith('Y'):
                df_model['time_step'] = df_model['date'].dt.year - min_date.year
            else:
                df_model['time_step'] = (df_model['date'] - min_date).dt.days

            Y = df_model['variant_count']
            X = sm.add_constant(df_model['time_step'])

            offset = np.log(df_model['total_sequenced']) if self.use_relative_incidence else None

            try:
                # statsmodels' NegativeBinomial family treats alpha as a fixed,
                # known dispersion; it is NOT estimated from the data here.
                model = sm.GLM(Y, X, family=sm.families.NegativeBinomial(alpha=1.0), offset=offset)
                self.fit_results_[name] = model.fit()
            except Exception as e:
                # Best effort: one failing stratum must not abort the others.
                warn(f"GLM failed to converge for stratum {name}: {e}")
                self.fit_results_[name] = None

        self.is_fitted_ = True
        return self

    def predict(self, inc_df: pd.DataFrame) -> IncidenceEstimates:
        """
        Extracts Incidence Rate Ratios (IRR) and forecasts trends.

        Args:
            inc_df: Aggregated incidence table with the same layout used in
                ``fit`` (``date``, ``variant_count``, ``total_sequenced`` and
                the stratum columns).

        Returns:
            IncidenceEstimates. Its ``data`` holds per-period fitted means
            (``estimate``/``lower``/``upper``, plus forecast rows when
            ``forecast_horizon > 0``). Its ``model_results`` holds one IRR row
            per stratum.

        Raises:
            RuntimeError: If the estimator has not been fitted.
        """
        self.check_is_fitted()

        freq_str = self.freq_
        try:
            freq = TemporalResolution(freq_str).pandas_offset
            if not freq:
                freq = 'MS'
        except ValueError:
            freq = 'MS'

        group_cols = self.strata_ + ['trait'] if 'trait' in inc_df.columns else self.strata_
        # Sort to ensure chronological order (also for the unstratified case).
        inc_df_sorted = inc_df.sort_values('date')
        groups = inc_df_sorted.groupby(group_cols, observed=True) if group_cols else [('Global', inc_df_sorted)]

        # Models pickled before origins were recorded fall back to each group's first date.
        time_origins = getattr(self, 'time_origins_', {})

        results = []
        all_pred_dfs = []

        for name, group in groups:
            fit = self.fit_results_.get(name)

            if not group_cols:
                row = {}
            elif isinstance(name, tuple):
                row = dict(zip(group_cols, name))
            else:
                row = {group_cols[0]: name}

            if fit is None:
                row.update({
                    'IRR': np.nan, 'IRR_lower': np.nan, 'IRR_upper': np.nan,
                    'p_value': np.nan, 'status': 'Failed/Insufficient Data'
                })

                df_pred = group.copy()
                df_pred['estimate'] = np.nan
                df_pred['lower'] = np.nan
                df_pred['upper'] = np.nan
                all_pred_dfs.append(df_pred)
            else:
                conf_int = fit.conf_int()
                coef = fit.params.get('time_step', 0)
                p_val = fit.pvalues.get('time_step', 1.0)
                ci_lower = conf_int[0].get('time_step', 0)
                ci_upper = conf_int[1].get('time_step', 0)

                row.update({
                    'IRR': np.exp(coef),
                    'IRR_lower': np.exp(ci_lower),
                    'IRR_upper': np.exp(ci_upper),
                    'p_value': p_val,
                    'status': 'Converged'
                })

                # --- Generate Predictions & Forecasts ---
                df_pred = group.copy()
                # Reuse the fit-time origin so that time_step (and therefore the
                # intercept) means the same thing as in fit(). The group's own
                # first date differs whenever its leading periods had zero
                # sequencing volume (excluded from the fit) or a different date
                # window is passed in.
                min_date = time_origins.get(name, df_pred['date'].min())

                # Append future horizon rows if required
                if getattr(self, 'forecast_horizon', 0) > 0:
                    max_date = df_pred['date'].max()
                    future_dates = pd.date_range(start=max_date, periods=self.forecast_horizon + 1, freq=freq)[1:]
                    future_df = pd.DataFrame({'date': future_dates})

                    for idx, col in enumerate(group_cols):
                        future_df[col] = name[idx] if isinstance(name, tuple) else name

                    # Use the mean historical sequencing volume for relative adjustments in the future
                    future_df['total_sequenced'] = df_pred['total_sequenced'].mean()
                    future_df['variant_count'] = 0

                    df_pred = pd.concat([df_pred, future_df], ignore_index=True)

                # Recalculate the mathematical time_step for ALL rows
                if self.freq_ == TemporalResolution.MONTH.value or self.freq_.startswith('M'):
                    df_pred['time_step'] = (df_pred['date'].dt.year - min_date.year) * 12 + (df_pred['date'].dt.month - min_date.month)
                elif self.freq_ == TemporalResolution.WEEK.value or self.freq_.startswith('W'):
                    df_pred['time_step'] = (df_pred['date'] - min_date).dt.days // 7
                elif self.freq_ == TemporalResolution.YEAR.value or self.freq_.startswith('Y'):
                    df_pred['time_step'] = df_pred['date'].dt.year - min_date.year
                else:
                    df_pred['time_step'] = (df_pred['date'] - min_date).dt.days

                # Generate predictions; clip avoids log(0) for zero-volume periods.
                X = sm.add_constant(df_pred['time_step'], has_constant='add')
                offset = np.log(df_pred['total_sequenced'].clip(lower=1e-8)) if self.use_relative_incidence else None

                pred_res = fit.get_prediction(X, offset=offset).summary_frame(alpha=0.05)
                df_pred['estimate'] = pred_res['mean'].values
                df_pred['lower'] = pred_res['mean_ci_lower'].values
                df_pred['upper'] = pred_res['mean_ci_upper'].values

                all_pred_dfs.append(df_pred)

            results.append(row)

        final_df = pd.concat(all_pred_dfs, ignore_index=True) if all_pred_dfs else inc_df.copy()

        return IncidenceEstimates(
            data=final_df,
            stratified_by=self.strata_,
            adjusted_for=self.meta_.get("adjusted_for", 'unknown'),
            trait=self.meta_.get("trait", "unknown"),
            freq=self.freq_,
            aggregation_type=self.meta_.get("aggregation_type", "unknown"),
            model_results=pd.DataFrame(results)
        )

__init__

__init__(use_relative_incidence: bool = True, forecast_horizon: int = 0)

Initializes the GLMIncidenceEstimator.

Parameters:

Name Type Description Default
use_relative_incidence bool

If True, models cases adjusting for total sequencing volume (offset). If False, models absolute counts.

True
forecast_horizon int

Number of future time steps to project. Defaults to 0.

0
Source code in src/seroepi/estimators/_modelled.py
def __init__(self, use_relative_incidence: bool = True, forecast_horizon: int = 0):
    """
    Initializes the GLMIncidenceEstimator.

    Args:
        use_relative_incidence: When True, total sequencing volume enters the
            model as an offset (relative incidence). When False, absolute
            counts are modelled.
        forecast_horizon: How many future time steps to project. Defaults to 0.
    """
    # Configuration, assigned in the same order as before.
    self.use_relative_incidence, self.forecast_horizon = use_relative_incidence, forecast_horizon
    self._method_label = "neg_binomial_glm"

    # Fitted state; fresh containers per instance, populated by fit().
    self.fit_results_, self.strata_, self.meta_ = {}, [], {}

fit

fit(inc_df: DataFrame) -> GLMIncidenceEstimator

Fits the Negative Binomial GLM to each stratum.

Source code in src/seroepi/estimators/_modelled.py
def fit(self, inc_df: pd.DataFrame) -> 'GLMIncidenceEstimator':
    """
    Fits the Negative Binomial GLM to each stratum.

    Each stratum (plus ``trait`` when that column exists) is fitted
    independently as ``variant_count ~ 1 + time_step``. When relative
    incidence is requested, ``log(total_sequenced)`` is used as an offset.
    Periods with zero sequencing volume are excluded. Strata with fewer
    than three usable periods are stored as ``None``.

    Args:
        inc_df: Aggregated incidence table whose ``attrs["metric_meta"]``
            carries ``trait`` and ``freq`` (optionally ``stratified_by``),
            with columns ``date``, ``variant_count`` and ``total_sequenced``.

    Returns:
        The fitted estimator (``self``).

    Raises:
        ValueError: If the incidence metadata is missing.
    """
    self.meta_ = inc_df.attrs.get("metric_meta", {})
    target_col = self.meta_.get("trait")
    self.freq_ = self.meta_.get("freq")
    self.strata_ = self.meta_.get("stratified_by", [])

    if not target_col or not self.freq_:
        raise ValueError("Incidence metadata missing. Ensure data came from `epi.aggregate_incidence`.")

    # Start from a clean slate so a re-fit never keeps strata from an earlier dataset.
    self.fit_results_ = {}
    # Date mapped to time_step == 0 for each stratum. predict() reuses it
    # so that the fitted intercept refers to the same point in time.
    self.time_origins_ = {}

    # Sort once up front so every group is already in chronological order.
    inc_df_sorted = inc_df.sort_values('date')
    group_cols = self.strata_ + ['trait'] if 'trait' in inc_df_sorted.columns else self.strata_
    groups = inc_df_sorted.groupby(group_cols, observed=True) if group_cols else [('Global', inc_df_sorted)]

    for name, group in groups:
        # Zero-volume periods would give log(0) offsets, so leave them out of the fit.
        df_model = group[group['total_sequenced'] > 0].copy()

        if len(df_model) < 3:
            self.fit_results_[name] = None  # Not enough data to model
            continue

        # Numeric time step for the slope, anchored at the first usable period.
        min_date = df_model['date'].min()
        self.time_origins_[name] = min_date

        if self.freq_ == TemporalResolution.MONTH.value or self.freq_.startswith('M'):
            df_model['time_step'] = (df_model['date'].dt.year - min_date.year) * 12 + (df_model['date'].dt.month - min_date.month)
        elif self.freq_ == TemporalResolution.WEEK.value or self.freq_.startswith('W'):
            df_model['time_step'] = (df_model['date'] - min_date).dt.days // 7
        elif self.freq_ == TemporalResolution.YEAR.value or self.freq_.startswith('Y'):
            df_model['time_step'] = df_model['date'].dt.year - min_date.year
        else:
            df_model['time_step'] = (df_model['date'] - min_date).dt.days

        Y = df_model['variant_count']
        X = sm.add_constant(df_model['time_step'])

        offset = np.log(df_model['total_sequenced']) if self.use_relative_incidence else None

        try:
            # statsmodels' NegativeBinomial family treats alpha as a fixed,
            # known dispersion; it is NOT estimated from the data here.
            model = sm.GLM(Y, X, family=sm.families.NegativeBinomial(alpha=1.0), offset=offset)
            self.fit_results_[name] = model.fit()
        except Exception as e:
            # Best effort: one failing stratum must not abort the others.
            warn(f"GLM failed to converge for stratum {name}: {e}")
            self.fit_results_[name] = None

    self.is_fitted_ = True
    return self

predict

predict(inc_df: DataFrame) -> IncidenceEstimates

Extracts Incidence Rate Ratios (IRR) and forecasts trends.

Source code in src/seroepi/estimators/_modelled.py
def predict(self, inc_df: pd.DataFrame) -> IncidenceEstimates:
    """
    Extracts Incidence Rate Ratios (IRR) and forecasts trends.

    Args:
        inc_df: Aggregated incidence table with the same layout used in
            ``fit`` (``date``, ``variant_count``, ``total_sequenced`` and
            the stratum columns).

    Returns:
        IncidenceEstimates. Its ``data`` holds per-period fitted means
        (``estimate``/``lower``/``upper``, plus forecast rows when
        ``forecast_horizon > 0``). Its ``model_results`` holds one IRR row
        per stratum.

    Raises:
        RuntimeError: If the estimator has not been fitted.
    """
    self.check_is_fitted()

    freq_str = self.freq_
    try:
        freq = TemporalResolution(freq_str).pandas_offset
        if not freq:
            freq = 'MS'
    except ValueError:
        freq = 'MS'

    group_cols = self.strata_ + ['trait'] if 'trait' in inc_df.columns else self.strata_
    # Sort to ensure chronological order (also for the unstratified case).
    inc_df_sorted = inc_df.sort_values('date')
    groups = inc_df_sorted.groupby(group_cols, observed=True) if group_cols else [('Global', inc_df_sorted)]

    # Models pickled before origins were recorded fall back to each group's first date.
    time_origins = getattr(self, 'time_origins_', {})

    results = []
    all_pred_dfs = []

    for name, group in groups:
        fit = self.fit_results_.get(name)

        if not group_cols:
            row = {}
        elif isinstance(name, tuple):
            row = dict(zip(group_cols, name))
        else:
            row = {group_cols[0]: name}

        if fit is None:
            row.update({
                'IRR': np.nan, 'IRR_lower': np.nan, 'IRR_upper': np.nan,
                'p_value': np.nan, 'status': 'Failed/Insufficient Data'
            })

            df_pred = group.copy()
            df_pred['estimate'] = np.nan
            df_pred['lower'] = np.nan
            df_pred['upper'] = np.nan
            all_pred_dfs.append(df_pred)
        else:
            conf_int = fit.conf_int()
            coef = fit.params.get('time_step', 0)
            p_val = fit.pvalues.get('time_step', 1.0)
            ci_lower = conf_int[0].get('time_step', 0)
            ci_upper = conf_int[1].get('time_step', 0)

            row.update({
                'IRR': np.exp(coef),
                'IRR_lower': np.exp(ci_lower),
                'IRR_upper': np.exp(ci_upper),
                'p_value': p_val,
                'status': 'Converged'
            })

            # --- Generate Predictions & Forecasts ---
            df_pred = group.copy()
            # Reuse the fit-time origin so that time_step (and therefore the
            # intercept) means the same thing as in fit(). The group's own
            # first date differs whenever its leading periods had zero
            # sequencing volume (excluded from the fit) or a different date
            # window is passed in.
            min_date = time_origins.get(name, df_pred['date'].min())

            # Append future horizon rows if required
            if getattr(self, 'forecast_horizon', 0) > 0:
                max_date = df_pred['date'].max()
                future_dates = pd.date_range(start=max_date, periods=self.forecast_horizon + 1, freq=freq)[1:]
                future_df = pd.DataFrame({'date': future_dates})

                for idx, col in enumerate(group_cols):
                    future_df[col] = name[idx] if isinstance(name, tuple) else name

                # Use the mean historical sequencing volume for relative adjustments in the future
                future_df['total_sequenced'] = df_pred['total_sequenced'].mean()
                future_df['variant_count'] = 0

                df_pred = pd.concat([df_pred, future_df], ignore_index=True)

            # Recalculate the mathematical time_step for ALL rows
            if self.freq_ == TemporalResolution.MONTH.value or self.freq_.startswith('M'):
                df_pred['time_step'] = (df_pred['date'].dt.year - min_date.year) * 12 + (df_pred['date'].dt.month - min_date.month)
            elif self.freq_ == TemporalResolution.WEEK.value or self.freq_.startswith('W'):
                df_pred['time_step'] = (df_pred['date'] - min_date).dt.days // 7
            elif self.freq_ == TemporalResolution.YEAR.value or self.freq_.startswith('Y'):
                df_pred['time_step'] = df_pred['date'].dt.year - min_date.year
            else:
                df_pred['time_step'] = (df_pred['date'] - min_date).dt.days

            # Generate predictions; clip avoids log(0) for zero-volume periods.
            X = sm.add_constant(df_pred['time_step'], has_constant='add')
            offset = np.log(df_pred['total_sequenced'].clip(lower=1e-8)) if self.use_relative_incidence else None

            pred_res = fit.get_prediction(X, offset=offset).summary_frame(alpha=0.05)
            df_pred['estimate'] = pred_res['mean'].values
            df_pred['lower'] = pred_res['mean_ci_lower'].values
            df_pred['upper'] = pred_res['mean_ci_upper'].values

            all_pred_dfs.append(df_pred)

        results.append(row)

    final_df = pd.concat(all_pred_dfs, ignore_index=True) if all_pred_dfs else inc_df.copy()

    return IncidenceEstimates(
        data=final_df,
        stratified_by=self.strata_,
        adjusted_for=self.meta_.get("adjusted_for", 'unknown'),
        trait=self.meta_.get("trait", "unknown"),
        freq=self.freq_,
        aggregation_type=self.meta_.get("aggregation_type", "unknown"),
        model_results=pd.DataFrame(results)
    )

GLMPrevalenceEstimator

Bases: ModelledMixin, BaseEstimator[PrevalenceEstimates]

Frequentist binomial GLM for prevalence estimation.

Uses statsmodels to fit a Generalized Linear Model with a binomial family and logit link.

Source code in src/seroepi/estimators/_modelled.py
class GLMPrevalenceEstimator(ModelledMixin, BaseEstimator[PrevalenceEstimates]):
    """
    Frequentist binomial GLM for prevalence estimation.

    Uses statsmodels to fit a Generalized Linear Model with a binomial family
    and logit link.
    """
    def __init__(self, target_event: str = 'event', target_n: str = 'n'):
        """
        Initializes the GLMPrevalenceEstimator.

        Args:
            target_event: Column name for event counts. Defaults to 'event'.
            target_n: Column name for total counts. Defaults to 'n'.
        """
        self.target_event = target_event
        self.target_n = target_n
        self._method_label = "binomial_glm"

    def fit(self, agg_df: pd.DataFrame) -> 'GLMPrevalenceEstimator':
        """
        Fits the binomial GLM.

        Rows whose total is zero (or missing) carry no binomial information.
        They are left out of the likelihood but still receive predictions
        from ``predict``.

        Args:
            agg_df: Aggregated table with event/total count columns and the
                stratum (and optional ``trait``) columns.

        Returns:
            The fitted estimator (``self``).

        Raises:
            ValueError: If an event count is negative or exceeds its total,
                or if no row has a positive total.
        """
        self.strata_, self.meta_ = self._extract_strata(agg_df, exclude_cols=[self.target_event, self.target_n, 'trait'])

        feature_cols = self.strata_ + ['trait'] if 'trait' in agg_df.columns else self.strata_

        events = agg_df[self.target_event]
        totals = agg_df[self.target_n]
        # A negative "failures" count is not a valid binomial observation.
        if ((events < 0) | (events > totals)).any():
            raise ValueError(
                f"'{self.target_event}' must lie between 0 and '{self.target_n}' for every row."
            )

        # n == 0 makes the success proportion 0/0 inside statsmodels' Binomial family.
        fit_df = agg_df[(totals > 0).fillna(False)]
        if fit_df.empty:
            raise ValueError(f"No rows with a positive '{self.target_n}' to fit the binomial GLM.")

        # 1. Fit the encoder (We still use sklearn here because statsmodels' categorical handling can be clunky)
        self.encoder_ = OneHotEncoder(drop='first', sparse_output=False, handle_unknown='ignore')
        X_encoded = self.encoder_.fit_transform(fit_df[feature_cols])

        # Always add the intercept, matching predict(); otherwise a column that
        # looks constant would make the fit and prediction designs differ in width.
        X = sm.add_constant(X_encoded, has_constant='add')

        # 2. statsmodels accepts a two-column [successes, failures] response.
        successes = fit_df[self.target_event].values
        failures = fit_df[self.target_n].values - successes
        Y = np.column_stack((successes, failures))

        # 3. Fit the Binomial GLM (logit link by default).
        glm_model = sm.GLM(Y, X, family=sm.families.Binomial())
        self.fit_results_ = glm_model.fit()

        self.is_fitted_ = True
        return self

    def predict(self, agg_df: pd.DataFrame) -> PrevalenceEstimates:
        """
        Generates predictions and 95% confidence intervals.

        Args:
            agg_df: Table with the same stratum/trait columns used in ``fit``.

        Returns:
            PrevalenceEstimates whose ``data`` is ``agg_df`` plus ``estimate``,
            ``lower`` and ``upper`` columns.

        Raises:
            RuntimeError: If the estimator has not been fitted.
        """
        self.check_is_fitted()

        # Transform new data
        feature_cols = self.strata_ + ['trait'] if 'trait' in agg_df.columns else self.strata_
        X_encoded = self.encoder_.transform(agg_df[feature_cols])
        X = sm.add_constant(X_encoded, has_constant='add')

        # statsmodels natively handles the delta method and inverse-link transformations
        predictions = self.fit_results_.get_prediction(X).summary_frame(alpha=0.05)

        new_cols = {
            'estimate': predictions['mean'].values,
            'lower': predictions['mean_ci_lower'].values,
            'upper': predictions['mean_ci_upper'].values
        }

        result_df = pd.concat([agg_df.copy(), pd.DataFrame(new_cols, index=agg_df.index)], axis=1)

        return PrevalenceEstimates(
            data=result_df,
            stratified_by=self.strata_,
            adjusted_for=self.meta_.get("adjusted_for", 'unknown'),
            method=self._method_label,
            aggregation_type=self.meta_.get("aggregation_type", "unknown"),
            trait=self.meta_.get("trait", "unknown")
        )

__init__

__init__(target_event: str = 'event', target_n: str = 'n')

Initializes the GLMPrevalenceEstimator.

Parameters:

Name Type Description Default
target_event str

Column name for event counts. Defaults to 'event'.

'event'
target_n str

Column name for total counts. Defaults to 'n'.

'n'
Source code in src/seroepi/estimators/_modelled.py
def __init__(self, target_event: str = 'event', target_n: str = 'n'):
    """
    Initializes the GLMPrevalenceEstimator.

    Args:
        target_event: Name of the column holding event (success) counts.
            Defaults to 'event'.
        target_n: Name of the column holding total counts. Defaults to 'n'.
    """
    self.target_event, self.target_n = target_event, target_n
    # Label reported as ``method`` on the returned PrevalenceEstimates.
    self._method_label = "binomial_glm"

fit

fit(agg_df: DataFrame) -> GLMPrevalenceEstimator

Fits the binomial GLM.

Source code in src/seroepi/estimators/_modelled.py
def fit(self, agg_df: pd.DataFrame) -> 'GLMPrevalenceEstimator':
    """
    Fits the binomial GLM.

    Rows whose total is zero (or missing) carry no binomial information.
    They are left out of the likelihood but still receive predictions
    from ``predict``.

    Args:
        agg_df: Aggregated table with event/total count columns and the
            stratum (and optional ``trait``) columns.

    Returns:
        The fitted estimator (``self``).

    Raises:
        ValueError: If an event count is negative or exceeds its total,
            or if no row has a positive total.
    """
    self.strata_, self.meta_ = self._extract_strata(agg_df, exclude_cols=[self.target_event, self.target_n, 'trait'])

    feature_cols = self.strata_ + ['trait'] if 'trait' in agg_df.columns else self.strata_

    events = agg_df[self.target_event]
    totals = agg_df[self.target_n]
    # A negative "failures" count is not a valid binomial observation.
    if ((events < 0) | (events > totals)).any():
        raise ValueError(
            f"'{self.target_event}' must lie between 0 and '{self.target_n}' for every row."
        )

    # n == 0 makes the success proportion 0/0 inside statsmodels' Binomial family.
    fit_df = agg_df[(totals > 0).fillna(False)]
    if fit_df.empty:
        raise ValueError(f"No rows with a positive '{self.target_n}' to fit the binomial GLM.")

    # 1. Fit the encoder (We still use sklearn here because statsmodels' categorical handling can be clunky)
    self.encoder_ = OneHotEncoder(drop='first', sparse_output=False, handle_unknown='ignore')
    X_encoded = self.encoder_.fit_transform(fit_df[feature_cols])

    # Always add the intercept, matching predict(); otherwise a column that
    # looks constant would make the fit and prediction designs differ in width.
    X = sm.add_constant(X_encoded, has_constant='add')

    # 2. statsmodels accepts a two-column [successes, failures] response.
    successes = fit_df[self.target_event].values
    failures = fit_df[self.target_n].values - successes
    Y = np.column_stack((successes, failures))

    # 3. Fit the Binomial GLM (logit link by default).
    glm_model = sm.GLM(Y, X, family=sm.families.Binomial())
    self.fit_results_ = glm_model.fit()

    self.is_fitted_ = True
    return self

predict

predict(agg_df: DataFrame) -> PrevalenceEstimates

Generates predictions and confidence intervals.

Source code in src/seroepi/estimators/_modelled.py
def predict(self, agg_df: pd.DataFrame) -> PrevalenceEstimates:
    """
    Generates predictions and 95% confidence intervals.

    Args:
        agg_df: Table with the same stratum/trait columns used in ``fit``.

    Returns:
        PrevalenceEstimates whose ``data`` is a copy of ``agg_df`` with
        ``estimate``, ``lower`` and ``upper`` columns appended.

    Raises:
        RuntimeError: If the estimator has not been fitted.
    """
    self.check_is_fitted()

    if 'trait' in agg_df.columns:
        feature_cols = self.strata_ + ['trait']
    else:
        feature_cols = self.strata_
    design = sm.add_constant(self.encoder_.transform(agg_df[feature_cols]), has_constant='add')

    # statsmodels applies the delta method and the inverse link for us.
    frame = self.fit_results_.get_prediction(design).summary_frame(alpha=0.05)
    interval_df = pd.DataFrame(
        {
            'estimate': frame['mean'].values,
            'lower': frame['mean_ci_lower'].values,
            'upper': frame['mean_ci_upper'].values,
        },
        index=agg_df.index,
    )
    combined = pd.concat([agg_df.copy(), interval_df], axis=1)

    meta = self.meta_
    return PrevalenceEstimates(
        data=combined,
        stratified_by=self.strata_,
        adjusted_for=meta.get("adjusted_for", 'unknown'),
        method=self._method_label,
        aggregation_type=meta.get("aggregation_type", "unknown"),
        trait=meta.get("trait", "unknown"),
    )

IncidenceEstimates dataclass

Bases: Estimates

Container for time-series incidence results.

Attributes:

Name Type Description
freq str

The time resolution used (e.g., TemporalResolution.MONTH.value).

model_results DataFrame

A DataFrame containing regression outputs (IRR, CIs, P-values).

Source code in src/seroepi/estimators/_base.py
@dataclass(frozen=True, slots=True)
class IncidenceEstimates(Estimates):
    """
    Container for time-series incidence results.

    Inherits every field of ``Estimates`` (which come first in the generated
    constructor) and adds the fields below.

    Attributes:
        freq: The time resolution used (e.g., TemporalResolution.MONTH.value).
        model_results: A DataFrame containing regression outputs (IRR, CIs, P-values).
    """
    freq: str                   # The time resolution (e.g., TemporalResolution.MONTH.value)
    model_results: pd.DataFrame # The regression outputs (IRR, CIs, P-values)

ModelledMixin

Bases: ABC

Contract for estimators with an internal fitted state.

Enforces the scikit-learn fit/predict paradigm and provides universal serialization for fitted models.

Attributes:

Name Type Description
is_fitted_ bool

Boolean indicating if the model has been fitted.

Source code in src/seroepi/estimators/_modelled.py
class ModelledMixin(ABC):
    """
    Contract for estimators with an internal fitted state.

    Enforces the scikit-learn fit/predict paradigm and provides universal
    serialization for fitted models.

    Attributes:
        is_fitted_: Boolean indicating if the model has been fitted.
    """

    # State tracking: class-level default, shadowed per instance once fit() sets it.
    is_fitted_: bool = False

    def check_is_fitted(self):
        """
        Checks if the model is fitted.

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if not self.is_fitted_:
            raise RuntimeError(
                "This estimator instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator."
            )

    @abstractmethod
    def fit(self, df: pd.DataFrame) -> 'ModelledMixin':
        """Calculates the internal state (e.g., MCMC samples) and saves it to self."""
        pass

    @abstractmethod
    def predict(self, df: pd.DataFrame):
        """Uses the fitted internal state to generate predictions on the dataframe."""
        pass

    def calculate(self, df: pd.DataFrame):
        """One-liner to fit and predict on the same data."""
        return self.fit(df).predict(df)

    def save_model(self, filepath: Union[str, Path]) -> None:
        """
        Universally serializes the fitted estimator instance to disk.

        Missing parent directories are created. Saving an unfitted estimator
        only emits a warning.

        Args:
            filepath: Path where the model should be saved.
        """
        if not self.is_fitted_:
            warn(f"You are saving a {self.__class__.__name__} that hasn't been fitted yet.")

        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        joblib_dump(self, path)

    @classmethod
    def load_model(cls: Type[T_Modelled], filepath: Union[str, Path]) -> T_Modelled:
        """
        Loads a serialized estimator from disk.

        Warning:
            Deserializing a pickle-based model file can execute arbitrary
            code. Only load files produced by ``save_model`` from a trusted
            source; never load models received from untrusted parties.

        Args:
            filepath: Path to the serialized model file.

        Returns:
            The loaded estimator instance.

        Raises:
            FileNotFoundError: If the file does not exist.
            TypeError: If the loaded model is not of the expected type.
        """
        path = Path(filepath)
        if not path.exists():
            raise FileNotFoundError(f"No model found at {path}")

        # SECURITY(review): joblib_load presumably wraps joblib.load, which
        # unpickles the file; the type guard below runs only AFTER any code
        # embedded in the file has already executed.
        estimator = joblib_load(path)

        # Strict Type Guard
        if not isinstance(estimator, cls):
            raise TypeError(
                f"Type mismatch: Attempted to load into {cls.__name__}, "
                f"but the file contains a {type(estimator).__name__}."
            )

        return estimator

calculate

calculate(df: DataFrame)

One-liner to fit and predict on the same data.

Source code in src/seroepi/estimators/_modelled.py
def calculate(self, df: pd.DataFrame):
    """
    Fits on ``df`` and then predicts on that same frame.

    Shortcut for ``self.fit(df).predict(df)``; returns whatever ``predict``
    returns.
    """
    fitted = self.fit(df)
    return fitted.predict(df)

check_is_fitted

check_is_fitted()

Checks if the model is fitted.

Raises:

Type Description
RuntimeError

If the model has not been fitted.

Source code in src/seroepi/estimators/_modelled.py
def check_is_fitted(self):
    """
    Guards against using the estimator before ``fit`` has run.

    Raises:
        RuntimeError: If ``self.is_fitted_`` is falsy.
    """
    if self.is_fitted_:
        return
    message = (
        "This estimator instance is not fitted yet. "
        "Call 'fit' with appropriate arguments before using this estimator."
    )
    raise RuntimeError(message)

fit abstractmethod

fit(df: DataFrame) -> ModelledMixin

Calculates the internal state (e.g., MCMC samples) and saves it to self.

Source code in src/seroepi/estimators/_modelled.py
@abstractmethod
def fit(self, df: pd.DataFrame) -> 'ModelledMixin':
    """
    Computes the estimator's internal state (e.g., MCMC samples) from ``df``
    and stores it on ``self``.

    Abstract: concrete estimators must override it.
    """

load_model classmethod

load_model(filepath: Union[str, Path]) -> T_Modelled

Loads a serialized estimator from disk.

Parameters:

Name Type Description Default
filepath Union[str, Path]

Path to the serialized model file.

required

Returns:

Type Description
T_Modelled

The loaded estimator instance.

Raises:

Type Description
FileNotFoundError

If the file does not exist.

TypeError

If the loaded model is not of the expected type.

Source code in src/seroepi/estimators/_modelled.py
@classmethod
def load_model(cls: Type[T_Modelled], filepath: Union[str, Path]) -> T_Modelled:
    """
    Loads a serialized estimator from disk.

    Warning:
        Deserializing a pickle-based model file can execute arbitrary
        code. Only load files produced by ``save_model`` from a trusted
        source; never load models received from untrusted parties.

    Args:
        filepath: Path to the serialized model file.

    Returns:
        The loaded estimator instance.

    Raises:
        FileNotFoundError: If the file does not exist.
        TypeError: If the loaded model is not of the expected type.
    """
    path = Path(filepath)
    if not path.exists():
        raise FileNotFoundError(f"No model found at {path}")

    # SECURITY(review): joblib_load presumably wraps joblib.load, which
    # unpickles the file; the type guard below runs only AFTER any code
    # embedded in the file has already executed.
    estimator = joblib_load(path)

    # Strict Type Guard
    if not isinstance(estimator, cls):
        raise TypeError(
            f"Type mismatch: Attempted to load into {cls.__name__}, "
            f"but the file contains a {type(estimator).__name__}."
        )

    return estimator

predict abstractmethod

predict(df: DataFrame)

Uses the fitted internal state to generate predictions on the dataframe.

Source code in src/seroepi/estimators/_modelled.py
@abstractmethod
def predict(self, df: pd.DataFrame):
    """
    Produces predictions for ``df`` from the previously fitted state.

    Abstract: concrete estimators must override it.
    """

save_model

save_model(filepath: Union[str, Path]) -> None

Universally serializes the fitted estimator instance to disk.

Parameters:

Name Type Description Default
filepath Union[str, Path]

Path where the model should be saved.

required
Source code in src/seroepi/estimators/_modelled.py
def save_model(self, filepath: Union[str, Path]) -> None:
    """
    Universally serializes the fitted estimator instance to disk.

    Any missing parent directories are created first. Saving an estimator
    that has not been fitted is allowed but triggers a warning.

    Args:
        filepath: Path where the model should be saved.
    """
    fitted = bool(self.is_fitted_)
    if fitted is False:
        warn(f"You are saving a {self.__class__.__name__} that hasn't been fitted yet.")

    destination = Path(filepath)
    destination.parent.mkdir(exist_ok=True, parents=True)
    joblib_dump(self, destination)

PrevalenceEstimates dataclass

Bases: Estimates

Container for prevalence results.

Attributes:

Name Type Description
method str

The statistical method used (e.g., 'bayesian_mcmc').

Source code in src/seroepi/estimators/_base.py
@dataclass(frozen=True, slots=True)
class PrevalenceEstimates(Estimates):
    """
    Container for prevalence results.

    Inherits every field of ``Estimates`` (which come first in the generated
    constructor) and adds the field below.

    Attributes:
        method: The statistical method used (e.g., 'bayesian_mcmc').
    """
    method: str  # estimator label, e.g. "binomial_glm" from GLMPrevalenceEstimator

SpatialPrevalenceEstimator

Bases: ModelledMixin, BayesianMixin, BaseEstimator[PrevalenceEstimates]

Gaussian Process (GP) based spatial prevalence estimator.

Fits a GP model to spatial binomial data, allowing for continuous mapping of prevalence across a geographic area.

Examples:

>>> from seroepi.estimators import SpatialPrevalenceEstimator
>>> estimator = SpatialPrevalenceEstimator(lat_col='lat', lon_col='lon')
>>> # result = estimator.calculate(agg_df)
Source code in src/seroepi/estimators/_modelled.py
class SpatialPrevalenceEstimator(ModelledMixin, BayesianMixin, BaseEstimator[PrevalenceEstimates]):
    """
    Gaussian Process (GP) based spatial prevalence estimator.

    Fits a GP model to spatial binomial data, allowing for continuous mapping
    of prevalence across a geographic area.

    Examples:
        >>> from seroepi.estimators import SpatialPrevalenceEstimator
        >>> estimator = SpatialPrevalenceEstimator(lat_col='lat', lon_col='lon')
        >>> # result = estimator.calculate(agg_df)
    """
    def __init__(self, lat_col: str = 'lat', lon_col: str = 'lon',
                 method: BayesianInferenceMethod = BayesianInferenceMethod.MCMC, num_samples: int = 1500, num_chains: int = 4,
                 num_warmup: int = 1000, svi_steps: int = 3000,
                 target_event: str = 'event', target_n: str = 'n', seed: int = 42):
        """
        Initializes the SpatialPrevalenceEstimator.

        Args:
            lat_col: Column name for latitude. Defaults to 'lat'.
            lon_col: Column name for longitude. Defaults to 'lon'.
            method: Inference method. Defaults to 'mcmc'.
            num_samples: Number of samples. Defaults to 1500.
            num_chains: Number of chains. Defaults to 4.
            num_warmup: Number of warmup steps. Defaults to 1000.
            svi_steps: Number of SVI steps. Defaults to 3000.
            target_event: Column for events. Defaults to 'event'.
            target_n: Column for totals. Defaults to 'n'.
            seed: Random seed. Defaults to 42.
        """
        self._init_bayesian(method, num_samples, num_chains, num_warmup, svi_steps, seed)
        self.lat_col = lat_col
        self.lon_col = lon_col
        self._method_label = f'spatial_gp_{self.method.value}'

        self.target_event = target_event
        self.target_n = target_n

        # Fitted attributes
        self.X_train_ = None
        self.loc_mean_ = None
        self.loc_scale_ = None
        self.meta_ = {}

    def _model(self, X, n, event=None):
        """
        NumPyro generative model: binomial counts with a logit-scale GP spatial field.

        Args:
            X: Standardized training coordinates, one row per unique location.
            n: Number of trials per location.
            event: Observed event counts; ``None`` leaves ``obs`` unobserved.
        """
        # 1. Global intercept on the logit scale
        alpha = samp("alpha", dist.Normal(0, 1.5))

        # 2. GP kernel parameters (variance/amplitude and length-scale)
        var = samp("var", dist.HalfNormal(1.0))
        length = samp("length", dist.InverseGamma(2.0, 1.0))

        # 3. Spatial covariance matrix between all training locations.
        # NOTE(review): no diagonal jitter is added here; confirm _rbf_kernel adds one,
        # otherwise very close locations can make K numerically singular.
        K = _rbf_kernel(X, X, var, length)

        # 4. Latent spatial field (f)
        f = samp("f", dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=K))

        # 5. Likelihood
        logit_p = alpha + f
        samp("obs", dist.Binomial(total_count=n, logits=logit_p), obs=event)

    def fit(self, agg_df: pd.DataFrame) -> 'SpatialPrevalenceEstimator':
        """
        Aggregates to unique locations, standardizes coordinates, and fits the GP.

        Args:
            agg_df: Aggregated data holding the latitude/longitude columns and the
                ``target_event`` / ``target_n`` count columns.

        Returns:
            The fitted estimator (``self``).

        Raises:
            KeyError: If the latitude or longitude column is missing.
            ValueError: If no rows with valid coordinates remain after grouping.
        """
        self._check_zero_padding(agg_df)

        if self.lat_col not in agg_df.columns or self.lon_col not in agg_df.columns:
            raise KeyError(
                f"Spatial estimator requires '{self.lat_col}' and '{self.lon_col}' "
                "in the aggregated data. Please ensure you included them in the 'Stratify By' dropdown."
            )

        # 1. Ensure we only have ONE row per unique lat/lon to prevent singular matrices
        spatial_df = agg_df.groupby([self.lat_col, self.lon_col], as_index=False).agg({
            self.target_event: 'sum',
            self.target_n: 'sum'
        })

        # groupby drops rows with missing coordinates, so empty input or input without
        # complete coordinates ends up with zero locations. Fail clearly here instead of
        # standardizing an empty array into NaNs and handing that to the sampler.
        if spatial_df.empty:
            raise ValueError(
                "Spatial estimator found no rows with valid "
                f"'{self.lat_col}'/'{self.lon_col}' coordinates to fit."
            )

        # 2. Extract and Standardize Coordinates
        raw_coords = spatial_df[[self.lat_col, self.lon_col]].astype(float).values
        self.loc_mean_ = np.mean(raw_coords, axis=0)
        self.loc_scale_ = np.std(raw_coords, axis=0) + 1e-8  # Prevent div by zero

        self.X_train_ = (raw_coords - self.loc_mean_) / self.loc_scale_

        jax_data = {
            "X": jnp.array(self.X_train_),
            "n": jnp.array(spatial_df[self.target_n].values),
            "event": jnp.array(spatial_df[self.target_event].values)
        }

        # 3. Run Inference
        rng_key = random.PRNGKey(self.seed)
        self.samples_ = self._run_inference(jax_data, rng_key)

        # Copy so the estimator does not alias (or later see mutations of) the input's
        # attrs dict; tolerate an explicit None stored under "metric_meta".
        self.meta_ = dict(agg_df.attrs.get("metric_meta") or {})
        self.is_fitted_ = True
        return self

    def predict(self, df: pd.DataFrame) -> PrevalenceEstimates:
        """
        Calculates the conditional predictive posterior for any set of coordinates.

        Args:
            df: DataFrame containing the estimator's latitude and longitude columns.

        Returns:
            A PrevalenceEstimates object whose data is a copy of ``df`` with
            ``estimate``, ``lower`` and ``upper`` columns at each location.

        Raises:
            KeyError: If a coordinate column is missing from ``df``.
        """
        self.check_is_fitted()

        missing = [col for col in (self.lat_col, self.lon_col) if col not in df.columns]
        if missing:
            raise KeyError(
                f"Spatial prediction requires '{self.lat_col}' and '{self.lon_col}' "
                f"in the input data; missing: {missing}."
            )

        # 1. Standardize new coordinates using the FITTED scaler so they share X_train_'s space
        raw_X_test = df[[self.lat_col, self.lon_col]].astype(float).values
        X_test = jnp.array((raw_X_test - self.loc_mean_) / self.loc_scale_)
        X_train = jnp.array(self.X_train_)

        # 2. Evaluate the conditional GP posterior at X_test for every retained posterior
        #    draw (presumably vectorized with JAX vmap inside the helper). The number of
        #    draws is set by the inference configuration, not fixed.
        estimate, lower, upper = _compute_spatial_posterior(
            self.samples_['var'],
            self.samples_['length'],
            self.samples_['alpha'],
            self.samples_['f'],
            X_train,
            X_test
        )

        result_df = df.copy()
        result_df['estimate'] = np.array(estimate)
        result_df['lower'] = np.array(lower)
        result_df['upper'] = np.array(upper)

        return PrevalenceEstimates(
            data=result_df,
            stratified_by=[self.lat_col, self.lon_col],
            adjusted_for=self.meta_.get("adjusted_for", 'unknown'),
            method=self._method_label,
            aggregation_type=self.meta_.get("aggregation_type", "unknown"),
            trait=self.meta_.get("trait", "unknown")
        )

__init__

__init__(lat_col: str = 'lat', lon_col: str = 'lon', method: BayesianInferenceMethod = BayesianInferenceMethod.MCMC, num_samples: int = 1500, num_chains: int = 4, num_warmup: int = 1000, svi_steps: int = 3000, target_event: str = 'event', target_n: str = 'n', seed: int = 42)

Initializes the SpatialPrevalenceEstimator.

Parameters:

Name Type Description Default
lat_col str

Column name for latitude. Defaults to 'lat'.

'lat'
lon_col str

Column name for longitude. Defaults to 'lon'.

'lon'
method BayesianInferenceMethod

Inference method. Defaults to 'mcmc'.

MCMC
num_samples int

Number of samples. Defaults to 1500.

1500
num_chains int

Number of chains. Defaults to 4.

4
num_warmup int

Number of warmup steps. Defaults to 1000.

1000
svi_steps int

Number of SVI steps. Defaults to 3000.

3000
target_event str

Column for events. Defaults to 'event'.

'event'
target_n str

Column for totals. Defaults to 'n'.

'n'
seed int

Random seed. Defaults to 42.

42
Source code in src/seroepi/estimators/_modelled.py
def __init__(self, lat_col: str = 'lat', lon_col: str = 'lon',
             method: BayesianInferenceMethod = BayesianInferenceMethod.MCMC, num_samples: int = 1500, num_chains: int = 4,
             num_warmup: int = 1000, svi_steps: int = 3000,
             target_event: str = 'event', target_n: str = 'n', seed: int = 42):
    """
    Configure a GP spatial prevalence estimator.

    Args:
        lat_col: Latitude column name. Defaults to 'lat'.
        lon_col: Longitude column name. Defaults to 'lon'.
        method: Bayesian inference method. Defaults to MCMC.
        num_samples: Posterior samples to draw. Defaults to 1500.
        num_chains: MCMC chains. Defaults to 4.
        num_warmup: Warmup steps. Defaults to 1000.
        svi_steps: SVI optimisation steps. Defaults to 3000.
        target_event: Event-count column. Defaults to 'event'.
        target_n: Trial-count column. Defaults to 'n'.
        seed: Random seed. Defaults to 42.
    """
    # Shared Bayesian settings; presumably sets self.method, which is read just below.
    self._init_bayesian(method, num_samples, num_chains, num_warmup, svi_steps, seed)

    # Column configuration
    self.lat_col, self.lon_col = lat_col, lon_col
    self.target_event, self.target_n = target_event, target_n
    self._method_label = f'spatial_gp_{self.method.value}'

    # State populated by fit()
    self.X_train_ = None
    self.loc_mean_ = None
    self.loc_scale_ = None
    self.meta_ = dict()

fit

fit(agg_df: DataFrame) -> SpatialPrevalenceEstimator

Aggregates to unique locations, standardizes the coordinates, and fits the GP.

Source code in src/seroepi/estimators/_modelled.py
def fit(self, agg_df: pd.DataFrame) -> 'SpatialPrevalenceEstimator':
    """
    Aggregates to unique locations, standardizes coordinates, and fits the GP.

    Args:
        agg_df: Aggregated data holding the latitude/longitude columns and the
            ``target_event`` / ``target_n`` count columns.

    Returns:
        The fitted estimator (``self``).

    Raises:
        KeyError: If the latitude or longitude column is missing.
        ValueError: If no rows with valid coordinates remain after grouping.
    """
    self._check_zero_padding(agg_df)

    if self.lat_col not in agg_df.columns or self.lon_col not in agg_df.columns:
        raise KeyError(
            f"Spatial estimator requires '{self.lat_col}' and '{self.lon_col}' "
            "in the aggregated data. Please ensure you included them in the 'Stratify By' dropdown."
        )

    # 1. Ensure we only have ONE row per unique lat/lon to prevent singular matrices
    spatial_df = agg_df.groupby([self.lat_col, self.lon_col], as_index=False).agg({
        self.target_event: 'sum',
        self.target_n: 'sum'
    })

    # groupby drops rows with missing coordinates, so empty input or input without
    # complete coordinates ends up with zero locations. Fail clearly here instead of
    # standardizing an empty array into NaNs and handing that to the sampler.
    if spatial_df.empty:
        raise ValueError(
            "Spatial estimator found no rows with valid "
            f"'{self.lat_col}'/'{self.lon_col}' coordinates to fit."
        )

    # 2. Extract and Standardize Coordinates
    raw_coords = spatial_df[[self.lat_col, self.lon_col]].astype(float).values
    self.loc_mean_ = np.mean(raw_coords, axis=0)
    self.loc_scale_ = np.std(raw_coords, axis=0) + 1e-8  # Prevent div by zero

    self.X_train_ = (raw_coords - self.loc_mean_) / self.loc_scale_

    jax_data = {
        "X": jnp.array(self.X_train_),
        "n": jnp.array(spatial_df[self.target_n].values),
        "event": jnp.array(spatial_df[self.target_event].values)
    }

    # 3. Run Inference
    rng_key = random.PRNGKey(self.seed)
    self.samples_ = self._run_inference(jax_data, rng_key)

    # Copy so the estimator does not alias (or later see mutations of) the input's
    # attrs dict; tolerate an explicit None stored under "metric_meta".
    self.meta_ = dict(agg_df.attrs.get("metric_meta") or {})
    self.is_fitted_ = True
    return self

predict

predict(df: DataFrame) -> PrevalenceEstimates

Calculates the conditional predictive posterior for any set of coordinates.

Parameters:

Name Type Description Default
df DataFrame

DataFrame containing latitude and longitude columns.

required

Returns:

Type Description
PrevalenceEstimates

A PrevalenceEstimates object with predicted values at the locations.

Source code in src/seroepi/estimators/_modelled.py
def predict(self, df: pd.DataFrame) -> PrevalenceEstimates:
    """
    Calculates the conditional predictive posterior for any set of coordinates.

    Args:
        df: DataFrame containing the estimator's latitude and longitude columns.

    Returns:
        A PrevalenceEstimates object whose data is a copy of ``df`` with
        ``estimate``, ``lower`` and ``upper`` columns at each location.

    Raises:
        KeyError: If a coordinate column is missing from ``df``.
    """
    self.check_is_fitted()

    missing = [col for col in (self.lat_col, self.lon_col) if col not in df.columns]
    if missing:
        raise KeyError(
            f"Spatial prediction requires '{self.lat_col}' and '{self.lon_col}' "
            f"in the input data; missing: {missing}."
        )

    # 1. Standardize new coordinates using the FITTED scaler so they share X_train_'s space
    raw_X_test = df[[self.lat_col, self.lon_col]].astype(float).values
    X_test = jnp.array((raw_X_test - self.loc_mean_) / self.loc_scale_)
    X_train = jnp.array(self.X_train_)

    # 2. Evaluate the conditional GP posterior at X_test for every retained posterior
    #    draw (presumably vectorized with JAX vmap inside the helper). The number of
    #    draws is set by the inference configuration, not fixed.
    estimate, lower, upper = _compute_spatial_posterior(
        self.samples_['var'],
        self.samples_['length'],
        self.samples_['alpha'],
        self.samples_['f'],
        X_train,
        X_test
    )

    result_df = df.copy()
    result_df['estimate'] = np.array(estimate)
    result_df['lower'] = np.array(lower)
    result_df['upper'] = np.array(upper)

    return PrevalenceEstimates(
        data=result_df,
        stratified_by=[self.lat_col, self.lon_col],
        adjusted_for=self.meta_.get("adjusted_for", 'unknown'),
        method=self._method_label,
        aggregation_type=self.meta_.get("aggregation_type", "unknown"),
        trait=self.meta_.get("trait", "unknown")
    )

UnpooledPrevalenceEstimator

Bases: BaseEstimator[PrevalenceEstimates]

Source code in src/seroepi/estimators/_core.py
class UnpooledPrevalenceEstimator(BaseEstimator[PrevalenceEstimates]):
    """
    Frequentist prevalence estimator that treats every row of an aggregated table independently.

    Each row's ``event``/``n`` counts are turned into a point estimate and interval bounds
    by the selected binomial-proportion interval method; nothing is pooled across rows.

    Args:
        method: Interval method name (case-insensitive); must be a key of the
            module's frequentist kernel registry.
        alpha: Significance level passed to the interval kernel; must lie strictly
            between 0 and 1 (0.05 presumably yields a 95% interval).

    Raises:
        ValueError: If ``method`` is not a registered method name or ``alpha`` is
            outside ``(0, 1)``.
    """

    Method = Literal['wilson', 'wald', 'agresti_coull', 'clopper_pearson', 'jeffreys']

    def __init__(self, method: Method = 'wilson', alpha: float = 0.05):
        # Report a non-string method like any other unknown method, instead of the
        # opaque AttributeError that .lower() would raise.
        if not isinstance(method, str):
            raise ValueError(f"Unknown method: {method!r}. "
                             f"Choose from: {list(_FREQUENTIST_KERNELS.keys())}")
        self.method = method.lower()
        self._method_label = f"unpooled_{self.method}"
        self._method_func = _FREQUENTIST_KERNELS.get(self.method, None)
        if self._method_func is None:
            raise ValueError(f"Unknown method: {self.method}. "
                             f"Choose from: {list(_FREQUENTIST_KERNELS.keys())}")
        # A significance level is only meaningful in (0, 1); reject bad values up front
        # rather than passing them to the interval kernel. (NaN also fails this check.)
        if not 0.0 < alpha < 1.0:
            raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha!r}.")
        self.alpha = alpha

    def get_params(self) -> dict:
        """Returns parameters for cloning compatibility during Cross-Validation."""
        return {'method': self.method, 'alpha': self.alpha}

    def calculate(self, agg_df: pd.DataFrame) -> PrevalenceEstimates:
        """
        Computes per-row unpooled prevalence estimates with interval bounds.

        Expects the output of ``df.epi.aggregate_prevalence()``: a table with ``event``
        (positive count) and ``n`` (denominator) columns plus stratification columns.

        Args:
            agg_df: Aggregated counts table.

        Returns:
            PrevalenceEstimates whose data is ``agg_df`` with ``estimate``, ``lower`` and
            ``upper`` columns (replacing any existing ones). NaN kernel outputs
            (presumably rows with ``n == 0``) are replaced by 0.0.
        """
        stratified_by, meta = self._extract_strata(agg_df, exclude_cols=['event', 'n', 'trait'])

        # 1. Extract vectors for fast numpy math
        counts = agg_df['event'].values
        denominators = agg_df['n'].values

        # 2. Route to the selected interval method
        prop, lower, upper = self._method_func(counts, denominators, self.alpha)

        new_cols = {
            'estimate': np.nan_to_num(prop, nan=0.0),
            'lower': np.nan_to_num(lower, nan=0.0),
            'upper': np.nan_to_num(upper, nan=0.0)
        }

        # 3. Fast horizontal concatenation. Drop stale result columns first (only when
        #    present, to keep the no-copy fast path) so that re-estimating an already
        #    estimated table does not create duplicate column labels.
        stale = [col for col in new_cols if col in agg_df.columns]
        base = agg_df.drop(columns=stale) if stale else agg_df
        result_df = pd.concat([base, pd.DataFrame(new_cols, index=agg_df.index)], axis=1)

        return PrevalenceEstimates(
            data=result_df,
            stratified_by=stratified_by,
            adjusted_for=meta.get("adjusted_for", 'unknown'),
            method=self._method_label,
            aggregation_type=meta.get("aggregation_type", AggregationType.TRAIT),
            trait=meta.get("trait", "unknown")
        )

calculate

calculate(agg_df: DataFrame) -> PrevalenceEstimates

Computes per-row unpooled prevalence estimates with interval bounds. Expects the output of `df.epi.aggregate_prevalence()`.

Source code in src/seroepi/estimators/_core.py
def calculate(self, agg_df: pd.DataFrame) -> PrevalenceEstimates:
    """
    Computes per-row unpooled prevalence estimates with interval bounds.

    Expects the output of ``df.epi.aggregate_prevalence()``: a table with ``event``
    (positive count) and ``n`` (denominator) columns plus stratification columns.

    Args:
        agg_df: Aggregated counts table.

    Returns:
        PrevalenceEstimates whose data is ``agg_df`` with ``estimate``, ``lower`` and
        ``upper`` columns (replacing any existing ones). NaN kernel outputs
        (presumably rows with ``n == 0``) are replaced by 0.0.
    """
    stratified_by, meta = self._extract_strata(agg_df, exclude_cols=['event', 'n', 'trait'])

    # 1. Extract vectors for fast numpy math
    counts = agg_df['event'].values
    denominators = agg_df['n'].values

    # 2. Route to the selected interval method
    prop, lower, upper = self._method_func(counts, denominators, self.alpha)

    new_cols = {
        'estimate': np.nan_to_num(prop, nan=0.0),
        'lower': np.nan_to_num(lower, nan=0.0),
        'upper': np.nan_to_num(upper, nan=0.0)
    }

    # 3. Fast horizontal concatenation. Drop stale result columns first (only when
    #    present, to keep the no-copy fast path) so that re-estimating an already
    #    estimated table does not create duplicate column labels.
    stale = [col for col in new_cols if col in agg_df.columns]
    base = agg_df.drop(columns=stale) if stale else agg_df
    result_df = pd.concat([base, pd.DataFrame(new_cols, index=agg_df.index)], axis=1)

    return PrevalenceEstimates(
        data=result_df,
        stratified_by=stratified_by,
        adjusted_for=meta.get("adjusted_for", 'unknown'),
        method=self._method_label,
        aggregation_type=meta.get("aggregation_type", AggregationType.TRAIT),
        trait=meta.get("trait", "unknown")
    )

get_params

get_params() -> dict

Returns parameters for cloning compatibility during Cross-Validation.

Source code in src/seroepi/estimators/_core.py
def get_params(self) -> dict:
    """Expose the constructor arguments so the estimator can be re-created (e.g. cloned for cross-validation)."""
    params = dict(method=self.method, alpha=self.alpha)
    return params

seroepi.plotting

BasePlotter

Bases: ABC

Stateless base class for all plotting engines in seroepi.

Source code in src/seroepi/plotting.py
class BasePlotter(ABC):
    """
    Stateless base class for all plotting engines in seroepi.

    Subclasses declare ``SUPPORTED_TYPES`` and implement ``render``. Shared colours,
    label cleaning and theming live here.
    """
    # The Hero Palette: Electric Cyan and Neon Pink
    # Cyan is scientifically colorblind-safe while looking stunning on dark backgrounds.
    _MAIN_COLOUR = '#0EA5E9'
    # Translucent Cyan for Confidence Interval Ribbons (20% Opacity)
    _CI_COLOUR = 'rgba(14, 165, 233, 0.2)'
    # A secondary highlight color (Optional, but great for distinguishing target groups)
    _ACCENT_COLOUR = '#EC4899'  # Vibrant Neon Pink
    _FONT_COLOUR = '#94A3B8'  # Slate 400 (Highly readable on both light and dark backgrounds)
    _GRID_COLOUR = 'rgba(148, 163, 184, 0.2)'  # Subtle translucent grid lines
    # Global cache to prevent reading the file from disk multiple times
    _WORLD_GEOJSON = None

    # To be overridden by subclasses with the supported result types
    SUPPORTED_TYPES = ()

    @classmethod
    def _get_world_geojson(cls) -> dict:
        """
        Lazily loads and caches the internal world boundaries GeoJSON.

        Loading is best-effort: on any failure a warning is emitted and an empty dict
        is cached and returned, so later calls do not retry the read.

        Returns:
            A dictionary containing the GeoJSON data, or ``{}`` if loading failed.
        """
        if cls._WORLD_GEOJSON is None:
            try:
                # Safely navigates the package structure regardless of where it's installed
                geojson_path = files('seroepi.data').joinpath('world_boundaries.geojson')
                with geojson_path.open(mode='r', encoding='utf-8') as f:
                    cls._WORLD_GEOJSON = json_load(f)
            except Exception as e:
                # Deliberately broad: a missing map must not break plotting.
                warn(f"Could not load internal world boundaries. Ensure the file exists: {e}")
                cls._WORLD_GEOJSON = {}
        return cls._WORLD_GEOJSON

    @classmethod
    def can_render(cls, result_obj: Any) -> bool:
        """Checks if the incoming result object is supported by this plotter."""
        # Plain instance check against the subclass's declared SUPPORTED_TYPES tuple;
        # an empty tuple (the base default) matches nothing.
        return isinstance(result_obj, cls.SUPPORTED_TYPES)

    @classmethod
    def _clean_label(cls, col_name: str) -> str:
        """
        Strips a leading domain prefix and turns underscores into spaces for UI labels.

        Only the leading ``"<domain>_"`` prefix is removed; the same text appearing
        later in the name is kept. Non-string inputs are returned as ``str(col_name)``.
        """
        if not isinstance(col_name, str):
            return str(col_name)
        for domain in [Domain.GENOTYPE.value, Domain.PHENOTYPE.value, Domain.AMR.value, Domain.VIRULENCE.value]:
            prefix = f"{domain}_"
            if col_name.startswith(prefix):
                # Slice off the prefix only; str.replace would also delete any later
                # occurrence of the prefix text inside the name.
                return col_name[len(prefix):].replace("_", " ")
        return col_name.replace("_", " ")

    @classmethod
    def get_colorscale(cls, transparent: bool = True) -> list:
        """Returns the standard Cyberpunk continuous color scale."""
        base_color = 'rgba(0,0,0,0)' if transparent else 'rgba(15, 23, 42, 0.4)'
        return [
            [0.0, base_color],
            [0.4, '#8B5CF6'],                # Deep Purple
            [0.7, cls._MAIN_COLOUR],         # Electric Cyan
            [1.0, cls._ACCENT_COLOUR]        # Neon Pink
        ]

    @classmethod
    def apply_theme(cls, fig: Figure) -> Figure:
        """Applies a universal transparent theme optimized for both light and dark web app modes."""
        fig.update_layout(
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            font=dict(color=cls._FONT_COLOUR),
            hoverlabel=dict(
                bgcolor="black",
                bordercolor=cls._MAIN_COLOUR,
                font_size=14,
                font_color="white"
            )
        )
        fig.update_xaxes(
            gridcolor=cls._GRID_COLOUR,
            zerolinecolor=cls._GRID_COLOUR,
            linecolor=cls._GRID_COLOUR
        )
        fig.update_yaxes(
            gridcolor=cls._GRID_COLOUR,
            zerolinecolor=cls._GRID_COLOUR,
            linecolor=cls._GRID_COLOUR
        )
        return fig

    @classmethod
    @abstractmethod
    def render(cls, result_obj: Any, **kwargs) -> 'Figure':
        """
        Renders the result object into a Plotly figure.

        Args:
            result_obj: The result object to visualize.
            **kwargs: Additional plotting arguments.

        Returns:
            A plotly Figure object.
        """
        pass

apply_theme classmethod

apply_theme(fig: Figure) -> Figure

Applies a universal transparent theme optimized for both light and dark web app modes.

Source code in src/seroepi/plotting.py
@classmethod
def apply_theme(cls, fig: Figure) -> Figure:
    """
    Style ``fig`` in place with the package's transparent theme and return it.

    Backgrounds are fully transparent so the figure suits both light and dark
    hosts. Both axes receive the shared translucent grid, zero-line and axis-line
    colour.
    """
    clear = 'rgba(0,0,0,0)'
    fig.update_layout(
        plot_bgcolor=clear,
        paper_bgcolor=clear,
        font={'color': cls._FONT_COLOUR},
        hoverlabel={
            'bgcolor': "black",
            'bordercolor': cls._MAIN_COLOUR,
            'font_size': 14,
            'font_color': "white",
        },
    )
    axis_style = {
        'gridcolor': cls._GRID_COLOUR,
        'zerolinecolor': cls._GRID_COLOUR,
        'linecolor': cls._GRID_COLOUR,
    }
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    return fig

can_render classmethod

can_render(result_obj: Any) -> bool

Checks if the incoming result object is supported by this plotter.

Source code in src/seroepi/plotting.py
@classmethod
def can_render(cls, result_obj: Any) -> bool:
    """Report whether ``result_obj`` is an instance of one of this plotter's ``SUPPORTED_TYPES``."""
    # isinstance accepts a tuple of types; an empty tuple never matches.
    supported = cls.SUPPORTED_TYPES
    return isinstance(result_obj, supported)

get_colorscale classmethod

get_colorscale(transparent: bool = True) -> list

Returns the standard Cyberpunk continuous color scale.

Source code in src/seroepi/plotting.py
@classmethod
def get_colorscale(cls, transparent: bool = True) -> list:
    """
    Build the package's continuous colour scale as Plotly ``[position, colour]`` stops.

    Args:
        transparent: When truthy the 0.0 stop is fully transparent; otherwise it is
            a translucent dark slate.

    Returns:
        Four ``[position, colour]`` pairs: base, deep purple, main colour, accent colour.
    """
    if transparent:
        floor_colour = 'rgba(0,0,0,0)'
    else:
        floor_colour = 'rgba(15, 23, 42, 0.4)'
    positions = (0.0, 0.4, 0.7, 1.0)
    colours = (floor_colour, '#8B5CF6', cls._MAIN_COLOUR, cls._ACCENT_COLOUR)
    return [[pos, colour] for pos, colour in zip(positions, colours)]

render abstractmethod classmethod

render(result_obj: Any, **kwargs) -> Figure

Renders the result object into a Plotly figure.

Parameters:

Name Type Description Default
result_obj Any

The result object to visualize.

required
**kwargs

Additional plotting arguments.

{}

Returns:

Type Description
Figure

A plotly Figure object.

Source code in src/seroepi/plotting.py
@classmethod
@abstractmethod
def render(cls, result_obj: Any, **kwargs) -> 'Figure':
    """
    Turn ``result_obj`` into a Plotly figure; every concrete plotter must implement this.

    Args:
        result_obj: The analysis result to draw.
        **kwargs: Plotter-specific options.

    Returns:
        A plotly Figure object.
    """
    ...

CumulativeCoveragePlotter

Bases: BasePlotter

Plots cumulative population coverage as variants are added to a formulation. Crucial for designing multivalent vaccines (e.g., K-locus targeting).

Source code in src/seroepi/plotting.py
class CumulativeCoveragePlotter(BasePlotter):
    """
    Plots cumulative population coverage as variants are added to a formulation.
    Crucial for designing multivalent vaccines (e.g., K-locus targeting).
    """
    SUPPORTED_TYPES = (estimators.PrevalenceEstimates, dict)

    @staticmethod
    def _cumulative_band(frame, target_order: list, z_score: float) -> tuple:
        """
        Accumulates estimates over ``target_order`` for one group of rows.

        Args:
            frame: Rows with ``target``, ``estimate`` and ``var`` columns.
            target_order: Targets in the order they are added to the formulation.
            z_score: Critical value used to widen the cumulative standard error.

        Returns:
            ``(cum_prop, cum_lower, cum_upper)`` Series indexed by ``target_order``,
            each clipped to [0, 1]. Targets absent from ``frame`` contribute 0.
        """
        by_target = frame.groupby('target', observed=True)
        est = by_target['estimate'].sum().reindex(target_order).fillna(0)
        var = by_target['var'].sum().reindex(target_order).fillna(0)

        cum_prop = est.cumsum().clip(0, 1)
        # NOTE(review): summing variances treats the per-target estimates as independent.
        cum_se = np.sqrt(var.cumsum())

        cum_lower = (cum_prop - z_score * cum_se).clip(0, 1)
        cum_upper = (cum_prop + z_score * cum_se).clip(0, 1)
        return cum_prop, cum_lower, cum_upper

    @classmethod
    def render(cls, result: Union['estimators.PrevalenceEstimates', dict],  max_valencies: int = None, **kwargs):
        """
        Renders cumulative coverage curves with approximate 95% bands.

        Args:
            result: Compositional ``PrevalenceEstimates``, or a dict with key ``"res"``
                (the estimates) and optionally ``"formulation"`` (an object whose
                ``get_formulation()`` returns the ordered targets).
            max_valencies: Without a formulation, keep only this many of the most
                frequent targets. ``None`` or ``0`` means no limit.
            **kwargs: Accepted for router compatibility; unused.

        Returns:
            A themed plotly Figure.

        Raises:
            TypeError: If ``result`` is not a supported type.
            ValueError: If a dict input lacks ``"res"`` or the estimates are not compositional.
        """
        if not cls.can_render(result):
            raise TypeError(f"{cls.__name__} does not support {type(result).__name__}.")

        if isinstance(result, dict):
            res = result.get("res")
            formulation = result.get("formulation")
            if res is None:
                raise ValueError("Cumulative coverage dict input requires a 'res' entry with prevalence estimates.")
        else:
            res = result
            formulation = None

        if res.aggregation_type != AggregationType.COMPOSITIONAL:
            raise ValueError("Cumulative coverage strictly requires compositional prevalence estimates.")

        data = res.data.copy()

        # Idiomatic SciPy: Retrieve the exact Z-score for a 95% two-sided interval (~1.96)
        z_score = norm.ppf(0.975)

        # Back out an SE from the existing CIs so any shrinkage/smoothing applied by the
        # upstream estimator is carried through.
        # NOTE(review): exact only if upstream bounds are symmetric 95% intervals.
        data['se'] = (data['upper'] - data['lower']) / (2 * z_score)
        data['var'] = data['se'] ** 2

        if formulation:
            # list() so the ribbon's `order + order[::-1]` concatenates, even if
            # get_formulation() returns an array-like (e.g. a pandas Index would add
            # element-wise instead).
            target_order = list(formulation.get_formulation())
        else:
            # Sort strictly by raw count to simulate prioritizing the most common variants globally.
            # observed=True keeps unobserved categorical levels out of the ranking.
            target_order = (
                data.groupby('target', observed=True)['event'].sum()
                .sort_values(ascending=False).index.tolist()
            )
            if max_valencies:
                target_order = target_order[:max_valencies]

        fig = go.Figure()
        group_cols = res.stratified_by

        import plotly.express as px
        colors = px.colors.qualitative.Plotly

        if not group_cols:
            # --- GLOBAL COVERAGE ---
            cum_prop, cum_lower, cum_upper = cls._cumulative_band(data, target_order, z_score)

            # Draw the translucent confidence ribbon (upper edge forward, lower edge back)
            fig.add_trace(go.Scatter(
                x=target_order + target_order[::-1],
                y=cum_upper.tolist() + cum_lower.tolist()[::-1],
                fill='toself',
                fillcolor=cls._MAIN_COLOUR,
                opacity=0.2,
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False,
                legendgroup="Cumulative Population Coverage"
            ))

            fig.add_trace(go.Scatter(
                x=target_order, y=cum_prop,
                mode='lines+markers',
                name='Cumulative Population Coverage',
                line=dict(color=cls._MAIN_COLOUR, width=3),
                marker=dict(size=8, color=cls._MAIN_COLOUR),
                customdata=np.column_stack((cum_lower.values, cum_upper.values)),
                hovertemplate="<b>%{x}</b><br>Cumulative Coverage: %{y:.1%}<br>95% CI: %{customdata[0]:.1%} - %{customdata[1]:.1%}<extra></extra>",
                legendgroup="Cumulative Population Coverage"
            ))
            strata_label = "Baseline"
        else:
            # --- STRATIFIED COVERAGE ---
            # NOTE(review): only the first stratification column is used; rows of any
            # further strata are summed into the same curve.
            color_col = group_cols[0]
            strata_label = f"Stratified by {cls._clean_label(color_col)}"

            for i, (stratum, group_df) in enumerate(data.groupby(color_col, observed=True)):
                cum_prop, cum_lower, cum_upper = cls._cumulative_band(group_df, target_order, z_score)

                color = colors[i % len(colors)]

                # Draw the translucent confidence ribbon
                fig.add_trace(go.Scatter(
                    x=target_order + target_order[::-1],
                    y=cum_upper.tolist() + cum_lower.tolist()[::-1],
                    fill='toself',
                    fillcolor=color,
                    opacity=0.2,
                    line=dict(color='rgba(255,255,255,0)'),
                    hoverinfo="skip",
                    showlegend=False,
                    legendgroup=str(stratum)
                ))

                fig.add_trace(go.Scatter(
                    x=target_order, y=cum_prop,
                    mode='lines+markers',
                    name=str(stratum),
                    line=dict(color=color, width=2),
                    marker=dict(size=6, color=color),
                    customdata=np.column_stack((cum_lower.values, cum_upper.values)),
                    hovertemplate=f"<b>%{{x}}</b><br>{cls._clean_label(color_col)}: {stratum}<br>Cumulative Coverage: %{{y:.1%}}<br>95% CI: %{{customdata[0]:.1%}} - %{{customdata[1]:.1%}}<extra></extra>",
                    legendgroup=str(stratum)
                ))

        return cls.apply_theme(fig.update_layout(
            title=f"<b>Cumulative Coverage</b><br><sup>Targeting top {len(target_order)} {cls._clean_label(res.trait)} variants | {strata_label}</sup>",
            xaxis=dict(title="Variant added to formulation", tickangle=45),
            yaxis=dict(title='Cumulative Population Coverage', tickformat='.0%', range=[0, 1.05]),
            hovermode="x unified"
        ))

NetworkPlotter

Bases: BasePlotter

Source code in src/seroepi/plotting.py
class NetworkPlotter(BasePlotter):
    """Renders isolate networks (SNP-threshold or transmission edges) over a 2-D layout."""
    SUPPORTED_TYPES = (DistancesBase,)

    @staticmethod
    def _build_edges(rows: np.ndarray, cols: np.ndarray, pos: np.ndarray) -> tuple[list, list]:
        """
        Builds Plotly line coordinates for the edges ``rows[i] -> cols[i]``.

        Each edge becomes three points (start, end, ``None``); the ``None`` breaks the
        line so one trace can draw every edge.
        """
        if len(rows) == 0:
            return [], []
        ex = np.full(len(rows) * 3, None, dtype=object)
        ex[0::3], ex[1::3] = pos[rows, 0], pos[cols, 0]

        ey = np.full(len(rows) * 3, None, dtype=object)
        ey[0::3], ey[1::3] = pos[rows, 1], pos[cols, 1]
        return ex.tolist(), ey.tolist()

    @classmethod
    def render(cls, result: DistancesBase, df: pd.DataFrame = None, pos: np.ndarray = None,
               edge_type: str = 'snp', threshold: int = 20,
               color_col: str = None, trans_network: DistancesBase = None, **kwargs) -> go.Figure:
        """
        Plots an interactive isolate network on precomputed layout (MDS) coordinates.

        Args:
            result: Distance object providing ``matrix``, ``index`` and ``layout()``.
            df: Optional metadata; rows are aligned to ``result.index`` via a
                ``sample_id`` column.
            pos: Optional node coordinates (first two columns are x/y); computed with
                ``result.layout()`` when omitted.
            edge_type: ``"snp"`` links isolates whose distance is ``<= threshold`` (only
                for distance-type metrics); ``"trans"`` links each nonzero upper-triangle
                entry of a similarity matrix.
            threshold: Maximum distance for an SNP edge.
            color_col: Optional metadata column used to colour nodes.
            trans_network: Optional similarity object used for ``"trans"`` edges
                instead of ``result``.
            **kwargs: Accepted for router compatibility; unused.

        Returns:
            A themed plotly Figure.

        Raises:
            TypeError: If ``result`` is not a supported distance object.
        """
        if not cls.can_render(result):
            raise TypeError(f"{cls.__name__} does not support {type(result).__name__}.")

        # Fallback to calculate MDS if it wasn't passed in via a cache
        if pos is None:
            pos = result.layout()

        # Align the dataframe rows to perfectly match the distance matrix index
        if df is not None and 'sample_id' in df.columns:
            df_aligned = df.set_index('sample_id').reindex(result.index)
        else:
            df_aligned = pd.DataFrame(index=result.index)

        edge_x, edge_y = [], []
        title = "<b>Isolate Network</b><br><sup>Nodes positioned by distance layout (MDS)</sup>"

        if edge_type == "snp" and getattr(result, 'metric_type', None) in [DistanceMetricType.ABSOLUTE_DISTANCE, DistanceMetricType.RELATIVE_DISTANCE]:
            # Densify only here: the threshold test needs every pairwise distance.
            dense_dist = result.matrix.toarray().astype(float)
            # Upper triangle (k=1) avoids self-loops and duplicate edges
            adj = (dense_dist <= threshold)
            rows, cols = np.where(np.triu(adj, k=1))

            edge_x, edge_y = cls._build_edges(rows, cols, pos)

            title = f"<b>Genomic SNP Network</b><br><sup>Edges connect isolates ≤ {threshold} SNPs apart</sup>"

        elif edge_type == "trans":
            net = trans_network if trans_network is not None else result
            if getattr(net, 'metric_type', None) in [DistanceMetricType.ABSOLUTE_SIMILARITY, DistanceMetricType.RELATIVE_SIMILARITY]:
                from scipy.sparse import triu
                # Query the sparse matrix directly without densifying it
                upper_adj = triu(net.matrix, k=1).tocoo()
                # Explicitly stored zeros are equivalent to absent entries, so they must
                # not become edges either.
                linked = upper_adj.data != 0
                edge_x, edge_y = cls._build_edges(upper_adj.row[linked], upper_adj.col[linked], pos)

                title = "<b>Transmission Network</b><br><sup>Edges connect isolates based on spatial/temporal proximity</sup>"

        edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.4, color='#888'), hoverinfo='none', mode='lines')

        if color_col and color_col in df_aligned.columns:
            color_series = df_aligned[color_col]
            # Safely add "Unknown" to the approved categories to maintain memory efficiency
            if isinstance(color_series.dtype, pd.CategoricalDtype) and "Unknown" not in color_series.cat.categories:
                color_series = color_series.cat.add_categories("Unknown")

            color_vals = color_series.fillna("Unknown").astype(str)
            color_map = {c: i for i, c in enumerate(color_vals.unique())}
            colors = [color_map[c] for c in color_vals]
            hover_text = [f"ID: {idx}<br>{color_col}: {c}" for idx, c in zip(result.index, color_vals)]
            marker_dict = dict(showscale=False, color=colors, colorscale='Turbo', size=10, line=dict(width=1, color='white'))
        else:
            hover_text = [f"ID: {idx}" for idx in result.index]
            marker_dict = dict(color=cls._MAIN_COLOUR, size=10, line=dict(width=1, color='white'))

        node_trace = go.Scatter(
            x=pos[:, 0], y=pos[:, 1], mode='markers',
            hovertext=hover_text, hoverinfo="text", marker=marker_dict
        )

        return cls.apply_theme(go.Figure(data=[edge_trace, node_trace]).update_layout(
            title=title,
            showlegend=False, hovermode='closest',
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            margin=dict(b=20, l=5, r=5, t=60)
        ))

render_plot

render_plot(result_obj: Any, plot_type: PlotType, **kwargs) -> go.Figure

A central router that invokes the correct plotter for the desired plot type.

Source code in src/seroepi/plotting.py
def render_plot(result_obj: Any, plot_type: PlotType, **kwargs) -> go.Figure:
    """
    Dispatch ``result_obj`` to the plotter registered for ``plot_type``.

    Args:
        result_obj: The analysis result to visualise; passed through unchanged.
        plot_type: Key into the plotter registry.
        **kwargs: Forwarded verbatim to the plotter's ``render`` method.

    Returns:
        The figure produced by the selected plotter.

    Raises:
        ValueError: If no plotter is registered for ``plot_type``.
    """
    plotter = _PLOTTER_MAP.get(plot_type)
    if not plotter:
        registered = list(_PLOTTER_MAP.keys())
        raise ValueError(f"Plot type '{plot_type}' is not registered. Available: {registered}")
    return plotter.render(result_obj, **kwargs)

seroepi.io

Module for genotype file I/O and parsing.

BaseGenotypeParser

Base class for standardizing external datasets.

Subclasses must define column mappings and category definitions for specific input formats (e.g., Kleborate output).

Source code in src/seroepi/io.py
class BaseGenotypeParser:
    """
    Base class for standardizing external datasets.

    Subclasses must define column mappings and category definitions for specific
    input formats (e.g., Kleborate output). :meth:`parse` first renames raw
    columns via ``column_map``. It then prefixes every column found in one of the
    domain lists with that domain's value. The lists are checked in the order
    QC, virulence, AMR, genotype, phenotype, and the first match wins.
    """

    # Subclasses define how to map raw columns to UnifiedIsolateSchema columns.
    # This class only reads these attributes; subclasses override them by reassignment.
    column_map: dict[str, str] = {}
    qc_cols: list[str] = []
    vir_cols: list[str] = []
    amr_cols: list[str] = []
    geno_cols: list[str] = []
    pheno_cols: list[str] = []

    @classmethod
    def get_parser(cls, flavour: Union[str, GenotypeFlavour]):
        """
        Return the parser class to use for ``flavour``.

        Args:
            flavour: A ``GenotypeFlavour`` member or its string value.

        Returns:
            ``PathogenwatchKleborateParser`` for ``"pathogenwatch-kleborate"``;
            otherwise the class this method was called on.
        """
        flavour_val = flavour.value if isinstance(flavour, GenotypeFlavour) else flavour
        if flavour_val == "pathogenwatch-kleborate":
            return PathogenwatchKleborateParser
        return cls

    @staticmethod
    def _clean_mixed_dates(df: pd.DataFrame, date_col: str, res_col: str) -> pd.DataFrame:
        """
        Standardizes mixed-format dates (YYYY, YYYY-MM, YYYY-MM-DD).

        Resolution is inferred from each value's string length:
        - 4 characters is a year.
        - 6-7 characters containing '-' or '/' is a year-month.
        - 8 or more characters is a full date.
        - Anything else is UNKNOWN.

        Partial dates are anchored to a representative day before parsing
        (year -> July 2nd, year-month -> the 15th). Values that fail to parse
        become NaT with UNKNOWN resolution.

        Note: ``df`` is modified in place (and also returned).

        Args:
            df: The DataFrame to clean.
            date_col: The target prefixed temporal column name.
            res_col: The target prefixed resolution column name.

        Returns:
            The DataFrame with a standardized temporal column and a parallel
            resolution column.
        """
        if date_col not in df.columns:
            return df
        s = df[date_col].astype(str).str.strip()
        # pandas stores integer columns containing missing values as float64, so
        # bare years stringify as e.g. "2020.0". Drop the spurious ".0" so they
        # are classified as year-resolution dates instead of UNKNOWN.
        s = s.str.replace(r'^(\d{4})\.0$', r'\1', regex=True)
        s = s.replace(['nan', 'NaN', 'NaT', '<NA>', 'None', ''], pd.NA)
        lengths = s.str.len()

        df[res_col] = TemporalResolution.UNKNOWN.value
        df.loc[lengths == 4, res_col] = TemporalResolution.YEAR.value
        df.loc[(lengths >= 6) & (lengths <= 7) & s.str.contains(r'-|/', na=False), res_col] = TemporalResolution.MONTH.value
        df.loc[lengths >= 8, res_col] = TemporalResolution.DAY.value

        is_year = df[res_col] == TemporalResolution.YEAR.value
        is_month = df[res_col] == TemporalResolution.MONTH.value

        # Anchor partial dates so they can be parsed as full dates.
        s.loc[is_year] = s.loc[is_year] + '-07-02'
        s.loc[is_month] = s.loc[is_month] + '-15'

        df[date_col] = pd.to_datetime(s, errors='coerce', format='mixed')
        df.loc[df[date_col].isna(), res_col] = TemporalResolution.UNKNOWN.value
        return df

    @classmethod
    def _ingest_user_metadata(cls, meta_df: pd.DataFrame, id_col: str,
                              date_col: Optional[str] = None, date_res: Optional[str] = None,
                              spatial_col: Optional[str] = None, spatial_res: Optional[str] = None,
                              lat_col: Optional[str] = None, lon_col: Optional[str] = None) -> pd.DataFrame:
        """
        Standardizes user-uploaded metadata.

        The ID and coordinate columns are renamed to ``sample_id``, ``latitude``
        and ``longitude``. The date and spatial columns receive the temporal and
        spatial domain prefixes plus a parallel resolution column. Every other
        column is prefixed with ``meta_``.

        Args:
            meta_df: The metadata DataFrame (not modified).
            id_col: Column name for sample IDs.
            date_col: Column name for isolation dates.
            date_res: User-specified temporal resolution; overrides the inferred
                resolution unless it is the UNKNOWN value.
            spatial_col: Column name for the primary geographic level.
            spatial_res: User-specified spatial resolution.
            lat_col: Column name for latitudes.
            lon_col: Column name for longitudes.

        Returns:
            A cleaned metadata DataFrame with prefixed user columns.
        """
        df = meta_df.copy()
        rename_map = {id_col: 'sample_id'}
        if lat_col:
            rename_map[lat_col] = 'latitude'
        if lon_col:
            rename_map[lon_col] = 'longitude'

        df = df.rename(columns=rename_map)

        if 'latitude' in df.columns:
            df['latitude'] = df['latitude'].astype("Float64")
        if 'longitude' in df.columns:
            df['longitude'] = df['longitude'].astype("Float64")

        # Safely prefix the temporal data while keeping the user's name intact
        if date_col and date_col in df.columns:
            new_date_col = f"{Domain.TEMPORAL.value}_{date_col}"
            res_col = f"{Domain.TEMPORAL_RES.value}_{date_col}"
            df = df.rename(columns={date_col: new_date_col})
            df = cls._clean_mixed_dates(df, date_col=new_date_col, res_col=res_col)
            if date_res and date_res != TemporalResolution.UNKNOWN.value:
                df[res_col] = date_res

        # Safely prefix the spatial data while keeping the user's name intact
        if spatial_col and spatial_col in df.columns:
            new_spatial_col = f"{Domain.SPATIAL.value}_{spatial_col}"
            res_col = f"{Domain.SPATIAL_RES.value}_{spatial_col}"
            df = df.rename(columns={spatial_col: new_spatial_col})
            if spatial_res and spatial_res != SpatialResolution.UNKNOWN.value:
                df[res_col] = spatial_res
            else:
                df[res_col] = SpatialResolution.UNKNOWN.value

        # Everything that is neither a core column nor already domain-prefixed is
        # namespaced as user metadata.
        core_prefixes = (f"{Domain.TEMPORAL.value}_", f"{Domain.TEMPORAL_RES.value}_",
                         f"{Domain.SPATIAL.value}_", f"{Domain.SPATIAL_RES.value}_")
        new_names = {col: f"meta_{col}" for col in df.columns
                     if col not in ['sample_id', 'latitude', 'longitude'] and not col.startswith(core_prefixes)}
        return df.rename(columns=new_names)

    @staticmethod
    def _optimize_categorical_dtypes(df: pd.DataFrame, threshold: float = 0.5) -> pd.DataFrame:
        """
        Converts string columns to categorical dtypes if cardinality is low.

        Existing categorical columns are made ordered. Other object/string
        columns are converted to ordered categoricals when
        ``unique / rows < threshold`` and at least one value repeats.

        Args:
            df: The DataFrame to optimize (not modified).
            threshold: The ratio of unique values to total rows below which
                a column is converted to categorical. Defaults to 0.5.

        Returns:
            The optimized DataFrame.
        """
        df_opt = df.copy()
        total_rows = len(df_opt)
        if total_rows == 0:
            return df_opt

        string_cols = df_opt.select_dtypes(include=['object', 'string', 'category']).columns
        for col in string_cols:
            if isinstance(df_opt[col].dtype, pd.CategoricalDtype):
                if not df_opt[col].cat.ordered:
                    df_opt[col] = df_opt[col].cat.as_ordered()
            else:
                num_unique = df_opt[col].nunique(dropna=False)
                if (num_unique / total_rows) < threshold and num_unique < total_rows:
                    df_opt[col] = df_opt[col].astype('category').cat.as_ordered()
        return df_opt

    @staticmethod
    def _optimize_binary_dtypes(df: pd.DataFrame) -> pd.DataFrame:
        """
        Converts numeric columns containing only 0 and 1 to Int8.

        Missing values are ignored when checking the values. A column with no
        observed values at all is left unchanged: it carries no evidence of
        being binary. Without this guard, for example, the Float64 coordinate
        columns guaranteed by :meth:`parse` would be downcast to Int8 whenever
        no coordinates were supplied.

        Args:
            df: The DataFrame to optimize (not modified).

        Returns:
            The optimized DataFrame.
        """
        df_opt = df.copy()
        for col in df_opt.columns:
            series = df_opt[col]
            if not pd.api.types.is_numeric_dtype(series):
                continue
            observed = set(series.dropna().unique())
            # {0, 1} also matches 0.0/1.0 and False/True (equal values hash alike).
            if observed and observed.issubset({0, 1}):
                df_opt[col] = series.astype('Int8')
        return df_opt

    @classmethod
    def from_files(
            cls,
            genotype_path: Union[str, Path],
            meta_path: Optional[Union[str, Path]] = None,
            meta_kwargs: Optional[dict] = None,
            dataset_name: str = "Unknown Dataset"
    ) -> pd.DataFrame:
        """
        Convenience factory to read CSV files and parse them.

        Both files are read with the pyarrow CSV engine.

        Args:
            genotype_path: Path to the raw genotype CSV.
            meta_path: Optional path to the metadata CSV.
            meta_kwargs: Arguments for metadata ingestion.
            dataset_name: Name to tag the resulting dataset with.

        Returns:
            The validated DataFrame produced by :meth:`parse`.
        """
        genotype_df = pd.read_csv(genotype_path, engine="pyarrow")
        meta_df = None

        if meta_path is not None:
            meta_df = pd.read_csv(meta_path, engine="pyarrow")

        return cls.parse(genotype_df, meta_df=meta_df, meta_kwargs=meta_kwargs, dataset_name=dataset_name)

    @classmethod
    def from_records(
            cls,
            records: list[dict],
            meta_df: Optional[pd.DataFrame] = None,
            meta_kwargs: Optional[dict] = None,
            sep: str = '/',
            dataset_name: str = "Unknown Dataset"
    ) -> pd.DataFrame:
        """
        Convenience factory to read a list of nested dictionaries (e.g., from an API) and parse them.

        Args:
            records: List of nested dictionaries.
            meta_df: Optional metadata DataFrame to merge.
            meta_kwargs: Arguments for metadata ingestion.
            sep: Separator for flattening nested JSON keys. Defaults to '/'.
            dataset_name: Name to tag the resulting dataset with.

        Returns:
            The validated DataFrame produced by :meth:`parse`.
        """
        genotype_df = pd.json_normalize(records, sep=sep)
        return cls.parse(genotype_df, meta_df=meta_df, meta_kwargs=meta_kwargs, dataset_name=dataset_name)

    @classmethod
    def parse(
            cls,
            genotype_df: pd.DataFrame,
            meta_df: Optional[pd.DataFrame] = None,
            meta_kwargs: Optional[dict] = None,
            dataset_name: str = "Unknown Dataset"
    ) -> pd.DataFrame:
        """
        Parses and validates a genotype dataset, optionally merging with metadata.

        Args:
            genotype_df: The raw genotype DataFrame (not modified).
            meta_df: Optional metadata DataFrame to merge. It is left-joined on
                ``sample_id``; its columns replace same-named genotype columns.
            meta_kwargs: Arguments for metadata ingestion (e.g., column names).
            dataset_name: Name to tag the resulting dataset with
                (stored in ``attrs["dataset_name"]``).

        Returns:
            A validated DataFrame conforming to UnifiedIsolateSchema.
        """
        # Snapshot the caller's attrs; they are re-attached after validation.
        attrs = dict(genotype_df.attrs)
        df = genotype_df.copy()
        df = df.rename(columns=cls.column_map)
        new_names = {}
        for col in df.columns:
            if col in cls.qc_cols:
                new_names[col] = f"{Domain.QC.value}_{col}"
            elif col in cls.vir_cols:
                new_names[col] = f"{Domain.VIRULENCE.value}_{col}"
            elif col in cls.amr_cols:
                new_names[col] = f"{Domain.AMR.value}_{col}"
            elif col in cls.geno_cols:
                new_names[col] = f"{Domain.GENOTYPE.value}_{col}"
            elif col in cls.pheno_cols:
                new_names[col] = f"{Domain.PHENOTYPE.value}_{col}"

        df = df.rename(columns=new_names)

        if meta_df is not None:
            kwargs = meta_kwargs or {}
            # Use the namespaced schema method
            clean_meta = cls._ingest_user_metadata(meta_df, **kwargs)
            overlap_cols = [col for col in clean_meta.columns
                            if col in df.columns and col != 'sample_id']
            df = pd.merge(df.drop(columns=overlap_cols), clean_meta,
                          on='sample_id', how='left')

        # 1. PANDERA VALIDATION FIRST
        # This ensures schema coercions (like Float64, datetime) are safely applied
        # before we aggressively downcast memory with custom categories.
        valid_df = UnifiedIsolateSchema.validate(df)

        # Guarantee GPS columns exist post-validation so accessors can safely impute
        # (This bypasses the pd.NA evaluation crash during Pandera's ge/le checks)
        if 'latitude' not in valid_df.columns:
            valid_df['latitude'] = pd.NA
        valid_df['latitude'] = valid_df['latitude'].astype("Float64")

        if 'longitude' not in valid_df.columns:
            valid_df['longitude'] = pd.NA
        valid_df['longitude'] = valid_df['longitude'].astype("Float64")

        # AUTOMATIC REVERSE GEOCODING
        # If the dataset has coordinates but lacks spatial regions, infer them!
        spatial_cols = valid_df.filter(regex=f"^{Domain.SPATIAL.value}_(?!res_)").columns.tolist()
        has_coords = valid_df['latitude'].notna().any() and valid_df['longitude'].notna().any()

        if not spatial_cols and has_coords:
            valid_df = valid_df.geo.reverse_geocode()

        # 2. CUSTOM OPTIMIZATIONS
        valid_df = valid_df.geo.standardize_and_impute()
        valid_df = cls._optimize_binary_dtypes(valid_df)
        valid_df = cls._optimize_categorical_dtypes(valid_df, threshold=0.4)

        # Pandera strips attrs on validation/copying, so attach them right at the end
        valid_df.attrs = attrs
        valid_df.attrs["dataset_name"] = dataset_name

        return valid_df

from_files classmethod

from_files(genotype_path: Union[str, Path], meta_path: Optional[Union[str, Path]] = None, meta_kwargs: dict = None, dataset_name: str = 'Unknown Dataset') -> pd.DataFrame

Convenience factory to read CSV files and parse them.

Parameters:

Name Type Description Default
genotype_path Union[str, Path]

Path to the raw genotype CSV.

required
meta_path Optional[Union[str, Path]]

Optional path to the metadata CSV.

None
meta_kwargs dict

Arguments for metadata ingestion.

None
dataset_name str

Name to tag the resulting dataset with.

'Unknown Dataset'
Source code in src/seroepi/io.py
@classmethod
def from_files(
        cls,
        genotype_path: Union[str, Path],
        meta_path: Optional[Union[str, Path]] = None,
        meta_kwargs: dict = None,
        dataset_name: str = "Unknown Dataset"
) -> pd.DataFrame:
    """
    Load genotype (and optional metadata) CSV files, then delegate to ``parse``.

    Both files are read with the pyarrow CSV engine; the genotype file is read first.

    Args:
        genotype_path: Location of the raw genotype CSV.
        meta_path: Location of an optional metadata CSV; skipped when ``None``.
        meta_kwargs: Arguments forwarded for metadata ingestion.
        dataset_name: Label to tag the parsed dataset with.

    Returns:
        The result of ``cls.parse`` for the loaded tables.
    """
    genotypes = pd.read_csv(genotype_path, engine="pyarrow")
    metadata = None if meta_path is None else pd.read_csv(meta_path, engine="pyarrow")
    return cls.parse(
        genotypes,
        meta_df=metadata,
        meta_kwargs=meta_kwargs,
        dataset_name=dataset_name,
    )

from_records classmethod

from_records(records: list[dict], meta_df: Optional[DataFrame] = None, meta_kwargs: dict = None, sep: str = '/', dataset_name: str = 'Unknown Dataset') -> pd.DataFrame

Convenience factory to read a list of nested dictionaries (e.g., from an API) and parse them.

Parameters:

Name Type Description Default
records list[dict]

List of nested dictionaries.

required
meta_df Optional[DataFrame]

Optional metadata DataFrame to merge.

None
meta_kwargs dict

Arguments for metadata ingestion.

None
sep str

Separator for flattening nested JSON keys. Defaults to '/'.

'/'
dataset_name str

Name to tag the resulting dataset with.

'Unknown Dataset'
Source code in src/seroepi/io.py
@classmethod
def from_records(
        cls,
        records: list[dict],
        meta_df: Optional[pd.DataFrame] = None,
        meta_kwargs: dict = None,
        sep: str = '/',
        dataset_name: str = "Unknown Dataset"
) -> pd.DataFrame:
    """
    Flatten nested records (e.g. an API payload) into a table and delegate to ``parse``.

    Args:
        records: Nested dictionaries, one per isolate.
        meta_df: Optional metadata table forwarded to ``parse``.
        meta_kwargs: Arguments forwarded for metadata ingestion.
        sep: Joiner placed between nested key names when flattening. Defaults to '/'.
        dataset_name: Label to tag the parsed dataset with.

    Returns:
        The result of ``cls.parse`` for the flattened table.
    """
    flattened = pd.json_normalize(records, sep=sep)
    parse_options = dict(meta_df=meta_df, meta_kwargs=meta_kwargs, dataset_name=dataset_name)
    return cls.parse(flattened, **parse_options)

parse classmethod

parse(genotype_df: DataFrame, meta_df: Optional[DataFrame] = None, meta_kwargs: dict = None, dataset_name: str = 'Unknown Dataset') -> pd.DataFrame

Parses and validates a genotype dataset, optionally merging with metadata.

Parameters:

Name Type Description Default
genotype_df DataFrame

The raw genotype DataFrame.

required
meta_df Optional[DataFrame]

Optional metadata DataFrame to merge.

None
meta_kwargs dict

Arguments for metadata ingestion (e.g., column names).

None
dataset_name str

Name to tag the resulting dataset with.

'Unknown Dataset'

Returns:

Type Description
DataFrame

A validated DataFrame conforming to UnifiedIsolateSchema.

Source code in src/seroepi/io.py
@classmethod
def parse(
        cls,
        genotype_df: pd.DataFrame,
        meta_df: Optional[pd.DataFrame] = None,
        meta_kwargs: dict = None,
        dataset_name: str = "Unknown Dataset"
) -> pd.DataFrame:
    """
    Standardize, optionally enrich with metadata, and validate a genotype table.

    Steps: rename raw columns, prefix domain columns, left-join cleaned metadata
    on ``sample_id``, validate against the unified schema, guarantee Float64
    coordinate columns, reverse-geocode when only coordinates are present, then
    impute and downcast dtypes.

    Args:
        genotype_df: The raw genotype table (not modified).
        meta_df: Optional metadata table; its columns replace same-named genotype columns.
        meta_kwargs: Arguments for metadata ingestion (e.g., column names).
        dataset_name: Label stored in ``attrs["dataset_name"]`` of the result.

    Returns:
        A validated DataFrame conforming to UnifiedIsolateSchema.
    """
    original_attrs = genotype_df.attrs

    # The first domain list containing a column decides its prefix (tuple order).
    domain_lists = (
        (cls.qc_cols, Domain.QC.value),
        (cls.vir_cols, Domain.VIRULENCE.value),
        (cls.amr_cols, Domain.AMR.value),
        (cls.geno_cols, Domain.GENOTYPE.value),
        (cls.pheno_cols, Domain.PHENOTYPE.value),
    )
    frame = genotype_df.copy().rename(columns=cls.column_map)
    prefixed = {}
    for column in frame.columns:
        for members, prefix in domain_lists:
            if column in members:
                prefixed[column] = f"{prefix}_{column}"
                break
    frame = frame.rename(columns=prefixed)

    if meta_df is not None:
        clean_meta = cls._ingest_user_metadata(meta_df, **(meta_kwargs or {}))
        shared = [c for c in clean_meta.columns if c != 'sample_id' and c in frame.columns]
        frame = pd.merge(frame.drop(columns=shared), clean_meta, on='sample_id', how='left')

    # Validate first so schema coercions happen before any custom downcasting.
    valid_df = UnifiedIsolateSchema.validate(frame)

    # Coordinates are added after validation to sidestep pd.NA in Pandera's range checks.
    for coord in ('latitude', 'longitude'):
        if coord not in valid_df.columns:
            valid_df[coord] = pd.NA
        valid_df[coord] = valid_df[coord].astype("Float64")

    # Infer regions from coordinates only when no spatial column exists yet.
    region_cols = valid_df.filter(regex=f"^{Domain.SPATIAL.value}_(?!res_)").columns
    has_coords = valid_df['latitude'].notna().any() and valid_df['longitude'].notna().any()
    if has_coords and region_cols.empty:
        valid_df = valid_df.geo.reverse_geocode()

    valid_df = valid_df.geo.standardize_and_impute()
    valid_df = cls._optimize_binary_dtypes(valid_df)
    valid_df = cls._optimize_categorical_dtypes(valid_df, threshold=0.4)

    # Validation drops attrs, so restore them last.
    valid_df.attrs = original_attrs
    valid_df.attrs["dataset_name"] = dataset_name
    return valid_df

PathogenwatchKleborateParser

Bases: BaseGenotypeParser

Adapter for Kleborate files downloaded from Pathogenwatch.

This parser maps Kleborate's specific column names and categories to the UnifiedIsolateSchema.

Source code in src/seroepi/io.py
class PathogenwatchKleborateParser(BaseGenotypeParser):
    """
    Parser preset for Kleborate result tables exported from Pathogenwatch.

    Translates Pathogenwatch/Kleborate column names into the
    UnifiedIsolateSchema layout and assigns each Kleborate output column to its
    domain (QC, virulence, AMR, genotype, phenotype).
    """
    # Raw header -> schema column. Location and date headers get domain prefixes.
    column_map = {
        'Genome Name': 'sample_id',
        'Latitude': 'latitude',
        'Longitude': 'longitude',
        **{raw: f'{Domain.SPATIAL.value}_{raw}' for raw in ('Country', 'Region', 'Continent')},
        'Collection Date': f'{Domain.TEMPORAL.value}_Collection_Date',
        **{raw: f'{Domain.TEMPORAL.value}_{raw}' for raw in ('Year', 'Month', 'Day')},
    }

    qc_cols = [
        'species', 'species_match',
        'contig_count', 'N50', 'largest_contig', 'total_size', 'ambiguous_bases',
        'QC_warnings',
        'K_locus_confidence', 'O_locus_confidence',
    ]
    vir_cols = [
        'YbST', 'Yersiniabactin',
        'CbST', 'Colibactin',
        'AbST', 'Aerobactin',
        'SmST', 'Salmochelin',
        'RmST', 'RmpADC',
        'virulence_score', 'rmpA2',
    ]
    amr_cols = [
        'AGly_acquired', 'Col_acquired', 'Fcyn_acquired', 'Flq_acquired',
        'Gly_acquired', 'MLS_acquired', 'Phe_acquired', 'Rif_acquired',
        'Sul_acquired', 'Tet_acquired', 'Tgc_acquired', 'Tmt_acquired',
        'Bla_acquired', 'Bla_inhR_acquired', 'Bla_ESBL_acquired',
        'Bla_ESBL_inhR_acquired', 'Bla_Carb_acquired', 'Bla_chr',
        'SHV_mutations', 'Omp_mutations', 'Col_mutations', 'Flq_mutations',
        'resistance_score', 'num_resistance_classes', 'num_resistance_genes',
        'Ciprofloxacin_prediction', 'Ciprofloxacin_profile', 'Ciprofloxacin_MIC_prediction',
    ]
    geno_cols = ['ST', 'K_locus', 'O_locus']
    pheno_cols = ['K_type', 'O_type']

UnifiedIsolateSchema

Bases: DataFrameModel

Pandera schema for validating and standardizing isolate datasets.

This schema ensures that all input data, whether from Pathogenwatch or user uploads, conforms to a unified structure for downstream analysis.

Attributes:

Name Type Description
sample_id Series[string]

Unique identifier for each isolate.

latitude Optional[Series[Float64]]

Latitude coordinate (-90 to 90).

longitude Optional[Series[Float64]]

Longitude coordinate (-180 to 180).

qc_metrics Optional[Series[Any]]

Dynamic columns for quality control (prefixed with 'qc_').

geno_traits Optional[Series[Any]]

Dynamic columns for genotypes/alleles (prefixed with 'geno_').

pheno_traits Optional[Series[Any]]

Dynamic columns for phenotypic traits (prefixed with 'pheno_').

amr_traits Optional[Series[Any]]

Dynamic columns for AMR markers (prefixed with 'amr_').

vir_traits Optional[Series[Any]]

Dynamic columns for virulence markers (prefixed with 'vir_').

temporal_cols Optional[Series[datetime64[ns]]]

Dynamic columns for temporal data (prefixed with 'temporal_', excluding 'temporal_res_' columns).

temporal_res_cols Optional[Series[category]]

Dynamic columns for temporal resolution (prefixed with 'temporal_res_').

spatial_cols Optional[Series[string]]

Dynamic columns for spatial data (prefixed with 'spatial_', excluding 'spatial_res_' columns).

spatial_res_cols Optional[Series[category]]

Dynamic columns for spatial resolution (prefixed with 'spatial_res_').

user_metadata Optional[Series[Any]]

Dynamic columns for user metadata (prefixed with 'meta_').

Examples:

>>> import pandas as pd
>>> df = pd.DataFrame({'sample_id': ['S1'], 'geno_K_locus': ['KL1']})
>>> validated_df = UnifiedIsolateSchema.validate(df)
Source code in src/seroepi/io.py
class UnifiedIsolateSchema(pa.DataFrameModel):
    """
    Pandera schema for validating and standardizing isolate datasets.

    This schema ensures that all input data, whether from Pathogenwatch or user
    uploads, conforms to a unified structure for downstream analysis. Because
    ``Config.strict`` is ``"filter"``, columns that match neither a declared
    field nor one of the dynamic prefix patterns are dropped during validation.

    Attributes:
        sample_id: Unique identifier for each isolate (coerced to string).
        latitude: Latitude coordinate (-90 to 90).
        longitude: Longitude coordinate (-180 to 180).
        qc_metrics: Dynamic columns for quality control (prefixed with 'qc_'); any dtype.
        geno_traits: Dynamic columns for genotypes/alleles (prefixed with 'geno_'); any dtype.
        pheno_traits: Dynamic columns for phenotypic traits (prefixed with 'pheno_'); any dtype.
        amr_traits: Dynamic columns for AMR markers (prefixed with 'amr_'); any dtype.
        vir_traits: Dynamic columns for virulence markers (prefixed with 'vir_'); any dtype.
        temporal_cols: Dynamic columns for temporal data (prefixed with 'temporal_'),
            coerced to datetime64[ns].
        temporal_res_cols: Dynamic columns for temporal resolution (prefixed with
            'temporal_res_'), categorical, restricted to TemporalResolution choices.
        spatial_cols: Dynamic columns for spatial data (prefixed with 'spatial_'),
            coerced to string.
        spatial_res_cols: Dynamic columns for spatial resolution (prefixed with
            'spatial_res_'), categorical, restricted to SpatialResolution choices.
        user_metadata: Dynamic columns for user metadata (prefixed with 'meta_'); any dtype.

    Examples:
        >>> import pandas as pd
        >>> df = pd.DataFrame({'sample_id': ['S1'], 'geno_K_locus': ['KL1']})
        >>> validated_df = UnifiedIsolateSchema.validate(df)
    """
    # Core
    sample_id: Series["string"] = pa.Field(unique=True, coerce=True)

    # Raw GPS Fields
    latitude: Optional[Series["Float64"]] = pa.Field(ge=-90, le=90, nullable=True, coerce=True)
    longitude: Optional[Series["Float64"]] = pa.Field(ge=-180, le=180, nullable=True, coerce=True)

    class Config:
        # "filter": drop columns not covered by the schema instead of raising.
        strict = "filter"

    @classmethod
    def to_schema(cls) -> pa.DataFrameSchema:
        """
        Build the static schema and extend it with the prefix-based dynamic columns.

        Returns:
            A ``pa.DataFrameSchema`` containing the declared fields plus one
            optional, nullable regex column per domain prefix.
        """
        schema = super().to_schema()
        # Dynamically attach regex fields using the Object-based API to bypass the DataFrameModel bug
        # where Optional[] is ignored for regex fields.
        # The "(?!res_)" lookaheads keep the *_res_* resolution columns out of the
        # temporal/spatial value patterns. This presumably relies on
        # Domain.TEMPORAL_RES/SPATIAL_RES values being "<TEMPORAL/SPATIAL>_res";
        # verify against the Domain enum.
        return schema.add_columns({
            f"^{Domain.QC.value}_.*$": pa.Column(regex=True, required=False, nullable=True),
            f"^{Domain.GENOTYPE.value}_.*$": pa.Column(regex=True, required=False, nullable=True),
            f"^{Domain.PHENOTYPE.value}_.*$": pa.Column(regex=True, required=False, nullable=True),
            f"^{Domain.AMR.value}_.*$": pa.Column(regex=True, required=False, nullable=True),
            f"^{Domain.VIRULENCE.value}_.*$": pa.Column(regex=True, required=False, nullable=True),
            f"^{Domain.TEMPORAL.value}_(?!res_).*$": pa.Column("datetime64[ns]", regex=True, required=False, nullable=True, coerce=True),
            f"^{Domain.TEMPORAL_RES.value}_.*$": pa.Column("category", regex=True, required=False, nullable=True, coerce=True, checks=pa.Check.isin(TemporalResolution.choices())),
            f"^{Domain.SPATIAL.value}_(?!res_).*$": pa.Column("string", regex=True, required=False, nullable=True, coerce=True),
            f"^{Domain.SPATIAL_RES.value}_.*$": pa.Column("category", regex=True, required=False, nullable=True, coerce=True, checks=pa.Check.isin(SpatialResolution.choices())),
            "^meta_.*$": pa.Column(regex=True, required=False, nullable=True)
        })

seroepi.dist

Module to handle genetic distance measures between isolates.

DistancesBase dataclass

Bases: ABC

Source code in src/seroepi/dist.py
@dataclass(frozen=True, slots=True)
class DistancesBase(ABC):
    matrix: csr_array
    index: pd.Series
    metric_type: DistanceMetricType
    max_value: float = None  # Required if converting between absolute and relative

    def __post_init__(self):
        """Validates the consistency of the distance matrix and labels."""
        # Airtight check 1: Dimensions must match
        if self.matrix.shape[0] != self.matrix.shape[1]:
            raise ValueError("Distance matrix must be square.")
        if len(self.index) != self.matrix.shape[0]:
            raise ValueError("Number of labels must match matrix dimensions.")
        self.index.name = 'sample_id'

    @abstractmethod
    def get_clusters(self, *args, **kwargs) -> pd.Series: ...

    def layout(self, random_state: int = 42, n_init: int = 1, max_iter: int = 100) -> np.ndarray:
        """
        Calculates a 2D layout for the distance matrix using Multi-Dimensional Scaling (MDS).

        Args:
            random_state: Seed for reproducibility. Defaults to 42.
            n_init: Number of initialization runs. Defaults to 1 for speed.
            max_iter: Maximum iterations. Defaults to 100 for speed.

        Returns:
            A numpy array of shape (n_samples, 2) containing the 2D coordinates.
        """
        dense_dist = self.matrix.toarray().astype(float)

        if self.metric_type in [DistanceMetricType.ABSOLUTE_SIMILARITY, DistanceMetricType.RELATIVE_SIMILARITY]:
            # Convert similarity to dissimilarity for MDS
            max_val = self.max_value if self.max_value is not None else 1.0
            dense_dist = max_val - dense_dist
            mask = (dense_dist == max_val) & (~np.eye(dense_dist.shape[0], dtype=bool))
            if mask.any():
                dense_dist[mask] = max_val * 2
        else:
            # If the matrix was sparse, 0s off the diagonal represent missing data. Fill with max distance.
            mask = (dense_dist == 0) & (~np.eye(dense_dist.shape[0], dtype=bool))
            if mask.any():
                dense_dist[mask] = dense_dist.max() * 2

        mds_kwargs = {
            'n_components': 2,
            'dissimilarity': 'precomputed',
            'random_state': random_state,
            'n_init': n_init,
            'max_iter': max_iter
        }

        # Dynamically suppress scikit-learn API FutureWarnings
        sig = inspect.signature(MDS.__init__)
        if 'normalized_stress' in sig.parameters: mds_kwargs['normalized_stress'] = 'auto'
        if 'init' in sig.parameters: mds_kwargs['init'] = 'random'

        mds = MDS(**mds_kwargs)
        return mds.fit_transform(dense_dist)

__post_init__

__post_init__()

Validates the consistency of the distance matrix and labels.

Source code in src/seroepi/dist.py
def __post_init__(self):
    """
    Validates the consistency of the distance matrix and labels.

    Raises:
        ValueError: If the matrix is not square, or the number of labels
            differs from the matrix dimension.
    """
    if self.matrix.shape[0] != self.matrix.shape[1]:
        raise ValueError("Distance matrix must be square.")
    if len(self.index) != self.matrix.shape[0]:
        raise ValueError("Number of labels must match matrix dimensions.")
    # Store a renamed copy rather than assigning ``.name`` on the caller's
    # object, which silently mutated it. The instance is frozen, hence
    # object.__setattr__.
    object.__setattr__(self, 'index', self.index.rename('sample_id'))

layout

layout(random_state: int = 42, n_init: int = 1, max_iter: int = 100) -> np.ndarray

Calculates a 2D layout for the distance matrix using Multi-Dimensional Scaling (MDS).

Parameters:

Name Type Description Default
random_state int

Seed for reproducibility. Defaults to 42.

42
n_init int

Number of initialization runs. Defaults to 1 for speed.

1
max_iter int

Maximum iterations. Defaults to 100 for speed.

100

Returns:

Type Description
ndarray

A numpy array of shape (n_samples, 2) containing the 2D coordinates.

Source code in src/seroepi/dist.py
def layout(self, random_state: int = 42, n_init: int = 1, max_iter: int = 100) -> np.ndarray:
    """
    Calculates a 2D layout for the distance matrix using Multi-Dimensional Scaling (MDS).

    Similarities are converted to dissimilarities (``max_val - s``).
    Off-diagonal entries that look missing are pushed far apart: zero
    similarity, or a zero distance in a distance matrix. The diagonal is
    forced to zero, as required for a precomputed dissimilarity matrix.

    Args:
        random_state: Seed for reproducibility. Defaults to 42.
        n_init: Number of initialization runs. Defaults to 1 for speed.
        max_iter: Maximum iterations. Defaults to 100 for speed.

    Returns:
        A numpy array of shape (n_samples, 2) containing the 2D coordinates.
    """
    dense_dist = self.matrix.toarray().astype(float)
    off_diagonal = ~np.eye(dense_dist.shape[0], dtype=bool)

    if self.metric_type in [DistanceMetricType.ABSOLUTE_SIMILARITY, DistanceMetricType.RELATIVE_SIMILARITY]:
        # Convert similarity to dissimilarity for MDS
        max_val = self.max_value if self.max_value is not None else 1.0
        dense_dist = max_val - dense_dist
        mask = (dense_dist == max_val) & off_diagonal
        if mask.any():
            dense_dist[mask] = max_val * 2
    else:
        # If the matrix was sparse, 0s off the diagonal represent missing data. Fill with max distance.
        mask = (dense_dist == 0) & off_diagonal
        if mask.any():
            dense_dist[mask] = dense_dist.max() * 2

    # Self-dissimilarity must be zero. If a sparse similarity matrix does not
    # store its diagonal, the conversion above would otherwise leave max_val there.
    np.fill_diagonal(dense_dist, 0.0)

    mds_kwargs = {
        'n_components': 2,
        'dissimilarity': 'precomputed',
        'random_state': random_state,
        'n_init': n_init,
        'max_iter': max_iter
    }

    # Dynamically suppress scikit-learn API FutureWarnings
    sig = inspect.signature(MDS.__init__)
    if 'normalized_stress' in sig.parameters:
        mds_kwargs['normalized_stress'] = 'auto'
    if 'init' in sig.parameters:
        mds_kwargs['init'] = 'random'

    mds = MDS(**mds_kwargs)
    return mds.fit_transform(dense_dist)

GenomicDistances dataclass

Bases: DistancesBase

Source code in src/seroepi/dist.py
@dataclass(frozen=True, slots=True)
class GenomicDistances(DistancesBase):
    """
    Pairwise genomic distances/similarities between isolates, stored as a sparse CSR array.

    The constructors (``from_pairwise``, ``from_pathogenwatch``, ``from_newick``) store every
    observed pair explicitly, including pairs with a value of 0. This keeps identical isolates
    connected in :meth:`get_clusters`. Pairs that were never observed are left implicit.
    """

    @staticmethod
    def _symmetric_csr(rows: np.ndarray, cols: np.ndarray, data: np.ndarray, n: int) -> csr_array:
        """Builds an (n, n) symmetric CSR array from (row, col, value) triples, keeping explicit zeros.

        Each triple is mirrored to (col, row). When a cell receives several values, the largest wins.
        """
        all_rows = np.concatenate([rows, cols])
        all_cols = np.concatenate([cols, rows])
        all_data = np.concatenate([data, data])
        # Sort by row, then column, then value, so the largest value of each cell comes last
        order = np.lexsort((all_data, all_cols, all_rows))
        all_rows, all_cols, all_data = all_rows[order], all_cols[order], all_data[order]
        is_last = np.ones(all_rows.size, dtype=bool)
        is_last[:-1] = (all_rows[1:] != all_rows[:-1]) | (all_cols[1:] != all_cols[:-1])
        # Entries are already in CSR order, so the row pointer is the cumulative per-row count.
        # Building from (data, indices, indptr) keeps stored zeros; sparse binary ops would drop them.
        indptr = np.zeros(n + 1, dtype=np.intp)
        np.cumsum(np.bincount(all_rows[is_last], minlength=n), out=indptr[1:])
        return csr_array((all_data[is_last], all_cols[is_last], indptr), shape=(n, n))

    @staticmethod
    def _offdiag_csr(dense: np.ndarray) -> csr_array:
        """Converts a square dense array to CSR, storing every off-diagonal entry (zeros included)."""
        n = dense.shape[0]
        rows, cols = np.nonzero(~np.eye(n, dtype=bool))  # row-major, i.e. already in CSR order
        indptr = np.arange(n + 1) * (n - 1)  # each row holds exactly n - 1 off-diagonal entries
        return csr_array((dense[rows, cols], cols, indptr), shape=(n, n))

    @classmethod
    def from_file(cls, filepath_or_buffer: Union[str, Path], flavour: Union[str, DistanceFlavour]) -> 'GenomicDistances':
        """
        Factory method to parse a distance matrix or tree from a file based on flavour.

        Args:
            filepath_or_buffer: Path to the file, or an already-open buffer (anything with ``read()``).
            flavour: A DistanceFlavour member or its string value.

        Returns:
            A new GenomicDistances instance.

        Raises:
            ValueError: If the flavour is not recognised.
        """
        flavour_val = flavour.value if isinstance(flavour, DistanceFlavour) else flavour
        if flavour_val == DistanceFlavour.PATHOGENWATCH.value:
            return cls.from_pathogenwatch(filepath_or_buffer)
        elif flavour_val == DistanceFlavour.SKA2.value:
            return cls.from_ska2(filepath_or_buffer)
        elif flavour_val == DistanceFlavour.NEWICK.value:
            # Honour the "buffer" half of the parameter name: open() fails on file-like objects
            if hasattr(filepath_or_buffer, 'read'):
                newick_string = filepath_or_buffer.read()
            else:
                with open(filepath_or_buffer, 'r', encoding='utf-8') as f:
                    newick_string = f.read()
            if isinstance(newick_string, bytes):
                newick_string = newick_string.decode('utf-8')
            return cls.from_newick(newick_string)
        else:
            raise ValueError(f"Unknown distance flavour: {flavour_val}")

    @classmethod
    def from_pairwise(cls, query_col: pd.Series, target_col: pd.Series, weight_col: pd.Series,
                      metric_type: DistanceMetricType = DistanceMetricType.ABSOLUTE_DISTANCE) -> 'GenomicDistances':
        """
        Creates a Distances instance from long-format pairwise data.

        Each (query, target, weight) row is stored in both directions. A pair listed more than
        once, including as both (a, b) and (b, a), keeps its largest weight. Weights of 0 are
        stored explicitly, so identical isolates are not mistaken for missing pairs.

        Args:
            query_col: Series containing the first isolate IDs.
            target_col: Series containing the second isolate IDs.
            weight_col: Series containing the distances/similarities.
            metric_type: The type of metric provided. Defaults to ABSOLUTE_DISTANCE.

        Returns:
            A new Distances instance indexed by the unique isolate IDs, in order of first appearance.
        """
        q_vals = query_col.values
        t_vals = target_col.values

        # pd.factorize maps IDs to dense integer codes via hashing (O(N)),
        # which is faster than np.unique's sort-based O(N log N) approach
        codes, uids = pd.factorize(np.concatenate([q_vals, t_vals]))
        half = len(q_vals)
        matrix = cls._symmetric_csr(codes[:half], codes[half:], np.asarray(weight_col.values), len(uids))
        return cls(matrix, pd.Series(uids), metric_type)

    @classmethod
    def from_ska2(cls, filepath_or_buffer) -> 'GenomicDistances':
        """
        Parses a pairwise distance matrix from SKA2 output.

        Args:
            filepath_or_buffer: Path to the SKA2 distance file.

        Returns:
            A new Distances instance.
        """
        # The first three tab-separated columns are (isolate 1, isolate 2, distance)
        df = pd.read_table(filepath_or_buffer, usecols=(0, 1, 2))
        return cls.from_pairwise(df.iloc[:, 0], df.iloc[:, 1], df.iloc[:, 2])

    @classmethod
    def from_pathogenwatch(cls, filepath_or_buffer) -> 'GenomicDistances':
        """
        Parses a square distance matrix from Pathogenwatch.

        The matrix is symmetrized with an element-wise maximum. Every off-diagonal entry is
        stored explicitly, including distances of 0 (identical isolates).

        Args:
            filepath_or_buffer: Path to the Pathogenwatch CSV file.

        Returns:
            A new Distances instance indexed by the CSV column headers.

        Raises:
            ValueError: If the parsed matrix is not square.
        """
        df = pd.read_csv(filepath_or_buffer, index_col=0)
        values = df.to_numpy()
        if values.shape[0] != values.shape[1]:
            raise ValueError(f"Pathogenwatch distance matrix must be square, got shape {values.shape}")
        # Symmetrize like sparse maximum(M, M.T) does, but without dropping zero distances
        values = np.maximum(values, values.T)
        return cls(cls._offdiag_csr(values), pd.Series(df.columns), DistanceMetricType.ABSOLUTE_DISTANCE)

    @classmethod
    def from_newick(cls, newick_string: str) -> 'GenomicDistances':
        """
        Parses a Newick string and calculates patristic distances.

        Requires Biopython (`pip install biopython`). Every leaf pair is stored explicitly,
        including pairs with a patristic distance of 0 (e.g. zero-length branches).

        Args:
            newick_string: The Newick tree string.

        Returns:
            A new Distances instance indexed by leaf name.

        Raises:
            ImportError: If biopython is not installed.
        """
        try:
            from Bio import Phylo
        except ImportError as err:
            raise ImportError("biopython is required to calculate patristics. Install with seroepi[dev]") from err

        # 1. Parse the tree
        tree = Phylo.read(StringIO(newick_string), "newick")

        # 2. Extract terminals (leaves)
        terminals = tree.get_terminals()
        labels = [leaf.name for leaf in terminals]
        n = len(terminals)

        # 3. Populate a dense NumPy array (symmetric distance calculation)
        matrix = np.zeros((n, n), dtype=np.float64)
        for i in range(n):
            for j in range(i + 1, n):
                dist = tree.distance(terminals[i], terminals[j])
                matrix[i, j] = dist
                matrix[j, i] = dist

        # 4. Convert to CSR keeping zero distances (csr_array(matrix) would drop them)
        return cls(
            matrix=cls._offdiag_csr(matrix),
            index=pd.Series(labels),
            metric_type=DistanceMetricType.ABSOLUTE_DISTANCE
        )

    def get_clusters(self, threshold: int = 20) -> pd.Series:
        """
        Identifies clusters via connected components based on a distance threshold.

        Args:
            threshold: Maximum distance to consider isolates as connected.
                Defaults to 20 (e.g., 20 SNPs).

        Returns:
            A Series of cluster labels indexed by isolate IDs.
        """
        adj = self.matrix.copy()  # Make a copy of the CSR array to avoid mutating the frozen original
        # Convert to a binary adjacency array:
        # If distance <= threshold, make it a 1 (Valid Edge).
        # If distance > threshold, make it a 0 (Severed Edge).
        adj.data = (adj.data <= threshold).astype(np.int8)
        # Safely eliminate the 0s (which are now only the severed edges).
        # Explicitly stored distances of 0 (identical isolates) were turned into 1s and survive.
        adj.eliminate_zeros()
        adj.setdiag(1)  # Ensure every sample is connected to itself on the diagonal
        _, labels = sp_connected_components(csgraph=adj, directed=False, return_labels=True)
        return pd.Series(labels, index=self.index, dtype='category', name=f"connected_components_{threshold=}").cat.as_ordered()

    def to_type(self, target_type: DistanceMetricType) -> 'GenomicDistances':
        """
        Converts the distances to a different metric type.

        Only explicitly stored entries are transformed and the sparsity pattern is kept.
        Missing (unstored) pairs therefore stay missing. They are not turned into
        "identical" values, which a dense ``1.0 - matrix`` subtraction would do; SciPy does
        not support that subtraction on sparse arrays anyway.

        Args:
            target_type: The desired target MetricType.

        Returns:
            A new Distances instance with the converted matrix.

        Raises:
            ValueError: If the conversion involves an Absolute metric but `max_value` is not set,
                or if `target_type` is not a supported metric type.
        """
        if self.metric_type == target_type:
            return self

        # Any conversion to or from an Absolute metric needs the scale, including
        # ABSOLUTE_DISTANCE <-> ABSOLUTE_SIMILARITY
        absolute = {DistanceMetricType.ABSOLUTE_DISTANCE, DistanceMetricType.ABSOLUTE_SIMILARITY}
        if (self.metric_type in absolute or target_type in absolute) and self.max_value is None:
            raise ValueError("Cannot convert to or from an Absolute metric without a max_value.")

        source = csr_array(self.matrix)
        values = source.data.astype(float)
        max_value = self.max_value

        # 1. Standardize stored values to Relative Distance (the base state)
        if self.metric_type == DistanceMetricType.ABSOLUTE_DISTANCE:
            base = values / max_value
        elif self.metric_type == DistanceMetricType.ABSOLUTE_SIMILARITY:
            base = 1.0 - (values / max_value)
        elif self.metric_type == DistanceMetricType.RELATIVE_SIMILARITY:
            base = 1.0 - values
        else:
            base = values

        # 2. Convert from base state (Relative Distance) to Target
        if target_type == DistanceMetricType.RELATIVE_DISTANCE:
            new_values = base
        elif target_type == DistanceMetricType.RELATIVE_SIMILARITY:
            new_values = 1.0 - base
        elif target_type == DistanceMetricType.ABSOLUTE_DISTANCE:
            new_values = base * max_value
        elif target_type == DistanceMetricType.ABSOLUTE_SIMILARITY:
            new_values = (1.0 - base) * max_value
        else:
            raise ValueError(f"Unsupported target metric type: {target_type}")

        # Reuse the original sparsity structure so explicit zeros and missing pairs are preserved
        new_mat = csr_array((new_values, source.indices.copy(), source.indptr.copy()), shape=source.shape)

        # Return a new frozen instance
        return replace(self, matrix=new_mat, metric_type=target_type)

from_file classmethod

from_file(filepath_or_buffer: Union[str, Path], flavour: Union[str, DistanceFlavour]) -> GenomicDistances

Factory method to parse a distance matrix or tree from a file based on flavour.

Source code in src/seroepi/dist.py
@classmethod
def from_file(cls, filepath_or_buffer: Union[str, Path], flavour: Union[str, DistanceFlavour]) -> 'GenomicDistances':
    """
    Factory method to parse a distance matrix or tree from a file based on flavour.

    Args:
        filepath_or_buffer: Path to the file, or an already-open buffer (anything with ``read()``).
        flavour: A DistanceFlavour member or its string value.

    Returns:
        A new GenomicDistances instance.

    Raises:
        ValueError: If the flavour is not recognised.
    """
    flavour_val = flavour.value if isinstance(flavour, DistanceFlavour) else flavour
    if flavour_val == DistanceFlavour.PATHOGENWATCH.value:
        return cls.from_pathogenwatch(filepath_or_buffer)
    elif flavour_val == DistanceFlavour.SKA2.value:
        return cls.from_ska2(filepath_or_buffer)
    elif flavour_val == DistanceFlavour.NEWICK.value:
        # Honour the "buffer" half of the parameter name: open() fails on file-like objects
        if hasattr(filepath_or_buffer, 'read'):
            newick_string = filepath_or_buffer.read()
        else:
            with open(filepath_or_buffer, 'r', encoding='utf-8') as f:
                newick_string = f.read()
        if isinstance(newick_string, bytes):
            newick_string = newick_string.decode('utf-8')
        return cls.from_newick(newick_string)
    else:
        raise ValueError(f"Unknown distance flavour: {flavour_val}")

from_newick classmethod

from_newick(newick_string: str) -> GenomicDistances

Parses a Newick string and calculates patristic distances.

Requires Biopython (pip install biopython).

Parameters:

Name Type Description Default
newick_string str

The Newick tree string.

required

Returns:

Type Description
GenomicDistances

A new Distances instance.

Raises:

Type Description
ImportError

If biopython is not installed.

Source code in src/seroepi/dist.py
@classmethod
def from_newick(cls, newick_string: str) -> 'GenomicDistances':
    """
    Parses a Newick string and calculates patristic distances.

    Requires Biopython (`pip install biopython`). Every leaf pair is stored explicitly in the
    sparse matrix, including pairs with a patristic distance of 0 (e.g. zero-length branches).

    Args:
        newick_string: The Newick tree string.

    Returns:
        A new Distances instance indexed by leaf name.

    Raises:
        ImportError: If biopython is not installed.
    """
    try:
        from Bio import Phylo
    except ImportError as err:
        raise ImportError("biopython is required to calculate patristics. Install with seroepi[dev]") from err

    # 1. Parse the tree
    tree = Phylo.read(StringIO(newick_string), "newick")

    # 2. Extract terminals (leaves)
    terminals = tree.get_terminals()
    labels = [leaf.name for leaf in terminals]
    n = len(terminals)

    # 3. Populate a dense NumPy array (symmetric distance calculation)
    matrix = np.zeros((n, n), dtype=np.float64)
    for i in range(n):
        for j in range(i + 1, n):
            dist = tree.distance(terminals[i], terminals[j])
            matrix[i, j] = dist
            matrix[j, i] = dist

    # 4. Convert to CSR, storing every off-diagonal entry so zero distances are kept
    #    (csr_array(matrix) would silently drop them)
    rows, cols = np.nonzero(~np.eye(n, dtype=bool))  # row-major, i.e. already in CSR order
    indptr = np.arange(n + 1) * (n - 1)  # each row holds exactly n - 1 off-diagonal entries
    return cls(
        matrix=csr_array((matrix[rows, cols], cols, indptr), shape=(n, n)),
        index=pd.Series(labels),
        metric_type=DistanceMetricType.ABSOLUTE_DISTANCE
    )

from_pairwise classmethod

from_pairwise(query_col: Series, target_col: Series, weight_col: Series, metric_type: DistanceMetricType = DistanceMetricType.ABSOLUTE_DISTANCE) -> GenomicDistances

Creates a Distances instance from long-format pairwise data.

Parameters:

Name Type Description Default
query_col Series

Series containing the first isolate IDs.

required
target_col Series

Series containing the second isolate IDs.

required
weight_col Series

Series containing the distances/similarities.

required
metric_type DistanceMetricType

The type of metric provided. Defaults to ABSOLUTE_DISTANCE.

ABSOLUTE_DISTANCE

Returns:

Type Description
GenomicDistances

A new Distances instance.

Source code in src/seroepi/dist.py
@classmethod
def from_pairwise(cls, query_col: pd.Series, target_col: pd.Series, weight_col: pd.Series,
                  metric_type: DistanceMetricType = DistanceMetricType.ABSOLUTE_DISTANCE) -> 'GenomicDistances':
    """
    Creates a Distances instance from long-format pairwise data.

    Each (query, target, weight) row is stored in both directions. A pair listed more than
    once, including as both (a, b) and (b, a), keeps its largest weight. Weights of 0 are
    stored explicitly, so identical isolates are not mistaken for missing pairs.

    Args:
        query_col: Series containing the first isolate IDs.
        target_col: Series containing the second isolate IDs.
        weight_col: Series containing the distances/similarities.
        metric_type: The type of metric provided. Defaults to ABSOLUTE_DISTANCE.

    Returns:
        A new Distances instance indexed by the unique isolate IDs, in order of first appearance.
    """
    # 1. Drop to pure NumPy for speed
    q_vals = query_col.values
    t_vals = target_col.values
    weights = np.asarray(weight_col.values)

    # 2. pd.factorize maps IDs to dense integer codes via hashing (O(N)),
    #    which is faster than np.unique's sort-based O(N log N) approach
    codes, uids = pd.factorize(np.concatenate([q_vals, t_vals]))
    half = len(q_vals)
    n = len(uids)

    # 3. Mirror every pair; sparse M.maximum(M.T) would drop explicit zero distances
    #    and coo_array would sum duplicated pairs
    rows = np.concatenate([codes[:half], codes[half:]])
    cols = np.concatenate([codes[half:], codes[:half]])
    data = np.concatenate([weights, weights])

    # 4. Sort by row, col, weight and keep the last (= largest) weight of each cell
    order = np.lexsort((data, cols, rows))
    rows, cols, data = rows[order], cols[order], data[order]
    is_last = np.ones(rows.size, dtype=bool)
    is_last[:-1] = (rows[1:] != rows[:-1]) | (cols[1:] != cols[:-1])

    # 5. Entries are in CSR order, so build the row pointer directly (keeps explicit zeros)
    indptr = np.zeros(n + 1, dtype=np.intp)
    np.cumsum(np.bincount(rows[is_last], minlength=n), out=indptr[1:])
    M = csr_array((data[is_last], cols[is_last], indptr), shape=(n, n))
    return cls(M, pd.Series(uids), metric_type)

from_pathogenwatch classmethod

from_pathogenwatch(filepath_or_buffer) -> GenomicDistances

Parses a square distance matrix from Pathogenwatch.

Parameters:

Name Type Description Default
filepath_or_buffer

Path to the Pathogenwatch CSV file.

required

Returns:

Type Description
GenomicDistances

A new Distances instance.

Source code in src/seroepi/dist.py
@classmethod
def from_pathogenwatch(cls, filepath_or_buffer) -> 'GenomicDistances':
    """
    Parses a square distance matrix from Pathogenwatch.

    The matrix is symmetrized with an element-wise maximum. Every off-diagonal entry is
    stored explicitly, including distances of 0 (identical isolates).

    Args:
        filepath_or_buffer: Path to the Pathogenwatch CSV file.

    Returns:
        A new Distances instance indexed by the CSV column headers.

    Raises:
        ValueError: If the parsed matrix is not square.
    """
    df = pd.read_csv(filepath_or_buffer, index_col=0)
    values = df.to_numpy()
    if values.shape[0] != values.shape[1]:
        raise ValueError(f"Pathogenwatch distance matrix must be square, got shape {values.shape}")
    # Symmetrize like sparse maximum(M, M.T) does, but without dropping zero distances
    values = np.maximum(values, values.T)
    n = values.shape[0]
    rows, cols = np.nonzero(~np.eye(n, dtype=bool))  # row-major, i.e. already in CSR order
    indptr = np.arange(n + 1) * (n - 1)  # each row holds exactly n - 1 off-diagonal entries
    M = csr_array((values[rows, cols], cols, indptr), shape=(n, n))
    return cls(M, pd.Series(df.columns), DistanceMetricType.ABSOLUTE_DISTANCE)

from_ska2 classmethod

from_ska2(filepath_or_buffer) -> GenomicDistances

Parses a pairwise distance matrix from SKA2 output.

Parameters:

Name Type Description Default
filepath_or_buffer

Path to the SKA2 distance file.

required

Returns:

Type Description
GenomicDistances

A new Distances instance.

Source code in src/seroepi/dist.py
@classmethod
def from_ska2(cls, filepath_or_buffer) -> 'GenomicDistances':
    """
    Parses a pairwise distance matrix from SKA2 output.

    Only the first three tab-separated columns are read. They are treated as
    (first isolate, second isolate, distance) and forwarded to ``from_pairwise``
    with its default metric type.

    Args:
        filepath_or_buffer: Path to the SKA2 distance file.

    Returns:
        A new Distances instance.
    """
    table = pd.read_table(filepath_or_buffer, usecols=(0, 1, 2))
    first_ids, second_ids, distances = (table.iloc[:, position] for position in range(3))
    return cls.from_pairwise(first_ids, second_ids, distances)

get_clusters

get_clusters(threshold: int = 20) -> pd.Series

Identifies clusters via connected components based on a distance threshold.

Parameters:

Name Type Description Default
threshold int

Maximum distance to consider isolates as connected. Defaults to 20 (e.g., 20 SNPs).

20

Returns:

Type Description
Series

A Series of cluster labels indexed by isolate IDs.

Source code in src/seroepi/dist.py
def get_clusters(self, threshold: int = 20) -> pd.Series:
    """
    Identifies clusters via connected components based on a distance threshold.

    Every stored pair whose value is ``<= threshold`` becomes an undirected edge. Pairs above
    the threshold, or not stored at all, stay unconnected.

    Args:
        threshold: Maximum distance to consider isolates as connected.
            Defaults to 20 (e.g., 20 SNPs).

    Returns:
        An ordered categorical Series of component labels indexed by isolate IDs, named
        ``connected_components_threshold=<threshold>``.
    """
    # Work on a copy so the (frozen) instance's matrix is never mutated
    graph = self.matrix.copy()
    # Binarise stored entries: 1 = within threshold (edge kept), 0 = too far (edge severed).
    # A stored distance of 0 passes the test, so it survives as an edge.
    graph.data = (graph.data <= threshold).astype(np.int8)
    graph.eliminate_zeros()
    graph.setdiag(1)
    component_ids = sp_connected_components(csgraph=graph, directed=False, return_labels=True)[1]
    clusters = pd.Series(component_ids, index=self.index, dtype='category',
                         name=f"connected_components_{threshold=}")
    return clusters.cat.as_ordered()

to_type

to_type(target_type: DistanceMetricType) -> GenomicDistances

Converts the distances to a different metric type.

Parameters:

Name Type Description Default
target_type DistanceMetricType

The desired target MetricType.

required

Returns:

Type Description
GenomicDistances

A new Distances instance with the converted matrix.

Raises:

Type Description
ValueError

If conversion requires max_value but it is not set.

Source code in src/seroepi/dist.py
def to_type(self, target_type: DistanceMetricType) -> 'GenomicDistances':
    """
    Converts the distances to a different metric type.

    Only explicitly stored entries are transformed and the sparsity pattern is kept.
    Missing (unstored) pairs therefore stay missing. They are not turned into
    "identical" values, which a dense ``1.0 - matrix`` subtraction would do; SciPy does
    not support that subtraction on sparse arrays anyway.

    Args:
        target_type: The desired target MetricType.

    Returns:
        A new Distances instance with the converted matrix.

    Raises:
        ValueError: If the conversion involves an Absolute metric but `max_value` is not set,
            or if `target_type` is not a supported metric type.
    """
    if self.metric_type == target_type:
        return self

    # Any conversion to or from an Absolute metric needs the scale, including
    # ABSOLUTE_DISTANCE <-> ABSOLUTE_SIMILARITY
    absolute = {DistanceMetricType.ABSOLUTE_DISTANCE, DistanceMetricType.ABSOLUTE_SIMILARITY}
    if (self.metric_type in absolute or target_type in absolute) and self.max_value is None:
        raise ValueError("Cannot convert to or from an Absolute metric without a max_value.")

    source = csr_array(self.matrix)
    values = source.data.astype(float)
    max_value = self.max_value

    # 1. Standardize stored values to Relative Distance (the base state)
    if self.metric_type == DistanceMetricType.ABSOLUTE_DISTANCE:
        base = values / max_value
    elif self.metric_type == DistanceMetricType.ABSOLUTE_SIMILARITY:
        base = 1.0 - (values / max_value)
    elif self.metric_type == DistanceMetricType.RELATIVE_SIMILARITY:
        base = 1.0 - values
    else:
        base = values

    # 2. Convert from base state (Relative Distance) to Target
    if target_type == DistanceMetricType.RELATIVE_DISTANCE:
        new_values = base
    elif target_type == DistanceMetricType.RELATIVE_SIMILARITY:
        new_values = 1.0 - base
    elif target_type == DistanceMetricType.ABSOLUTE_DISTANCE:
        new_values = base * max_value
    elif target_type == DistanceMetricType.ABSOLUTE_SIMILARITY:
        new_values = (1.0 - base) * max_value
    else:
        raise ValueError(f"Unsupported target metric type: {target_type}")

    # Reuse the original sparsity structure so explicit zeros and missing pairs are preserved
    new_mat = csr_array((new_values, source.indices.copy(), source.indptr.copy()), shape=source.shape)

    # Return a new frozen instance
    return replace(self, matrix=new_mat, metric_type=target_type)

TransmissionDistances dataclass

Bases: DistancesBase

Source code in src/seroepi/dist.py
class TransmissionDistances(DistancesBase):
    """Spatiotemporal transmission network stored as a symmetric sparse adjacency matrix."""

    @classmethod
    def from_spatiotemporal(
            cls,
            sample_ids: pd.Series,
            coords: np.ndarray,
            dates: np.ndarray,
            clones: np.ndarray,
            spatial_threshold_km: float = 10.0,
            temporal_threshold_days: int = 20,
    ) -> 'TransmissionDistances':
        """
        Builds a sparse transmission adjacency graph from spatiotemporal arrays.

        Two samples are linked when all three conditions hold:
        they share a clone label; their coordinates lie within ``spatial_threshold_km``
        of each other (haversine distance on a 6371 km sphere); and their ``dates``
        differ by at most ``temporal_threshold_days``. Samples whose clone, coordinates
        or date are missing are never linked.

        Args:
            sample_ids: Sample identifiers; their order defines the matrix rows/columns.
            coords: Array of shape (n, 2) fed to a haversine BallTree. scikit-learn's haversine
                metric expects [latitude, longitude] in radians (NOTE(review): confirm callers
                convert from degrees).
            dates: Numeric sampling times compared directly with ``temporal_threshold_days``
                (presumably days; TODO confirm against callers).
            clones: Clone/lineage label for each sample; missing labels are skipped.
            spatial_threshold_km: Maximum spatial separation (km) for a link. Defaults to 10.0.
            temporal_threshold_days: Maximum absolute difference in ``dates`` for a link. Defaults to 20.

        Returns:
            A TransmissionDistances instance holding a symmetric int8 adjacency matrix
            (metric type ABSOLUTE_SIMILARITY, max_value=1.0).
        """
        n = len(sample_ids)
        global_rows = []
        global_cols = []

        # pd.factorize returns (codes, uniques): codes[i] is the position of clones[i] in uniques.
        # Missing labels get code -1 and are not in uniques, so the loop below skips them.
        clone_codes, unique_clones = pd.factorize(clones)
        radius_radians = spatial_threshold_km / 6371.0  # km -> radians on the unit sphere

        for clone_code in range(len(unique_clones)):
            idx = np.flatnonzero(clone_codes == clone_code)

            # Filter to items with valid spatiotemporal data
            valid_mask = ~(np.isnan(coords[idx, 0]) | np.isnan(coords[idx, 1]) | np.isnan(dates[idx]))
            valid_idx = idx[valid_mask]

            if len(valid_idx) < 2:
                continue  # a lone sample cannot form a transmission link

            group_coords = coords[valid_idx]
            group_dates = dates[valid_idx]

            tree = BallTree(group_coords, metric='haversine')
            spatial_neighbors = tree.query_radius(group_coords, r=radius_radians)

            for i, neighbors in enumerate(spatial_neighbors):
                # Of the spatial neighbours, keep only those that are also close in time
                time_diffs = np.abs(group_dates[i] - group_dates[neighbors])
                valid_neighbors = neighbors[time_diffs <= temporal_threshold_days]

                # Translate group-local positions back to global sample positions
                global_rows.extend([valid_idx[i]] * len(valid_neighbors))
                global_cols.extend(valid_idx[valid_neighbors])

        if global_rows:
            adj = csr_array((np.ones(len(global_rows), dtype=np.int8), (global_rows, global_cols)), shape=(n, n))
            adj = adj.maximum(adj.T)  # Ensure undirected symmetry
        else:
            adj = csr_array((n, n), dtype=np.int8)

        return cls(matrix=adj, index=pd.Series(sample_ids.values, name='sample_id'),
                   metric_type=DistanceMetricType.ABSOLUTE_SIMILARITY, max_value=1.0)

    def get_clusters(self) -> pd.Series:
        """Extracts cluster labels directly from the pre-computed adjacency network."""
        _, labels = sp_connected_components(csgraph=self.matrix, directed=False, return_labels=True)
        return pd.Series(labels, index=self.index, dtype='category').cat.as_ordered()

from_spatiotemporal classmethod

from_spatiotemporal(sample_ids: Series, coords: ndarray, dates: ndarray, clones: ndarray, spatial_threshold_km: float = 10.0, temporal_threshold_days: int = 20) -> TransmissionDistances

Builds a sparse transmission adjacency graph from spatiotemporal arrays.

Source code in src/seroepi/dist.py
@classmethod
def from_spatiotemporal(
        cls,
        sample_ids: pd.Series,
        coords: np.ndarray,
        dates: np.ndarray,
        clones: np.ndarray,
        spatial_threshold_km: float = 10.0,
        temporal_threshold_days: int = 20,
) -> 'TransmissionDistances':
    """
    Builds a sparse transmission adjacency graph from spatiotemporal arrays.

    Two samples are linked when all three conditions hold:
    they share a clone label; their coordinates lie within ``spatial_threshold_km``
    of each other (haversine distance on a 6371 km sphere); and their ``dates``
    differ by at most ``temporal_threshold_days``. Samples whose clone, coordinates
    or date are missing are never linked.

    Args:
        sample_ids: Sample identifiers; their order defines the matrix rows/columns.
        coords: Array of shape (n, 2) fed to a haversine BallTree. scikit-learn's haversine
            metric expects [latitude, longitude] in radians (NOTE(review): confirm callers
            convert from degrees).
        dates: Numeric sampling times compared directly with ``temporal_threshold_days``
            (presumably days; TODO confirm against callers).
        clones: Clone/lineage label for each sample; missing labels are skipped.
        spatial_threshold_km: Maximum spatial separation (km) for a link. Defaults to 10.0.
        temporal_threshold_days: Maximum absolute difference in ``dates`` for a link. Defaults to 20.

    Returns:
        A TransmissionDistances instance holding a symmetric int8 adjacency matrix
        (metric type ABSOLUTE_SIMILARITY, max_value=1.0).
    """
    n = len(sample_ids)
    global_rows = []
    global_cols = []

    # pd.factorize returns (codes, uniques): codes[i] is the position of clones[i] in uniques.
    # Missing labels get code -1 and are not in uniques, so the loop below skips them.
    clone_codes, unique_clones = pd.factorize(clones)
    radius_radians = spatial_threshold_km / 6371.0  # km -> radians on the unit sphere

    for clone_code in range(len(unique_clones)):
        idx = np.flatnonzero(clone_codes == clone_code)

        # Filter to items with valid spatiotemporal data
        valid_mask = ~(np.isnan(coords[idx, 0]) | np.isnan(coords[idx, 1]) | np.isnan(dates[idx]))
        valid_idx = idx[valid_mask]

        if len(valid_idx) < 2:
            continue  # a lone sample cannot form a transmission link

        group_coords = coords[valid_idx]
        group_dates = dates[valid_idx]

        tree = BallTree(group_coords, metric='haversine')
        spatial_neighbors = tree.query_radius(group_coords, r=radius_radians)

        for i, neighbors in enumerate(spatial_neighbors):
            # Of the spatial neighbours, keep only those that are also close in time
            time_diffs = np.abs(group_dates[i] - group_dates[neighbors])
            valid_neighbors = neighbors[time_diffs <= temporal_threshold_days]

            # Translate group-local positions back to global sample positions
            global_rows.extend([valid_idx[i]] * len(valid_neighbors))
            global_cols.extend(valid_idx[valid_neighbors])

    if global_rows:
        adj = csr_array((np.ones(len(global_rows), dtype=np.int8), (global_rows, global_cols)), shape=(n, n))
        adj = adj.maximum(adj.T)  # Ensure undirected symmetry
    else:
        adj = csr_array((n, n), dtype=np.int8)

    return cls(matrix=adj, index=pd.Series(sample_ids.values, name='sample_id'),
               metric_type=DistanceMetricType.ABSOLUTE_SIMILARITY, max_value=1.0)

get_clusters

get_clusters() -> pd.Series

Extracts cluster labels directly from the pre-computed adjacency network.

Source code in src/seroepi/dist.py
def get_clusters(self) -> pd.Series:
    """
    Extracts cluster labels directly from the pre-computed adjacency network.

    Samples joined by any chain of stored edges share a label. Labels are returned as an
    ordered categorical Series indexed like ``self.index``.
    """
    adjacency = self.matrix
    component_ids = sp_connected_components(csgraph=adjacency, directed=False, return_labels=True)[1]
    clusters = pd.Series(component_ids, index=self.index, dtype='category')
    return clusters.cat.as_ordered()

seroepi.constants

Enums for internal (non-user-facing) API constants, mainly used by the app.

DistanceMetricType

Bases: StrEnum

Enumeration of supported metric types for pairwise comparisons.

These metric types define how to interpret the numerical values in a distance/similarity matrix.

Attributes:

Name Type Description
ABSOLUTE_DISTANCE

An absolute distance measure (e.g., 5 SNPs).

RELATIVE_DISTANCE

A relative distance measure typically between 0.0 and 1.0 (e.g., 0.05 Hamming).

ABSOLUTE_SIMILARITY

An absolute similarity measure (e.g., 95 shared nucleotides).

RELATIVE_SIMILARITY

A relative similarity measure typically between 0.0 and 1.0 (e.g., 0.95 Jaccard).

Source code in src/seroepi/constants.py
class DistanceMetricType(StrEnum):
    """
    Enumeration of supported metric types for pairwise comparisons.

    These metric types define how to interpret the numerical values
    in a distance/similarity matrix.

    Attributes:
        ABSOLUTE_DISTANCE: An absolute distance measure (e.g., 5 SNPs).
        RELATIVE_DISTANCE: A relative distance measure typically between 0.0 and 1.0 (e.g., 0.05 Hamming).
        ABSOLUTE_SIMILARITY: An absolute similarity measure (e.g., 95 shared nucleotides).
        RELATIVE_SIMILARITY: A relative similarity measure typically between 0.0 and 1.0 (e.g., 0.95 Jaccard).
    """
    ABSOLUTE_DISTANCE = auto()  # e.g., 5 SNPs
    RELATIVE_DISTANCE = auto()  # e.g., 0.05 Hamming
    ABSOLUTE_SIMILARITY = auto()  # e.g., 95 shared nucleotides
    RELATIVE_SIMILARITY = auto()  # e.g., 0.95 Jaccard

    @classmethod
    def _missing_(cls, value):
        if isinstance(value, str):
            if v := cls.__members__.get(value.upper().replace(" ", "_").replace("-", "_")):
                return v
        return cls.ABSOLUTE_DISTANCE

TemporalResolution

Bases: _UiEnum

Source code in src/seroepi/constants.py
class TemporalResolution(_UiEnum):
    """
    Time-bucketing resolutions, each mapped to a pandas offset alias and a period alias.

    Unrecognised values resolve to ``UNKNOWN``, whose aliases are empty strings.
    """
    YEAR = auto()
    MONTH = auto()
    WEEK = auto()
    DAY = auto()
    UNKNOWN = auto()

    @classmethod
    def _missing_(cls, value):
        # Any unrecognised value maps to UNKNOWN instead of raising ValueError
        return cls.UNKNOWN

    @property
    def pandas_offset(self) -> str:
        """Returns the modern Pandas 2.2+ start-offset alias."""
        resolution = type(self)
        offset_aliases = {
            resolution.YEAR: 'YS',
            resolution.MONTH: 'MS',
            resolution.WEEK: 'W-MON',
            resolution.DAY: 'D',
            resolution.UNKNOWN: '',
        }
        return offset_aliases[self]

    @property
    def pandas_period(self) -> str:
        """Returns the Pandas period alias."""
        resolution = type(self)
        period_aliases = {
            resolution.YEAR: 'Y',
            resolution.MONTH: 'M',
            resolution.WEEK: 'W',
            resolution.DAY: 'D',
            resolution.UNKNOWN: '',
        }
        return period_aliases[self]

pandas_offset property

pandas_offset: str

Returns the modern Pandas 2.2+ start-offset alias.

pandas_period property

pandas_period: str

Returns the Pandas period alias.

seroepi.client

Module to interact with the Pathogenwatch Next API.

PathogenwatchClient

Client for the Pathogenwatch Next API.

Handles automatic retries, rate limiting, and pagination. It uses a session with a retry strategy to handle common transient errors and rate limits.

Attributes:

Name Type Description
session Session

The underlying requests session with retry strategy.

Examples:

>>> from seroepi.client import PathogenwatchClient
>>> with PathogenwatchClient(api_key="your_api_key") as client:
...     collections = list(client.get_collections(limit=5))
...     for collection in collections:
...         print(collection.name)
Source code in src/seroepi/client.py
class PathogenwatchClient:
    """
    Client for the Pathogenwatch Next API.

    Handles automatic retries, rate limiting, and pagination. It uses a session
    with a retry strategy to handle common transient errors and rate limits.

    Attributes:
        session (requests.Session): The underlying requests session with retry strategy.

    Examples:
        >>> from seroepi.client import PathogenwatchClient
        >>> with PathogenwatchClient(api_key="your_api_key") as client:
        ...     collections = list(client.get_collections(limit=5))
        ...     for collection in collections:
        ...         print(collection.name)
    """
    _BASE = "https://next.pathogen.watch/api/"
    _COLLECTIONS_ENDPOINT = "collections/list"
    _FOLDERS_ENDPOINT = "folders/list"

    def __init__(self, api_key: str):
        """
        Initializes the PathogenwatchClient with an API key.

        Args:
            api_key: The API key for Pathogenwatch Next.
        """
        self.session = requests.Session()

        # Set authentication
        self.session.headers.update({
            "X-API-Key": api_key,
            "Content-Type": "application/json",
            "User-Agent": "seroepi-client/1.0"
        })

        # This replaces the entire threading/lock/backoff mechanism of the old template.
        # It automatically pauses and retries on rate limits (429) or server errors (50X).
        # NOTE(review): POST is retried too; that is only safe if the POST endpoints used are idempotent.
        retries = Retry(
            total=5,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS", "POST"]
        )
        adapter = HTTPAdapter(max_retries=retries, pool_connections=20, pool_maxsize=20)
        self.session.mount("https://", adapter)

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit, closes the session."""
        self.session.close()

    def prefetch(self, items: Iterable['PathogenwatchContainerMixin'], max_workers: int = 10) -> None:
        """
        Concurrently populates the details and genomes cache for multiple collections or folders.

        Uses thread pooling to fetch in parallel while urllib3 safely handles 429 rate limit backoffs.

        Args:
            items: An iterable of PathogenwatchCollection or PathogenwatchFolder objects.
            max_workers: Maximum number of concurrent threads. Defaults to 10.
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submitting item.get_genomes naturally triggers item.get_details,
            # resolving and caching both sequentially per-item, but concurrently across items.
            futures = [executor.submit(item.get_genomes, self) for item in items]
            for future in concurrent.futures.as_completed(futures):
                future.result()  # Raise any exceptions encountered during fetching

    def _url(self, endpoint: str) -> str:
        """Joins the API base URL and an endpoint path with exactly one '/' between them."""
        # _BASE already ends with '/', so f"{_BASE}/{endpoint}" would produce '.../api//endpoint'
        return f"{self._BASE.rstrip('/')}/{endpoint.lstrip('/')}"

    def request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
        """
        Sends an HTTP request to the Pathogenwatch API.

        The response is returned open, so ``stream=True`` callers can still consume the body.

        Args:
            method: HTTP method (e.g., 'GET', 'POST').
            endpoint: API endpoint path.
            **kwargs: Additional arguments passed to requests.Session.request.

        Returns:
            The HTTP response.

        Raises:
            requests.HTTPError: If the request returned an error status code.
        """
        response = self.session.request(method, self._url(endpoint), **kwargs)
        try:
            response.raise_for_status()  # Raises for 4xx/5xx responses
        except requests.HTTPError:
            response.close()  # Release the connection before propagating
            raise
        return response

    def get(self, endpoint: str, **kwargs) -> requests.Response:
        """
        Sends a GET request to the Pathogenwatch API.

        Args:
            endpoint: API endpoint path.
            **kwargs: Additional arguments passed to requests.Session.request.

        Returns:
            The HTTP response.

        Raises:
            requests.HTTPError: If the request returned an error status code.
        """
        return self.request("GET", endpoint, **kwargs)

    @staticmethod
    def _list_params(exclude: str | None, limit: int | None, binned: bool | None) -> dict:
        """Builds query parameters for the list endpoints, omitting unset values.

        ``binned`` is sent as the lowercase string "true" or "false".
        """
        params = {
            "exclude": exclude,
            "limit": limit,
            "binned": None if binned is None else str(binned).lower(),
        }
        return {key: value for key, value in params.items() if value is not None}

    def get_collections(self, exclude: str | None = None, limit: int | None = None, binned: bool | None = None
                        ) -> Generator['PathogenwatchCollection', None, None]:
        """
        Retrieves collections from Pathogenwatch.

        Args:
            exclude: Collections to exclude.
            limit: Maximum number of collections to retrieve.
            binned: Whether to include binned collections.

        Yields:
            PathogenwatchCollection: The next collection retrieved.
        """
        params = self._list_params(exclude, limit, binned)
        # Ignore response keys the dataclass does not accept, so new API fields don't break parsing
        valid_keys = {f.name for f in fields(PathogenwatchCollection) if f.init}
        for collection_dict in self.get(self._COLLECTIONS_ENDPOINT, params=params).json():
            yield PathogenwatchCollection(**{k: v for k, v in collection_dict.items() if k in valid_keys})

    def get_folders(self, exclude: str | None = None, limit: int | None = None, binned: bool | None = None
                    ) -> Generator['PathogenwatchFolder', None, None]:
        """
        Retrieves folders from Pathogenwatch.

        Args:
            exclude: Folders to exclude.
            limit: Maximum number of folders to retrieve.
            binned: Whether to include binned folders.

        Yields:
            PathogenwatchFolder: The next folder retrieved.
        """
        params = self._list_params(exclude, limit, binned)
        # Ignore response keys the dataclass does not accept, so new API fields don't break parsing
        valid_keys = {f.name for f in fields(PathogenwatchFolder) if f.init}
        for folder_dict in self.get(self._FOLDERS_ENDPOINT, params=params).json():
            yield PathogenwatchFolder(**{k: v for k, v in folder_dict.items() if k in valid_keys})

__enter__

__enter__()

Context manager entry.

Source code in src/seroepi/client.py
def __enter__(self):
    """Context manager entry."""
    return self

__exit__

__exit__(exc_type, exc_val, exc_tb)

Context manager exit, closes the session.

Source code in src/seroepi/client.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context manager exit, closes the session."""
    self.session.close()

__init__

__init__(api_key: str)

Initializes the PathogenwatchClient with an API key.

Parameters:

Name Type Description Default
api_key str

The API key for Pathogenwatch Next.

required
Source code in src/seroepi/client.py
def __init__(self, api_key: str):
    """
    Initializes the PathogenwatchClient with an API key.

    Creates a ``requests.Session`` carrying the authentication and JSON headers.
    It mounts an ``HTTPAdapter`` on ``https://`` that retries rate-limited (429)
    and server-error (500/502/503/504) responses with exponential backoff.

    Args:
        api_key: The API key for Pathogenwatch Next.
    """
    self.session = requests.Session()

    # Every request made through this session carries these headers.
    self.session.headers.update({
        "X-API-Key": api_key,
        "Content-Type": "application/json",
        "User-Agent": "seroepi-client/1.0"
    })

    # urllib3 pauses and retries on rate limits (429) and transient server
    # errors (50X) before the response ever reaches requests.
    # raise_on_status=False: once retries are exhausted, hand back the last
    # response instead of raising urllib3's MaxRetryError (which requests
    # surfaces as RetryError). get()/request() then call raise_for_status(),
    # so callers consistently see requests.HTTPError as documented.
    # NOTE(review): POST is retried as well; that is only safe if the POST
    # endpoints used are idempotent on 5xx — confirm.
    retries = Retry(
        total=5,
        backoff_factor=1,  # urllib3 sleeps ~0s, 2s, 4s, 8s, 16s between attempts
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["HEAD", "GET", "OPTIONS", "POST"],
        raise_on_status=False,
    )
    adapter = HTTPAdapter(max_retries=retries, pool_connections=20, pool_maxsize=20)
    self.session.mount("https://", adapter)

get

get(endpoint: str, **kwargs) -> requests.Response

Sends a GET request to the Pathogenwatch API.

Parameters:

Name Type Description Default
endpoint str

API endpoint path.

required
**kwargs

Additional arguments passed to requests.get.

{}

Returns:

Type Description
Response

The HTTP response.

Raises:

Type Description
HTTPError

If the request returned an error status code.

Source code in src/seroepi/client.py
def get(self, endpoint: str, **kwargs) -> requests.Response:
    """
    Sends a GET request to the Pathogenwatch API.

    Args:
        endpoint: API endpoint path, relative to the client's base URL.
            A leading slash is ignored.
        **kwargs: Additional arguments passed to requests.get
            (e.g. ``params``, ``timeout``, ``stream``).

    Returns:
        The HTTP response. It is returned open, so the body of a
        ``stream=True`` response can still be read by the caller.

    Raises:
        requests.HTTPError: If the request returned an error status code.
    """
    url = f"{self._BASE}/{endpoint.lstrip('/')}"
    response = self.session.get(url, **kwargs)
    try:
        response.raise_for_status()  # Raises for 4xx/5xx responses
    except requests.HTTPError:
        # Release the connection before propagating. This matters for
        # stream=True, where the body has not been read yet.
        response.close()
        raise
    # Deliberately no ``with`` block: closing on exit would discard an unread
    # stream=True body. Without stream=True, requests has already read the whole
    # body and returned the connection to the pool.
    return response

get_collections

get_collections(exclude: str = None, limit: int = None, binned: bool = None) -> Generator[PathogenwatchCollection, None, None]

Retrieves collections from Pathogenwatch.

Parameters:

Name Type Description Default
exclude str

Collections to exclude.

None
limit int

Maximum number of collections to retrieve.

None
binned bool

Whether to include binned collections.

None

Yields:

Name Type Description
PathogenwatchCollection PathogenwatchCollection

The next collection retrieved.

Source code in src/seroepi/client.py
def get_collections(self, exclude: str = None, limit: int = None, binned: bool = None
                    ) -> Generator['PathogenwatchCollection', None, None]:
    """
    Lazily yields the collections visible to this client.

    Only the filters that were actually supplied (i.e. are not None) are sent
    as query parameters. Keys in each response record that are not init
    fields of PathogenwatchCollection are dropped before construction.

    Args:
        exclude: Collections to exclude.
        limit: Maximum number of collections to retrieve.
        binned: Whether to include binned collections; sent as the
            lower-cased string form of the value (e.g. "true"/"false").

    Yields:
        PathogenwatchCollection: One object per record in the response.
    """
    query = {}
    if exclude is not None:
        query["exclude"] = exclude
    if limit is not None:
        query["limit"] = limit
    if binned is not None:
        query["binned"] = str(binned).lower()

    # Only constructor arguments are forwarded; unknown API keys would raise TypeError.
    accepted = {f.name for f in fields(PathogenwatchCollection) if f.init}
    records = self.get(self._COLLECTIONS_ENDPOINT, params=query).json()
    for record in records:
        yield PathogenwatchCollection(**{key: value for key, value in record.items() if key in accepted})

get_folders

get_folders(exclude: str = None, limit: int = None, binned: bool = None) -> Generator[PathogenwatchFolder, None, None]

Retrieves folders from Pathogenwatch.

Parameters:

Name Type Description Default
exclude str

Folders to exclude.

None
limit int

Maximum number of folders to retrieve.

None
binned bool

Whether to include binned folders.

None

Yields:

Name Type Description
PathogenwatchFolder PathogenwatchFolder

The next folder retrieved.

Source code in src/seroepi/client.py
def get_folders(self, exclude: str = None, limit: int = None, binned: bool = None
                ) -> Generator['PathogenwatchFolder', None, None]:
    """
    Lazily yields the folders visible to this client.

    Only the filters that were actually supplied (i.e. are not None) are sent
    as query parameters. Keys in each response record that are not init
    fields of PathogenwatchFolder are dropped before construction.

    Args:
        exclude: Folders to exclude.
        limit: Maximum number of folders to retrieve.
        binned: Whether to include binned folders; sent as the lower-cased
            string form of the value (e.g. "true"/"false").

    Yields:
        PathogenwatchFolder: One object per record in the response.
    """
    query = {}
    if exclude is not None:
        query["exclude"] = exclude
    if limit is not None:
        query["limit"] = limit
    if binned is not None:
        query["binned"] = str(binned).lower()

    # Only constructor arguments are forwarded; unknown API keys would raise TypeError.
    accepted = {f.name for f in fields(PathogenwatchFolder) if f.init}
    records = self.get(self._FOLDERS_ENDPOINT, params=query).json()
    for record in records:
        yield PathogenwatchFolder(**{key: value for key, value in record.items() if key in accepted})

prefetch

prefetch(items: Iterable[PathogenwatchContainerMixin], max_workers: int = 10) -> None

Concurrently fetches genomes for multiple collections or folders, populating each item's details cache along the way.

Uses thread pooling to fetch in parallel while urllib3 safely handles 429 rate limit backoffs.

Parameters:

Name Type Description Default
items Iterable[PathogenwatchContainerMixin]

An iterable of PathogenwatchCollection or PathogenwatchFolder objects.

required
max_workers int

Maximum number of concurrent threads. Defaults to 10.

10
Source code in src/seroepi/client.py
def prefetch(self, items: Iterable['PathogenwatchContainerMixin'], max_workers: int = 10) -> list[list[dict]]:
    """
    Concurrently fetches genomes for multiple collections or folders.

    Each item's ``get_genomes`` runs on a thread pool. ``get_genomes`` resolves
    the item's details first, so each item's details cache is populated as a
    side effect. ``get_genomes`` itself does not cache the genomes, so the
    fetched lists are returned here instead of being thrown away. urllib3
    handles 429 rate-limit backoffs underneath.

    Args:
        items: An iterable of PathogenwatchCollection or PathogenwatchFolder objects.
        max_workers: Maximum number of concurrent threads. Defaults to 10.

    Returns:
        One genome list per item, in the same order as ``items``. Callers
        that ignore the return value keep the previous behaviour.

    Raises:
        Exception: The first exception observed from any item's ``get_genomes``
            is re-raised once the pool has finished its remaining work.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Details and genomes are resolved sequentially within one item,
        # but different items are processed concurrently.
        futures = [executor.submit(item.get_genomes, self) for item in items]
        # Surface the first failure as soon as its future completes.
        for future in concurrent.futures.as_completed(futures):
            future.result()
    # All futures have succeeded here; collect results in input order.
    return [future.result() for future in futures]

request

request(method: str, endpoint: str, **kwargs) -> requests.Response

Sends an HTTP request to the Pathogenwatch API.

Parameters:

Name Type Description Default
method str

HTTP method (e.g., 'GET', 'POST').

required
endpoint str

API endpoint path.

required
**kwargs

Additional arguments passed to requests.request.

{}

Returns:

Type Description
Response

The HTTP response.

Raises:

Type Description
HTTPError

If the request returned an error status code.

Source code in src/seroepi/client.py
def request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
    """
    Sends an HTTP request to the Pathogenwatch API.

    Args:
        method: HTTP method (e.g., 'GET', 'POST').
        endpoint: API endpoint path, relative to the client's base URL.
            A leading slash is ignored.
        **kwargs: Additional arguments passed to requests.request
            (e.g. ``params``, ``json``, ``timeout``, ``stream``).

    Returns:
        The HTTP response. It is returned open, so the body of a
        ``stream=True`` response can still be read by the caller.

    Raises:
        requests.HTTPError: If the request returned an error status code.
    """
    url = f"{self._BASE}/{endpoint.lstrip('/')}"
    response = self.session.request(method, url, **kwargs)
    try:
        response.raise_for_status()  # Raises for 4xx/5xx responses
    except requests.HTTPError:
        # Release the connection before propagating. This matters for
        # stream=True, where the body has not been read yet.
        response.close()
        raise
    # Deliberately no ``with`` block: closing on exit would discard an unread
    # stream=True body. Without stream=True, requests has already read the whole
    # body and returned the connection to the pool.
    return response

PathogenwatchCollection dataclass

Bases: PathogenwatchContainerMixin

A lazy-loaded proxy object representing a single Pathogenwatch collection.

Attributes:

Name Type Description
binned bool

Whether the collection is binned.

createdAt str

Creation timestamp.

description str

Collection description.

name str

Collection name.

organismId str

ID of the organism.

owner str

Owner of the collection.

uuid str

Unique identifier for the collection.

size int

Number of genomes in the collection.

Source code in src/seroepi/client.py
@dataclass(frozen=True, slots=True)
class PathogenwatchCollection(PathogenwatchContainerMixin):
    """
    A lazy-loaded proxy object representing a single Pathogenwatch collection.

    Instances are immutable (``frozen=True``). Collection details and genomes
    are fetched on demand through the PathogenwatchContainerMixin methods,
    which take a PathogenwatchClient. The details response is memoised in the
    private ``_details_cache`` field.

    Attributes:
        binned: Whether the collection is binned.
        createdAt: Creation timestamp.
        description: Collection description.
        name: Collection name.
        organismId: ID of the organism.
        owner: Owner of the collection.
        uuid: Unique identifier for the collection.
        size: Number of genomes in the collection.
    """
    # Configuration consumed by PathogenwatchContainerMixin.
    _ENTITY_TYPE: ClassVar[str] = "collections"  # endpoints are "collections/details" and "collections/genomes"
    _DETAILS_QUERY_PARAM: ClassVar[str] = "uuid"  # details are looked up by this collection's uuid
    _GENOMES_ID_PARAM: ClassVar[str] = "collectionId"  # carries the 'id' taken from the details response
    _GENOMES_CURSOR_PARAM: ClassVar[str] = "cursor"  # pagination cursor query key for collections
    _ATTR_PREFIX: ClassVar[str] = "pw_collection"  # NOTE(review): not referenced by the mixin methods visible here

    # Fields populated from the API record; all are required by __init__.
    # Field order defines the positional __init__ signature, so do not reorder.
    binned: bool
    createdAt: str
    description: str
    name: str
    organismId: str
    owner: str
    uuid: str
    size: int
    # Memoised get_details() response. It is excluded from __init__, repr, and
    # equality/hash, and is written via object.__setattr__ despite frozen=True.
    _details_cache: Optional[dict] = field(default=None, init=False, repr=False, compare=False)

PathogenwatchContainerMixin

Mixin providing shared fetching logic for Pathogenwatch Collection and Folder dataclasses.

Classes using this mixin must define _ENTITY_TYPE, _DETAILS_QUERY_PARAM, _GENOMES_ID_PARAM, _GENOMES_CURSOR_PARAM, and _ATTR_PREFIX.

Source code in src/seroepi/client.py
class PathogenwatchContainerMixin:
    """
    Shared fetch logic for the Pathogenwatch Collection and Folder dataclasses.

    Concrete classes configure the endpoints and query keys through the class
    variables below. They must also provide a ``uuid`` attribute and a
    ``_details_cache`` attribute that starts out as None.
    """
    _ENTITY_TYPE: ClassVar[str]  # plural URL prefix, e.g. "collections"
    _DETAILS_QUERY_PARAM: ClassVar[str]  # query key that carries self.uuid on the details call
    _GENOMES_ID_PARAM: ClassVar[str]  # query key that carries the internal id on the genomes call
    _GENOMES_CURSOR_PARAM: ClassVar[str]  # query key for the pagination cursor
    _ATTR_PREFIX: ClassVar[str]

    def get_details(self, client: PathogenwatchClient) -> dict:
        """
        Return this container's details, fetching them at most once.

        The first non-None response is stored in ``_details_cache``. Because the
        dataclasses are frozen, it is stored via ``object.__setattr__``.

        Args:
            client: An instance of PathogenwatchClient.

        Returns:
            The decoded JSON body of the ``<entity>/details`` endpoint.
        """
        cached = self._details_cache
        if cached is not None:
            return cached
        query = {self._DETAILS_QUERY_PARAM: self.uuid}
        details = client.get(f"{self._ENTITY_TYPE}/details", params=query).json()
        object.__setattr__(self, '_details_cache', details)
        return self._details_cache

    def get_genomes(self, client: PathogenwatchClient, limit: int = 1000) -> list[dict]:
        """
        Collect every genome in this container by following the pagination cursor.

        Args:
            client: An instance of PathogenwatchClient.
            limit: Page size requested on each call. Defaults to 1000.

        Returns:
            The concatenated ``genomes`` lists of all pages, in page order.

        Raises:
            ValueError: If the details response has no usable ``id``.
        """
        internal_id = self.get_details(client).get('id')
        if not internal_id:
            raise ValueError(f"Could not resolve internal ID for {self._ENTITY_TYPE[:-1]} {self.uuid}")

        endpoint = f"{self._ENTITY_TYPE}/genomes"
        genomes = []
        page_cursor = None
        while True:
            query = {self._GENOMES_ID_PARAM: internal_id, "limit": limit}
            if page_cursor:
                query[self._GENOMES_CURSOR_PARAM] = page_cursor
            page = client.get(endpoint, params=query).json()
            genomes.extend(page.get("genomes", []))
            meta = page.get("meta", {})
            page_cursor = meta.get("endCursor")
            # Stop when there is no cursor to follow, or the server marks the page empty.
            if not page_cursor or meta.get("empty"):
                return genomes

get_details

get_details(client: PathogenwatchClient) -> dict

Fetches detailed information for the container.

Parameters:

Name Type Description Default
client PathogenwatchClient

An instance of PathogenwatchClient.

required

Returns:

Type Description
dict

A dictionary containing the details.

Source code in src/seroepi/client.py
def get_details(self, client: PathogenwatchClient) -> dict:
    """
    Return this container's details, fetching them at most once.

    The first non-None response is stored in ``_details_cache``. Because the
    dataclasses are frozen, it is stored via ``object.__setattr__``.

    Args:
        client: An instance of PathogenwatchClient.

    Returns:
        The decoded JSON body of the ``<entity>/details`` endpoint.
    """
    cached = self._details_cache
    if cached is not None:
        return cached
    query = {self._DETAILS_QUERY_PARAM: self.uuid}
    details = client.get(f"{self._ENTITY_TYPE}/details", params=query).json()
    object.__setattr__(self, '_details_cache', details)
    return self._details_cache

get_genomes

get_genomes(client: PathogenwatchClient, limit: int = 1000) -> list[dict]

Fetches all genomes associated with this container.

Parameters:

Name Type Description Default
client PathogenwatchClient

An instance of PathogenwatchClient.

required
limit int

Number of genomes to fetch per request for pagination. Defaults to 1000.

1000

Returns:

Type Description
list[dict]

A list of dictionaries, each representing a genome.

Raises:

Type Description
ValueError

If the internal ID for the container cannot be resolved.

Source code in src/seroepi/client.py
def get_genomes(self, client: PathogenwatchClient, limit: int = 1000) -> list[dict]:
    """
    Collect every genome in this container by following the pagination cursor.

    Args:
        client: An instance of PathogenwatchClient.
        limit: Page size requested on each call. Defaults to 1000.

    Returns:
        The concatenated ``genomes`` lists of all pages, in page order.

    Raises:
        ValueError: If the details response has no usable ``id``.
    """
    internal_id = self.get_details(client).get('id')
    if not internal_id:
        raise ValueError(f"Could not resolve internal ID for {self._ENTITY_TYPE[:-1]} {self.uuid}")

    endpoint = f"{self._ENTITY_TYPE}/genomes"
    genomes = []
    page_cursor = None
    while True:
        query = {self._GENOMES_ID_PARAM: internal_id, "limit": limit}
        if page_cursor:
            query[self._GENOMES_CURSOR_PARAM] = page_cursor
        page = client.get(endpoint, params=query).json()
        genomes.extend(page.get("genomes", []))
        meta = page.get("meta", {})
        page_cursor = meta.get("endCursor")
        # Stop when there is no cursor to follow, or the server marks the page empty.
        if not page_cursor or meta.get("empty"):
            return genomes

PathogenwatchFolder dataclass

Bases: PathogenwatchContainerMixin

A lazy-loaded proxy object representing a single Pathogenwatch folder.

Attributes:

Name Type Description
createdAt str

Creation timestamp.

id str

Internal folder ID.

uuid str

Unique identifier for the folder.

access str

Access level.

name str

Folder name.

binned bool

Whether the folder is binned.

Source code in src/seroepi/client.py
@dataclass(frozen=True, slots=True)
class PathogenwatchFolder(PathogenwatchContainerMixin):
    """
    A lazy-loaded proxy object representing a single Pathogenwatch folder.

    Instances are immutable (``frozen=True``). Folder details and genomes are
    fetched on demand through the PathogenwatchContainerMixin methods, which
    take a PathogenwatchClient. The details response is memoised in the
    private ``_details_cache`` field.

    Attributes:
        createdAt: Creation timestamp.
        id: Internal folder ID.
        uuid: Unique identifier for the folder.
        access: Access level.
        name: Folder name. Defaults to an empty string.
        binned: Whether the folder is binned. Defaults to False.
    """
    # Configuration consumed by PathogenwatchContainerMixin.
    _ENTITY_TYPE: ClassVar[str] = "folders"  # endpoints are "folders/details" and "folders/genomes"
    # NOTE(review): the mixin sends this key with self.uuid as its value (not self.id);
    # confirm the folders/details endpoint expects the uuid under "id".
    _DETAILS_QUERY_PARAM: ClassVar[str] = "id"
    # NOTE(review): carries the 'id' from the details response, not self.id — confirm they match.
    _GENOMES_ID_PARAM: ClassVar[str] = "folderId"
    _GENOMES_CURSOR_PARAM: ClassVar[str] = "after"  # folders paginate with "after"; collections use "cursor"
    _ATTR_PREFIX: ClassVar[str] = "pw_folder"  # NOTE(review): not referenced by the mixin methods visible here

    # Fields populated from the API record. The first four are required by __init__.
    # Field order defines the positional __init__ signature, so do not reorder.
    createdAt: str
    id: str
    uuid: str
    access: str
    name: str = ""
    binned: bool = False
    # Memoised get_details() response. It is excluded from __init__, repr, and
    # equality/hash, and is written via object.__setattr__ despite frozen=True.
    _details_cache: Optional[dict] = field(default=None, init=False, repr=False, compare=False)

seroepi.app