import logging
from rich.logging import RichHandler
# Route log records through Rich for readable output and pretty tracebacks.
# `force=True` removes any handlers already attached to the root logger first:
# without it `basicConfig` is a silent no-op whenever the root logger has been
# configured before (e.g. by an imported library or a previous run of this
# cell), and the RichHandler below would never be installed.
logging.basicConfig(
    level="INFO",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True)],
    force=True,
)
# Root logger: the same logger that module-level `logging.info(...)` /
# `logging.error(...)` calls write to.
logger = logging.getLogger()
Interactive plots with Jupyter and ipywidgets
Summary
This tutorial demonstrates how to create interactive plots in Jupyter notebooks using matplotlib and ipywidgets, which is especially useful for workshops.
from typing import Callable
import numpy as np
import matplotlib.pyplot as plt
import sympy
from ipywidgets import interact
import ipywidgets as widgets
# For easy reading
# (`π` is simply an alias for `np.pi`, so ranges can be written as `-2 * π`.)
π = np.pi
# IPython line magic, not plain Python: select matplotlib's interactive
# "widget" backend so figures can be redrawn in place from widget callbacks.
%matplotlib widget
# The one free variable that user-typed expressions may use. Numerical grids
# elsewhere in this notebook are therefore named `X`, not `x`.
x = sympy.Symbol("x")
# Names the user may type, mapped to SymPy objects (a subset of SymPy
# functions that are, mostly, differentiable). Only the callable entries are
# listed back to the user as "Allowed" when an unknown function is rejected.
ALLOWED_FUNCS = {
    "x": x,
    # These names map one-to-one onto attributes of the sympy module.
    **{
        name: getattr(sympy, name)
        for name in (
            "sin", "cos", "tan",
            "asin", "acos", "atan",
            "sinh", "cosh", "tanh",
            "exp", "log",
        )
    },
    "ln": sympy.log,  # natural-log alias
    "sqrt": sympy.sqrt,
    "pi": sympy.pi,
    "π": sympy.pi,
}
def parse(expression_str: str) -> tuple[sympy.Expr, sympy.Expr]:
    """Parse a string into a sympy expression f(x) and its derivative f'(x).

    Args:
        expression_str: Expression in the single variable ``x``, e.g.
            ``"sin(x) * exp(-x)"``. Leading/trailing whitespace is ignored.

    Returns:
        ``(f, dfdx)``: the parsed expression and its symbolic derivative with
        respect to ``x``.

    Raises:
        ValueError: if the string cannot be parsed; if it does not evaluate to
            a scalar SymPy expression (e.g. ``"x > 0"`` or ``"[x]"``); if its
            free symbols are not exactly ``{x}`` (this also rejects
            constants); or if it applies a function SymPy does not know.
    """
    # SECURITY NOTE: `sympy.sympify` evaluates string input with `eval`, so it
    # must never see untrusted input. It is acceptable here only because the
    # expression is typed by the notebook user into their own kernel.
    try:
        f = sympy.sympify(expression_str.strip(), locals=ALLOWED_FUNCS)
    except Exception as e:  # sympify raises many types (SympifyError, SyntaxError, TypeError, ...)
        raise ValueError(f"Could not parse expression: {e}") from e

    # Relations, booleans, lists, ... cannot be differentiated or plotted.
    if not isinstance(f, sympy.Expr):
        raise ValueError(
            f"Expression must be a scalar expression in x, not {type(f).__name__}."
        )

    # Ensure it's a function of x only
    if f.free_symbols != {x}:
        raise ValueError("Expression must be a function of x only.")

    # Reject applications of undefined functions such as `foo(x)`: sympify
    # turns them into opaque `AppliedUndef` objects instead of failing.
    undefined = f.atoms(sympy.core.function.AppliedUndef)
    if undefined:
        names = ", ".join(sorted(str(u.func) for u in undefined))
        allowed = ", ".join(
            sorted(name for name, obj in ALLOWED_FUNCS.items() if callable(obj))
        )
        raise ValueError(f"Unknown function(s): {names}. Allowed: {allowed}")

    dfdx = sympy.diff(f, x)

    return f, dfdx
