class MatrixFactorization(Module):
def __init__(
self,
signal_name: str,
encodings: dict[str, dict[str, int] | str],
hierarchy: list[str],
start_date: str | None = None,
end_date: str | None = None,
exclude_brands: list[str] | None = None,
hyperparameters: Hyperparams | None = None,
direction_map: Mapping[str, Literal[-1, 0, 1]] | None = None,
name: str | None = None,
recovery_allowed: bool = False,
num_weeks_to_learn: int | None = None,
        hier_temporal: bool = False,
):
"""A layer that computes a signal based on a matrix factorization like approach on the
different indices. It assumes it has time as an axis and will forward fill during
inference if doing inference on dates in the range it applies but that were not trained
on.
Args:
signal_name (str): The name to use for the impact.
encodings (dict[str, Any]): The encodings used to lookup the meanings of the indices.
hierarchy (list[str]): The hierarchy levels/indices to do matrix factorization over.
Note that granularity is handled differently as an additive effect.
start_date (str | None, optional): The first date to have an impact for. If not
supplied uses the first date in encodings.
end_date (str | None, optional): The last date to have an impact for. If not
supplied uses the last date in encodings.
exclude_brands (list[str] | None, optional): A list of brands it should not have
impact on. If not supplied it assumes it should apply to all brands.
hyperparameters (Hyperparams | None, optional): The hyperparameters it should read
from if not supplied will use defaults.
direction_map (Mapping[str, Literal[-1, 0, 1]] | None, optional): Any levels
specified here will be required to be positive or negative if the value they
have is 1 or -1 respectively. Specifying 0 has no effect. If not supplied will
treat it as though everything mapped to 0. Does not apply to granularity due to
it being additive.
name (str | None, optional): Name to use for the layer. If not supplied uses the
class name.
recovery_allowed (bool, optional): Whether to learn a decay on the impact over time.
If not supplied will use False.
num_weeks_to_learn (int | None, optional): Restrict the number of weeks it is
allowed to learn. Every week in the period with impacts after this will use the
last value (similar to fill forward). If not supplied will learn every week in
the inclusive interval of `start_date` to `end_date` (after appropriate
adjustment if it wasn't supplied).
hier_temporal (bool, optional): Whether to allow the date axis to depend on
granularity. This is implemented as matrix factorization into a lower
dimensional embedding.
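
        Example (illustrative; the encodings below are a minimal stand-in for real data)::

            layer = MatrixFactorization(
                signal_name="example_signal",
                encodings={
                    "date": {"2023-01-02": 0, "2023-01-09": 1},
                    "brand": {"brand_a": 0},
                    "granularity": {"all": 0},
                },
                hierarchy=["date", "brand", "granularity"],
            )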
"""
super().__init__(hyperparameters=hyperparameters, name=name)
self.signal_name = signal_name
self.encodings = encodings
        # Default any level the caller did not specify to a neutral direction (0) so later
        # direction lookups never raise KeyError.
        self.direction_map = {k: 0 for k in (hierarchy or ["date"])}
        if direction_map:
            self.direction_map.update(direction_map)
        self.hierarchy = [k for k in (hierarchy or ["date"]) if k in self.encodings]
self.hier_temporal = hier_temporal
self.start_idx: int = get_date_idx(dates=self.encodings["date"], date=start_date, default=0)
self.end_idx: int = get_date_idx(
dates=self.encodings["date"], date=end_date, default=len(self.encodings["date"])
)
        default_num_weeks_to_learn = self.end_idx - self.start_idx + 1
        if num_weeks_to_learn is None:
            num_weeks_to_learn = default_num_weeks_to_learn
        elif num_weeks_to_learn <= 0:
            raise ValueError("You must allow learning at least 1 week.")
        elif default_num_weeks_to_learn < num_weeks_to_learn:
            logger.warning(
                f"Cannot learn {num_weeks_to_learn} weeks in a period of length {default_num_weeks_to_learn}."
            )
            num_weeks_to_learn = default_num_weeks_to_learn
self.num_weeks_to_learn = num_weeks_to_learn
if exclude_brands is None:
exclude_brands = []
self.exclude_brands = [self.encodings["brand"][br] for br in exclude_brands if br in self.encodings["brand"]]
self.recovery_allowed = recovery_allowed

    def set_regularizer(self):
self.enable_date_regularizer = self.hyperparameters.get_bool(
"enable_date_regularizer",
default=False,
help="Flag to specify whether to regularise the time axis.",
)
if self.enable_date_regularizer:
self.date_diff_reg_type = self.hyperparameters.get_choice(
"date_diff_reg_type",
default="l1",
choices=("l1", "l2"),
help="The kind of regularizer to be applied to time axis.",
)
self.date_diff_reg_weight = self.hyperparameters.get_float(
"date_diff_reg_weight",
default=0.5,
min=0.00,
max=1.0,
help="The regularizer weight to be applied to reg loss.",
)
self.date_accel_reg_weight = self.hyperparameters.get_float(
"date_accel_reg_weight",
default=0.5,
min=0.00,
max=1.0,
help="The regularizer weight to be applied to reg loss.",
)
self.num_free = self.hyperparameters.get_int(
"num_free",
default=10,
min=0,
max=30,
help="length of period of free changing before l1 reg kicks in for date axis.",
)
self.enable_factor_l2 = self.hyperparameters.get_bool(
"enable_factor_l2",
default=False,
help="Flag to specify whether to use l2 regularization on the factors.",
)
if self.enable_factor_l2:
self.batch_mult_l2_reg_weight = self.hyperparameters.get_float(
"batch_mult_l2_reg_weight",
default=0.5,
min=0.00,
max=1.0,
help="The regularizer weight to be applied to reg loss on the matrix factorization factors.",
)
self.batch_additive_l2_reg_weight = self.hyperparameters.get_float(
"batch_additive_l2_reg_weight",
default=0.5,
min=0.00,
max=1.0,
help="The regularizer weight to be applied to reg loss on the additive adjustments.",
)
self.batch_recovery_l2_reg_weight = self.hyperparameters.get_float(
"batch_recovery_l2_reg_weight",
default=0.5,
min=0.00,
max=1.0,
help="The regularizer weight to be applied to reg loss on the recovery rates.",
)
self.date_l2_reg_weight = self.hyperparameters.get_float(
"date_l2_reg_weight",
default=0.5,
min=0.00,
max=1.0,
help="The regularizer weight to be applied to reg loss on the date matrix factorization factor.",
)

    def build(self, input_shapes: InputShapes):  # noqa: U100
self.learning_rate = self.hyperparameters.get_float(
"learning_rate", default=100.0, min=1, max=10000, help="The parameter to control learning rates of MF vars."
)
self.scale_factor = self.hyperparameters.get_float(
"scale_factor", default=0.03, min=0.0, max=100.0, help="The parameter to control scaling of MF vars."
)
self.date_learn_faster = self.hyperparameters.get_float(
"date_learn_faster",
default=10.0,
min=0.001,
max=100000,
help="The parameter to control learning rates of date axis.",
)
self.n_emb = self.hyperparameters.get_choice(
"emb_count",
default=32,
choices=(2, 4, 8, 16, 32, 64),
help="The number of different time series to let it learn combinations of.",
)
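        # One multiplicative factor vector per hierarchy level; the date and granularity
        # levels get dedicated variables below instead.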
        self.batch_mf_vars = {
            name: self.create_var(name=f"{name}_mult", shape=[len(self.encodings[name])], annotated_shape=(name,))
            for name in self.hierarchy
            if name not in ("date", "granularity")
        }
if self.hier_temporal:
self.temporal_date_raw = self.create_var(
name="temporal_date_raw",
shape=[self.num_weeks_to_learn, self.n_emb],
annotated_shape=("trimmed_date", "n_emb"),
initializer=0.0,
)
self.temporal_granularity_raw = self.create_var(
name="temporal_granularity_raw",
shape=[len(self.encodings["granularity"]), self.n_emb],
annotated_shape=("granularity", "n_emb"),
initializer=lambda shape, dtype: self.hyperparameters.rng.normal(size=shape, loc=0.0, scale=1.0).astype(
dtype.as_numpy_dtype
),
)
else:
self.date_mf_var = self.create_var(
name="date_mult_raw", shape=[self.num_weeks_to_learn], annotated_shape=("date",)
)
if "granularity" in self.hierarchy:
self.granularity_mf_var = self.create_var(
name="additive_granularity",
shape=[len(self.encodings["granularity"])],
annotated_shape=("granularity",),
)
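        # Recovery lets the impact decay back toward zero over the period at a learned
        # per-wholesaler rate, bounded by recovery_rate_min/max.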
if self.recovery_allowed:
self.recovery_rates = {
"wholesaler": self.create_var(
name="wholesaler_recovery_rate",
shape=[len(self.encodings["wholesaler"])],
annotated_shape=("wholesaler",),
)
}
self.recovery_min = self.hyperparameters.get_float(
"recovery_rate_min",
default=0.0,
min=0.0,
max=1.0,
help="The minimum percentage that it's allowed to learn to recover over the period.",
)
self.recovery_max = self.hyperparameters.get_float(
"recovery_rate_max",
default=1.0,
min=0.0,
max=1.0,
help="The maximum percentage that it's allowed to learn to recover over the period.",
)
self.softbound_weight = self.hyperparameters.get_float(
"softbound_weight",
default=0.01,
min=0.0,
max=100.0,
help="The weight to push the softplus from the extremes.",
)
self.set_regularizer()

    def __call__(
self,
batch: MatrixFactorizationInput,
training: bool = False,
        debug: bool = False,
skip_metrics: bool = False,
) -> MatrixFactorizationIntermediaries:
lr_factor = tf.constant(self.learning_rate, dtype=tf.float32)
scale_factor = tf.constant(self.scale_factor, dtype=tf.float32)
start_idx = tf.cast(self.start_idx, tf.int32)
end_idx = tf.cast(self.end_idx, tf.int32)
mask = (
create_mask(batch.hierarchy["date"], start_idx, end_idx, tf.shape(batch.hierarchy["brand"])[0]) * batch.mask
)
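        # Zero the mask for excluded brands: build a per-brand {0, 1} lookup vector and
        # gather it by each row's brand index.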
if len(self.exclude_brands) > 0:
brand_idx_mask = tf.tensor_scatter_nd_update(
tf.ones((len(self.encodings["brand"]),), dtype=tf.float32),
tf.constant(self.exclude_brands, dtype=tf.int32)[:, None],
tf.zeros((len(self.exclude_brands),), dtype=tf.float32),
)
mask = mask * tf.gather(brand_idx_mask, batch.hierarchy["brand"])[:, None]
if self.hier_temporal:
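            # Mix the n_emb learned date curves with per-granularity softmax weights,
            # yielding one time curve per batch row (shape [batch, dates]).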
temporal_granularity_raw = tf.gather(
self.temporal_granularity_raw, tf.cast(batch.hierarchy["granularity"], tf.int32)
)
temporal_wibble_emb = tf.math.softmax(temporal_granularity_raw, axis=1)
time_mults_raw = tf.einsum("be,de,->bd", temporal_wibble_emb, self.temporal_date_raw, lr_factor)
else:
time_mults_raw = self.date_mf_var * tf.constant(self.date_learn_faster, dtype=tf.float32) * lr_factor
if not training:
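            # Forward fill the date axis at inference so dates inside the impact range that
            # were not trained on reuse the last trained value (see class docstring).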
time_mults_raw = roll_forward_validation(time_mults_raw)
        if self.direction_map.get("date", 0) != 0:
time_mults_raw = (
transform_softbounded(
time_mults_raw,
add_loss=self.add_loss,
name="time_mults_raw",
max_val=28.0,
min_val=-4.0,
fcn=softplus,
mult=self.softbound_weight,
)
* self.direction_map["date"]
)
date_mask = tf.math.reduce_any(mask == 1, axis=0)
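        # Clip the date offset into [0, num_weeks_to_learn - 1] so weeks past the learned
        # range forward fill with the last learned value.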
temporal_mult = tf.gather(
time_mults_raw,
tf.clip_by_value(tf.cast(batch.hierarchy["date"], tf.int32) - start_idx, 0, self.num_weeks_to_learn - 1),
axis=1 if self.hier_temporal else 0,
) * tf.cast(date_mask, dtype=tf.float32)
batch_mults = self.get_batch_mults(lr_factor, skip_metrics)
non_temporal_mult = prod_n([tf.gather(mult_raw, batch.hierarchy[k]) for k, mult_raw in batch_mults.items()])
batch_additive_adjustments = {}
if "granularity" in self.hierarchy:
batch_additive_adjustments["granularity"] = (
tf.gather(self.granularity_mf_var, batch.hierarchy["granularity"])
* batch.learning_scales
* scale_factor
)
        # tf.add_n cannot take an empty list, so fall back to a scalar zero when there are no
        # additive adjustments (i.e. "granularity" is not in the hierarchy).
        if batch_additive_adjustments:
            additive_adjustment = tf.add_n(list(batch_additive_adjustments.values()))
        else:
            additive_adjustment = tf.constant(0.0, dtype=tf.float32)
mult = (non_temporal_mult + additive_adjustment)[:, None] * temporal_mult
        mult_no_recovery = mult
if self.recovery_allowed:
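            # Normalize time to [0, 1] across the impact period so the learned recovery rate
            # equals the total fraction recovered by the period's end.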
float_start = tf.constant(self.start_idx, dtype=tf.float32)
float_end = tf.constant(
min(self.end_idx, self.encodings["date"][self.encodings["max_real_date"]]), dtype=tf.float32
)
float_dates = tf.cast(batch.hierarchy["date"], tf.float32)
time = (float_dates - float_start) / (float_end - float_start)
batch_recovery_rates = {k: val * lr_factor for k, val in self.recovery_rates.items()}
recovery_rate_raw = prod_n([tf.gather(val, batch.hierarchy[k]) for k, val in batch_recovery_rates.items()])
recovery_rate_indexed = (
transform_softbounded(
recovery_rate_raw,
add_loss=self.add_loss,
name="recovery_amount",
max_val=4.0,
min_val=-4.0,
fcn=tf.nn.sigmoid,
mult=self.softbound_weight,
)
* (self.recovery_max - self.recovery_min)
+ self.recovery_min
)
recovery_impact_over_time = 1.0 - recovery_rate_indexed[:, None] * time
if not training:
recovery_impact_over_time = tf.math.maximum(0.0, recovery_impact_over_time)
mult = mult * recovery_impact_over_time
else:
batch_recovery_rates = {}
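        # softplus(log(e - 1)) == 1, so rows where the mask (or mult) is zero get an impact
        # of exactly 1, i.e. multiplicatively neutral.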
impact = softplus(np.log(np.e - 1) + mask * mult)
impact_by_signal = tf.expand_dims(impact, -1)
if not skip_metrics:
used_dates = tf.gather(batch.hierarchy["date"], tf.where(date_mask))
shifted_start_idx = tf.maximum(start_idx, tf.math.reduce_min(used_dates)) - start_idx
shifted_end_idx = (
tf.minimum(
start_idx + tf.constant(self.num_weeks_to_learn - 1, dtype=tf.int32), tf.math.reduce_max(used_dates)
)
- start_idx
)
self._regularize_mf(
time_mults_raw if self.hier_temporal else time_mults_raw[None],
shifted_start_idx,
shifted_end_idx,
batch_mults,
batch_additive_adjustments,
batch_recovery_rates,
)
return MatrixFactorizationIntermediaries(
impact_by_signal=impact_by_signal,
impact=impact,
signal_names=tf.convert_to_tensor([self.signal_name]),
recovery_rate=recovery_rate_indexed if self.recovery_allowed else None,
recovery_over_time=recovery_impact_over_time if self.recovery_allowed else None,
temporal_mult=temporal_mult if debug else None,
non_temporal_mult=non_temporal_mult if debug else None,
date_var=time_mults_raw if debug else None,
granularity_effect=additive_adjustment if debug else None,
            mult_no_recovery=mult_no_recovery * mask if debug else None,
)

    def get_batch_mults(self, lr_factor, skip_metrics):
batch_mults = {}
for k, var in self.batch_mf_vars.items():
mult_raw = var * lr_factor
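            # Special case: for the bud_light_event signal, penalize positive raw brand
            # factors (when factor L2 is enabled) and shift them down, biasing the brand
            # effect negative.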
if k.lower() == "brand" and self.signal_name == "bud_light_event":
if not skip_metrics and self.enable_factor_l2:
self.add_loss(
"brand_l2_raw_reg",
tf.reduce_sum(tf.math.square(mult_raw) * tf.cast(mult_raw > 0.0, dtype=tf.float32)) / 10.0,
"aux",
self.batch_mult_l2_reg_weight,
)
mult_raw = mult_raw - 0.5
if self.direction_map[k] != 0:
mult_raw = (
transform_softbounded(
mult_raw,
add_loss=self.add_loss,
name=f"{k}_mults_raw",
max_val=28.0,
min_val=-4.0,
fcn=softplus,
mult=self.softbound_weight,
)
* self.direction_map[k]
)
batch_mults[k] = mult_raw
return batch_mults

    def _regularize_mf(
self, time_mults_raw, start_idx, end_idx, batch_mults, batch_additive_adjustments, batch_recovery_rates
):
if self.enable_date_regularizer:
self._add_date_diff_reg_loss(time_mults_raw, start_idx, end_idx)
if self.enable_factor_l2:
self._add_mf_vars_reg_loss(
batch_mults,
batch_additive_adjustments,
batch_recovery_rates,
time_mults_raw[:, start_idx : end_idx + 1],
)

    def _add_date_diff_reg_loss(self, time_mults_raw, start_idx, end_idx):
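        # Penalize week-over-week changes (velocity) and changes of those changes
        # (acceleration) along the date axis, leaving the first `num_free` weeks unpenalized.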
loss_amount = tf.constant(0, dtype=tf.float32)
accel_loss_amount = tf.constant(0, dtype=tf.float32)
loss_name = f"date_velocity_{self.date_diff_reg_type}"
if start_idx + self.num_free < end_idx:
velocity = (
time_mults_raw[:, start_idx + self.num_free + 1 : end_idx + 1]
- time_mults_raw[:, start_idx + self.num_free : end_idx]
)
if self.date_diff_reg_type == "l1":
loss_amount = tf.reduce_sum(
tf.abs(velocity),
name="date_diff_l1_reg_loss",
)
else:
loss_amount = tf.reduce_sum(tf.math.square(velocity))
if start_idx + self.num_free + 1 < end_idx:
accel_loss_amount = tf.reduce_sum(tf.math.squared_difference(velocity[1:], velocity[:-1]))
self.add_loss(loss_name, loss_amount, "aux", self.date_diff_reg_weight)
self.add_loss("date_accel_l2", accel_loss_amount, "aux", self.date_accel_reg_weight)

    def _add_mf_vars_reg_loss(self, batch_mults, batch_additive_adjustments, batch_recovery_rates, time_mults_raw):
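        # Apply an L2 penalty to each factor group; tensors with a batch dimension are
        # normalized by batch size so the penalty is batch-size invariant.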
for cat, val_lookup, weight in (
("mult", batch_mults, self.batch_mult_l2_reg_weight),
("additive", batch_additive_adjustments, self.batch_additive_l2_reg_weight),
("recovery", batch_recovery_rates, self.batch_recovery_l2_reg_weight),
("mult", {"date": time_mults_raw}, self.date_l2_reg_weight),
):
for name, var in val_lookup.items():
loss_amount = tf.reduce_sum(tf.square(var), name=f"{name}_{cat}_l2")
if len(var.shape) > 1:
loss_amount = loss_amount / tf.cast(tf.shape(var)[0], dtype=loss_amount.dtype)
self.add_loss(f"{name}_{cat}_l2", loss_amount, "aux", weight)
|