Plotting and Data Visualization with Matplotlib
Raw data in the form of a CSV (comma-separated values) file does not tell a visual story. Visualized well with a library like Matplotlib, however, the same data becomes much easier for your users to understand, because they can connect the dots at a glance.
This article is an introduction to using Matplotlib for plotting and data visualizations.
GitHub Repo
Check the complete source code in this repo.
What is Matplotlib
Matplotlib is a Python plotting library that allows you to turn data into pretty visualizations, also known as plots or figures.
Here are some reasons Matplotlib is useful for data scientists:
- It is built on NumPy arrays (and Python)
- Integrates directly with Pandas
- Can create basic or advanced plots
Importing Matplotlib
To start with Matplotlib, import it into your Jupyter Notebook like this:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
- %matplotlib inline: this magic command (note the percent sign in front of matplotlib) makes sure all Matplotlib plots and graphs appear inside the notebook, directly below the cell that creates them.
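The %matplotlib inline magic only works in IPython-based environments such as Jupyter. As a minimal sketch of the alternative for a plain Python script, assuming you want to write the figure to a file (here an in-memory buffer, so nothing touches disk) rather than open a window, you can select Matplotlib's non-interactive Agg backend and save the figure:

```python
# Outside Jupyter, select a backend explicitly instead of using %matplotlib inline.
# "Agg" is a non-interactive backend that renders to files, handy for scripts and servers.
import io

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [11, 22, 33, 44])

# Render the figure to an in-memory PNG instead of displaying it
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
png_bytes = buffer.getvalue()

# With an interactive backend in a script, you would call plt.show() to open a window
print(png_bytes[:8])  # PNG files always start with this fixed 8-byte signature
plt.close(fig)  # free the figure's memory once it is saved
```

Swapping buffer for a filename such as "plot.png" writes the same image to disk.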
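The simplest way to create a (blank) plot is with the line below. The trailing semicolon stops Jupyter from printing the object the call returns, which here is an empty list: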
plt.plot();
Let's add some data to the plot:
x = [1, 2, 3, 4]
y = [11, 22, 33, 44]
plt.plot(x, y);
This code plots the values in y against the values in x, drawing a straight line through the points (1, 11), (2, 22), (3, 33), and (4, 44).
The recommended way to plot a graph is the object-oriented approach, which creates a figure and an axes explicitly and should give the same result as the previous method:
fig, ax = plt.subplots()
ax.plot(x, y);
Note: Changing the values in x and y produces a different graph.
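To see what the two styles actually share, here is a minimal sketch (using the Agg backend so it also runs outside a notebook) that draws the same data both ways. Both calls produce a Line2D object with identical data; the difference is that the object-oriented version hands you a named ax to keep customizing, which matters once a figure has several subplots:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs outside Jupyter
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [11, 22, 33, 44]

# pyplot interface: plt.plot() draws on the "current" axes, creating one if needed
plt.figure()
(line_a,) = plt.plot(x, y)

# Object-oriented interface: plot on an explicit Axes object returned by subplots()
fig, ax = plt.subplots()
(line_b,) = ax.plot(x, y)

# Both calls return a list holding one Line2D object with the same data
print(list(line_a.get_ydata()) == list(line_b.get_ydata()))  # True
# ...but only the object-oriented version gives you a named handle to the axes
print(line_b.axes is ax)  # True
```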
Anatomy of a Matplotlib Plot
A Matplotlib figure contains one or more axes (the areas where data is plotted), and each axes typically includes:
- A title
- Legend
- y-axis label
- x-axis label
Let's see an example of the workflow.
# 0. Import Matplotlib and get it ready for plotting in Jupyter
%matplotlib inline
import os
import matplotlib.pyplot as plt
# 1. Prepare data
x = [1, 2, 3, 4]
y = [11, 22, 33, 44]
# 2. Set up the plot (one figure containing one axes)
fig, ax = plt.subplots(figsize=(10, 10))
# 3. Plot data
ax.plot(x, y)
# 4. Customize the plot
ax.set(title="Simple plot",
       xlabel="x-axis",
       ylabel="y-axis")
# 5. Save the whole figure (create the images folder first if it doesn't exist)
os.makedirs("images", exist_ok=True)
fig.savefig("images/sample-plot.png")
The code above uses the ax.set() method to add a title and axis labels, and fig.savefig() to save the whole figure as a .png file in the images folder.
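The anatomy list includes a legend, but the workflow example does not draw one. A minimal sketch of adding it: give each line a label= and call ax.legend(). The second data series here is made up purely for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs outside Jupyter
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]

fig, ax = plt.subplots()
# The label= argument gives each line a name for the legend
ax.plot(x, [11, 22, 33, 44], label="y = 11x")
ax.plot(x, [5, 10, 15, 20], label="y = 5x")  # hypothetical second series
ax.set(title="Simple plot", xlabel="x-axis", ylabel="y-axis")
legend = ax.legend()  # collects the labelled lines into a legend box

print([text.get_text() for text in legend.get_texts()])
# ['y = 11x', 'y = 5x']
```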
Creating Figures with NumPy arrays
In this section, you will create line, scatter, and bar plots; histograms and subplots follow in the next section.
Copy and paste this code into your notebook:
import numpy as np
x = np.linspace(0, 10, 100)
x[:10]
linspace returns evenly spaced numbers over a specified interval; here, 100 numbers from 0 to 10, including both endpoints. The slice x[:10] displays only the first ten of them.
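With a smaller count, the spacing is easier to see. In this sketch, np.linspace(start, stop, num) includes both endpoints by default, so the gap between neighbouring values is (stop - start) / (num - 1):

```python
import numpy as np

# 5 evenly spaced numbers from 0 to 10; both endpoints are included by default
small = np.linspace(0, 10, 5)
print(small)  # [ 0.   2.5  5.   7.5 10. ]

# The article's array: 100 numbers, so the step is (10 - 0) / (100 - 1)
x = np.linspace(0, 10, 100)
print(x.shape)  # (100,)
print(x[0], x[-1])  # 0.0 10.0
print(np.isclose(x[1] - x[0], 10 / 99))  # True
```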
Plot the data and create a line plot:
fig, ax = plt.subplots()
ax.plot(x, x**2);
The result is a line plot of x squared, which curves upward as a parabola.
For a scatter plot, use the same data from above:
fig, ax = plt.subplots()
ax.scatter(x, np.exp(x));
Note: Instead of calling .plot() on the ax axes, call .scatter(), which draws a separate marker for each point without connecting them.
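The difference between the two methods shows up in what they return. This sketch draws the same data side by side: .plot() gives back Line2D objects (points joined by a line), while .scatter() gives back a single PathCollection holding one marker per point:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs outside Jupyter
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, (ax_line, ax_scatter) = plt.subplots(ncols=2, figsize=(10, 4))

# .plot() returns a list of Line2D objects: points joined by a line
lines = ax_line.plot(x, np.exp(x))

# .scatter() returns a PathCollection: one marker per point, no connecting line
points = ax_scatter.scatter(x, np.exp(x))

print(type(lines[0]).__name__)  # Line2D
print(type(points).__name__)  # PathCollection
print(points.get_offsets().shape)  # (100, 2), one (x, y) pair per marker
```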
You can also build a bar plot from a dictionary, using its keys as the bar labels and its values as the bar heights:
nut_butter_prices = {"Almond butter": 10, "Peanut butter": 9, "Cashew butter": 5}
fig, ax = plt.subplots()
ax.bar(nut_butter_prices.keys(), nut_butter_prices.values())
ax.set(title = "Teri's Nut Butter Store",
ylabel = "Price ($)"
);
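A common next step is printing each value on its bar. A minimal sketch, assuming Matplotlib 3.4 or later (the version that added Axes.bar_label): ax.bar() returns a BarContainer, which bar_label() uses to place one text label per bar:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs outside Jupyter
import matplotlib.pyplot as plt

nut_butter_prices = {"Almond butter": 10, "Peanut butter": 9, "Cashew butter": 5}

fig, ax = plt.subplots()
# ax.bar() returns a BarContainer holding one Rectangle per bar
bars = ax.bar(list(nut_butter_prices.keys()), list(nut_butter_prices.values()))
ax.set(title="Teri's Nut Butter Store", ylabel="Price ($)")

# Write each bar's height above it (requires Matplotlib 3.4 or later)
labels = ax.bar_label(bars)

print([float(bar.get_height()) for bar in bars])  # [10.0, 9.0, 5.0]
print([label.get_text() for label in labels])  # ['10', '9', '5']
```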
Horizontal Bar
Another option is a horizontal bar plot, created with .barh() instead of .bar():
fig, ax = plt.subplots()
ax.barh(list(nut_butter_prices.keys()), list(nut_butter_prices.values()));
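Horizontal bars are easier to compare when they are sorted. One way to do this, sketched below, is to sort the dictionary's (name, price) pairs by price before plotting. Because barh() places the first item at the bottom of the y-axis, ascending order puts the most expensive item on top:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs outside Jupyter
import matplotlib.pyplot as plt

nut_butter_prices = {"Almond butter": 10, "Peanut butter": 9, "Cashew butter": 5}

# Sort (name, price) pairs by price so the bars appear in order of value
sorted_items = sorted(nut_butter_prices.items(), key=lambda item: item[1])
names = [name for name, _ in sorted_items]
prices = [price for _, price in sorted_items]

fig, ax = plt.subplots()
# barh() draws the first item at the bottom, so the most expensive ends up on top
bars = ax.barh(names, prices)
ax.set(title="Teri's Nut Butter Store", xlabel="Price ($)")

print(names)  # ['Cashew butter', 'Peanut butter', 'Almond butter']
print([float(bar.get_width()) for bar in bars])  # [5.0, 9.0, 10.0]
```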
Subplots and Histograms
You can split a single figure into a 2 x 2 grid of four equally sized subplots with either of these approaches:
# Subplots option 1
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
nrows = 2,
ncols = 2,
figsize = (10, 5)
)
# Plot to each different axes
ax1.plot(x, x / 2);
ax2.scatter(np.random.random(10), np.random.random(10));
ax3.bar(nut_butter_prices.keys(), nut_butter_prices.values());
ax4.hist(np.random.randn(1000));
# Subplots option 2
fig, ax = plt.subplots(nrows = 2,
ncols = 2,
figsize = (10, 5))
# Plot to each axes by its [row, column] index
ax[0, 0].plot(x, x/2);
ax[0, 1].scatter(np.random.random(10), np.random.random(10));
ax[1, 0].bar(nut_butter_prices.keys(), nut_butter_prices.values());
ax[1, 1].hist(np.random.randn(1000));
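In option 2, ax is a 2-D NumPy array of Axes objects, which is why it can be indexed as ax[row, column]. The sketch below also loops over the grid with ax.flat and inspects what hist() returns. It uses NumPy's newer Generator API (default_rng().standard_normal) in place of np.random.randn; both draw from the standard normal distribution, and the seed is only there to make the sketch reproducible:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs outside Jupyter
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))
print(type(ax).__name__, ax.shape)  # ndarray (2, 2)

# ax.flat iterates over the grid row by row, so you can loop instead of indexing
for i, single_ax in enumerate(ax.flat):
    single_ax.set_title(f"Subplot {i}")

# hist() returns the bin counts, the bin edges, and the drawn bars
rng = np.random.default_rng(seed=42)  # seeded so the data is reproducible
counts, bin_edges, patches = ax[1, 1].hist(rng.standard_normal(1000))

print(len(counts), len(bin_edges))  # 10 11 (10 bins by default)
print(counts.sum())  # 1000.0, every sample falls into some bin
```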
Plotting from a Pandas DataFrame
This section shows how to visualize data from a .csv file using a Pandas DataFrame.
Before reading the file, import the pandas library:
import pandas as pd
Make a DataFrame with this command:
car_sales = pd.read_csv("car_sales.csv")
car_sales
This assumes car_sales.csv is saved in the same directory as the notebook. If you save it in a folder instead, include that folder in the path you pass to the .read_csv() method.
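The Price column is stored as text containing a $ sign and commas. To remove them and convert the column to an integer data type, run the command below. Passing regex=False makes .str.replace() treat $ and , as literal characters rather than regular expressions, and the intermediate .astype(float) is needed because a string with a decimal point cannot be converted straight to an integer: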
car_sales['Price']=car_sales['Price'].str.replace('$','',regex=False).str.replace(',','',regex=False).astype(float).astype(int)
car_sales
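Since car_sales.csv is not reproduced here, this sketch runs the same cleaning chain on a few hypothetical rows, assuming prices are formatted like "$4,000.00" (inferred from the characters the command removes). Each step is split onto its own line so its effect is easier to follow:

```python
import pandas as pd

# Hypothetical sample rows; the real car_sales.csv isn't reproduced here
car_sales = pd.DataFrame({"Make": ["Toyota", "Honda", "BMW"],
                          "Price": ["$4,000.00", "$5,000.00", "$22,000.00"]})

car_sales["Price"] = (car_sales["Price"]
                      .str.replace("$", "", regex=False)  # drop the dollar sign
                      .str.replace(",", "", regex=False)  # drop thousands separators
                      .astype(float)                      # "4000.00" -> 4000.0
                      .astype(int))                       # 4000.0 -> 4000

print(car_sales["Price"].tolist())  # [4000, 5000, 22000]
```

Note that .astype(int) truncates rather than rounds, so any cents in the original prices would be dropped.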
Add a Sale Date column with one consecutive day per row, starting on 1 January 2023:
car_sales["Sale Date"] = pd.date_range("1/1/2023", periods=len(car_sales))
car_sales
Add a Total Sales column that holds the running total of the prices:
car_sales["Total Sales"] = car_sales["Price"].cumsum()
car_sales
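To check what the two new columns contain, here is a sketch on three hypothetical, already-cleaned prices. pd.date_range() defaults to a daily frequency, and cumsum() adds each row's price to the total of all the rows before it:

```python
import pandas as pd

# Hypothetical, already-cleaned prices standing in for car_sales.csv
car_sales = pd.DataFrame({"Price": [4000, 5000, 7000]})

# One consecutive calendar day per row, starting on 1 January 2023
car_sales["Sale Date"] = pd.date_range("1/1/2023", periods=len(car_sales))

# Running total: each row adds its price to the sum of all previous rows
car_sales["Total Sales"] = car_sales["Price"].cumsum()

print(car_sales["Total Sales"].tolist())  # [4000, 9000, 16000]
print(car_sales["Sale Date"].dt.strftime("%Y-%m-%d").tolist())
# ['2023-01-01', '2023-01-02', '2023-01-03']
```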
Plot the Total Sales:
car_sales.plot(x = "Sale Date", y = "Total Sales");
Repeat the same process with any pair of columns. For example, this scatter plot shows price against odometer reading:
car_sales.plot(x="Odometer (KM)", y = "Price", kind = "scatter");
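DataFrame.plot() draws with Matplotlib under the hood and returns a Matplotlib Axes, so everything from the earlier sections still applies. This sketch, on hypothetical rows, passes ax= to draw pandas plots into a subplot grid and then customizes one with ax.set():

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs outside Jupyter
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical rows standing in for car_sales.csv
car_sales = pd.DataFrame({"Odometer (KM)": [150000, 87000, 32000],
                          "Price": [4000, 5000, 7000]})

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4))

# ax= tells pandas which Matplotlib Axes to draw on; the same Axes is returned
returned = car_sales.plot(x="Odometer (KM)", y="Price", kind="scatter", ax=ax1)
ax1.set(title="Price vs. odometer")

# kind="hist" shows the distribution of a single column
car_sales["Price"].plot(kind="hist", ax=ax2, title="Price distribution")

print(returned is ax1)  # True
print(ax1.get_xlabel(), "|", ax1.get_ylabel())  # Odometer (KM) | Price
```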
In Summary
Matplotlib can produce appealing visualizations for whatever you want to show, because it offers a wide range of plot types and customization options for presenting your data.