def lambdify(expression: sympy.Expr) -> Callable[[np.ndarray], np.ndarray]:
    """Compile a SymPy expression in ``x`` into a NumPy-backed Python function.

    Args:
        expression: Expression whose free symbol is the module-level ``x``
            (as produced by ``parse``).

    Returns:
        A callable taking value(s) of ``x`` and evaluating ``expression`` with
        NumPy functions. If ``expression`` does not depend on ``x`` (e.g. the
        derivative of a linear function), the callable presumably returns a
        scalar rather than an array shaped like its input. ``safe_eval`` is
        the place to normalize that.
    """
    # NOTE: the return annotation previously read `Callable[[...], [np.ndarray]]`;
    # a list is not a valid Callable result type, and `typing` on Python <= 3.10
    # raises TypeError for it when the function is defined.
    return sympy.lambdify(x, expression, modules=["numpy"])
def safe_eval(func, xs):
    """
    Evaluate a numpy-lambdified function with error handling.
    Returns an array with non-finite values masked to np.nan.

    ``xs`` is converted to a float NumPy array before calling ``func``. As a
    result, operations such as ``1 / x`` at ``x = 0`` follow NumPy semantics
    (``inf``, with the warning suppressed) instead of raising
    ``ZeroDivisionError`` as plain Python floats would.

    Args:
        func: Callable accepting a NumPy array, e.g. the output of ``lambdify``.
        xs: Scalar or array of evaluation points.

    Returns:
        A float array with the same shape as ``xs`` (0-d for scalar ``xs``):
        - A scalar result (e.g. from a constant expression) is broadcast to
          that shape; a result that cannot be broadcast gives an all-nan array.
        - Non-finite values (nan, ±inf) and complex values with a non-zero
          imaginary part become ``np.nan``.
    """
    xs = np.asarray(xs, dtype=float)
    with np.errstate(all="ignore"):
        y = np.asarray(func(xs))

    # Complex output (e.g. an expression containing `I`) cannot be drawn on a
    # real-valued plot: keep purely-real entries, mask the rest.
    if np.iscomplexobj(y):
        y = np.where(y.imag == 0, y.real, np.nan)
    y = np.asarray(y, dtype=float)

    # Ensure shape compatibility
    if y.shape != xs.shape:
        try:
            y = np.broadcast_to(y, xs.shape)
        except ValueError:
            y = np.full(xs.shape, np.nan)

    # Mask non-finite values. np.where also yields a fresh writable array
    # (np.broadcast_to returns a read-only view).
    return np.where(np.isfinite(y), y, np.nan)
# Parse a default expression once, then compile it and its derivative
# into NumPy-callable functions.
f_expr, dfdx_expr = parse("sin(x)")
f = lambdify(f_expr)
dfdx = lambdify(dfdx_expr)
fig, ax = plt.subplots()
ax.set_xlabel("x")
ax.set_ylabel("y")

# NOTE: the grid is named `X` because `x` is already bound to the SymPy symbol.
X = np.linspace(-2 * π, 2 * π, 100)
ax.set_xlim(-2 * π, 2 * π)

