Interactive plots with Jupyter and ipywidgets

Author

Joe Marsh Rossney

Published

August 29, 2025

Summary

This tutorial demonstrates how to create interactive plots in Jupyter notebooks using ipywidgets and Matplotlib, which is especially useful for workshops.

import logging

from rich.logging import RichHandler

# Send all log records through Rich (coloured output, pretty tracebacks).
# Only the message text is shown; timestamps use the locale's time format.
logging.basicConfig(
    handlers=[RichHandler(rich_tracebacks=True)],
    datefmt="[%X]",
    format="%(message)s",
    level="INFO",
)

# The root logger, i.e. the one basicConfig has just configured.
logger = logging.root
from typing import Callable

import numpy as np
import matplotlib.pyplot as plt
import sympy

from ipywidgets import interact
import ipywidgets as widgets

# Unicode alias for pi so that ranges below read like the maths (e.g. -2π..2π)
π = float(np.pi)
%matplotlib widget
# The single SymPy symbol that user expressions are written in terms of
x = sympy.Symbol("x")

# Namespace handed to `sympy.sympify` when parsing user input: the variable x,
# a subset of elementary functions that are differentiable on (most of) their
# domain, and the constant pi (also spelled π). "ln" is an alias for log.
ALLOWED_FUNCS = dict(
    x=x,
    sin=sympy.sin,
    cos=sympy.cos,
    tan=sympy.tan,
    asin=sympy.asin,
    acos=sympy.acos,
    atan=sympy.atan,
    sinh=sympy.sinh,
    cosh=sympy.cosh,
    tanh=sympy.tanh,
    exp=sympy.exp,
    log=sympy.log,
    ln=sympy.log,
    sqrt=sympy.sqrt,
    pi=sympy.pi,
    π=sympy.pi,
)


def parse(expression_str: str) -> tuple[sympy.Expr, sympy.Expr]:
    """Parse a string into a SymPy expression f(x) and its derivative f'(x).

    Only the names in ``ALLOWED_FUNCS`` are bound to SymPy objects while
    parsing; any other name becomes a free symbol or an undefined function
    application and is rejected below.

    Args:
        expression_str: User-supplied text such as ``"sin(x) + x**2"``.

    Returns:
        ``(f, dfdx)``: the parsed expression and its derivative w.r.t. ``x``.

    Raises:
        ValueError: if the text cannot be parsed, is not a scalar expression,
            does not have ``x`` as its only free symbol (constants such as
            ``"2"`` are rejected too), or calls a function that is not in
            ``ALLOWED_FUNCS``.
    """
    # SECURITY: sympify() evaluates its input with Python's eval(). This is
    # acceptable for a local notebook, but never feed it untrusted input.
    try:
        f = sympy.sympify(expression_str.strip(), locals=ALLOWED_FUNCS)
    except Exception as e:  # sympify raises a variety of exception types
        raise ValueError(f"Could not parse expression: {e}") from e

    # Reject non-scalar results such as tuples ("x, x") or relations ("x > 1"),
    # which cannot be differentiated and plotted as a single curve.
    if not isinstance(f, sympy.Expr):
        raise ValueError(
            f"Expression must be a scalar expression in x, got {type(f).__name__}."
        )

    # Ensure it's a function of x only
    free_syms = f.free_symbols
    if free_syms != {x}:
        raise ValueError("Expression must be a function of x only.")

    # Reject any undefined function applications, e.g. "foo(x)"
    undefined = f.atoms(sympy.core.function.AppliedUndef)
    if undefined:
        names = ", ".join(sorted(str(u.func) for u in undefined))
        # List only function-like entries: SymPy objects such as the symbol x
        # and the constant pi are valid names but not functions.
        allowed = ", ".join(
            sorted(
                name
                for name, obj in ALLOWED_FUNCS.items()
                if callable(obj) and not isinstance(obj, sympy.Basic)
            )
        )
        raise ValueError(f"Unknown function(s): {names}. Allowed: {allowed}")

    dfdx = sympy.diff(f, x)

    return f, dfdx


def lambdify(expression: sympy.Expr) -> Callable[[np.ndarray], np.ndarray]:
    """Compile a SymPy expression in ``x`` into a NumPy-backed Python function.

    The returned function takes the value(s) of ``x`` and evaluates
    ``expression`` with NumPy functions, so it accepts arrays element-wise.
    Note that an expression that does not depend on ``x`` (e.g. the derivative
    of ``2*x``) compiles to a function returning a scalar, not an array;
    ``safe_eval`` is the place where such results are broadcast.
    """
    return sympy.lambdify(x, expression, modules=["numpy"])


def safe_eval(func, xs):
    """
    Evaluate a numpy-lambdified function with error handling.

    Args:
        func: Function produced by ``lambdify`` (or any NumPy-compatible
            function of one argument).
        xs: Scalar or array of sample points.

    Returns:
        A float ndarray with the same shape as ``np.asarray(xs)`` (0-d for a
        scalar input) in which non-finite values (inf, -inf, nan) are replaced
        by nan, so that plotting simply leaves gaps. If the result cannot be
        broadcast to that shape, an all-nan array is returned.
    """
    # Always hand `func` an ndarray, never a bare Python float: with NumPy
    # arithmetic a singular point (e.g. 1/x at x=0) yields inf instead of
    # raising ZeroDivisionError. Floating-point warnings are silenced.
    points = np.asarray(xs, dtype=float)
    with np.errstate(all="ignore"):
        values = np.asarray(func(points), dtype=float)

    try:
        # Constant expressions lambdify to functions returning a scalar,
        # so stretch the result to the shape of the input.
        values = np.broadcast_to(values, points.shape)
    except ValueError:
        values = np.full(points.shape, np.nan)

    # np.where builds a new array, so the read-only broadcast view is never
    # written to.
    return np.where(np.isfinite(values), values, np.nan)


