How to make a Scatter Plot using matplotlib in Python
What is matplotlib?
matplotlib is a Python-based, object-oriented plotting library used for data visualization. It produces publication-quality plots in a variety of formats.
Plotting data graphically is an important step in statistical data analysis. Scatter plots help visualize the correlation between two variables, and they are often the clearest way to display how those variables relate.
This article gives you a head start in creating a scatter plot with matplotlib and Python.
The following scatter plot was generated using matplotlib:
Source: National Geographic
Python Code
# Import the required libraries.
import matplotlib
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import matplotlib.mlab as mlab
# Read data from the CSV file HealthExpenditure.csv into a record array.
r = mlab.csv2rec('HealthExpenditure.csv')
# Create a figure with size 6 x 6 inches.
fig = Figure(figsize=(6, 6))
# Create a canvas and add the figure to it.
canvas = FigureCanvas(fig)
# Create a subplot.
ax = fig.add_subplot(111)
# Set the title.
ax.set_title("Health Expenditure Across The World", fontsize=14)
# Set the X Axis label.
ax.set_xlabel("Expenditure per person (US Dollars)", fontsize=12)
# Set the Y Axis label.
ax.set_ylabel("Average Life Expectancy at Birth (Years)", fontsize=12)
# Display Grid.
ax.grid(True, linestyle='-', color='0.75')
# Generate the Scatter Plot.
ax.scatter(r.expenditure, r.life_expectancy, s=20, color='tomato')
# Save the generated Scatter Plot to a PNG file.
canvas.print_figure('healthvsexpense.png', dpi=500)
Step by Step
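r = mlab.csv2rec('HealthExpenditure.csv')
The mlab module contains functions written to be compatible with the MATLAB commands of the same names. Here, csv2rec() reads the CSV file into a record array, whose columns can then be accessed as attributes such as r.expenditure. Note that csv2rec() was deprecated in matplotlib 2.2 and removed in matplotlib 3.1, so this call only works with older releases such as the one listed under My Development Setup.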
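On recent matplotlib releases, where csv2rec() is no longer available, a similar record array can be built with NumPy's genfromtxt(). The following is a minimal sketch, not the article's method. It assumes (the original does not say so) that HealthExpenditure.csv has a header row whose column names become expenditure and life_expectancy once lower-cased with spaces turned into underscores. The helper name load_records() is hypothetical, and the sample rows are made-up numbers used only to show the call.

```python
import io

import numpy as np


def load_records(source):
    """Read a CSV file with a header row into a NumPy record array.

    Hypothetical stand-in for mlab.csv2rec(): field names are taken from
    the header, lower-cased, and spaces are replaced with underscores, so
    a column headed "Life Expectancy" becomes the field life_expectancy.
    """
    data = np.genfromtxt(
        source,
        delimiter=",",
        names=True,              # take field names from the header row
        dtype=None,              # infer a type for each column
        case_sensitive="lower",  # lower-case the field names
        encoding="utf-8",
    )
    # Viewing as a recarray allows attribute access such as r.expenditure.
    return data.view(np.recarray)


# Made-up numbers, used only to demonstrate the call; not real data.
sample = io.StringIO("Expenditure,Life Expectancy\n100,60\n200,70\n")
r = load_records(sample)
print(r.expenditure, r.life_expectancy)
```

With a helper like this, r = load_records('HealthExpenditure.csv') could take the place of the csv2rec() line, and the rest of the script could keep using r.expenditure and r.life_expectancy.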
fig = Figure(figsize=(6,6))
A Figure is the top-level container that holds all the elements of a plot, and a single Figure can contain multiple Axes. The figsize argument sets the dimensions of the figure in the form figsize=(w, h), where w is the width and h is the height, both measured in inches.
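A quick way to confirm the figsize=(w, h) order is to create a figure and read its size back. This small sketch uses the Figure methods get_size_inches() and set_size_inches(); the 6 x 4 size and dpi=100 are arbitrary values chosen for illustration. Multiplying the size in inches by the figure's dpi gives its size in pixels.

```python
from matplotlib.figure import Figure

# figsize is (width, height) in inches; dpi converts inches to pixels.
fig = Figure(figsize=(6, 4), dpi=100)

width_in, height_in = fig.get_size_inches()
width_px = width_in * fig.dpi
height_px = height_in * fig.dpi
print(width_in, height_in)
print(width_px, height_px)

# The size can also be changed after the figure has been created.
fig.set_size_inches(6, 6)
```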
canvas = FigureCanvas(fig)
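FigureCanvas holds the Figure instance, and its primary purpose is to render the figure. The canvas used here, FigureCanvasAgg, draws the figure with the Agg (Anti-Grain Geometry) raster renderer, so no window or GUI toolkit is needed.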
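To see the canvas render without writing a file, the sketch below draws a small figure with FigureCanvasAgg and reads the rendered pixels back as a NumPy array through buffer_rgba(). The figure size, dpi and data points are arbitrary illustrative values.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

# 2 x 1 inches at 50 dpi gives a 100 x 50 pixel image.
fig = Figure(figsize=(2, 1), dpi=50)
canvas = FigureCanvas(fig)  # attaches the canvas to the figure
ax = fig.add_subplot(111)
ax.scatter([1, 2, 3], [1, 4, 9])

# Render the figure with the Agg renderer, then read the pixels.
canvas.draw()
pixels = np.asarray(canvas.buffer_rgba())
print(pixels.shape)  # (rows, columns, RGBA channels)
```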
ax = fig.add_subplot(111)
As mentioned earlier, a figure can contain multiple plots, called subplots (each one is an Axes). The add_subplot() method adds a subplot to a figure. The argument 111 describes a grid of 1 row and 1 column and selects subplot number 1 in that grid.
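The same notation extends to several subplots. In this sketch (the 8 x 4 figure size and the titles are arbitrary), 121 and 122 place two subplots side by side in a grid of 1 row and 2 columns; add_subplot() also accepts the three numbers as separate arguments.

```python
from matplotlib.figure import Figure

fig = Figure(figsize=(8, 4))

# A grid of 1 row and 2 columns: 121 is the left cell, 122 the right one.
left = fig.add_subplot(121)
right = fig.add_subplot(1, 2, 2)  # same as 122, written as three arguments

left.set_title("Subplot 1")
right.set_title("Subplot 2")
print(len(fig.axes))
```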
ax.set_title("Health Expenditure Across The World", fontsize=14)
ax.set_xlabel("Expenditure per person (US Dollars)", fontsize=12)
ax.set_ylabel("Average Life Expectancy at Birth (Years)", fontsize=12)
The three statements above set the text of the title, the x axis label and the y axis label, respectively. Keyword arguments such as fontsize are properties of the matplotlib.text.Text class, so any other Text property can be used to format the text as well.
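Because set_title() and set_xlabel() return Text objects, formatting can be passed as keyword arguments or applied afterwards. The sketch below shows a few common Text properties; the bold title, gray label and italic style are illustrative choices, not part of the original plot.

```python
from matplotlib.figure import Figure

fig = Figure(figsize=(6, 6))
ax = fig.add_subplot(111)

# Extra keyword arguments are Text properties (fontsize, fontweight, color).
title = ax.set_title("Health Expenditure Across The World",
                     fontsize=14, fontweight="bold")
xlabel = ax.set_xlabel("Expenditure per person (US Dollars)",
                       fontsize=12, color="dimgray")

# The returned Text objects can also be restyled after creation.
xlabel.set_fontstyle("italic")
print(title.get_fontsize(), title.get_fontweight(), xlabel.get_fontstyle())
```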
ax.grid(True, linestyle='-', color='0.75')
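The grid() method draws horizontal and vertical gridlines on the plot; here they are solid lines (linestyle='-') in light gray (color='0.75', a gray level between 0 for black and 1 for white). To draw only the horizontal or only the vertical gridlines, use ax.yaxis.grid() or ax.xaxis.grid(), respectively.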
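For example, a version of the plot with only horizontal gridlines might look like this sketch. The call to set_axisbelow(True), which draws the gridlines beneath the scatter points, is an optional addition not used in the original code, and the data points are placeholders.

```python
from matplotlib.figure import Figure

fig = Figure(figsize=(6, 6))
ax = fig.add_subplot(111)

# Horizontal gridlines only: they belong to the y axis ticks.
ax.yaxis.grid(True, linestyle="-", color="0.75")  # "0.75" is a gray level
# Draw the gridlines underneath the plotted points rather than on top.
ax.set_axisbelow(True)
ax.scatter([1, 2, 3], [1, 4, 9])
```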
ax.scatter(r.expenditure, r.life_expectancy, s=20, color='tomato')
scatter() is the function that draws the scatter plot. In the example, r.expenditure supplies the x values, r.life_expectancy supplies the y values, s sets the marker size (in points squared) and color sets the marker color. For the complete list of parameters, refer to the scatter() documentation.
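Besides x, y, s and color, scatter() can encode an extra variable through marker size and color. This sketch uses synthetic placeholder values (not the health data set) and a hypothetical third variable called weight: s scales the marker area, c together with cmap maps each value to a color, and fig.colorbar() adds a scale for the colors.

```python
import numpy as np
from matplotlib.figure import Figure

# Synthetic placeholder values, not the article's health data.
x = np.array([1, 2, 3, 4, 5])
y = np.array([1, 4, 9, 16, 25])
weight = np.array([5, 1, 4, 2, 3])  # hypothetical third variable

fig = Figure(figsize=(6, 6))
ax = fig.add_subplot(111)

# s is the marker area in points squared; a larger weight gives a larger marker.
# c picks each marker's color from the "viridis" colormap based on weight.
points = ax.scatter(x, y, s=weight * 20, c=weight, cmap="viridis",
                    alpha=0.7, edgecolors="black")
# The colorbar is drawn in a new Axes placed next to ax.
fig.colorbar(points, ax=ax, label="weight (hypothetical)")
```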
canvas.print_figure('healthvsexpense.png', dpi=500)
The print_figure() method of the FigureCanvas class writes the plot to an image file. The statement above generates a PNG file with a resolution of 500 dots per inch; since the figure is 6 x 6 inches, the resulting image is 3000 x 3000 pixels.
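The pixel size follows from figsize multiplied by dpi. The sketch below checks this by printing the figure into an in-memory io.BytesIO buffer instead of a file and reading the width and height from the PNG header. The png_size() helper is a hypothetical name written for this example, and the plotted points are placeholders.

```python
import io
import struct

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure


def png_size(data):
    """Return (width, height) read from the IHDR chunk of PNG bytes."""
    # After the 8-byte signature and the chunk length and type (8 bytes),
    # width and height are stored as big-endian 32-bit integers.
    return struct.unpack(">II", data[16:24])


fig = Figure(figsize=(6, 6))
canvas = FigureCanvas(fig)
fig.add_subplot(111).scatter([1, 2, 3], [1, 4, 9], s=20, color="tomato")

# Write the PNG to memory rather than to a file on disk.
buffer = io.BytesIO()
canvas.print_figure(buffer, format="png", dpi=500)
print(png_size(buffer.getvalue()))  # 6 in x 500 dpi = 3000 pixels per side
```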
My Development Setup
Ubuntu 9.04
Python 2.6.2
matplotlib 0.98.5.2-1ubuntu3
References
The official matplotlib site