# `f`/`dfdx` and their symbolic forms `f_expr`/`dfdx_expr` were produced from
# "sin(x)" in the previous cell; parsing and lambdifying again is redundant.
ax.plot(X, f(X), label=f"f(x)={f_expr}")
ax.plot(X, dfdx(X), label=f"f'(x)={dfdx_expr}")
ax.legend()
def plot_with_slider():
    """Create a figure and return an ``update`` callback suitable for ``interact``.

    The callback draws f(x) over the chosen domain together with its tangent
    line at ``x0``. Invalid input is reported via ``logging.error`` and the
    callback returns early rather than raising inside the widget machinery.
    """
    fig, ax = plt.subplots(figsize=(6.5, 4), constrained_layout=True)

    def update(expression_str: str, domain: tuple[float, float], x0: float, N: int):
        """Redraw f(x) on ``domain`` (``N`` samples) and its tangent at ``x0``."""
        # Parse expression
        try:
            f_expr, dfdx_expr = parse(expression_str)
        except Exception as e:
            logging.error(e)
            return

        # Check valid domain
        xmin, xmax = domain
        if xmin >= xmax:
            logging.error("`xmin` must be strictly less than `xmax`")
            return

        logging.info(f"Valid expression: {f_expr}")

        # Lambdify and evaluate *before* clearing the axes, so a numerical
        # failure leaves the previous plot on screen instead of a blank one.
        try:
            f, dfdx = map(lambdify, [f_expr, dfdx_expr])
            X = np.linspace(xmin, xmax, N)
            y = safe_eval(f, X)
            # Evaluate at x0 as a NumPy scalar: with a plain Python float,
            # e.g. `1/x` at x0 = 0 raises ZeroDivisionError instead of giving inf.
            x0_np = np.float64(x0)
            y0 = float(safe_eval(f, x0_np))
            m = float(safe_eval(dfdx, x0_np))
        except Exception as e:
            logging.error(f"Could not evaluate expression: {e}")
            return

        # Tangent segment spanning ±5% of the domain width on either side of x0
        dx = 0.05 * (xmax - xmin)
        x1, x2 = x0 - dx, x0 + dx
        y1 = y0 + m * (x1 - x0)
        y2 = y0 + m * (x2 - x0)

        # Clear the canvas.
        # NOTE: this is a lazy way to do it - it would be better to remove
        # elements individually. The downside of being lazy is we have to
        # re-draw everything, including the title.
        # Use `ax.cla()`, not `plt.cla()`: the latter clears whichever axes is
        # *current*, which belongs to another figure once any other plot exists.
        ax.cla()
        ax.set_title("Functions and their tangents")
        ax.set_xlabel("$x$")
        ax.set_ylabel("$y$")
        ax.set_xlim(xmin, xmax)

        # The function
        ax.plot(X, y, label=f"$f(x) = {sympy.latex(f_expr)}$")

        # The tangent
        ax.plot(
            [x1, x2],
            [y1, y2],
            color="red",
            linewidth=2.5,
            label=rf"$x_0={x0:+.3f}$" + "\n" + rf"$f'(x_0)={m:+.3f}$",
        )

        # The point
        ax.plot([x0], [y0], "o", color="red")

        ax.legend(loc="upper right")

        # Ask the interactive backend to refresh after this widget-driven update.
        fig.canvas.draw_idle()

    return update
# Wire the widgets to the callback returned by `plot_with_slider()`.
# The result is assigned to `_` so the notebook does not echo it.
_ = interact(
    plot_with_slider(),
    # `continuous_update=False`: re-plot only when the user presses Enter or the
    # box loses focus. Otherwise every keystroke triggers an update, and
    # half-typed input such as "sin(" logs a parse error each time.
    expression_str=widgets.Text(
        value="sin(x)",
        description="f(x):",
        continuous_update=False,
        layout=widgets.Layout(width="400px"),
    ),
    domain=widgets.FloatRangeSlider(
        value=(-2 * π, 2 * π),
        min=-4 * π,
        max=4 * π,
        step=0.01,
        description="Domain",
        layout=widgets.Layout(width="400px"),
    ),
    # NOTE(review): this range is fixed and does not follow `domain`, so the
    # tangent point can lie outside the plotted x-range.
    x0=widgets.FloatSlider(
        value=0.0,
        min=-2 * π,
        max=2 * π,
        step=0.01,
        description="x0",
        layout=widgets.Layout(width="400px"),
    ),
    # Number of sample points; `fixed` means no widget is shown for it.
    N=widgets.fixed(800),
)
Reuse
Copyright
2025, UK Centre for Ecology & Hydrology Research Software Engineering Group