Last modified: Dec 13, 2024 By Alexander Williams

Python Matplotlib Legend: Enhance Plot Readability

When creating data visualizations with Matplotlib, adding legends is crucial for making your plots more informative and easier to understand. The plt.legend() function helps identify different data series in your plots.

Basic Legend Usage

To add a basic legend to your plot, you can use the label parameter when plotting and then call plt.legend(). Here's a simple example:


import matplotlib.pyplot as plt
import numpy as np

# Create sample data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)

# Create plots with labels
plt.plot(x, y1, label='Sine')
plt.plot(x, y2, label='Cosine')

# Add legend
plt.legend()
plt.show()
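plt.legend() builds its entries from every artist on the current Axes that has a label. Unlabeled artists are skipped, and so is any label that starts with an underscore. The sketch below uses the object-oriented API and Axes.get_legend_handles_labels() to show which entries the legend will pick up. The horizontal gray line and the "_threshold" label are hypothetical additions for illustration. The Agg backend is selected only so the example runs without opening a window.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, ax = plt.subplots()
ax.plot(x, np.sin(x), label="Sine")
ax.plot(x, np.cos(x), label="Cosine")
ax.plot(x, np.zeros_like(x), color="gray")            # no label: left out of the legend
ax.plot(x, 0.5 * np.ones_like(x), label="_threshold")  # leading underscore: also left out

# The (handles, labels) pairs that ax.legend() uses by default
handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['Sine', 'Cosine']

leg = ax.legend()
```

Prefixing a label with an underscore is a handy way to plot a helper line, such as a reference level, without cluttering the legend.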

Customizing Legend Location

You can control where the legend appears using the loc parameter, which accepts a set of predefined location strings. The default is 'best', which places the legend where it overlaps the plotted data the least:


# Common legend locations
plt.plot(x, y1, label='Line 1')
plt.plot(x, y2, label='Line 2')
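plt.legend(loc='best')  # Place the legend where it overlaps the data least
# Other options: 'upper right', 'upper left', 'lower left', 'lower right', 'right',
# 'center left', 'center right', 'lower center', 'upper center', 'center'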
plt.show()
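To see where the corner strings actually put the legend, the sketch below draws four small subplots, one per corner string. It then checks which quarter of each Axes contains the center of the legend's bounding box. The quadrant() helper is hypothetical and written only for this illustration. get_window_extent() and the Agg backend are standard Matplotlib, and positions are read only after the figure has been drawn.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
locations = ["upper left", "upper right", "lower left", "lower right"]

fig, axes = plt.subplots(2, 2, figsize=(8, 6))
placed = []
for ax, loc in zip(axes.flat, locations):
    ax.plot(x, np.sin(x), label="Sine")
    ax.plot(x, np.cos(x), label="Cosine")
    placed.append((loc, ax, ax.legend(loc=loc, title=loc)))

fig.canvas.draw()  # legend positions are final only after a draw
renderer = fig.canvas.get_renderer()

def quadrant(ax, legend):
    """Hypothetical helper: name the quarter of the Axes holding the legend's center."""
    a = ax.get_window_extent(renderer)
    b = legend.get_window_extent(renderer)
    # Display coordinates grow to the right and upward
    vertical = "upper" if (b.y0 + b.y1) / 2 > (a.y0 + a.y1) / 2 else "lower"
    horizontal = "right" if (b.x0 + b.x1) / 2 > (a.x0 + a.x1) / 2 else "left"
    return f"{vertical} {horizontal}"

for loc, ax, legend in placed:
    print(f"{loc!r:15} -> {quadrant(ax, legend)}")
```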

Styling Your Legend

The legend can be customized with various parameters to match your visualization needs. Here's an example showing different styling options:


plt.plot(x, y1, label='Dataset 1')
plt.plot(x, y2, label='Dataset 2')

# Customize legend appearance
plt.legend(
    frameon=True,           # Draw a frame (on by default)
    shadow=True,            # Add shadow
    facecolor='white',      # Background color
    edgecolor='black',      # Border color
    fontsize=12,            # Font size
    title='Data Series'     # Legend title
)
plt.show()
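plt.legend() returns a Legend object, so you can also adjust a legend after creating it. This sketch assumes the same two datasets. It thickens the frame, bolds the title, and colors each label to match its line. The color-matching loop is only one illustrative approach. Recent Matplotlib releases also accept labelcolor='linecolor' in legend() for a similar effect.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, ax = plt.subplots()
ax.plot(x, np.sin(x), label="Dataset 1")
ax.plot(x, np.cos(x), label="Dataset 2")

leg = ax.legend(title="Data Series", fontsize=12)

# The frame is a patch object that can be restyled after creation
frame = leg.get_frame()
frame.set_edgecolor("black")
frame.set_linewidth(1.5)

# Bold the title and give each label the color of its line
leg.get_title().set_fontweight("bold")
for text, line in zip(leg.get_texts(), ax.get_lines()):
    text.set_color(line.get_color())
```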

Multiple Column Legends

For plots with many data series, you might want to organize your legend in multiple columns. This can be achieved using the ncol parameter:


# Create multiple plots
y3 = np.tan(x)
y4 = x**2

plt.plot(x, y1, label='Sine')
plt.plot(x, y2, label='Cosine')
plt.plot(x, y3, label='Tangent')
plt.plot(x, y4, label='Square')

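# Limit the y-range: tan(x) spikes near its asymptotes and x**2 climbs to 100,
# which would otherwise flatten the sine and cosine curves
plt.ylim(-5, 5)

# Create a two-column legend
plt.legend(ncol=2)
plt.show()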
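Setting ncol to the number of entries lays the legend out in a single row, which works well when the legend is anchored just above the plot. The sketch below reuses the four curves from the example above and clips the y-axis to -5 to 5, a range chosen only to keep the curves readable. After drawing, it checks that the labels share one row and that the legend sits above the Axes. Matplotlib 3.6 and later also accept the spelling ncols for this parameter.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
series = {
    "Sine": np.sin(x),
    "Cosine": np.cos(x),
    "Tangent": np.tan(x),
    "Square": x**2,
}

fig, ax = plt.subplots()
for name, y in series.items():
    ax.plot(x, y, label=name)
ax.set_ylim(-5, 5)  # keep tan(x) and x**2 from flattening the other curves

# One column per entry gives a single row; pin its lower center to the top of the Axes
leg = ax.legend(ncol=len(series), loc="lower center", bbox_to_anchor=(0.5, 1.0))

fig.canvas.draw()
renderer = fig.canvas.get_renderer()
boxes = [t.get_window_extent(renderer) for t in leg.get_texts()]
ys = [(b.y0 + b.y1) / 2 for b in boxes]  # vertical centers of the labels
xs = [b.x0 for b in boxes]               # left edges of the labels
print("one row:", max(ys) - min(ys) < 3, "| left to right:", xs == sorted(xs))
```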

Legend Outside the Plot

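When the legend would cover your data, you can place it outside the plot area instead. The bbox_to_anchor parameter sets an anchor point in axes coordinates, where (0, 0) is the lower-left corner and (1, 1) is the upper-right corner of the plot area. The loc parameter then says which corner of the legend is pinned to that point. With bbox_to_anchor=(1.05, 1) and loc='upper left', the legend's upper-left corner sits just to the right of the plot's top edge: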


plt.plot(x, y1, label='Data 1')
plt.plot(x, y2, label='Data 2')

# Place legend outside plot
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()  # Adjust layout to prevent legend cutoff
plt.show()
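plt.tight_layout() is one way to keep an outside legend from being cut off. Another is to reserve space explicitly with fig.subplots_adjust(). The sketch below assumes a right edge of 0.75 for the Axes, a value picked by eye rather than a rule. It then checks that the legend lies to the right of the Axes but still inside the figure.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, ax = plt.subplots(figsize=(6.4, 4.8), dpi=100)
ax.plot(x, np.sin(x), label="Data 1")
ax.plot(x, np.cos(x), label="Data 2")

# Pin the legend's upper-left corner just right of the Axes' top edge
leg = ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

# Shrink the Axes so the right quarter of the figure is free for the legend
fig.subplots_adjust(right=0.75)

fig.canvas.draw()
renderer = fig.canvas.get_renderer()
legend_box = leg.get_window_extent(renderer)
axes_box = ax.get_window_extent(renderer)
print("outside axes:", legend_box.x0 > axes_box.x1)
print("inside figure:", legend_box.x1 <= fig.bbox.x1)
```

If you save the figure rather than show it, fig.savefig('plot.png', bbox_inches='tight') is another common option: it sizes the saved image so that artists outside the Axes, such as this legend, are included.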

Adding Custom Legend Elements

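Sometimes you need legend entries that don't correspond directly to plotted data, or you want full control over what the legend shows. You can create these stand-in artists (Matplotlib calls them proxy artists) yourself and pass them to plt.legend() together with their labels: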


from matplotlib.lines import Line2D

# Create custom legend elements
custom_lines = [Line2D([0], [0], color='red', lw=2),
               Line2D([0], [0], color='blue', lw=2)]

plt.plot(x, y1, 'r-')
plt.plot(x, y2, 'b-')

# Add custom legend
plt.legend(custom_lines, ['Custom 1', 'Custom 2'])
plt.show()
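Proxy artists don't have to be lines. The sketch below assumes you have shaded an x-range (2 to 4 here, chosen only for illustration) with axvspan() and want it in the legend alongside the labeled curves. It starts from ax.get_legend_handles_labels() and appends a matplotlib.patches.Patch. When only handles are passed, the legend reads each label from its handle.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

x = np.linspace(0, 10, 100)

fig, ax = plt.subplots()
ax.plot(x, np.sin(x), "r-", label="Sine")
ax.plot(x, np.cos(x), "b-", label="Cosine")
ax.axvspan(2, 4, color="gold", alpha=0.3)  # shaded region without a label of its own

# Start from the automatically collected entries, then append a proxy patch
handles, labels = ax.get_legend_handles_labels()
handles.append(mpatches.Patch(color="gold", alpha=0.3, label="Region of interest"))

leg = ax.legend(handles=handles)  # labels come from each handle's label
print([t.get_text() for t in leg.get_texts()])  # ['Sine', 'Cosine', 'Region of interest']
```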

For more detailed axis labeling, you might want to check out Python Matplotlib xlabel() and Python Matplotlib ylabel().

Conclusion

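The plt.legend() function is an essential tool for creating clear, informative visualizations in Matplotlib. Labeling each data series, choosing a position with loc or bbox_to_anchor, arranging entries with ncol, and styling the frame and text let you build legends that explain your plots without hiding the data.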

For more complex visualizations, consider exploring Python Matplotlib Scatter Plot Tutorial to combine different plot types with legends.