MatrixFactorization

Bases: Module

Source code in wt_ml/layers/matrix_factorization.py
class MatrixFactorization(Module):
    def __init__(
        self,
        signal_name: str,
        encodings: dict[str, dict[str, int] | str],
        hierarchy: list[str],
        start_date: str | None = None,
        end_date: str | None = None,
        exclude_brands: list[str] | None = None,
        hyperparameters: Hyperparams | None = None,
        direction_map: Mapping[str, Literal[-1, 0, 1]] | None = None,
        name: str | None = None,
        recovery_allowed: bool = False,
        num_weeks_to_learn: int | None = None,
        hier_temporal=False,
    ):
        """A layer that computes a signal via a matrix-factorization-like approach over the
        different indices. It assumes time is one of its axes and will forward-fill during
        inference for dates that fall inside its active range but were not seen during
        training.

        Args:
            signal_name (str): The name to use for the impact.
            encodings (dict[str, Any]): The encodings used to look up the meanings of the
                indices.
            hierarchy (list[str]): The hierarchy levels/indices to do matrix factorization
                over. Note that granularity is handled differently, as an additive effect.
            start_date (str | None, optional): The first date to have an impact for. If not
                supplied, uses the first date in encodings.
            end_date (str | None, optional): The last date to have an impact for. If not
                supplied, uses the last date in encodings.
            exclude_brands (list[str] | None, optional): Brands this layer should have no
                impact on. If not supplied, it applies to all brands.
            hyperparameters (Hyperparams | None, optional): The hyperparameters to read from.
                If not supplied, defaults are used.
            direction_map (Mapping[str, Literal[-1, 0, 1]] | None, optional): Levels mapped
                to 1 or -1 here are constrained to be positive or negative, respectively.
                Mapping to 0 has no effect. If not supplied, every level is treated as
                mapped to 0. Does not apply to granularity, since it is additive.
            name (str | None, optional): Name to use for the layer. If not supplied, uses the
                class name.
            recovery_allowed (bool, optional): Whether to learn a decay on the impact over
                time. Defaults to False.
            num_weeks_to_learn (int | None, optional): Restricts the number of weeks the layer
                is allowed to learn. Weeks in the impact period after this limit reuse the
                last learned value (a fill-forward). If not supplied, learns every week in
                the inclusive interval from `start_date` to `end_date` (after the adjustments
                above when either was omitted).
            hier_temporal (bool, optional): Whether to let the date axis depend on
                granularity. This is implemented as matrix factorization into a lower
                dimensional embedding.
        """
        super().__init__(hyperparameters=hyperparameters, name=name)
        self.signal_name = signal_name
        self.encodings = encodings
        self.direction_map = direction_map or {k: 0 for k in hierarchy}
        self.hierarchy = [k for k in (hierarchy or ["date"]) if k in self.encodings.keys()]
        self.hier_temporal = hier_temporal
        self.start_idx: int = get_date_idx(dates=self.encodings["date"], date=start_date, default=0)
        self.end_idx: int = get_date_idx(
            dates=self.encodings["date"], date=end_date, default=len(self.encodings["date"])
        )
        default_num_weeks_to_learn = self.end_idx - self.start_idx + 1
        if num_weeks_to_learn is None:
            num_weeks_to_learn = default_num_weeks_to_learn
        elif default_num_weeks_to_learn < num_weeks_to_learn:
            logger.warning(
                f"Cannot learn {num_weeks_to_learn} weeks in a period of length {default_num_weeks_to_learn}."
            )
            num_weeks_to_learn = default_num_weeks_to_learn
        elif num_weeks_to_learn <= 0:
            raise ValueError("You must allow learning at least 1 week.")
        self.num_weeks_to_learn = num_weeks_to_learn
        if exclude_brands is None:
            exclude_brands = []
        self.exclude_brands = [self.encodings["brand"][br] for br in exclude_brands if br in self.encodings["brand"]]
        self.recovery_allowed = recovery_allowed

    def set_regularizer(self):
        self.enable_date_regularizer = self.hyperparameters.get_bool(
            "enable_date_regularizer",
            default=False,
            help="Flag to specify whether to regularize the time axis.",
        )
        if self.enable_date_regularizer:
            self.date_diff_reg_type = self.hyperparameters.get_choice(
                "date_diff_reg_type",
                default="l1",
                choices=("l1", "l2"),
                help="The kind of regularizer to be applied to time axis.",
            )
            self.date_diff_reg_weight = self.hyperparameters.get_float(
                "date_diff_reg_weight",
                default=0.5,
                min=0.00,
                max=1.0,
            help="The regularizer weight applied to the date velocity reg loss.",
            )
            self.date_accel_reg_weight = self.hyperparameters.get_float(
                "date_accel_reg_weight",
                default=0.5,
                min=0.00,
                max=1.0,
            help="The regularizer weight applied to the date acceleration reg loss.",
            )
            self.num_free = self.hyperparameters.get_int(
                "num_free",
                default=10,
                min=0,
                max=30,
                help="Length of the initial period the date axis may change freely before the diff regularizer kicks in.",
            )
        self.enable_factor_l2 = self.hyperparameters.get_bool(
            "enable_factor_l2",
            default=False,
            help="Flag to specify whether to use l2 regularization on the factors.",
        )
        if self.enable_factor_l2:
            self.batch_mult_l2_reg_weight = self.hyperparameters.get_float(
                "batch_mult_l2_reg_weight",
                default=0.5,
                min=0.00,
                max=1.0,
                help="The regularizer weight to be applied to reg loss on the matrix factorization factors.",
            )
            self.batch_additive_l2_reg_weight = self.hyperparameters.get_float(
                "batch_additive_l2_reg_weight",
                default=0.5,
                min=0.00,
                max=1.0,
                help="The regularizer weight to be applied to reg loss on the additive adjustments.",
            )
            self.batch_recovery_l2_reg_weight = self.hyperparameters.get_float(
                "batch_recovery_l2_reg_weight",
                default=0.5,
                min=0.00,
                max=1.0,
                help="The regularizer weight to be applied to reg loss on the recovery rates.",
            )
            self.date_l2_reg_weight = self.hyperparameters.get_float(
                "date_l2_reg_weight",
                default=0.5,
                min=0.00,
                max=1.0,
                help="The regularizer weight to be applied to reg loss on the date matrix factorization factor.",
            )

    def build(self, input_shapes: InputShapes):  # noqa: U100
        self.learning_rate = self.hyperparameters.get_float(
            "learning_rate", default=100.0, min=1, max=10000, help="The parameter to control learning rates of MF vars."
        )
        self.scale_factor = self.hyperparameters.get_float(
            "scale_factor", default=0.03, min=0.0, max=100.0, help="The parameter to control scaling of MF vars."
        )
        self.date_learn_faster = self.hyperparameters.get_float(
            "date_learn_faster",
            default=10.0,
            min=0.001,
            max=100000,
            help="The parameter to control learning rates of date axis.",
        )
        self.n_emb = self.hyperparameters.get_choice(
            "emb_count",
            default=32,
            choices=(2, 4, 8, 16, 32, 64),
            help="The number of different time series to let it learn combinations of.",
        )
        self.batch_mf_vars = {
            name: self.create_var(name=f"{name}_mult", shape=shape, annotated_shape=(name,))
            for name, shape in zip(
                self.hierarchy,
                [[len(self.encodings[k])] for k in self.hierarchy],
            )
            if name not in ("date", "granularity")
        }
        if self.hier_temporal:
            self.temporal_date_raw = self.create_var(
                name="temporal_date_raw",
                shape=[self.num_weeks_to_learn, self.n_emb],
                annotated_shape=("trimmed_date", "n_emb"),
                initializer=0.0,
            )
            self.temporal_granularity_raw = self.create_var(
                name="temporal_granularity_raw",
                shape=[len(self.encodings["granularity"]), self.n_emb],
                annotated_shape=("granularity", "n_emb"),
                initializer=lambda shape, dtype: self.hyperparameters.rng.normal(size=shape, loc=0.0, scale=1.0).astype(
                    dtype.as_numpy_dtype
                ),
            )
        else:
            self.date_mf_var = self.create_var(
                name="date_mult_raw", shape=[self.num_weeks_to_learn], annotated_shape=("date",)
            )
        if "granularity" in self.hierarchy:
            self.granularity_mf_var = self.create_var(
                name="additive_granularity",
                shape=[len(self.encodings["granularity"])],
                annotated_shape=("granularity",),
            )
        if self.recovery_allowed:
            self.recovery_rates = {
                "wholesaler": self.create_var(
                    name="wholesaler_recovery_rate",
                    shape=[len(self.encodings["wholesaler"])],
                    annotated_shape=("wholesaler",),
                )
            }
            self.recovery_min = self.hyperparameters.get_float(
                "recovery_rate_min",
                default=0.0,
                min=0.0,
                max=1.0,
                help="The minimum percentage that it's allowed to learn to recover over the period.",
            )
            self.recovery_max = self.hyperparameters.get_float(
                "recovery_rate_max",
                default=1.0,
                min=0.0,
                max=1.0,
                help="The maximum percentage that it's allowed to learn to recover over the period.",
            )
        self.softbound_weight = self.hyperparameters.get_float(
            "softbound_weight",
            default=0.01,
            min=0.0,
            max=100.0,
            help="The weight to push the softplus from the extremes.",
        )
        self.set_regularizer()

    def __call__(
        self,
        batch: MatrixFactorizationInput,
        training: bool = False,
        debug: bool = False,
        skip_metrics: bool = False,
    ) -> MatrixFactorizationIntermediaries:
        lr_factor = tf.constant(self.learning_rate, dtype=tf.float32)
        scale_factor = tf.constant(self.scale_factor, dtype=tf.float32)
        start_idx = tf.cast(self.start_idx, tf.int32)
        end_idx = tf.cast(self.end_idx, tf.int32)
        mask = (
            create_mask(batch.hierarchy["date"], start_idx, end_idx, tf.shape(batch.hierarchy["brand"])[0]) * batch.mask
        )
        if len(self.exclude_brands) > 0:
            brand_idx_mask = tf.tensor_scatter_nd_update(
                tf.ones((len(self.encodings["brand"]),), dtype=tf.float32),
                tf.constant(self.exclude_brands, dtype=tf.int32)[:, None],
                tf.zeros((len(self.exclude_brands),), dtype=tf.float32),
            )
            mask = mask * tf.gather(brand_idx_mask, batch.hierarchy["brand"])[:, None]
        if self.hier_temporal:
            temporal_granularity_raw = tf.gather(
                self.temporal_granularity_raw, tf.cast(batch.hierarchy["granularity"], tf.int32)
            )
            temporal_wibble_emb = tf.math.softmax(temporal_granularity_raw, axis=1)
            time_mults_raw = tf.einsum("be,de,->bd", temporal_wibble_emb, self.temporal_date_raw, lr_factor)
        else:
            time_mults_raw = self.date_mf_var * tf.constant(self.date_learn_faster, dtype=tf.float32) * lr_factor
        if not training:
            time_mults_raw = roll_forward_validation(time_mults_raw)
        if self.direction_map["date"] != 0:
            time_mults_raw = (
                transform_softbounded(
                    time_mults_raw,
                    add_loss=self.add_loss,
                    name="time_mults_raw",
                    max_val=28.0,
                    min_val=-4.0,
                    fcn=softplus,
                    mult=self.softbound_weight,
                )
                * self.direction_map["date"]
            )
        date_mask = tf.math.reduce_any(mask == 1, axis=0)
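        # Clipping the shifted date index into [0, num_weeks_to_learn - 1] holds the last
        # learned week's value for any later in-range dates (fill-forward).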
        temporal_mult = tf.gather(
            time_mults_raw,
            tf.clip_by_value(tf.cast(batch.hierarchy["date"], tf.int32) - start_idx, 0, self.num_weeks_to_learn - 1),
            axis=1 if self.hier_temporal else 0,
        ) * tf.cast(date_mask, dtype=tf.float32)
        batch_mults = self.get_batch_mults(lr_factor, skip_metrics)
        non_temporal_mult = prod_n([tf.gather(mult_raw, batch.hierarchy[k]) for k, mult_raw in batch_mults.items()])
        batch_additive_adjustments = {}
        if "granularity" in self.hierarchy:
            batch_additive_adjustments["granularity"] = (
                tf.gather(self.granularity_mf_var, batch.hierarchy["granularity"])
                * batch.learning_scales
                * scale_factor
            )
        additive_adjustment = tf.add_n(list(batch_additive_adjustments.values()))
        mult = (non_temporal_mult + additive_adjustment)[:, None] * temporal_mult
        mult_no_recover = mult
        if self.recovery_allowed:
            float_start = tf.constant(self.start_idx, dtype=tf.float32)
            float_end = tf.constant(
                min(self.end_idx, self.encodings["date"][self.encodings["max_real_date"]]), dtype=tf.float32
            )
            float_dates = tf.cast(batch.hierarchy["date"], tf.float32)
            time = (float_dates - float_start) / (float_end - float_start)
            batch_recovery_rates = {k: val * lr_factor for k, val in self.recovery_rates.items()}
            recovery_rate_raw = prod_n([tf.gather(val, batch.hierarchy[k]) for k, val in batch_recovery_rates.items()])
            recovery_rate_indexed = (
                transform_softbounded(
                    recovery_rate_raw,
                    add_loss=self.add_loss,
                    name="recovery_amount",
                    max_val=4.0,
                    min_val=-4.0,
                    fcn=tf.nn.sigmoid,
                    mult=self.softbound_weight,
                )
                * (self.recovery_max - self.recovery_min)
                + self.recovery_min
            )
            recovery_impact_over_time = 1.0 - recovery_rate_indexed[:, None] * time
            if not training:
                recovery_impact_over_time = tf.math.maximum(0.0, recovery_impact_over_time)
            mult = mult * recovery_impact_over_time
        else:
            batch_recovery_rates = {}
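        # softplus(log(e - 1)) == 1, so masked-out entries (mask * mult == 0) yield an impact of exactly 1.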
        impact = softplus(np.log(np.e - 1) + mask * mult)
        impact_by_signal = tf.expand_dims(impact, -1)
        if not skip_metrics:
            used_dates = tf.gather(batch.hierarchy["date"], tf.where(date_mask))
            shifted_start_idx = tf.maximum(start_idx, tf.math.reduce_min(used_dates)) - start_idx
            shifted_end_idx = (
                tf.minimum(
                    start_idx + tf.constant(self.num_weeks_to_learn - 1, dtype=tf.int32), tf.math.reduce_max(used_dates)
                )
                - start_idx
            )
            self._regularize_mf(
                time_mults_raw if self.hier_temporal else time_mults_raw[None],
                shifted_start_idx,
                shifted_end_idx,
                batch_mults,
                batch_additive_adjustments,
                batch_recovery_rates,
            )
        return MatrixFactorizationIntermediaries(
            impact_by_signal=impact_by_signal,
            impact=impact,
            signal_names=tf.convert_to_tensor([self.signal_name]),
            recovery_rate=recovery_rate_indexed if self.recovery_allowed else None,
            recovery_over_time=recovery_impact_over_time if self.recovery_allowed else None,
            temporal_mult=temporal_mult if debug else None,
            non_temporal_mult=non_temporal_mult if debug else None,
            date_var=time_mults_raw if debug else None,
            granularity_effect=additive_adjustment if debug else None,
            mult_no_recovery=mult_no_recover * mask if debug else None,
        )

    def get_batch_mults(self, lr_factor, skip_metrics):
        batch_mults = {}
        for k, var in self.batch_mf_vars.items():
            mult_raw = var * lr_factor
            if k.lower() == "brand" and self.signal_name == "bud_light_event":
                if not skip_metrics and self.enable_factor_l2:
                    self.add_loss(
                        "brand_l2_raw_reg",
                        tf.reduce_sum(tf.math.square(mult_raw) * tf.cast(mult_raw > 0.0, dtype=tf.float32)) / 10.0,
                        "aux",
                        self.batch_mult_l2_reg_weight,
                    )
                mult_raw = mult_raw - 0.5
            if self.direction_map[k] != 0:
                mult_raw = (
                    transform_softbounded(
                        mult_raw,
                        add_loss=self.add_loss,
                        name=f"{k}_mults_raw",
                        max_val=28.0,
                        min_val=-4.0,
                        fcn=softplus,
                        mult=self.softbound_weight,
                    )
                    * self.direction_map[k]
                )
            batch_mults[k] = mult_raw
        return batch_mults

    def _regularize_mf(
        self, time_mults_raw, start_idx, end_idx, batch_mults, batch_additive_adjustments, batch_recovery_rates
    ):
        if self.enable_date_regularizer:
            self._add_date_diff_reg_loss(time_mults_raw, start_idx, end_idx)
        if self.enable_factor_l2:
            self._add_mf_vars_reg_loss(
                batch_mults,
                batch_additive_adjustments,
                batch_recovery_rates,
                time_mults_raw[:, start_idx : end_idx + 1],
            )

    def _add_date_diff_reg_loss(self, time_mults_raw, start_idx, end_idx):
        loss_amount = tf.constant(0, dtype=tf.float32)
        accel_loss_amount = tf.constant(0, dtype=tf.float32)
        loss_name = f"date_velocity_{self.date_diff_reg_type}"
        if start_idx + self.num_free < end_idx:
            velocity = (
                time_mults_raw[:, start_idx + self.num_free + 1 : end_idx + 1]
                - time_mults_raw[:, start_idx + self.num_free : end_idx]
            )
            if self.date_diff_reg_type == "l1":
                loss_amount = tf.reduce_sum(
                    tf.abs(velocity),
                    name="date_diff_l1_reg_loss",
                )
            else:
                loss_amount = tf.reduce_sum(tf.math.square(velocity))
            if start_idx + self.num_free + 1 < end_idx:
                accel_loss_amount = tf.reduce_sum(tf.math.squared_difference(velocity[1:], velocity[:-1]))
        self.add_loss(loss_name, loss_amount, "aux", self.date_diff_reg_weight)
        self.add_loss("date_accel_l2", accel_loss_amount, "aux", self.date_accel_reg_weight)

    def _add_mf_vars_reg_loss(self, batch_mults, batch_additive_adjustments, batch_recovery_rates, time_mults_raw):
        for cat, val_lookup, weight in (
            ("mult", batch_mults, self.batch_mult_l2_reg_weight),
            ("additive", batch_additive_adjustments, self.batch_additive_l2_reg_weight),
            ("recovery", batch_recovery_rates, self.batch_recovery_l2_reg_weight),
            ("mult", {"date": time_mults_raw}, self.date_l2_reg_weight),
        ):
            for name, var in val_lookup.items():
                loss_amount = tf.reduce_sum(tf.square(var), name=f"{name}_{cat}_l2")
                if len(var.shape) > 1:
                    loss_amount = loss_amount / tf.cast(tf.shape(var)[0], dtype=loss_amount.dtype)
                self.add_loss(f"{name}_{cat}_l2", loss_amount, "aux", weight)
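
To make the composition above concrete, here is a minimal NumPy sketch of the forward computation for the non-hier_temporal path; the levels, shapes, and values are made up for illustration, and the direction constraints and recovery decay are omitted:

import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

rng = np.random.default_rng(0)
n_rows, n_weeks = 4, 6

# One learned multiplicative factor per hierarchy level, gathered per batch row.
brand_factor = rng.normal(size=5)
wholesaler_factor = rng.normal(size=3)
brand_idx = rng.integers(0, 5, size=n_rows)
wholesaler_idx = rng.integers(0, 3, size=n_rows)
non_temporal = brand_factor[brand_idx] * wholesaler_factor[wholesaler_idx]

# Granularity enters additively rather than multiplicatively.
granularity_adj = rng.normal(size=n_rows) * 0.03

# A single shared multiplier per week (the non-hier_temporal path).
temporal = rng.normal(size=n_weeks)

mask = np.ones((n_rows, n_weeks))
mult = (non_temporal + granularity_adj)[:, None] * temporal[None, :]
# softplus(log(e - 1)) == 1, so masked-out cells yield a neutral impact of 1.
impact = softplus(np.log(np.e - 1) + mask * mult)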

__init__(signal_name, encodings, hierarchy, start_date=None, end_date=None, exclude_brands=None, hyperparameters=None, direction_map=None, name=None, recovery_allowed=False, num_weeks_to_learn=None, hier_temporal=False)

A layer that computes a signal via a matrix-factorization-like approach over the different indices. It assumes time is one of its axes and will forward-fill during inference for dates that fall inside its active range but were not seen during training.

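The fill-forward comes from clipping the shifted date index into the learned window (the tf.gather call in __call__); a minimal NumPy sketch of that behavior, with made-up values:

import numpy as np

weekly_values = np.array([0.2, 0.5, 0.9])  # learned for num_weeks_to_learn == 3
date_idx = np.arange(6)                    # inference requests 6 weeks in range
# Indices past the learned window are clipped, so the last learned value is held.
held = weekly_values[np.clip(date_idx, 0, len(weekly_values) - 1)]
print(held)  # [0.2 0.5 0.9 0.9 0.9 0.9]
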
Parameters:

signal_name (str, required): The name to use for the impact.

encodings (dict[str, Any], required): The encodings used to look up the meanings of the indices.

hierarchy (list[str], required): The hierarchy levels/indices to do matrix factorization over. Note that granularity is handled differently, as an additive effect.

start_date (str | None, default None): The first date to have an impact for. If not supplied, uses the first date in encodings.

end_date (str | None, default None): The last date to have an impact for. If not supplied, uses the last date in encodings.

exclude_brands (list[str] | None, default None): Brands this layer should have no impact on. If not supplied, it applies to all brands.

hyperparameters (Hyperparams | None, default None): The hyperparameters to read from. If not supplied, defaults are used.

direction_map (Mapping[str, Literal[-1, 0, 1]] | None, default None): Levels mapped to 1 or -1 here are constrained to be positive or negative, respectively. Mapping to 0 has no effect. If not supplied, every level is treated as mapped to 0. Does not apply to granularity, since it is additive.

name (str | None, default None): Name to use for the layer. If not supplied, uses the class name.

recovery_allowed (bool, default False): Whether to learn a decay on the impact over time.

num_weeks_to_learn (int | None, default None): Restricts the number of weeks the layer is allowed to learn. Weeks in the impact period after this limit reuse the last learned value (a fill-forward). If not supplied, learns every week in the inclusive interval from `start_date` to `end_date` (after the adjustments above when either was omitted).

hier_temporal (bool, default False): Whether to let the date axis depend on granularity. This is implemented as matrix factorization into a lower dimensional embedding.
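
When recovery_allowed is set, the layer learns a per-row recovery rate that is bounded into [recovery_rate_min, recovery_rate_max] (via transform_softbounded with a sigmoid; the sketch below approximates that with a plain sigmoid) and applies it linearly over normalized time, with a clamp at zero during inference:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

recovery_min, recovery_max = 0.0, 1.0      # the hyperparameter defaults
raw = 0.3                                  # an assumed learned value, unbounded
rate = sigmoid(raw) * (recovery_max - recovery_min) + recovery_min
time = np.linspace(0.0, 1.2, 7)            # normalized; values > 1 fall past the fitted period
decay = 1.0 - rate * time
decay = np.maximum(0.0, decay)             # the inference-time clamp keeps the decay factor nonnegative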