Matplotlib

Introduction

Matplotlib is a plotting library for Python. It is powerful and widely used, and other Python libraries, such as Pandas and Seaborn, build their plotting features on top of it.

Matplotlib is a relatively low-level library, meaning that it provides lots of flexibility to create custom plots, while at the same time it requires more code to do so. Other libraries, such as Seaborn, are built on top of Matplotlib and provide a higher-level interface.

Note

This article is just an introduction to Matplotlib and is based on the Matplotlib documentation. For more details, please refer to the documentation.

Plotting with subplots

Typically, we import Matplotlib as follows and then create a Figure and an Axes with the subplots function:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()

subplots is a function that returns a Figure object and an Axes object. You can think of the Figure as the canvas that contains all the plots we want to create, while the Axes is an actual plot (the blank sheet on which you draw your data). The term Axes is synonymous with subplot, and a single Figure can contain several of them.

Note

To specify the number of subplots that we want to create, we can pass the nrows and ncols arguments to the subplots function. For example, to create a figure with 2 rows and 2 columns, we can do:

fig, ax = plt.subplots(2, 2)
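
As a quick check of what subplots returns for different grid sizes, the sketch below inspects the returned objects. It assumes the default squeeze=True behavior, and the variable names are illustrative only.

```python
import matplotlib.pyplot as plt
import numpy as np

# With no grid size, a single Axes object is returned.
fig1, ax = plt.subplots()

# With a single row (or column), a 1D NumPy array of Axes is returned.
fig2, axs_row = plt.subplots(nrows=1, ncols=2)

# With a full grid, a 2D NumPy array of Axes is returned.
fig3, axs_grid = plt.subplots(nrows=2, ncols=2)

print(isinstance(ax, np.ndarray))  # False
print(axs_row.shape)               # (2,)
print(axs_grid.shape)              # (2, 2)
```

This is why a single subplot is used directly (ax.plot), while a grid must be indexed first (for example axs_grid[0, 1].plot).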

After we have called subplots, we can use the plot method of the Axes object to plot some data with a line connecting the different data points. The input to the plot method is typically two arrays, one for the x-axis and one for the y-axis. For example, if we have time-series data, we can plot the time on the x-axis and the values on the y-axis:

# climate_change is assumed to be a Pandas DataFrame with 'date' and 'co2' columns
ax.plot(climate_change['date'], climate_change['co2'])
ax.set(
    title='Amount of CO2 (ppm) in each year',
    xlabel='Year',
    ylabel='Amount of CO2 (ppm)'
)

Note

In the last example, the inputs to the plot method were two columns of a Pandas DataFrame. However, the plot method accepts any array-like objects (e.g., lists, NumPy arrays, or Pandas Series). Typically, the first two arguments are the x and y values, and the rest of the arguments are optional.
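
To illustrate the point about array-like inputs, here is a minimal sketch that plots plain Python lists and NumPy arrays on the same Axes. The data values are made up for the example.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()

# Plain Python lists work as x and y values (made-up numbers).
years = [2000, 2005, 2010, 2015]
values = [1.0, 1.5, 1.2, 2.0]
(line_from_lists,) = ax.plot(years, values)

# NumPy arrays work the same way.
x = np.linspace(2000, 2015, 50)
y = 1.0 + 0.05 * (x - 2000)
(line_from_arrays,) = ax.plot(x, y)
```

The plot method returns a list of line objects, one per plotted line, which is why each result is unpacked from a one-element tuple here.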

If we create several subplots at the same time, subplots returns an array of Axes objects, and we can index that array to choose which subplot to plot on:

fig, axs = plt.subplots(2, 2)
axs[0, 0].plot(x, y)
axs[0, 0].set_title('Axis [0, 0]')
axs[0, 1].plot(x, y, 'tab:orange')
axs[0, 1].set_title('Axis [0, 1]')
axs[1, 0].plot(x, -y, 'tab:green')
axs[1, 0].set_title('Axis [1, 0]')
axs[1, 1].plot(x, -y, 'tab:red')
axs[1, 1].set_title('Axis [1, 1]')

4 subplots in a single figure.

Among other parameters, subplots has two that specify the grid size: nrows and ncols set the number of rows and columns, respectively.
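
The grid example can also be written as a loop. The following is a self-contained sketch, with x and y defined here as a sine curve purely for illustration:

```python
import matplotlib.pyplot as plt
import numpy as np

# Illustrative data; any pair of 1D arrays would do.
x = np.linspace(0, 2 * np.pi, 100)
y = np.sin(x)

fig, axs = plt.subplots(nrows=2, ncols=2)

# axs.flat iterates over the 2D array of Axes row by row.
for index, ax in enumerate(axs.flat):
    row, col = divmod(index, 2)
    ax.plot(x, y)
    ax.set_title(f'Axis [{row}, {col}]')

# Adjust spacing so titles and tick labels overlap less.
fig.tight_layout()
```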

.plot() parameters

The .plot() method has a lot of parameters that can be used to customize the plot. For example, we can change the color, width and the line style of the plot:

import matplotlib.pyplot as plt
import numpy as np

# Sample data: one period of a sine wave
x = np.linspace(0, 2 * np.pi, 100)
y = np.sin(x)

fig, ax = plt.subplots(1, 1)

# Set the color, width, and style of the line
ax.plot(x, y, color='green', linewidth=2, linestyle='--')

The color can be specified in different ways. For example, we can use the name of the color, as in the previous example, or we can use the hexadecimal code of the color:

ax.plot(x, y, color='#eeefff')
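
As a small sketch of the different color specifications, the example below draws three lines using a color name, a hexadecimal code, and an RGB tuple. It also uses matplotlib.colors.to_hex to show that they all resolve to the same kind of value. The data is arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_hex

x = np.linspace(0, 1, 20)

fig, ax = plt.subplots()

# Three ways of specifying a color: a name, a hex code, and an RGB tuple.
(named,) = ax.plot(x, x, color='green')
(hexed,) = ax.plot(x, x + 1, color='#1f77b4')
(rgb,) = ax.plot(x, x + 2, color=(1.0, 0.0, 0.0))

# to_hex converts any valid color specification to its hex code.
print(to_hex(named.get_color()))  # #008000
print(to_hex(rgb.get_color()))    # #ff0000
```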

Other types of plots

Scatter plots

To create a scatter plot of two 1D arrays, we can use the scatter method:

fig, ax = plt.subplots(1, 1)

ax.scatter(x, y, color='green')

A scatter plot example.
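
The scatter method can also vary the size and color of each marker. Below is a self-contained sketch with random data. The s argument sets the marker area in points squared, c supplies per-point values that are mapped to colors through the default colormap, and alpha sets the transparency.

```python
import matplotlib.pyplot as plt
import numpy as np

# Fixed seed so the random data is reproducible.
rng = np.random.default_rng(19680801)

N = 50
x = rng.random(N)
y = rng.random(N)
colors = rng.random(N)             # one value per point, mapped to a color
area = (30 * rng.random(N)) ** 2   # marker areas in points squared

fig, ax = plt.subplots()
points = ax.scatter(x, y, s=area, c=colors, alpha=0.5)
```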

Histograms

To create a histogram of the distribution of the data for a single 1D array, we can use the hist method:

fig, axs = plt.subplots(1, 2)

n_bins = 20

axs[0].hist(dist1, bins=n_bins)
axs[1].hist(dist2, bins=n_bins)

A double histogram example.
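
For a version that runs on its own, the sketch below generates dist1 and dist2 from two normal distributions (an assumption made for illustration, since the data behind the example above is not shown) and keeps the values that hist returns.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical sample data: two normal distributions.
rng = np.random.default_rng(0)
n_points = 1000
dist1 = rng.normal(loc=0.0, scale=1.0, size=n_points)
dist2 = rng.normal(loc=5.0, scale=0.4, size=n_points)

n_bins = 20
fig, axs = plt.subplots(1, 2)

# hist returns the bin counts, the bin edges, and the drawn patches.
counts1, edges1, _ = axs[0].hist(dist1, bins=n_bins)
counts2, edges2, _ = axs[1].hist(dist2, bins=n_bins)
```

With n_bins bins there are n_bins + 1 edges, and the counts add up to the number of data points.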

Bar plots

To create a bar plot, we can use the bar method:

fig, ax = plt.subplots(1, 1)

# Defines X-axis labels and Y-axis values
fruits = ['apple', 'blueberry', 'cherry', 'orange']
counts = [40, 100, 30, 55]
# A label starting with an underscore is left out of the legend
bar_labels = ['red', 'blue', '_red', 'orange']
bar_colors = ['tab:red', 'tab:blue', 'tab:red', 'tab:orange']

ax.bar(fruits, counts, label=bar_labels, color=bar_colors)

A bar plot example.
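
As a possible extension, the value of each bar can be written above it. This sketch assumes Matplotlib 3.4 or newer, where Axes.bar_label is available, and reuses the fruit data from the example above.

```python
import matplotlib.pyplot as plt

fruits = ['apple', 'blueberry', 'cherry', 'orange']
counts = [40, 100, 30, 55]

fig, ax = plt.subplots()

# bar returns a container holding one rectangle per bar.
bars = ax.bar(fruits, counts, color='tab:blue')

# bar_label writes each bar's height above it (Matplotlib 3.4+).
value_labels = ax.bar_label(bars)
ax.set_ylabel('Count')
```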

More types of plots

You can find a list of all the different types of plots that can be created with Matplotlib in the Matplotlib gallery.

Show figure

Once the plot is ready, we can show the figure with the show method:

plt.show()

This will open a new window with the figure.

Save figure

Finally, we can also save the figure as a file (e.g., a PNG or an SVG file) with the savefig method:

fig.savefig('co2_levels.png')
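
The output format is inferred from the file extension. The sketch below saves the same figure as PNG and SVG into a temporary directory (a hypothetical location chosen for the example) and uses the dpi and bbox_inches options of savefig. The plotted values are placeholders.

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9])  # placeholder data

# Hypothetical output location: a fresh temporary directory.
output_dir = Path(tempfile.mkdtemp())
png_path = output_dir / 'co2_levels.png'
svg_path = output_dir / 'co2_levels.svg'

# dpi sets the resolution of raster output; bbox_inches='tight' trims
# the empty margins around the plot.
fig.savefig(png_path, dpi=150, bbox_inches='tight')
fig.savefig(svg_path)
```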

Minimal working example

import matplotlib.pyplot as plt
import numpy as np

# Create data for plotting. Any pair of 1D data can be used here,
# for example two columns of a Pandas DataFrame.
t = np.arange(0.0, 2.0, 0.01)
s = 1 + np.sin(2 * np.pi * t)

# Create figure and axes
fig, ax = plt.subplots()

# Plot data (the label is what the legend below displays)
ax.plot(t, s, label='voltage')

# Customize plot
ax.set(xlabel='time (s)', ylabel='voltage (mV)',
       title='About as simple as it gets, folks')

# add a grid and legend
ax.grid()
ax.legend()

# save as file and show window with plot
fig.savefig("test.png")
plt.show()

Other options to customize the plot

Legends

There are many other options that can be used to customize the plot. For example, we can add a legend to the plot with the legend method:

ax.legend()

The legend method will use the label argument of the plot method to create the legend. We can also specify the location of the legend with the loc argument:

ax.legend(loc='upper center')
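
Putting the two pieces together, here is a minimal sketch with two labeled lines and a positioned legend. The sine and cosine curves are just sample data.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)

fig, ax = plt.subplots()

# Each label passed to plot becomes one entry in the legend.
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')

legend = ax.legend(loc='upper center')

print([text.get_text() for text in legend.get_texts()])  # ['sin(x)', 'cos(x)']
```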

xlim and ylim

Two other useful methods are set_xlim and set_ylim, which set the limits of the x-axis and y-axis. They are used to show only a part of the plot, for example:

ax.set_xlim([1980, 1990])
ax.set_ylim([0, 2.5])
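
Setting the limits changes only the visible window, not the plotted data. A small sketch with made-up values:

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up data covering 1960 to 2020.
years = np.arange(1960, 2021)
values = 0.05 * (years - 1960)

fig, ax = plt.subplots()
(line,) = ax.plot(years, values)

# Show only the 1980s; the full data set stays attached to the line.
ax.set_xlim(1980, 1990)
ax.set_ylim(0, 2.5)
```

The limits can be passed either as two separate arguments, as here, or as a single pair, as in the example above.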

Customizing the font

We can customize the fonts of the plot by defining a fontdict dictionary and unpacking it as keyword arguments in the set_title, set_xlabel, and set_ylabel methods:

fontdict={'fontsize': 18, 'fontweight': 'bold', 'color': 'blue', 'family': 'serif'}

ax.set_title('Amount of CO2 (ppm) in each year', **fontdict)
ax.set_xlabel('Year', **fontdict)
ax.set_ylabel('Amount of CO2 (ppm)', **fontdict)

Another way to customize the font is to use the rcParams dictionary. This will change the default font for all the plots. For example, to change the font size and family, we can do:

plt.rcParams.update({'font.size': 18, 'font.family': 'serif'})
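
If the change should apply only to some figures, one option is plt.rc_context, which restores the previous settings when the with block ends. A minimal sketch:

```python
import matplotlib.pyplot as plt

default_size = plt.rcParams['font.size']

# Settings given here are active only inside the with block.
with plt.rc_context({'font.size': 18, 'font.family': 'serif'}):
    fig, ax = plt.subplots()
    ax.set_title('Serif title')
    size_inside = plt.rcParams['font.size']

# Outside the block, the previous default is back.
size_after = plt.rcParams['font.size']
```

Drawing or saving the figure inside the with block is the safest way to make sure the temporary settings are the ones used for rendering.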

Customizing ticks

We can customize the ticks of the plot (i.e., which numbers are shown on the x-axis and y-axis) with the set_xticks and set_yticks methods:

ax.set_xticks([1980, 1990, 2000, 2010, 2020])
ax.set_yticks([0, 1, 2, 3, 4, 5])
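
Tick positions and tick labels can be set separately. The sketch below fixes the positions with set_xticks, replaces the text with set_xticklabels, and rotates the labels with tick_params. The data and the short year labels are invented for the example.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1980, 2000, 2020], [0, 2, 5])  # placeholder data

# Fix the tick positions first, then supply one label per position.
tick_positions = [1980, 1990, 2000, 2010, 2020]
ax.set_xticks(tick_positions)
ax.set_xticklabels(["'80", "'90", "'00", "'10", "'20"])

# Rotate the x tick labels by 45 degrees.
ax.tick_params(axis='x', labelrotation=45)
```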

Grid

We can also add a grid to the plot with the grid method:

ax.grid()

Sample figure with a grid.
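
The grid method also accepts options. As a sketch, the example below draws only the horizontal grid lines (those attached to the y-axis) and styles them as dashed, semi-transparent lines. The sine curve is sample data.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# Only grid lines for the y-axis, dashed and semi-transparent.
ax.grid(True, axis='y', linestyle='--', alpha=0.5)
```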