import logging
from rich.logging import RichHandler
# Configure the root logger: INFO and above go through Rich's handler, which
# also renders exception tracebacks with Rich (`rich_tracebacks=True`).
logging.basicConfig(
    level="INFO",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True)],
)
# No name is passed, so this is the root logger configured just above.
logger = logging.getLogger()

# Interactive plots with Jupyter and ipywidgets
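Summary
This tutorial demonstrates how to create interactive plots in Jupyter notebooks with ipywidgets. You type a function f(x), and sliders control the plotting domain and the point x₀ at which the tangent line is drawn. Interactive plots like this are especially useful in workshops.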
from typing import Callable
import numpy as np
import matplotlib.pyplot as plt
import sympy
from ipywidgets import interact
import ipywidgets as widgets
# For easy reading
π = np.pi

# Use Matplotlib's interactive widget backend. This is the programmatic form
# of the `%matplotlib widget` notebook magic. `get_ipython` only exists when
# running under IPython/Jupyter; elsewhere the default backend is kept.
try:
    get_ipython().run_line_magic("matplotlib", "widget")  # noqa: F821
except NameError:
    pass

# Define the sympy symbol
x = sympy.symbols("x")
# Names a user expression may refer to:
# - the symbol `x`;
# - a subset of SymPy's (mostly) differentiable elementary functions;
# - the constant pi.
# "ln" is an alias for "log" and "π" an alias for "pi".
ALLOWED_FUNCS = {
    "x": x,
    # Functions exposed under the same name SymPy uses.
    **{
        name: getattr(sympy, name)
        for name in (
            "sin", "cos", "tan",      # trigonometric
            "asin", "acos", "atan",   # inverse trigonometric
            "sinh", "cosh", "tanh",   # hyperbolic
            "exp", "log",             # exponential / logarithm
        )
    },
    # Aliases and remaining entries.
    "ln": sympy.log,
    "sqrt": sympy.sqrt,
    "pi": sympy.pi,
    "π": sympy.pi,
}
def parse(expression_str: str) -> tuple[sympy.Expr, sympy.Expr]:
    """Parse a string into a sympy expression f(x) and its derivative f'(x).

    Args:
        expression_str: User-typed expression in ``x``. Surrounding whitespace
            is ignored, and names are resolved through ``ALLOWED_FUNCS``.

    Returns:
        ``(f, dfdx)``: the parsed expression and its derivative with respect
        to ``x``.

    Raises:
        ValueError: in any of these cases:
            - the text cannot be parsed;
            - it is not a single SymPy expression (e.g. a list, tuple or
              comparison);
            - its free symbols are not exactly ``{x}`` (so constants such as
              ``"2"`` are rejected too);
            - it applies a function that is not in ``ALLOWED_FUNCS``.
    """
    # NOTE(security): `sympy.sympify` evaluates its input with `eval`. Only
    # pass text typed by the person running the notebook, never untrusted
    # or remote input.
    try:
        f = sympy.sympify(expression_str.strip(), locals=ALLOWED_FUNCS)
    except Exception as e:
        # sympify can fail in several ways; normalise them all to ValueError.
        raise ValueError(f"Could not parse expression: {e}") from e

    # Inputs like "[x]", "x, 1" or "x > 1" parse, but are not plottable
    # expressions. A plain list would also have no `.free_symbols` below.
    if not isinstance(f, sympy.Expr):
        raise ValueError("Expression must be a single algebraic expression in x.")

    # Ensure it's a function of x only
    if f.free_symbols != {x}:
        raise ValueError("Expression must be a function of x only.")

    # Reject applications of undefined functions, e.g. "foo(x)", which SymPy
    # parses into a placeholder function instead of failing.
    undefined = f.atoms(sympy.core.function.AppliedUndef)
    if undefined:
        names = ", ".join(sorted(str(u.func) for u in undefined))
        allowed = ", ".join(
            sorted(k for k in ALLOWED_FUNCS if callable(ALLOWED_FUNCS[k]))
        )
        raise ValueError(f"Unknown function(s): {names}. Allowed: {allowed}")

    dfdx = sympy.diff(f, x)
    return f, dfdx
def lambdify(expression: sympy.Expr) -> Callable[[np.ndarray], np.ndarray]:
    """Compile a SymPy expression in ``x`` into a NumPy-backed Python function.

    The returned callable takes the value(s) of ``x`` and evaluates
    ``expression`` with NumPy's functions (``modules=["numpy"]``).

    NOTE(review): if ``expression`` does not contain ``x`` (e.g. the derivative
    of a linear function), the callable presumably returns a scalar rather
    than an array shaped like its input. Callers should not assume the output
    shape matches the input.
    """
    # The result type in Callable[...] must be a type, not a list. The
    # previous `[np.ndarray]` raised TypeError at definition on Python <= 3.10.
    return sympy.lambdify(x, expression, modules=["numpy"])
def safe_eval(func, xs):
    """
    Evaluate a numpy-lambdified function with error handling.

    Args:
        func: Callable produced by ``lambdify``. It is given NumPy float
            values, never Python scalars.
        xs: A scalar or array of x values.

    Returns:
        A float ndarray with the same shape as ``xs`` (0-d for a scalar
        ``xs``). Three cases are handled specially:
        - non-finite values (inf, NaN) are replaced by ``np.nan``;
        - complex values with a non-zero imaginary part are replaced by
          ``np.nan``;
        - a scalar result is broadcast to ``xs``'s shape, and a result that
          cannot be broadcast becomes all-NaN.
    """
    # Convert first so `func` always operates on NumPy floats. For example,
    # `1/x` at a Python float 0.0 raises ZeroDivisionError, whereas NumPy
    # yields inf (warning silenced below), which is then masked.
    xs = np.asarray(xs, dtype=float)
    with np.errstate(all="ignore"):
        y = np.asarray(func(xs))

    # Complex results cannot be drawn on a real axis: keep purely real values.
    if np.iscomplexobj(y):
        y = np.where(y.imag == 0, y.real, np.nan)
    y = np.asarray(y, dtype=float)

    # Expressions without x evaluate to a scalar: stretch them over xs.
    if y.shape != xs.shape:
        try:
            y = np.broadcast_to(y, xs.shape)
        except ValueError:
            y = np.full(xs.shape, np.nan)

    # np.where builds a fresh array. `broadcast_to` returns a read-only view,
    # so in-place masking would fail.
    return np.where(np.isfinite(y), y, np.nan)


# Static example: parse sin(x) and its derivative for the plot below.
f_expr, dfdx_expr = parse("sin(x)")
# Compile the example expression and its derivative (parsed just above) into
# NumPy functions. Both are lambdified once and reused for plotting.
f, dfdx = map(lambdify, [f_expr, dfdx_expr])
fig, ax = plt.subplots()
ax.set_xlabel("x")
ax.set_ylabel("y")
# NOTE: the sample grid cannot be called `x`. That name is already bound to
# the SymPy symbol used by `parse` and `lambdify`.
X = np.linspace(-2 * π, 2 * π, 100)
ax.set_xlim(-2 * π, 2 * π)
ax.plot(X, f(X), label=f"f(x)={f_expr}")
ax.plot(X, dfdx(X), label=f"f'(x)={dfdx_expr}")
ax.legend()


def plot_with_slider():
    """Create a figure and return a callback that redraws it from widget values.

    Returns:
        ``update(expression_str, domain, x0, N)``, for use with
        ``ipywidgets.interact``.
    """
    fig, ax = plt.subplots(figsize=(6.5, 4), constrained_layout=True)

    def update(expression_str: str, domain: tuple[float, float], x0: float, N: int):
        """Redraw f(x) on ``domain`` and its tangent line at ``x0``.

        ``N`` is the number of sample points. Invalid input is logged and
        leaves the previous plot untouched.
        """
        # Parse expression. This is the UI boundary: report problems instead
        # of raising into the widget.
        try:
            f_expr, dfdx_expr = parse(expression_str)
        except Exception as e:
            logging.error(e)
            return

        # Check valid domain
        xmin, xmax = domain
        if xmin >= xmax:
            logging.error("`xmin` must be strictly less than `xmax`")
            return
        logging.info("Valid expression: %s", f_expr)
        if not xmin <= x0 <= xmax:
            logging.warning(
                "x0 = %s lies outside the domain; its tangent will not be visible",
                x0,
            )

        # Clear *this* figure's axes. `plt.cla()` would clear whatever axes
        # pyplot considers current, which belongs to another figure once any
        # other plot has been created. Clearing also removes the title and
        # labels, so they are redrawn here.
        ax.clear()
        ax.set_title("Functions and their tangents")
        ax.set_xlabel("$x$")
        ax.set_ylabel("$y$")
        ax.set_xlim(xmin, xmax)

        # Lambdify
        f, dfdx = map(lambdify, [f_expr, dfdx_expr])

        # Compute and plot function
        X = np.linspace(xmin, xmax, N)
        y = safe_eval(f, X)
        ax.plot(X, y, label=f"$f(x) = {sympy.latex(f_expr)}$")

        # Tangent at x0: a short segment of slope f'(x0) through (x0, f(x0)).
        y0 = float(safe_eval(f, x0))
        m = float(safe_eval(dfdx, x0))
        dx = 0.05 * (xmax - xmin)  # ±5% of the domain width on either side of x0
        x1, x2 = x0 - dx, x0 + dx
        y1 = y0 + m * (x1 - x0)
        y2 = y0 + m * (x2 - x0)
        ax.plot(
            [x1, x2],
            [y1, y2],
            color="red",
            linewidth=2.5,
            label=rf"$x_0={x0:+.3f}$" + "\n" + rf"$f'(x_0)={m:+.3f}$",
        )
        # The point
        ax.plot([x0], [y0], "o", color="red")
        ax.legend(loc="upper right")
        # Ask for a redraw explicitly so the widget canvas refreshes even when
        # pyplot's interactive mode is off.
        fig.canvas.draw_idle()

    return update


_ = interact(
    plot_with_slider(),
    expression_str=widgets.Text(
        value="sin(x)", description="f(x):", layout=widgets.Layout(width="400px")
    ),
    domain=widgets.FloatRangeSlider(
        value=(-2 * π, 2 * π),
        min=-4 * π,
        max=4 * π,
        step=0.01,
        description="Domain",
        layout=widgets.Layout(width="400px"),
    ),
    x0=widgets.FloatSlider(
        value=0.0,
        min=-2 * π,
        max=2 * π,
        step=0.01,
        description="x0",
        layout=widgets.Layout(width="400px"),
    ),
    N=widgets.fixed(800),
)

# Reuse
Copyright
2025, UK Centre for Ecology & Hydrology Research Software Engineering Group