LocationGroupedEconomicDataset

Bases: EconomicDataset

Creates one batch per wholesaler location group, so that all wholesalers mapped to the same location are batched together.

Source code in wt_ml/optimizer/sim_optimizer/model_compressor.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class LocationGroupedEconomicDataset(EconomicDataset):
    """Economic dataset that yields one batch per wholesaler location group.

    Wholesalers are grouped by the location they map to in
    ``encodings[f"wholesaler_{location_type}_lookup"]``. Iterating the dataset
    yields one ``EconomicModelInput`` per location, holding the rows of every
    wholesaler that belongs to that location.
    """

    def __init__(
        self,
        data: dict[str, np.ndarray],
        encodings: dict[str, Any],
        location_type: Literal[LOCATION_TYPES],
        batch_size: int | None = 32,
        seed: int | np.random.Generator = 37,
    ):
        """Build the location encoding and the per-location wholesaler groups.

        Args:
            data: Arrays keyed by dataset name; forwarded to ``EconomicDataset``.
            encodings: Must contain ``"wholesaler"`` (wholesaler -> encoded index)
                and ``f"wholesaler_{location_type}_lookup"`` (wholesaler -> location).
            location_type: Selects which wholesaler -> location lookup is used.
            batch_size: Forwarded unchanged to ``EconomicDataset``; ``__iter__``
                in this class does not read it.
            seed: Forwarded unchanged to ``EconomicDataset``
                (presumably used to build ``self.rng`` - confirm in the base class).

        Raises:
            KeyError: If a required encoding is missing, or the location lookup
                names a wholesaler that is absent from the wholesaler encoding.
        """
        # Assigned before the base initializer runs, as in the original ordering.
        self.location_type = location_type
        super().__init__(data, encodings, batch_size, seed)
        self.wlsr_encoding = self.encodings["wholesaler"]
        self.wlsr_location_lookup: dict[str, str] = self.encodings[f"wholesaler_{self.location_type}_lookup"]
        # Sorted so every location gets a deterministic integer id.
        self.location_encoding: dict[str, int] = {
            k: i for i, k in enumerate(sorted(set(self.wlsr_location_lookup.values())))
        }
        self.encodings["location"] = self.location_encoding
        # Single pass over the lookup instead of rescanning it once per location.
        # NOTE(review): values are whatever ``encodings["wholesaler"]`` maps to -
        # presumably integer indices, since they are matched against
        # ``wholesaler_index`` in ``__iter__``.
        groups: dict[str, list[int]] = {loc: [] for loc in self.location_encoding}
        for wlsr, loc in self.wlsr_location_lookup.items():
            groups[loc].append(self.wlsr_encoding[wlsr])
        self.location_wlsr_idx_group: dict[str, list[int]] = groups

    @property
    def num_batches(self) -> int:
        """Number of batches per pass: one per distinct location."""
        return len(self.location_encoding)

    def __iter__(self) -> Iterator[EconomicModelInput]:
        """Yield one ``EconomicModelInput`` per location group.

        Each batch contains every row whose wholesaler belongs to the location,
        in a freshly randomised row order. Arrays whose leading axis is not
        ``Batch`` are passed through whole with every batch.
        """
        wholesaler_index = self._data["wholesaler_index"]
        # Shuffling ``wholesaler_index`` in place would permute that one stored
        # column while every other batch-axis array keeps its row order, so rows
        # would stop matching their wholesaler (and the damage would accumulate
        # on each pass). Draw a row permutation instead and apply it to the
        # selected indices, which keeps all arrays aligned and the data intact.
        # Assumes ``wholesaler_index`` is 1-D (one entry per batch row).
        row_order = self.rng.permutation(len(wholesaler_index))
        shuffled_wholesalers = wholesaler_index[row_order]

        # The batch / non-batch split of DATASETS does not depend on the location.
        batch_keys = [k for k, meta in DATASETS.items() if meta.axes[0] == Batch]
        shared_kwargs = {k: self._data[k] for k, meta in DATASETS.items() if meta.axes[0] != Batch}

        for location in self.location_encoding:
            wlsrs_idx = self.location_wlsr_idx_group[location]
            # Original row positions of this location's wholesalers, in shuffled order.
            batch_indices = row_order[np.isin(shuffled_wholesalers, wlsrs_idx)]
            data_kwargs = {k: self._data[k][batch_indices] for k in batch_keys}
            data_kwargs |= shared_kwargs
            yield EconomicModelInput(**data_kwargs)