create_animation_with_fixed_axes(figures, slider_labels, animation_plot_title=None, x_range=None, y_range=None)

Create an animated figure from a list of go.Figure objects with smooth transitions.

Parameters:

Name Type Description Default
figures list of go.Figure

List of Plotly Figure objects to animate.

required
slider_labels list of int

Custom labels for the slider.

required
animation_plot_title str | None

Title for the animation plot. Defaults to None.

None
x_range tuple or list

Fixed x-axis limits as (min, max) or a list of unique categories.

None
y_range tuple or list

Fixed y-axis limits as (min, max).

None

Returns:

Type Description
Figure

go.Figure: A single animated figure with smooth transitions between frames.

Source code in wt_ml/output/animation.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
def _animation_trace_spec(trace: Any, trace_idx: int) -> tuple[type, dict[str, Any]]:
    """
    Return the ``go`` trace class and constructor kwargs used to re-create *trace* in an animation.

    Bar traces are rebuilt as ``go.Bar``; every other trace type is rebuilt as ``go.Scatter`` with a
    spline line shape.

    Parameters:
        trace: A trace taken from ``go.Figure.data``.
        trace_idx (int): Position of the trace within its figure; used to name unnamed traces.

    Returns:
        tuple: (trace class, keyword arguments for that class).
    """
    # to_plotly_json() may omit properties that were never set, so optional keys such as
    # 'marker' or 'line' cannot be assumed to exist.
    trace_kwargs = trace.to_plotly_json()

    # Give the trace the same name in every frame (unnamed traces get a positional fallback).
    trace_kwargs["name"] = trace.name if trace.name else f"trace_{trace_idx}"
    # 'type' is implied by the constructor class returned below.
    trace_kwargs.pop("type", None)

    if trace.type == "bar":
        marker = trace_kwargs.get("marker")
        # Drop 'symbol', which is not a bar marker property. Guarded because a bar without explicit
        # marker settings has no 'marker' key at all (indexing it directly raised KeyError).
        if isinstance(marker, dict):
            marker.pop("symbol", None)
        return go.Bar, trace_kwargs

    # Spline interpolation for smoother-looking lines.
    trace_kwargs.setdefault("line", {})["shape"] = "spline"
    return go.Scatter, trace_kwargs


def create_animation_with_fixed_axes(
    figures: list[go.Figure],
    slider_labels: list[int],
    animation_plot_title: str | None = None,
    x_range: tuple[float, float] | list[Any] | None = None,
    y_range: tuple[float, float] | list[float] | None = None,
) -> go.Figure:
    """
    Create an animated figure from a list of go.Figure objects with smooth transitions.

    Each input figure becomes one animation frame, named "0", "1", ... in input order. The first
    figure's traces also populate the initial view. Bar traces are re-created as ``go.Bar`` and every
    other trace as ``go.Scatter`` with spline-shaped lines. The input figures' layouts are not copied.

    Parameters:
        figures (list of go.Figure): List of Plotly Figure objects to animate, one per frame.
        slider_labels (list of int): Slider label for each frame; needs at least one entry per figure.
        animation_plot_title (str | None): Title for the animation plot. Defaults to None.
        x_range (tuple or list | None): Fixed x-axis limits as (min, max), or a list of categories that
            fixes the x-axis category order.
        y_range (tuple or list | None): Fixed y-axis limits as (min, max).

    Returns:
        go.Figure: A single animated figure with a frame slider and Play/Pause buttons.

    Raises:
        ValueError: If fewer slider labels than figures are given.
    """
    if len(slider_labels) < len(figures):
        raise ValueError(
            f"Need one slider label per figure: got {len(slider_labels)} labels for {len(figures)} figures."
        )

    fig = go.Figure()

    frames = []
    for frame_idx, figure in enumerate(figures):
        frame_data = []
        for trace_idx, trace in enumerate(figure.data):
            trace_cls, trace_kwargs = _animation_trace_spec(trace, trace_idx)
            frame_data.append(trace_cls(**trace_kwargs))
            # The first figure's traces double as the figure's initial (pre-animation) traces.
            if frame_idx == 0:
                fig.add_trace(trace_cls(**trace_kwargs))

        frames.append(
            go.Frame(
                data=frame_data,
                name=str(frame_idx),
                # Map frame traces onto figure traces by position for consistent ordering.
                traces=list(range(len(figure.data))),
            )
        )
    fig.frames = frames

    # A list x_range fixes the category order; any other non-None value is a numeric (min, max) range.
    if isinstance(x_range, list):
        fig.update_xaxes(categoryorder="array", categoryarray=x_range)
    elif x_range is not None:
        fig.update_xaxes(range=x_range)

    if y_range is not None:
        fig.update_yaxes(range=y_range)

    def animate_options(**extra: Any) -> dict[str, Any]:
        # Shared timing for slider steps and the Play button: 750 ms per frame and a 500 ms
        # cubic-in-out transition between frames.
        return dict(
            frame=dict(duration=750, redraw=True),
            transition=dict(duration=500, easing="cubic-in-out"),
            mode="immediate",
            **extra,
        )

    # One slider step per frame; the step's target "str(i)" matches the frame name above.
    steps = [
        dict(method="animate", args=[[str(i)], animate_options()], label=label)
        for i, label in zip(range(len(figures)), slider_labels)
    ]

    play_button = dict(label="Play", method="animate", args=[None, animate_options(fromcurrent=True)])
    # Animating to [None] with a zero-duration frame stops the running animation.
    pause_button = dict(
        label="Pause", method="animate", args=[[None], dict(frame=dict(duration=0, redraw=False), mode="immediate")]
    )

    fig.update_layout(
        sliders=[
            dict(
                active=0,
                currentvalue={"prefix": "Frame: "},
                pad={"t": 30},
                len=0.8,
                x=0.1,
                y=0,
                steps=steps,
            )
        ],
        title=animation_plot_title,
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                buttons=[play_button, pause_button],
                direction="left",
                x=0.1,
                y=0,
                pad={"r": 10, "t": 30},
                xanchor="right",
                yanchor="top",
            )
        ],
    )

    return fig

