In [None]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

In [None]:
# Seed the generator so "Restart Kernel -> Run All" reproduces the same point cloud.
SEED = 0
N_PER_CLUSTER = 100
rng = np.random.default_rng(SEED)

# Three synthetic 2-D clusters. They are all plotted in grey in the next cell,
# so their true memberships are hidden; the coloured ellipses there are
# positioned interactively (drag the means, use the covariance sliders).
g1 = rng.multivariate_normal(mean=[-3.3, 1.], cov=[[1, -.2], [-.2, 1]], size=N_PER_CLUSTER)
g2 = rng.multivariate_normal(mean=[0.3, -3.], cov=[[2.5, -.5], [-.5, 0.5]], size=N_PER_CLUSTER)
g3 = rng.multivariate_normal(mean=[3, -1.], cov=[[0.5, 1], [1, 3]], size=N_PER_CLUSTER)

In [None]:
# --- set up figure ---
fig, ax = plt.subplots(figsize=(6,6))
ax.set_xlim(-6,6)
ax.set_ylim(-6,6)
ax.set_aspect('equal', adjustable='box')
ax.grid(True, alpha=0.1)
ax.set_axisbelow(True)
# Title and axis labels so the figure stands on its own when skimmed.
ax.set_title("Drag the x markers to move the means")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")

# One colour per hand-placed component; the order matches `means`, `covs`
# and the slider rows built further down.
colors = ["red", "blue", "green"]

# initial means (one 2-vector per component)
means = [np.array([0.,0.]), np.array([2.,2.]), np.array([-2.,2.])]

# initial covariances: a separate unit-covariance (identity) matrix per component
covs = [np.eye(2) for _ in colors]

# Artists:
#   gaussians     - the data points, all grey so cluster identity is hidden
#   ellipse_lines - one empty outline per component, filled in by redraw()
#   centers       - a pickable "x" marker at each mean (picked for dragging)
gaussians = [ax.scatter(*g.T, color='grey') for g in [g1,g2,g3]]
ellipse_lines = [ax.plot([], [], color=c, lw=2)[0] for c in colors]
centers = [ax.scatter(m[0], m[1], color=c, marker="x", s=80, picker=True) for m,c in zip(means,colors)]

def ellipse_points(mu, cov, n_std=1.0, n_points=200):
    """Return the outline of the ``n_std``-sigma contour ellipse of a 2-D Gaussian.

    Parameters
    ----------
    mu : sequence of length 2
        Centre of the ellipse; only ``mu[0]`` and ``mu[1]`` are read.
    cov : (2, 2) array
        Covariance matrix, assumed symmetric (it is decomposed with
        ``np.linalg.eigh``).
    n_std : float, default 1.0
        Radius of the contour in standard deviations. 1.0 gives the
        one-sigma ellipse, as before this parameter existed.
    n_points : int, default 200
        Number of samples along the outline (first and last coincide).

    Returns
    -------
    (x, y) : tuple of 1-D numpy arrays
        Outline coordinates. Both arrays are empty when ``cov`` has a
        non-positive eigenvalue (not positive definite). A line given these
        empty arrays draws nothing instead of raising.
    """
    vals, vecs = np.linalg.eigh(cov)
    if np.any(vals <= 0):
        # Can happen with slider settings where c12**2 >= c11 * c22.
        return np.array([]), np.array([])
    t = np.linspace(0, 2*np.pi, n_points)
    circle = np.array([np.cos(t), np.sin(t)])
    # Scale the unit circle by n_std*sqrt(eigenvalues) along the eigenvectors.
    ellipse = vecs @ np.diag(n_std * np.sqrt(vals)) @ circle
    return mu[0] + ellipse[0], mu[1] + ellipse[1]

def redraw():
    """Push the current `means` and `covs` into the plot artists.

    For each component, the ellipse outline is recomputed with
    `ellipse_points` and the "x" marker is moved onto the mean. The canvas
    is then asked to repaint lazily via `draw_idle`.
    """
    components = zip(means, covs, ellipse_lines, centers)
    for centre_xy, sigma, outline, marker in components:
        outline.set_data(*ellipse_points(centre_xy, sigma))
        marker.set_offsets([centre_xy[0], centre_xy[1]])
    fig.canvas.draw_idle()

# --- dragging logic ---
# Index (into `centers` / `means`) of the marker being dragged, or None when idle.
dragging_idx = None

def on_pick(event):
    """Begin a drag when one of the mean markers in `centers` is picked.

    Picks on any other artist leave `dragging_idx` untouched.
    """
    global dragging_idx
    hits = [i for i, marker in enumerate(centers) if event.artist == marker]
    if hits:
        dragging_idx = hits[-1]

def on_motion(event):
    """While a marker is held, move its mean to the cursor and redraw.

    Motion outside `ax` is ignored, so a dragged mean stays at its last
    in-axes position until the cursor comes back.
    """
    # dragging_idx is only read here, so no `global` declaration is needed.
    if dragging_idx is not None and event.inaxes == ax:
        means[dragging_idx] = np.array([event.xdata, event.ydata])
        redraw()

def on_release(event):
    """End any drag in progress once a mouse button is released."""
    global dragging_idx
    dragging_idx = None

# Wire the three drag handlers to the figure canvas, in the original order.
for event_name, handler in (
    ("pick_event", on_pick),
    ("motion_notify_event", on_motion),
    ("button_release_event", on_release),
):
    fig.canvas.mpl_connect(event_name, handler)

# --- covariance sliders ---
# Three sliders per component, keyed "<colour><entry>":
#   "11", "22" - the diagonal (variance) entries, kept >= 0.1
#   "12"       - the shared off-diagonal entry
# NOTE: these ranges allow c12**2 >= c11 * c22, i.e. a matrix that is not
# positive definite; ellipse_points then returns no points and that
# component's ellipse disappears until the sliders are moved back.
_SLIDER_RANGES = {
    "11": dict(min=0.1, max=5, value=1),
    "12": dict(min=-2, max=2, value=0),
    "22": dict(min=0.1, max=5, value=1),
}

# Generated rather than copy-pasted; keys, order and settings are unchanged.
controls = {
    f"{colour}{entry}": widgets.FloatSlider(
        description=f"{colour}{entry}", step=0.1, **slider_kwargs
    )
    for colour in ("red", "blue", "green")
    for entry, slider_kwargs in _SLIDER_RANGES.items()
}

def update(red11,red12,red22, blue11,blue12,blue22, green11,green12,green22):
    """Rebuild the three covariance matrices from the slider values and redraw.

    The keyword names must match the keys of `controls`, because
    `interactive_output` calls this function with those keys.
    """
    slider_triples = (
        (red11, red12, red22),
        (blue11, blue12, blue22),
        (green11, green12, green22),
    )
    for k, (var_a, cov_ab, var_b) in enumerate(slider_triples):
        # symmetric 2x2: the single off-diagonal slider fills both corners
        covs[k] = np.array([[var_a, cov_ab], [cov_ab, var_b]])
    redraw()

# Call `update` with the current slider values whenever any slider changes.
out = widgets.interactive_output(update, controls)

# One horizontal row of (11, 12, 22) sliders per component, stacked vertically.
slider_rows = [
    widgets.HBox([controls[f"{name}{entry}"] for entry in ("11", "12", "22")])
    for name in ("red", "blue", "green")
]
ui = widgets.VBox(slider_rows)

display(ui, out)
redraw()
