1 Oct 2016

Additive Blending in Matplotlib


A short tutorial on additive blending using matplotlib and numpy.

A problem I have encountered several times is how to plot a large number of data points in matplotlib while still showing their density. The usual method is simple: plot the data points with their opacity turned down, so that you can see where they overlap.
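To put a rough number on how that overlap builds up, here is a minimal sketch. It assumes the standard "over" compositing rule and identical points of opacity 0.3 landing on the same pixel; the helper name `stacked_opacity` is made up for illustration and is not part of matplotlib.

```python
# Opacity of a pixel after k identical points of opacity `alpha` are drawn
# on top of each other, assuming standard "over" compositing:
# each new point covers a fraction `alpha` of whatever was still showing through.
def stacked_opacity(alpha, k):
    return 1.0 - (1.0 - alpha) ** k

alpha = 0.3  # the same opacity used for the scatter plots in this post
for k in (1, 2, 5, 10, 20):
    print(k, round(stacked_opacity(alpha, k), 3))
```

Under these assumptions, ten overlapping points already give an opacity of about 0.97, so anything drawn on top after that makes almost no visible difference.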

The issue I have with this is that once you stack a few data points you reach alpha=1, and then you may as well not bother showing any more data points. Here is an example of this, where I simply plot twenty thousand realisations of a simple multivariate normal.

import numpy as np
import matplotlib.pyplot as plt

# Sample our data
n = 20000
np.random.seed(0)
x = np.random.normal(size=n)
y = np.random.normal(size=n)

# Create a figure with the axis filling it, with a black background
fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, facecolor="#000000")  # facecolor replaces the removed axisbg keyword
fig.subplots_adjust(0, 0, 1, 1)

# Plot our data points
ax.scatter(x, y, alpha=0.3, c="b", lw=0)

# Save it out, without whitespace and labels getting in the way
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.savefig("bad.png", pad_inches=0, bbox_inches="tight")

This leaves a lot to be desired. What I wanted to do was implement additive blending, a technique I use a lot in Photoshop and After Effects. However, search as I might, it seems that this is not something that simply comes with matplotlib. My initial query on StackOverflow went unanswered, but I did find a similar question here, and my implementation closely follows it.

So we have to do the additive blending ourselves if we want it, which is a pain. Straightforward, but a pain nonetheless. Example code is shown below, and hopefully the comments explain what I am doing sufficiently. The gist of it is that you have to render the layers yourself, stack them manually and then display the final image buffer using imshow.

import numpy as np
import matplotlib.pyplot as plt

# Sample our data
n = 20000
np.random.seed(0)
x = np.random.normal(size=n)
y = np.random.normal(size=n)

# Create a figure with the axis filling it, with a black background
fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, facecolor="#000000")  # facecolor replaces the removed axisbg keyword
fig.subplots_adjust(0, 0, 1, 1)

# Draw the empty axis, which we use as a base.
fig.canvas.draw()
w, h = fig.canvas.get_width_height()
buffer = np.frombuffer(fig.canvas.buffer_rgba(), np.uint8)
first = buffer.astype(np.int16).reshape(h, w, -1)  # int16 so the running sum doesn't overflow
first[first[:, :, -1] == 0] = 0 # Set transparent pixels to 0

layers = 20 # Number of layers to add. Higher numbers are better but slower
for i in range(layers):
    ax.clear()
    ax.patch.set_facecolor("#000000")
    ax.scatter(x[i::layers], y[i::layers], alpha=0.3, c="#1E3B9C", lw=0)
    fig.canvas.draw()
    img = np.frombuffer(fig.canvas.buffer_rgba(), np.uint8).astype(np.int16).reshape(h, w, -1)
    img[img[:, :, -1] == 0] = 0
    first += img # Add these particles to the main layer
    
first = np.clip(first, 0, 255) # Clip buffer back to the uint8 range (0-255)

ax.clear()
plt.axis("off")
ax.imshow(first.astype(np.uint8), aspect='auto')
fig.savefig("good.png", pad_inches=0)

These are the same data points as in the first plot, but now the individual points are clearer overall, even in the dense sections, and the brighter center gives a better gauge of the distribution.
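One detail worth isolating is the accumulation step in the code above. The following is a minimal sketch of why the layers are summed in int16 and clipped afterwards. It uses a made-up 1x2 pixel RGBA layer rather than a real canvas buffer; the pixel values are purely illustrative.

```python
import numpy as np

# A hypothetical 1x2 RGBA layer: one coloured pixel, one black pixel.
layer = np.array([[[30, 59, 156, 255], [0, 0, 0, 255]]], dtype=np.uint8)
layers = [layer, layer, layer]

# Summing directly in uint8 wraps around modulo 256, corrupting bright pixels.
wrapped = np.zeros_like(layer)
for img in layers:
    wrapped = wrapped + img  # result stays uint8, so 3 * 156 = 468 becomes 212

# Summing in int16 keeps the true totals, which are then clipped to 0..255.
total = np.zeros(layer.shape, dtype=np.int16)
for img in layers:
    total += img
blended = np.clip(total, 0, 255).astype(np.uint8)

print(wrapped[0, 0])  # blue channel has wrapped around
print(blended[0, 0])  # blue channel saturates at 255 instead
```

The clipped version is what gives the dense regions their bright, saturated look, whereas the wrapped version would produce dark speckles exactly where the data is densest.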