# Static sanity check: parse sin(x), compile f and f', and plot both.
f_expr, dfdx_expr = parse("sin(x)")
f, dfdx = map(lambdify, [f_expr, dfdx_expr])

fig, ax = plt.subplots()
ax.set_xlabel("x")
ax.set_ylabel("y")

# NOTE: the sample grid must not be called `x`: that global is the SymPy
# symbol that parse() and lambdify() rely on, and rebinding it breaks them.
X = np.linspace(-2 * π, 2 * π, 100)
ax.set_xlim(-2 * π, 2 * π)

# safe_eval masks non-finite values and broadcasts constant results
# (e.g. the derivative of a linear function) to the shape of X.
ax.plot(X, safe_eval(f, X), label=f"f(x)={f_expr}")
ax.plot(X, safe_eval(dfdx, X), label=f"f'(x)={dfdx_expr}")
ax.legend()
def plot_with_slider():
    """Create a figure and return a callback that redraws it from widget values.

    The returned ``update`` function is meant to be passed to
    ``ipywidgets.interact``. Each call clears this figure's axes and plots
    f(x) over the chosen domain, together with the tangent line to f at x0.

    Returns:
        The ``update(expression_str, domain, x0, N)`` callback.
    """
    fig, ax = plt.subplots(figsize=(6.5, 4), constrained_layout=True)

    def update(expression_str: str, domain: tuple[float, float], x0: float, N: int):
        """Redraw the figure for the given widget values.

        Args:
            expression_str: Text of f(x), e.g. ``"sin(x)"``.
            domain: ``(xmin, xmax)`` range of the x-axis; requires xmin < xmax.
            x0: Point at which the tangent line is drawn.
            N: Number of sample points used to draw f.

        Invalid input is logged and leaves the previous plot unchanged.
        """
        # Parse expression. This is the UI boundary, so failures are logged
        # rather than replacing the widget output with a traceback.
        try:
            f_expr, dfdx_expr = parse(expression_str)
        except Exception as e:
            logger.error(e)
            return

        # Check valid domain
        xmin, xmax = domain
        if xmin >= xmax:
            logger.error("`xmin` must be strictly less than `xmax`")
            return

        logger.info(f"Valid expression: {f_expr}")

        # Clear only this figure's axes. plt.cla() acts on pyplot's *current*
        # axes, which belong to another figure once a later cell creates one.
        # Clearing also removes the title and labels, so they are redrawn.
        ax.clear()

        ax.set_title("Functions and their tangents")
        ax.set_xlabel("$x$")
        ax.set_ylabel("$y$")
        ax.set_xlim(xmin, xmax)

        # Lambdify
        f, dfdx = map(lambdify, [f_expr, dfdx_expr])

        # Compute and plot function
        X = np.linspace(xmin, xmax, N)
        ax.plot(X, safe_eval(f, X), label=f"$f(x) = {sympy.latex(f_expr)}$")

        # Evaluate at a NumPy scalar, not a Python float: with NumPy
        # arithmetic a singular point (e.g. 1/x at x0 = 0) yields inf/nan
        # instead of raising ZeroDivisionError.
        x0_np = np.float64(x0)
        y0 = float(safe_eval(f, x0_np))
        m = float(safe_eval(dfdx, x0_np))

        if not xmin <= x0 <= xmax:
            logger.warning(f"x0={x0:+.3f} lies outside the plotted domain")

        if np.isfinite(y0) and np.isfinite(m):
            # Tangent segment spanning ±5% of the domain on either side of x0
            dx = 0.05 * (xmax - xmin)
            x1, x2 = x0 - dx, x0 + dx
            ax.plot(
                [x1, x2],
                [y0 + m * (x1 - x0), y0 + m * (x2 - x0)],
                color="red",
                linewidth=2.5,
                label=rf"$x_0={x0:+.3f}$" + "\n" + rf"$f'(x_0)={m:+.3f}$",
            )
            # The point of tangency
            ax.plot([x0], [y0], "o", color="red")
        else:
            logger.warning(f"f or f' is not finite at x0={x0:+.3f}; tangent not drawn")

        ax.legend(loc="upper right")
        # Request a repaint explicitly; harmless if the backend already
        # redraws automatically in interactive mode.
        fig.canvas.draw_idle()

    return update

# Connect the callback to widgets; `interact` calls it once immediately and
# again whenever one of the widget values changes.
_ = interact(
    plot_with_slider(),
    expression_str=widgets.Text(
        value="sin(x)",
        description="f(x):",
        # Re-plot only on Enter / focus loss rather than on every keystroke,
        # so half-typed input such as "sin(" does not flood the log with errors.
        continuous_update=False,
        layout=widgets.Layout(width="400px"),
    ),
    domain=widgets.FloatRangeSlider(
        value=(-2 * π, 2 * π),
        min=-4 * π,
        max=4 * π,
        step=0.01,
        description="Domain",
        layout=widgets.Layout(width="400px"),
    ),
    x0=widgets.FloatSlider(
        value=0.0,
        min=-2 * π,
        max=2 * π,
        step=0.01,
        description="x0",
        layout=widgets.Layout(width="400px"),
    ),
    # Sample count is fixed: `fixed` passes the value without creating a widget
    N=widgets.fixed(800),
)

Reuse