create_lead_lag_signal(time_signal, n_leads, n_lags)

Generate multiple lead/lag signals for the given time_signal for each n_leads and n_lags period. Args: time_signal (tf.Tensor): Input time signal tensor with shape (batch_size, time_steps, num_signals). n_leads (int): Number of lead periods to create. If n_leads <= 0, no lead signals will be created. n_lags (int): Number of lag periods to create. If n_lags <= 0, no lag signals will be created. Returns: tf.Tensor: Output signal tensor with shape (batch_size, time_steps, num_signals, (n_lags + 1 + n_leads)). The signal contains the lag signals, followed by the original time signal, followed by the lead signals. Example: batch_size = 2 time_steps = 10 num_signals = 3 n_leads = 2 n_lags = 1 time_signal = tf.constant(np.random.rand(batch_size, time_steps, num_signals), dtype=tf.float32) output_signal = create_lead_lag_signal(time_signal, n_leads, n_lags) assert output_signal.shape == (batch_size, time_steps, num_signals, (n_lags + 1 + n_leads))

Source code in wt_ml/layers/lead_lag_wrapper.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def create_lead_lag_signal(time_signal: tf.Tensor, n_leads: int, n_lags: int) -> tf.Tensor:
    """Generate multiple lead/lag signals for the given `time_signal` for each `n_leads` and `n_lags` period.

    Args:
        time_signal (tf.Tensor): Input time signal tensor with shape (batch_size, time_steps, num_signals).
        n_leads (int): Number of lead periods to create. If n_leads <= 0, no lead signals will be created.
        n_lags (int): Number of lag periods to create. If n_lags <= 0, no lag signals will be created.

    Returns:
        tf.Tensor: Output signal tensor with shape (batch_size, time_steps, num_signals, (n_lags + 1 + n_leads)),
            where a negative `n_leads` / `n_lags` counts as 0.
            Along the last axis the signal contains the lag signals (largest lag first), followed by the
            original time signal, followed by the lead signals (smallest lead first). Time steps shifted in
            from outside the original time range are filled with zeros.

    Example:
        batch_size = 2
        time_steps = 10
        num_signals = 3
        n_leads = 2
        n_lags = 1
        time_signal = tf.constant(np.random.rand(batch_size, time_steps, num_signals), dtype=tf.float32)
        output_signal = create_lead_lag_signal(time_signal, n_leads, n_lags)
        assert output_signal.shape == (batch_size, time_steps, num_signals, (n_lags + 1 + n_leads))
    """
    # A non-positive count means "create none of these". Clamp to 0 so that a negative value can never
    # reach tf.pad as a negative padding (which tf.pad rejects) when the other count is positive.
    n_leads = max(n_leads, 0)
    n_lags = max(n_lags, 0)
    if n_leads == 0 and n_lags == 0:
        return tf.expand_dims(time_signal, -1)
    # time_signal is (batch_size, time_steps, num_signals); use the dynamic length of the time axis.
    num_time = tf.shape(time_signal)[1]

    # Zero-pad the time axis with n_lags steps in front and n_leads steps behind, so that every shifted
    # window of length num_time stays inside the padded tensor. The original signal starts at index n_lags.
    # e.g. signal [1,2,3,4] with n_lags=3, n_leads=3 -> padded signal [0,0,0,1,2,3,4,0,0,0]
    padded_signal = tf.pad(time_signal, [[0, 0], [n_lags, n_leads], [0, 0]], name="padded_signal")

    # Lead signals, lead_period in [1, n_leads]: a window of num_time steps starting lead_period steps
    # after the original start, i.e. lead_signal[t] = signal[t + lead_period] (zero past the end).
    # e.g. padded signal [0,0,0,1,2,3,4,0,0,0], lead_period = 2 -> lead_signal = [3,4,0,0]
    lead_signals = [
        padded_signal[:, n_lags + lead_period : n_lags + lead_period + num_time]
        for lead_period in range(1, n_leads + 1)
    ]
    # Lag signals, lag_period from n_lags down to 1 (largest lag first): a window of num_time steps starting
    # lag_period steps before the original start, i.e. lag_signal[t] = signal[t - lag_period] (zero before start).
    # e.g. padded signal [0,0,0,1,2,3,4,0,0,0], lag_period = 2 -> lag_signal = [0,0,1,2]
    lag_signals = [
        padded_signal[:, n_lags - lag_period : n_lags - lag_period + num_time]
        for lag_period in range(n_lags, 0, -1)
    ]
    # The unshifted input itself is used as the centre entry.
    combined_signals = lag_signals + [time_signal] + lead_signals

    # Stack the lag, original, and lead signals along a new last axis:
    # [signal_-n_lags, ..., signal_-1, signal, signal_+1, ..., signal_+n_leads]
    output_signal = tf.stack(combined_signals, axis=-1, name="combined_lead_lag_output")

    return output_signal