Bases: EconomicDataset
This creates batches based on wholesaler location groups.
Source code in wt_ml/optimizer/sim_optimizer/model_compressor.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class LocationGroupedEconomicDataset(EconomicDataset):
    """Dataset that yields one batch per wholesaler location group.

    Instead of fixed-size batches, every batch contains all rows whose
    wholesaler maps to the same location (via the
    ``wholesaler_<location_type>_lookup`` encoding), so a location group is
    never split across batches.

    Args:
        data: Mapping of dataset name to array, forwarded to ``EconomicDataset``.
            Must contain ``"wholesaler_index"``.
        encodings: Encoding tables, forwarded to ``EconomicDataset``. Must contain
            ``"wholesaler"`` (wholesaler key -> encoded index) and
            ``f"wholesaler_{location_type}_lookup"`` (wholesaler key -> location).
        location_type: Which wholesaler-to-location lookup to group by.
        batch_size: Forwarded to ``EconomicDataset``. Not used by ``__iter__``
            here, which batches by location instead.
        seed: Forwarded to ``EconomicDataset`` (presumably used there to build
            ``self.rng`` -- confirm against the base class).
    """

    def __init__(
        self,
        data: dict[str, np.ndarray],
        encodings: dict[str, Any],
        location_type: Literal[LOCATION_TYPES],
        batch_size: int | None = 32,
        seed: int | np.random.Generator = 37,
    ):
        # Set before the base initializer runs, in case the base relies on it.
        self.location_type = location_type
        super().__init__(data, encodings, batch_size, seed)
        self.wlsr_encoding = self.encodings["wholesaler"]
        # wholesaler key -> location name for the chosen location type.
        self.wlsr_location_lookup: dict[str, str] = self.encodings[f"wholesaler_{self.location_type}_lookup"]
        # Locations are numbered in sorted order so the encoding is deterministic.
        self.location_encoding: dict[str, int] = {
            loc: i for i, loc in enumerate(sorted(set(self.wlsr_location_lookup.values())))
        }
        # NOTE(review): this writes into self.encodings; if the base class keeps the
        # caller's dict without copying it, the caller's ``encodings`` gains this key.
        self.encodings["location"] = self.location_encoding
        # location -> encoded wholesaler indices belonging to it. Built in a single
        # pass over the lookup rather than rescanning the whole lookup per location.
        self.location_wlsr_idx_group: dict[str, list[int]] = {loc: [] for loc in self.location_encoding}
        for wlsr, loc in self.wlsr_location_lookup.items():
            self.location_wlsr_idx_group[loc].append(self.wlsr_encoding[wlsr])

    @property
    def num_batches(self) -> int:
        """Number of batches per pass: one per distinct location."""
        return len(self.location_encoding)

    def __iter__(self) -> Iterator[EconomicModelInput]:
        """Yield one ``EconomicModelInput`` per location, in ``location_encoding`` order.

        Each batch holds every row whose ``wholesaler_index`` is in that
        location's wholesaler group. Datasets whose first axis is ``Batch`` are
        sliced to those rows. All other datasets are passed through whole.

        Rows inside each batch come in a random order drawn from ``self.rng``.
        The shuffle is applied to a separate row permutation, so ``self._data``
        is never modified and all batch-axis arrays stay aligned with each other.
        """
        # Shuffle a permutation, never wholesaler_index itself: shuffling the stored
        # array in place would misalign it with every other batch-axis array.
        # NOTE(review): assumes wholesaler_index is 1-D along the batch axis -- confirm.
        wholesaler_index = np.asarray(self._data["wholesaler_index"])
        row_order = np.arange(len(wholesaler_index))
        self.rng.shuffle(row_order)
        shuffled_wlsr_index = wholesaler_index[row_order]
        batch_keys = [k for k, meta in DATASETS.items() if meta.axes[0] == Batch]
        shared_keys = [k for k, meta in DATASETS.items() if meta.axes[0] != Batch]
        for location in self.location_encoding:
            wlsrs_idx = self.location_wlsr_idx_group[location]
            # Original row positions (in shuffled order) that belong to this location.
            batch_indices = row_order[np.isin(shuffled_wlsr_index, wlsrs_idx)]
            data_kwargs = {k: self._data[k][batch_indices] for k in batch_keys}
            data_kwargs |= {k: self._data[k] for k in shared_keys}
            yield EconomicModelInput(**data_kwargs)
|