In this post I cover how to make scatter plots using the most popular data visualization libraries in Python: the Pandas .plot method, Matplotlib, Seaborn, Plotly Express and Plotnine.

A common issue with scatter plots is overplotting, so I’m going to demonstrate how you can handle it using 2d histograms. These are variations of the scatter plot that work well when you have a large dataset.

I’ll also cover how you can do small multiples with the libraries that support it.

Getting the data

For this example we will be working with the diamonds dataset. We will try to “explain” the price of diamonds using “carat”, a numeric variable that represents the weight of each diamond.

import pandas as pd
import numpy as np

import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt
from plotnine import *

# Read the diamonds dataset
df = pd.read_csv("https://raw.githubusercontent.com/martinbel/datasets/master/diamonds.csv")
df.head()
Diamonds Dataset

1. Pandas Scatter Plot

The default plot of pandas looks like this. Not bad for a quick plot but there is a good amount of overplotting.

df.plot(x='carat', 
        y='price', 
        kind='scatter', 
        title="Relation between Carat and Price of Diamonds");
Default Scatter Plot

We have two options to improve this plot.

  • Using the alpha argument: this parameter controls the opacity of each point.
  • Using the s argument: this parameter controls the size of each point. We can make the points smaller and that should help a bit.

This plot is a bit better, but it still doesn’t seem to be the best way to visualize this data. With the pandas .plot method it’s difficult to map a color to the points, which is why I generally use other libraries for more complex plots.

df.plot(x='carat', 
        y='price', 
        kind='scatter', 
        alpha=0.1, s=2,
        title="Relation between Carat and Price of Diamonds");
Pandas Scatter Plot – Improved
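Coloring by a category is still possible with the pandas .plot method; it just takes a loop. The sketch below is one possible approach rather than the only one: it draws one scatter layer per cut on a shared axes. The synthetic `demo` dataframe and the color choices are assumptions made so the snippet runs on its own; swap in `df` from above to use the real data.

```python
import matplotlib
matplotlib.use("Agg")  # headless-safe backend; not needed inside a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for the diamonds data (an assumption, not the real dataset)
rng = np.random.default_rng(0)
n = 300
demo = pd.DataFrame({
    "carat": rng.uniform(0.2, 2.5, n),
    "cut": rng.choice(["Fair", "Good", "Ideal"], n),
})
demo["price"] = 4000 * demo["carat"] ** 1.5 + rng.normal(0, 300, n)

# One color per category (arbitrary choice for illustration)
colors = {"Fair": "tab:blue", "Good": "tab:orange", "Ideal": "tab:green"}

fig, ax = plt.subplots(figsize=(6, 4))
for cut_name, group in demo.groupby("cut"):
    # Each call adds one layer of points to the same axes
    group.plot(x="carat", y="price", kind="scatter",
               color=colors[cut_name], label=cut_name,
               alpha=0.4, s=8, ax=ax)
ax.set_title("Carat vs. price, colored by cut")
```

It works, but compared with a single hue argument it is a fair amount of ceremony, which is the point made above.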

2. Matplotlib Scatter Plot

The pandas plots are actually using matplotlib. This library is the default backend of pandas for data visualization. However, we can use matplotlib directly to create the same scatter plot. This is how to do it.

plt.scatter(x=df.carat, y=df.price, alpha=0.5, s=2)
plt.title("Relation between Carat and Price of Diamonds");

Another option with matplotlib is doing what’s called a “Hexagonal heatmap of 2d bin counts”. To use a shorter expression I’ll call this a “2d hexbin histogram”.

Here is how to do it with matplotlib. There is a lot going on in this code, but we are basically looping over the cut values and calling ax.hexbin to make this “2d hexbin histogram” plot.

# Define values of grouping variable
group_values = list(df.cut.unique())

# set number of columns in the plot
ncols = 3

# calculate number of rows in the plot
nrows = len(group_values) // ncols + (len(group_values) % ncols > 0)

# Define the plot 
plt.figure(figsize = (12, 8))
plt.subplots_adjust(hspace=0.25)
plt.suptitle("Relation between Carat & Price", fontsize=16, y=0.95)

for n, group_val in enumerate(group_values):
    # add a new subplot at each iteration using nrows and cols
    ax = plt.subplot(nrows, ncols, n + 1)
    
    # Filter the dataframe to the rows of the current cut
    df_temp = df.query("cut == @group_val")
    
    # Hexbin chart
    ax.hexbin(x = df_temp.carat, 
              y = df_temp.price, 
              gridsize = 40, 
              bins = "log",
              cmap = "Blues")
    
    # chart formatting
    ax.set_title(group_val)
    ax.set_xlabel("")
2d Hexbin Histogram – Matplotlib

This plot is probably the best solution to the overplotting issue. Even though we need to write a decent amount of code to make it work, I think it’s worth it.
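If the amount of code feels heavy, here is a hedged variation on the same idea using plt.subplots. It also spells out the row calculation (a ceiling division), adds a colorbar per panel, and hides the unused grid cell. The `grid_shape` helper and the synthetic `demo` dataframe are hypothetical names introduced for this sketch, not part of the code above.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless-safe backend; not needed inside a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def grid_shape(n_panels, ncols):
    """Rows and columns needed to fit n_panels plots with ncols per row."""
    # Ceiling division: same result as n // ncols + (n % ncols > 0)
    nrows = math.ceil(n_panels / ncols)
    return nrows, ncols


# Synthetic stand-in for the diamonds data (an assumption, not the real dataset)
rng = np.random.default_rng(1)
n = 2000
demo = pd.DataFrame({
    "carat": rng.uniform(0.2, 2.5, n),
    "cut": rng.choice(["Fair", "Good", "Very Good", "Premium", "Ideal"], n),
})
demo["price"] = 4000 * demo["carat"] ** 1.5 + rng.normal(0, 300, n)

groups = sorted(demo["cut"].unique())
nrows, ncols = grid_shape(len(groups), ncols=3)
fig, axes = plt.subplots(nrows, ncols, figsize=(12, 8), squeeze=False)

for ax, name in zip(axes.ravel(), groups):
    sub = demo[demo["cut"] == name]
    # bins="log" colors each hexagon by the log of its point count
    hb = ax.hexbin(sub["carat"], sub["price"],
                   gridsize=40, bins="log", cmap="Blues")
    fig.colorbar(hb, ax=ax, label="count (log scale)")
    ax.set_title(name)

# Hide the cells left over when the grid has more slots than groups
for ax in axes.ravel()[len(groups):]:
    ax.set_visible(False)
```

With five cuts and three columns this gives a 2 x 3 grid with one hidden cell.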

3. Seaborn Scatter Plot

The Seaborn library has a good trade-off between simplicity of use and visually appealing plots.

This is how you can do a simple scatter plot with Seaborn. In this example, I’ll color the points using the cut variable. This is quite simple to do using the hue parameter in Seaborn whereas with Pandas or matplotlib it’s more difficult.

sns.scatterplot(data=df, x='carat', y='price', 
                alpha=0.2, s=6, # handle overplotting
                hue='cut')
Seaborn Colored Scatter Plot

Another option that’s supported in Seaborn to visualize this data is using small multiples or facets. You can do this with the sns.FacetGrid class in Seaborn.

The idea is we “condition” the linear relation between carat and price by a categorical variable, in this case, the cut of the diamonds.

g = sns.FacetGrid(df, 
                  col='cut', 
                  col_wrap=3,
                  sharex=True, sharey=False 
                 )

g.map(sns.scatterplot, 
      'carat', 
      'price', 
      alpha=0.2,  # Reduce Overplotting
      s=6         # Control the size of the marker
     )

g.fig.subplots_adjust(top=0.9);
g.fig.suptitle("Relation between Carat & Price by Diamond Cut");
Scatter Plot – Small Multiples / Facets

Finally, there is one more option in Seaborn that helps solve the overplotting issue. The idea is using the sns.histplot function instead of sns.scatterplot in the FacetGrid call. Literally this is all we need to change to try a different visualization.

g = sns.FacetGrid(df, 
                  col='cut', 
                  col_wrap=3,
                  hue='cut', 
                  sharex=False, sharey=False)

g.map(sns.histplot, 
      'carat', 
      'price', 
      bins=40
     )

g.fig.subplots_adjust(top=0.9);
g.fig.suptitle("Relation between Carat & Price by Diamond Cut");

This plot is a great solution to visualize this data. It’s similar to the hexbin plot we did with matplotlib but we are using squares.

4. Plotly Express Scatter Plot

If we were developing an interactive visualization, then I think Plotly is the best library available. However, there isn’t a simple way to do a 2d hexbin histogram with it.

Plotly is one of the simplest libraries to use for making scatter plots. It’s even easier than Seaborn. If you want to do a small multiples plot, you can just map a column to the facet_col argument.

fig = px.scatter(
    df, x='carat', y='price',
    color='cut',     
    opacity=0.5,     # Similar to alpha
    facet_col='cut', 
    facet_col_wrap=3
)

# smaller markers to reduce overplotting
fig.update_traces(marker={'size': 3})
# better hover labels
fig.update_traces(hovertemplate=None)
fig.update_layout(hovermode="x")
Scatter Plot – Plotly Express

I think this is a great plot that is quite simple to make. However, for this particular use case we don’t get the same functionality we had with matplotlib or seaborn for handling overplotting.

5. Plotnine Scatter Plot

Plotnine is a port of the popular R data visualization library, ggplot2.

This is how you can do a simple scatter plot with a color mapping to a categorical variable.

I would only recommend plotnine to people who are coming from R and don’t want to learn the other libraries. If you come from R, the syntax is practically the same as what you wrote there. This is a very non-pythonic library!

The main idea is you map the data to x, y, color and then decide how you want to visualize it. In this case we are using geom_point but there are multiple geoms available.

(ggplot(df, 
        aes(x='carat', y='price', color='cut')) +
 geom_point(alpha=0.5, size=0.25) +
 coord_cartesian(ylim=(0, 20000))
)

It’s possible to do a small multiples plot also. Here is how to do it.

We are adding a call to facet_wrap with a formula inside. This formula is common in R and means “explained by cut”.

(ggplot(df, 
        aes(x='carat', y='price', color='cut')) +
 facet_wrap('~ cut', scales='free_y') +
 geom_point(alpha=0.2, size=0.25) +
 geom_smooth(method='lowess', se=False) +
 coord_cartesian(ylim=(0, 20000)) +
 theme(figure_size=(10, 6))
)
Scatter Plot – Plotnine – Python ggplot2 port

The terms facet grid and facet wrap come from the ggplot2 book, which explains how ggplot was designed and the main ideas behind it.

Conclusion

That covers how to make scatter plots using the most popular data visualization libraries in Python.

I think they all have their strengths and weaknesses, and it’s good to be familiar with all of them. The idea is to be able to use the best tool for the job, and for data visualization in Python that means knowing multiple libraries.