detect_axis_ranges(figures)

Detect the global x-axis and y-axis ranges from a list of go.Figure objects.

Parameters:

Name Type Description Default
figures list of go.Figure

List of Plotly Figure objects to detect axis ranges.

required

Returns:

Name Type Description
tuple tuple[list[Any] | tuple[Any, Any], tuple[Any, Any]]

(x_range, y_range), where x_range is (min_x, max_x) if the x-values are numeric, or a sorted list of unique categories if they are categorical, and y_range is (min_y, max_y) for the y-axis (always numeric).

Source code in wt_ml/output/animation.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def _non_missing_numbers(values: Any) -> list[Any] | None:
    """
    Return the entries of *values* that are neither None nor NaN, or None if any entry is non-numeric.

    Python ints/floats and NumPy integer/floating scalars count as numeric. NumPy integer scalars
    (e.g. elements of ``np.array([2000, 2001])``) are not subclasses of ``int``, so they are listed
    explicitly.
    """
    kept = []
    for value in values:
        if value is None:
            continue
        if not isinstance(value, (int, float, np.integer, np.floating)):
            return None
        if value != value:  # NaN is the only value that is not equal to itself.
            continue
        kept.append(value)
    return kept


def detect_axis_ranges(
    figures: list[go.Figure],
) -> tuple[list[Any] | tuple[Any, Any], tuple[Any, Any]]:
    """
    Detect the global x-axis and y-axis ranges from a list of go.Figure objects.

    None and NaN entries are ignored when computing numeric ranges, and None is excluded from
    categories. The x-axis is treated as categorical as soon as any trace has a non-numeric x-value;
    the categories then include the x-values of every trace.

    Parameters:
        figures (list of go.Figure): List of Plotly Figure objects to detect axis ranges.

    Returns:
        tuple: (x_range, y_range) where:
            - x_range is (min_x, max_x) if x-values are numeric, or a sorted list of unique categories if x-values are
            categorical (ordered by their string form when the categories mix types that cannot be compared).
            - y_range is (min_y, max_y) for the y-axis (always numeric).
            A bound stays at +inf / -inf when no trace contributes values to that axis.
    """
    y_min: Any = np.inf
    y_max: Any = -np.inf
    x_min: Any = np.inf
    x_max: Any = -np.inf
    x_categories: set[Any] = set()
    is_x_categorical = False

    for figure in figures:
        for trace in figure.data:
            y_data = trace["y"]  # type: ignore
            if y_data is not None:
                # Drop gaps (None) and NaN, which would otherwise crash or skew min()/max().
                y_values = [y for y in y_data if y is not None and y == y]
                if y_values:
                    y_min = min(y_min, min(y_values))
                    y_max = max(y_max, max(y_values))

            x_data = trace["x"]  # type: ignore
            if x_data is not None:
                x_values = _non_missing_numbers(x_data)
                if x_values is None:
                    is_x_categorical = True
                    x_categories.update(x for x in x_data if x is not None)
                else:
                    # Also kept as categories in case another trace turns the axis categorical.
                    x_categories.update(x_values)
                    if x_values:
                        x_min = min(x_min, min(x_values))
                        x_max = max(x_max, max(x_values))

    x_range: list[Any] | tuple[Any, Any]
    if is_x_categorical:
        try:
            x_range = sorted(x_categories)
        except TypeError:
            # Mixed category types (e.g. numbers and strings) cannot be compared directly.
            x_range = sorted(x_categories, key=str)
    else:
        x_range = (x_min, x_max)

    y_range: tuple[Any, Any] = (y_min, y_max)

    return x_range, y_range