s01: Plotting

This is a quick introduction to Matplotlib.

From Claus O. Wilke: “Fundamentals of Data Visualization”:

One thing I have learned over the years is that automation is your friend. I think figures should be autogenerated as part of the data analysis pipeline (which should also be automated), and they should come out of the pipeline ready to be sent to the printer, no manual post-processing needed.

Objectives of this session:

  • Be able to create simple plots with Matplotlib and tweak them

  • Know about object-oriented vs pyplot interfaces of Matplotlib

  • Be able to adapt gallery examples

  • Know how to look for help

  • Know that other tools exist

Repeatability/reproducibility

  • No manual post-processing. This will bite you when you need to regenerate 50 figures one day before a deadline or regenerate a set of figures after changes in your analysis.

  • Within Python, many libraries exist:

    • Matplotlib: probably the most standard and most widely used

    • Seaborn: high-level interface to Matplotlib, statistical functions built in

    • Altair: declarative visualization (R users will be more at home), statistics built in

    • Plotly: interactive graphs

    • Bokeh: also good for interactivity

    • ggplot: R users will be more at home
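To make "no manual post-processing" concrete, here is a minimal sketch of a script that regenerates a whole set of figures in one go and writes them to disk with fig.savefig. The output directory name "figures", the file names and the sine-curve data are assumptions chosen only for illustration; in a real project the loop would run over your own analysis results.

```python
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Hypothetical output directory; adapt it to your project layout
out_dir = Path("figures")
out_dir.mkdir(exist_ok=True)

x = np.linspace(0, 2 * np.pi, 200)

# Regenerating every figure is a single loop, no manual editing involved
for frequency in [1, 2, 3]:
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(x, np.sin(frequency * x))
    ax.set_xlabel("x")
    ax.set_ylabel(f"sin({frequency}x)")
    fig.savefig(out_dir / f"sine_{frequency}.png", dpi=150)
    plt.close(fig)  # free memory when creating many figures
```

If the analysis changes, rerunning the script recreates all figures with identical settings.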

Why are we starting with Matplotlib?

Matplotlib is perhaps the most “standard” Python plotting library, and many libraries build on top of it. Even if you choose another library (see the list above), chances are high that you will need to adapt a Matplotlib plot written by somebody else.

x/y-Plots

One of the most important functions in Matplotlib is plot. A simple line graph can be obtained as follows:

import matplotlib.pyplot as plt
import numpy as np

y = np.arange(10, 100, 10)
print(y)
plt.plot(y)
[10 20 30 40 50 60 70 80 90]
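When only y is passed, plot uses the indices 0, 1, 2, ... as x values. The sketch below (written with the object-oriented interface that is introduced further down) checks this via Line2D.get_xdata and then passes x explicitly; the x values 1 to 9 are an arbitrary choice for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

y = np.arange(10, 100, 10)

fig, ax = plt.subplots()
(line_implicit,) = ax.plot(y)           # x defaults to 0, 1, ..., len(y) - 1
x = np.arange(1, len(y) + 1)            # illustrative explicit x values: 1 ... 9
(line_explicit,) = ax.plot(x, y, "--")  # same y values, shifted by one along x

print(line_implicit.get_xdata())
print(line_explicit.get_xdata())
```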

You can also plot the data as points by passing a format string as an additional argument. In the example below, "o" is passed as the format argument, so each value is drawn as a circular marker without a connecting line. The format string controls the marker symbol, the line style and the color; the documentation of plt.plot gives an overview of all possible values.

plt.plot(y, "o")
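A few more format strings, as a small sketch: a format string combines an optional marker, line style and color. The offsets of 5 and 10 are only there so that the three lines do not overlap; the print statements read back what Matplotlib made of each format string.

```python
import matplotlib.pyplot as plt
import numpy as np

y = np.arange(10, 100, 10)

fig, ax = plt.subplots()
(dots,) = ax.plot(y, "o")           # circles, no connecting line
(dashed,) = ax.plot(y + 5, "r--")   # red dashed line, no markers
(mixed,) = ax.plot(y + 10, "g^:")   # green triangles joined by a dotted line

print(dots.get_marker(), dots.get_linestyle())
print(dashed.get_color(), dashed.get_linestyle())
print(mixed.get_marker(), mixed.get_linestyle())
```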

Histograms

With the function plt.hist you can create histograms.

mu, sigma = 100, 15
x = mu + sigma * np.random.randn(10000)
plt.hist(x)
(array([  33.,  187.,  628., 1655., 2637., 2480., 1564.,  635.,  157.,
          24.]),
 array([ 48.10828174,  58.55164457,  68.99500741,  79.43837025,
         89.88173308, 100.32509592, 110.76845876, 121.2118216 ,
        131.65518443, 142.09854727, 152.54191011]),
 <BarContainer object of 10 artists>)
plt.hist(x, bins=50, density=True)
plt.title("Histogram")
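As the printed output of the first histogram shows, hist returns the bar heights, the bin edges and the drawn bars. The following sketch uses that return value to illustrate what density=True does: the bar heights are scaled so that the total area of the bars is 1. The fixed random seed (42) is an assumption added so that the figure is reproducible.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=42)  # fixed seed so the figure is reproducible
x = 100 + 15 * rng.standard_normal(10000)

fig, ax = plt.subplots()
# ax.hist (like plt.hist) returns bar heights, bin edges and the bar artists
counts, edges, patches = ax.hist(x, bins=50, density=True)

print(len(counts), len(edges))  # 50 51: n bins have n + 1 edges

# Area = sum of (bar height * bar width); close to 1.0 with density=True
area = np.sum(counts * np.diff(edges))
print(area)
```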

Getting more comfortable with Matplotlib

In the line and point plots above, the x values were generated automatically (0, 1, 2, ...), but you can also pass them explicitly, for example to create a scatter plot. The following example uses matplotlib.pyplot.subplots, matplotlib.axes.Axes.scatter and a few other methods of the matplotlib.axes.Axes object:

# this line tells Jupyter to display matplotlib figures in the notebook
%matplotlib inline

import matplotlib.pyplot as plt

# this is dataset 1 from
# https://en.wikipedia.org/wiki/Anscombe%27s_quartet
data_x = [10.0, 8.0, 13.0, 9.0, 11.0, 14.0, 6.0, 4.0, 12.0, 7.0, 5.0]
data_y = [8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68]

fig, ax = plt.subplots()  # creates a Figure and one Axes object to draw on

ax.scatter(x=data_x, y=data_y, c="#E69F00")  # "#RRGGBB" is a hexadecimal color code (red, green, blue); you can create color themes here: https://color.adobe.com/de/create/color-wheel

ax.set_xlabel("we should label the x axis")
ax.set_ylabel("we should label the y axis")
ax.set_title("some title")
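A hex color code such as "#E69F00" consists of two hexadecimal digits each for red, green and blue. As a small sketch, matplotlib.colors.to_rgb converts it into the (0 to 1) float triple that Matplotlib uses internally, and matplotlib.colors.to_hex converts back:

```python
from matplotlib import colors

hex_color = "#E69F00"            # the orange used in the scatter plot above
rgb = colors.to_rgb(hex_color)   # three floats between 0 and 1
print(rgb)

# Scale back to the familiar 0-255 integer range
print(tuple(round(channel * 255) for channel in rgb))  # (230, 159, 0)

print(colors.to_hex(rgb))        # '#e69f00' (lowercase)
```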

Matplotlib has two different interfaces

The more traditional option is the pyplot interface, in which functions such as plt.scatter and plt.xlabel act on a "current" figure and axes that pyplot keeps track of globally:

import matplotlib.pyplot as plt

# this is dataset 1 from
# https://en.wikipedia.org/wiki/Anscombe%27s_quartet
data_x = [10.0, 8.0, 13.0, 9.0, 11.0, 14.0, 6.0, 4.0, 12.0, 7.0, 5.0]
data_y = [8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68]

plt.scatter(x=data_x, y=data_y, c="#E69F00")

plt.xlabel("we should label the x axis")
plt.ylabel("we should label the y axis")
plt.title("some title")

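When searching for help on the internet, you will find both approaches; they can also be mixed. Although the pyplot interface looks more compact, we recommend learning and using the object-oriented interface, because it makes explicit which figure and axes each command modifies, which matters as soon as a figure contains several plots.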
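Mixing works because pyplot functions simply act on the current axes, which plt.gca() ("get current axes") returns. A minimal sketch using the first points of Anscombe dataset 1 from above; the label texts are placeholders:

```python
import matplotlib.pyplot as plt

# First five points of Anscombe dataset 1 (from the example above)
data_x = [10.0, 8.0, 13.0, 9.0, 11.0]
data_y = [8.04, 6.95, 7.58, 8.81, 8.33]

fig, ax = plt.subplots()                      # object-oriented: we keep a reference to ax
plt.scatter(x=data_x, y=data_y, c="#E69F00")  # pyplot: draws on the current axes
plt.xlabel("set via pyplot")                  # also acts on the current axes
ax.set_ylabel("set via the Axes object")

print(plt.gca() is ax)  # True: the current axes is the one from plt.subplots
print(ax.get_xlabel())  # the pyplot call changed our Axes object
```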

Subplots

With the object-oriented interface you call plotting methods such as ax.plot directly on an Axes object instead of using plt.plot as before. If a figure contains several plots, plt.subplots returns a NumPy array of Axes objects, and you can draw the desired plot into each of them:

fig, axes = plt.subplots(1, 3, figsize=(10, 3))  # 1 row, 3 columns
axes[0].plot([1, 2, 3], "o")  # left
axes[1].hist(np.random.randn(100))  # middle
axes[2].plot(np.sin(np.arange(0, 2 * np.pi, 1/360)))  # right
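With a single row (as above) the returned array is one-dimensional, so axes[0] is the left panel. For a grid with several rows and columns it is two-dimensional and indexed as axes[row, column]. A short sketch with a 2x2 grid; the panel titles are placeholders:

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 2, figsize=(8, 6))  # 2 rows, 2 columns
print(type(axes).__name__, axes.shape)          # ndarray (2, 2)

x = np.linspace(0, 2 * np.pi, 100)
axes[0, 0].plot(x, np.sin(x))           # top left: row 0, column 0
axes[0, 1].plot(x, np.cos(x))           # top right
axes[1, 0].hist(np.random.randn(100))   # bottom left
axes[1, 1].plot([1, 2, 3], "o")         # bottom right

# axes.flat iterates over all Axes, whatever the grid shape
for i, ax in enumerate(axes.flat):
    ax.set_title(f"panel {i}")

fig.tight_layout()  # adjust spacing so titles and labels do not overlap
```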

Styling and customizing plots

  • Do not customize “manually” using a graphical program (not easily repeatable/reproducible).

  • Matplotlib and the other libraries let you customize almost every aspect of a plot.

  • It is useful to study the "parts of a figure" overview in the Matplotlib documentation so that you know what to search for when customizing things.

  • Matplotlib cheatsheets: https://github.com/matplotlib/cheatsheets

  • You can also select among pre-defined themes/style sheets with matplotlib.style.use, for instance:

plt.style.use('ggplot')
import numpy as np

# We observe a bunch of squirrels with [height (cm), weight (grams)] pairs
data = np.array([[10., 600.], [16., 1200], [6., 800], [12., 700.], [17., 1400.],
                [8., 500.], [20., 1500.], [21., 1300.], [11., 800.], [18., 1100.]])
# Visualize our data!
fig, ax = plt.subplots()
ax.plot(data[:, 0], data[:, 1], '.', ms=12)
ax.set(xlabel='Height (cm)', ylabel='Weight (g)',
    xlim=[0, 25], ylim=[300, 1700])
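plt.style.use('seaborn-v0_8')  # this style was called 'seaborn' before Matplotlib 3.6; the old name fails in recent versions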

# Visualize our data!
fig, ax = plt.subplots()
ax.plot(data[:, 0], data[:, 1], '.', ms=12)
ax.set(xlabel='Height (cm)', ylabel='Weight (g)',
    xlim=[0, 25], ylim=[300, 1700])

You can set several properties of your Axes object at once with ax.set(key=value, ...) or each one individually with ax.set_key(value), for example ax.set_xlabel("Height (cm)"). The following two examples produce the same plot:

fig, ax = plt.subplots()
ax.plot(data[:, 0], data[:, 1], '.', ms=12)
ax.set(xlabel='Height (cm)', ylabel='Weight (g)', title='Squirrel data')
fig, ax = plt.subplots()
ax.plot(data[:, 0], data[:, 1], '.', ms=12)
ax.set_xlabel("Height (cm)")
ax.set_ylabel("Weight (g)")
ax.set_title("Squirrel data")

For an overview of all predefined styles in Matplotlib, see the style sheets reference in the gallery.
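Note that plt.style.use changes the style globally for all figures created afterwards. A minimal sketch of two alternatives: listing the available style names and applying a style only temporarily with the plt.style.context context manager. The plotted values are placeholders.

```python
import matplotlib.pyplot as plt

# Names of all style sheets shipped with your Matplotlib version
print(plt.style.available)

# Apply a style only to the figures created inside the with-block
with plt.style.context("ggplot"):
    fig, ax = plt.subplots()
    ax.plot([10, 16, 6], [600, 1200, 800], ".", ms=12)

# Reset the global style to Matplotlib's defaults
plt.style.use("default")
```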

Exercises

First load the iris data set from scikit-learn as a pandas DataFrame. You can address each column of a DataFrame via df["columnname"].

from sklearn.datasets import load_iris
import pandas as pd

iris = load_iris()
df = pd.DataFrame(iris['data'], columns=iris['feature_names'])
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)

print(df.columns)
Index(['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)',
       'petal width (cm)', 'species'],
      dtype='object')
# get one feature
length_sep = df["sepal length (cm)"]
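A single column such as length_sep is a pandas Series, which can be passed directly to Matplotlib plotting methods. To keep the sketch independent of scikit-learn, it reuses the squirrel data from the styling section; the column names are our own choice.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Squirrel data from the styling section: [height (cm), weight (grams)]
data = np.array([[10., 600.], [16., 1200], [6., 800], [12., 700.], [17., 1400.],
                 [8., 500.], [20., 1500.], [21., 1300.], [11., 800.], [18., 1100.]])
df_squirrels = pd.DataFrame(data, columns=["height (cm)", "weight (g)"])

heights = df_squirrels["height (cm)"]  # one column -> pandas Series
print(type(heights).__name__)          # Series
print(heights.mean())                  # 13.9

# A Series can be handed to Axes methods like any array
fig, ax = plt.subplots()
ax.hist(df_squirrels["weight (g)"], bins=5)
ax.set(xlabel="Weight (g)", ylabel="Count")
```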
Task 1.6: Make a scatter plot of sepal length against petal length and color each dot according to its species (the numeric species codes can be found in iris.target) (2 points). Add labels to the x and y axes (1 point).

Task 1.7: This is a great exercise which is very close to real life. (3 points)

Your task is to select one visualization library (some need to be installed first; if in doubt, choose Matplotlib or Seaborn, since they are part of the Anaconda installation):

(i) Matplotlib: probably the most standard and most widely used

(ii) Seaborn: high-level interface to Matplotlib, statistical functions built in

(iii) ggplot: R users will be more at home

  • Browse the various example galleries (links above).

  • Select one example that simply interests you.

  • First try to reproduce this example in the Jupyter notebook.

  • Then try to print out the data that is used in this example just before the call of the plotting function to learn about its structure. Is it a pandas dataframe? Is it a NumPy array? Is it a dictionary? A list or a list of lists?
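Besides print, it can help to check the type of the data explicitly. The function below is a hypothetical helper (not part of any library), sketched only to show how the questions above can be answered with isinstance checks:

```python
import numpy as np
import pandas as pd


def describe_structure(obj):
    """Hypothetical helper: return a short description of a plotting input."""
    if isinstance(obj, pd.DataFrame):
        return f"DataFrame with shape {obj.shape} and columns {list(obj.columns)}"
    if isinstance(obj, np.ndarray):
        return f"NumPy array with shape {obj.shape} and dtype {obj.dtype}"
    if isinstance(obj, dict):
        return f"dict with keys {list(obj.keys())}"
    if isinstance(obj, list):
        if obj and all(isinstance(item, list) for item in obj):
            return f"list of {len(obj)} lists"
        return f"list with {len(obj)} elements"
    return type(obj).__name__


print(describe_structure(np.zeros((2, 3))))  # NumPy array with shape (2, 3) and dtype float64
print(describe_structure([[1, 2], [3]]))     # list of 2 lists
```

Calling describe_structure(df) on the iris DataFrame from above would, for example, report its shape and column names.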