"""
==============================================================================
NicePlots: A collection of stylesheets and helper functions for matplotlib
==============================================================================
"""
# ==============================================================================
# Standard Python modules
# ==============================================================================
import warnings
import os
import copy
from collections import OrderedDict
# ==============================================================================
# External Python modules
# ==============================================================================
from matplotlib import patheffects
from matplotlib.collections import LineCollection
import matplotlib.colors as mcolor
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import make_interp_spline, Akima1DInterpolator
# ==============================================================================
# Extension modules
# ==============================================================================
from .parula import parula_map
[docs]
def get_style(styleName="doumont-light"):
    """
    Resolve a style name into a value accepted by matplotlib's style functions.

    Names of styles bundled with niceplots are turned into the path of the
    corresponding ``.mplstyle`` file; any other value is returned unchanged so
    that matplotlib's built-in styles keep working. Usage examples::

        import matplotlib.pyplot as plt
        import niceplots

        plt.style.use(niceplots.get_style())
        plt.plot([0, 1], [0, 1])

        # Or you can use it within a context manager
        with plt.style.context(niceplots.get_style()):
            plt.plot([0, 1], [0, 1])

        # Also try different styles
        plt.style.use(niceplots.get_style("james-dark"))  # niceplots james dark style
        plt.style.use(niceplots.get_style("default"))  # matplotlib default style

    Parameters
    ----------
    styleName : str, optional
        Name of desired style. By default uses doumont-light style. Available styles are:

        - doumont-light: the niceplots style you know and love
        - doumont-dark: the dark version of the niceplots style you know and love
        - james-dark: a really cool alternative to classic niceplots
        - james-light: a version of james with a light background, naturally

    Returns
    -------
    str
        The style string to be passed to one of matplotlib's style setting functions.
    """
    # Anything that is not a bundled niceplots style is assumed to be a
    # matplotlib style and is handed back untouched
    if styleName not in get_available_styles():
        return styleName

    # Bundled styles live in the "styles" folder next to this module
    package_dir = os.path.dirname(os.path.abspath(__file__))
    stylesheet = f"{styleName}.mplstyle"
    return os.path.join(package_dir, "styles", stylesheet)
[docs]
def get_colors(styleName=None):
    """
    Get a dictionary with the colors for a style. This function only works
    when niceplots styles are used (not built-in matplotlib ones).

    Parameters
    ----------
    styleName : str, optional
        Name of desired style. By default gets the colors for the current style. Available styles are:

        - doumont-light: the niceplots style you know and love
        - doumont-dark: the dark version of the niceplots style you know and love
        - james-dark: a really cool alternative to classic niceplots
        - james-light: a version of james with a light background, naturally

    Returns
    -------
    dict
        Ordered dictionary of the colors for the requested style. The keys
        are human-readable names and the values are the color codes. It
        also adds color names from the rcParams that are generally useful:

        - "Axis": axis spine color
        - "Background": axis background color
        - "Text": default text color
        - "Label": axis label color

    Raises
    ------
    ValueError
        If the number of color names stored in the style does not match the
        number of colors in the color cycle.
    """
    # rcParams entries that are exposed under a friendlier name
    extra_entries = (
        ("Axis", "axes.edgecolor"),
        ("Background", "axes.facecolor"),
        ("Text", "text.color"),
        ("Label", "axes.labelcolor"),
    )

    def _collect_active_colors():
        # The color names are read from the "keymap.help" rcParam
        codes = get_colors_list()
        names = plt.rcParams["keymap.help"]

        # Every color in the cycle needs exactly one name
        if len(names) != len(codes):
            raise ValueError(
                "The colors are not properly named in the stylesheet, please open an issue on GitHub with the details!"
            )

        palette = OrderedDict(zip(names, codes))
        for friendly_name, rc_key in extra_entries:
            palette[friendly_name] = plt.rcParams[rc_key]
        return palette

    if not styleName:
        return _collect_active_colors()

    # Temporarily activate the requested style while reading its colors
    with plt.style.context(get_style(styleName)):
        return _collect_active_colors()
[docs]
def get_colors_list(styleName=None):
    """
    Get a list with the colors of a style's color cycle. This function
    works with all matplotlib styles.

    Parameters
    ----------
    styleName : str, optional
        Name of desired style. By default gets the colors for the current style. Available styles are:

        - doumont-light: the niceplots style you know and love
        - doumont-dark: the dark version of the niceplots style you know and love
        - james-dark: a really cool alternative to classic niceplots
        - james-light: a version of james with a light background, naturally

    Returns
    -------
    list
        List of the colors for the requested style.
    """

    def _active_cycle_colors():
        # The color cycle is stored in the "axes.prop_cycle" rcParam
        return plt.rcParams["axes.prop_cycle"].by_key()["color"]

    if not styleName:
        return _active_cycle_colors()

    # Temporarily activate the requested style while reading its color cycle
    with plt.style.context(get_style(styleName)):
        return _active_cycle_colors()
[docs]
def get_available_styles():
    """
    Get a list of the names of styles available.

    The names are the ``.mplstyle`` files found in the ``styles`` folder next
    to this module, with the extension removed.

    Returns
    -------
    list
        The names of the available styles, sorted alphabetically.
    """
    package_dir = os.path.dirname(os.path.abspath(__file__))
    entries = os.listdir(os.path.join(package_dir, "styles"))

    # Split every entry into (name, extension) and keep only the stylesheets
    split_entries = (os.path.splitext(entry) for entry in entries)
    return sorted(stem for stem, suffix in split_entries if suffix == ".mplstyle")
[docs]
def handle_close(evt):
    """Handler function that saves the figure as a pdf if the window is closed.

    Parameters
    ----------
    evt : matplotlib event
        The ``close_event`` passed in by matplotlib. The figure attached to
        the event's canvas is the one that gets saved; if the event carries no
        canvas/figure, the current pyplot figure is used instead.
    """
    # Save the figure whose window was closed rather than whichever figure
    # happens to be current, which may differ when several windows are open
    canvas = getattr(evt, "canvas", None)
    fig = getattr(canvas, "figure", None)
    if fig is None:
        fig = plt.gcf()

    fig.tight_layout()
    fig.savefig("figure.pdf")
[docs]
def adjust_spines(ax=None, spines=("left", "bottom"), outward=True, offset_pts=12):
    """Function to shift the axes/spines so they have that offset
    Doumont look.

    Parameters
    ----------
    ax : matplotlib axes, optional
        Axes to adjust, by default the current pyplot axes
    spines : iterable of str, optional
        Names of the spines to keep visible (any of "left", "right", "bottom",
        "top"), all other spines are hidden, by default ("left", "bottom")
    outward : bool, optional
        Whether to move the visible spines outward from the plot area, by default True
    offset_pts : int, float, optional
        Distance in points that the visible spines are moved outward when
        ``outward`` is True, by default 12
    """
    if ax is None:
        ax = plt.gca()

    # Show and (optionally) shift the requested spines, hide all the others
    for loc, spine in ax.spines.items():
        if loc in spines:
            spine.set_visible(True)
            if outward:
                spine.set_position(("outward", offset_pts))
        else:
            spine.set_visible(False)

    # Put the y ticks on the visible vertical spine, or hide the y axis if there is none
    if "left" in spines:
        ax.yaxis.set_ticks_position("left")
    elif "right" in spines:
        ax.yaxis.set_ticks_position("right")
    else:
        ax.yaxis.set_visible(False)

    # Put the x ticks on the visible horizontal spine, or hide the x axis if there is none
    if "bottom" in spines:
        ax.xaxis.set_ticks_position("bottom")
    elif "top" in spines:
        ax.xaxis.set_ticks_position("top")
    else:
        ax.xaxis.set_visible(False)
[docs]
def draggable_legend(axis=None, color_on=True, **kwargs):
    """Function to create draggable labels on a plot.

    One annotation is created per line in the axes, using the line's label as
    the text. The annotations start out on a square grid spanning 0.1 to 0.9
    in axes-fraction coordinates and can then be dragged into place.

    Parameters
    ----------
    axis : matplotlib axes, optional
        Axes whose lines should be labelled, by default the current pyplot axes
    color_on : bool, optional
        If True each label uses the color of its line, otherwise the labels
        are black, by default True
    kwargs
        Additional keyword arguments passed to the axes' ``annotate`` method

    Returns
    -------
    list of matplotlib annotation objects
        The draggable annotations, in the same order as the lines
    """
    if axis is None:
        axis = plt.gca()

    lines = list(axis.lines)

    # Smallest square grid of starting positions that has room for every line
    side = int(np.ceil(np.sqrt(len(lines))))
    ticks = np.linspace(0.1, 0.9, side)
    grid_x, grid_y = np.meshgrid(ticks, ticks)

    legend = []
    for line, x, y in zip(lines, grid_x.ravel(), grid_y.ravel()):
        color = line.get_color() if color_on else "k"

        # Give each annotation its own coordinate tuple; sharing one mutable
        # array between annotations risks them all ending up at the last position
        annotation = axis.annotate(
            line.get_label(),
            xy=(float(x), float(y)),
            color=color,
            xycoords="axes fraction",
            **kwargs,
        )
        annotation.draggable()
        legend.append(annotation)

    return legend
[docs]
def label_line_ends(ax, lines=None, labels=None, colors=None, x_offset_pts=6, y_offset_pts=0, **kwargs):
    """Place a label just to the right of each line in the axes

    Note: Because the labels are placed outside of the axes, this function works best for plots where all lines end as
    close to the right edge of the axes as possible. Additionally you need to either use constrained_layout=True
    (as NicePlots styles do), or call plt.tight_layout() after calling this function.

    Parameters
    ----------
    ax : Matplotlib axes
        axes to label the lines of
    lines : iterable of matplotlib line objects, optional
        Lines to label, by default all lines in the axes
    labels : list of strings, optional
        Labels for each line, by default uses each line's label attribute
    colors : single or list of colors, optional
        Color(s) to use for each line, can be a single color for all lines or a list containing an entry for each line,
        by default uses each line's color
    x_offset_pts : int, float, optional
        Horizontal offset of label from the right end of the line, in points, by default 6
    y_offset_pts : int, float, optional
        Vertical offset of label from the right end of the line, in points, by default 0
    kwargs
        Any valid keywords for matplotlib's annotate function, except ``xy``, ``xytext``, ``color``, ``textcoords``,
        ``va``

    Returns
    -------
    list of matplotlib annotation objects
        The annotations created

    Raises
    ------
    ValueError
        If the number of labels or colors does not match the number of lines
    """
    # Without an explicit selection every line in the axes is a candidate
    if lines is None:
        lines = ax.get_lines()

    # Without explicit labels fall back on the lines' own labels, dropping the lines that still carry matplotlib's
    # automatic label (which starts with an underscore)
    if labels is None:
        lines = [ln for ln in lines if not ln.get_label().startswith("_")]
        labels = [ln.get_label() for ln in lines]

    line_count = len(lines)
    if len(labels) != line_count:
        raise ValueError(f"Number of labels ({len(labels)}) doesn't match number of lines ({line_count})")

    # Anything that is not a list is treated as one color shared by all lines
    if colors is None:
        colors = [ln.get_color() for ln in lines]
    elif not isinstance(colors, list):
        colors = [colors] * line_count
    if len(colors) != line_count:
        raise ValueError(f"Number of colors ({len(colors)}) doesn't match number of lines ({line_count})")

    offset = (x_offset_pts, y_offset_pts)
    annotations = []
    for ln, text, shade in zip(lines, labels, colors):
        # Anchor the label at the data point with the largest x value
        right_most = np.argmax(ln.get_xdata())
        anchor = (ln.get_xdata()[right_most], ln.get_ydata()[right_most])
        annotations.append(
            ax.annotate(
                text,
                xy=anchor,
                xytext=offset,
                color=shade,
                textcoords="offset points",
                va="center",
                **kwargs,
            )
        )
    return annotations
[docs]
def horiz_bar(labels, times, header, nd=1, size=(5, 0.5), color=None):
    """Creates a horizontal bar chart to compare positive numbers.

    Each entry gets its own thin subplot with a horizontal guide line, a dot
    at the entry's value, the label on the left and the formatted value on
    the right.

    Parameters
    ----------
    labels : list of str
        contains the ordered labels for each data set
    times : list of float
        contains the numerical data for each entry
    header : list of two str
        contains the left and right header for the labels and
        numeric data, respectively
    nd : int
        the number of digits to show after the decimal point for the data
    size : list of two float
        the width of the figure and the height of each individual bar (iffy results)
    color : str
        color of the scatter points used, by default the first color of the
        current color cycle

    Returns
    -------
    fig: matplotlib Figure
        Figure created
    axes: array of matplotlib Axes
        The subplot axes, one for each bar

    Raises
    ------
    ValueError
        If ``times`` is empty
    """
    num = len(times)
    if num == 0:
        raise ValueError("horiz_bar requires at least one value in times")

    # Use the first color if none is specified
    if color is None:
        color = get_colors_list()[0]
    line_color = get_colors()["Axis"]

    # Obtain parameters to size the chart correctly
    width = size[0]
    height = size[1] * num
    t_max = max(times)
    digits = int(nd)

    # One subplot per entry; plt.subplots returns a bare Axes (not an array)
    # when there is a single row, so normalise to a 1D array in all cases
    fig, axes = plt.subplots(num, 1, figsize=[width, height])
    axes = np.atleast_1d(axes)

    for j, (label, t, ax) in enumerate(zip(labels, times, axes)):
        # Draw the guide line and the dot marking the value
        ax.axhline(y=1, c=line_color, lw=3, zorder=0, alpha=0.5)
        ax.scatter([t], [1], c=color, lw=0, s=100, zorder=1, clip_on=False)

        # Set chart properties
        ax.set_ylim(0.99, 1.01)
        ax.set_xlim(0, t_max * 1.05)

        # Switch off every tick and tick label on both axes. (The original
        # code passed ``bottom=j == num``, which is never true inside this
        # loop, so the bottom ticks have always been off.)
        ax.tick_params(
            axis="both",
            which="both",
            left=False,
            right=False,
            bottom=False,
            top=False,
            labelleft=False,
            labelright=False,
            labelbottom=False,
        )
        for side in ("top", "left", "right", "bottom"):
            ax.spines[side].set_visible(False)

        # Label on the left, formatted value just right of the axes
        ax.set_ylabel(label, rotation="horizontal", ha="right", va="center")
        ax.annotate(
            f"{t:.{digits}f}",
            xy=(1, 1),
            xytext=(6, 0),
            xycoords=ax.get_yaxis_transform(),
            textcoords="offset points",
            va="center",
        )

        # The headers are drawn once, above the first bar
        if j == 0:
            ax.text(0, 1.02, header[0], ha="right", fontweight="bold", fontsize="large")
            ax.text(t_max, 1.02, header[1], ha="left", fontweight="bold", fontsize="large")

    return fig, axes
def stacked_plots(
    xlabel,
    xdata,
    data_dict_list,
    figsize=(12, 10),
    outward=True,
    filename="stacks.png",
    xticks=None,
    cushion=0.1,
    colors=None,
    lines_only=False,
    line_scaler=1.0,
    xlim=None,
    dpi=200,
):
    """Plot several quantities against a shared x axis in vertically stacked subplots and save the figure

    Parameters
    ----------
    xlabel : str
        Label of the x axis, shown below the bottom subplot only
    xdata : array-like
        X values shared by all the plotted quantities
    data_dict_list : dict or list of dict
        One dictionary per data set. The keys are the y axis labels (one subplot per key) and each value is
        either the y data itself or a dictionary with the y data under ``"data"`` and optionally either
        ``"limits"`` (y axis limits) or ``"ticks"`` (y tick values). The labels, limits and ticks are taken from
        the first dictionary; every dictionary is plotted on the same axes in its own color
    figsize : tuple, optional
        Figure size, by default (12, 10)
    outward : bool, optional
        Passed to ``adjust_spines``, whether the spines are shifted outward, by default True
    filename : str, optional
        File that the figure is saved to, by default "stacks.png"
    xticks : list, optional
        Tick values for the x axis of the bottom subplot, by default None
    cushion : float, optional
        Fraction of the tick range added above and below the outermost ticks when ``"ticks"`` are given,
        by default 0.1
    colors : list, optional
        One color per data set, by default the colors of the current style. The colors are cycled if there are
        more data sets than colors
    lines_only : bool, optional
        If True only lines are drawn, otherwise markers are added at each data point, by default False
    line_scaler : float, optional
        Scaling factor applied to the line widths and marker sizes, by default 1.0
    xlim : iterable of two floats, optional
        If given, invisible points are drawn at these x values (at the mean of the finite y data) so that the
        x axis extends at least that far, by default None
    dpi : int, optional
        Resolution used when "png" appears in ``filename``, by default 200

    Returns
    -------
    f : matplotlib Figure
        Figure created
    axarr : array of matplotlib Axes
        The subplot axes, one for each quantity
    """
    # If it's a dictionary, make it into a list so we can generically loop over it
    if isinstance(data_dict_list, dict):
        data_dict_list = [data_dict_list]
    if colors is None:
        colors = get_colors_list()

    first_dict = data_dict_list[0]
    f, axarr = plt.subplots(len(first_dict), figsize=figsize)
    # plt.subplots returns a bare Axes when there is a single quantity; always work with a 1D array
    axarr = np.atleast_1d(axarr)

    # --- Axis labels, limits and ticks come from the first data set ---
    for i, (ylabel, ydata) in enumerate(first_dict.items()):
        if isinstance(ydata, dict):
            if "limits" in ydata:
                axarr[i].set_ylim(ydata["limits"])
            elif "ticks" in ydata:
                axarr[i].set_yticks(ydata["ticks"])
                low_tick = ydata["ticks"][0]
                high_tick = ydata["ticks"][-1]
                height = high_tick - low_tick
                axarr[i].set_ylim([low_tick - cushion * height, high_tick + cushion * height])
        axarr[i].set_ylabel(ylabel, rotation="horizontal", horizontalalignment="right")

        # Invisible points at the requested x limits stretch the x axis
        # (original note: doesn't correctly work when we give a dict version)
        if xlim is not None:
            if isinstance(ydata, dict):
                ydata = ydata["data"]
            ydata = np.array(ydata, dtype="float")
            mean_y = np.mean(ydata[np.isfinite(ydata)])
            axarr[i].scatter(list(xlim), [mean_y, mean_y], alpha=0.0)

    # --- Plot every data set on the shared axes, one color per data set ---
    for j, data_dict in enumerate(data_dict_list):
        color = colors[j % len(colors)]
        for i, ydata in enumerate(data_dict.values()):
            if isinstance(ydata, dict):
                ydata = ydata["data"]
            axarr[i].plot(xdata, ydata, clip_on=False, lw=6 * line_scaler, color=color)
            if not lines_only:
                axarr[i].scatter(
                    xdata,
                    ydata,
                    clip_on=False,
                    edgecolors=axarr[i].get_facecolor(),
                    s=100 * line_scaler**2,
                    lw=1.5 * line_scaler,
                    zorder=100,
                    color=color,
                )

    # --- Only the bottom subplot keeps its x ticks ---
    last = len(axarr) - 1
    for i, ax in enumerate(axarr):
        adjust_spines(ax, outward=outward)
        if i < last:
            ax.xaxis.set_ticks([])
        else:
            ax.xaxis.set_ticks_position("bottom")
            if xticks is not None:
                ax.xaxis.set_ticks(xticks)

    f.align_labels()
    axarr[-1].set_xlabel(xlabel)

    if "png" in filename:
        plt.savefig(filename, bbox_inches="tight", dpi=dpi)
    else:
        plt.savefig(filename, bbox_inches="tight")

    return f, axarr
[docs]
def plot_opt_prob(
    obj,
    xRange,
    yRange,
    ineqCon=None,
    eqCon=None,
    nPoints=51,
    optPoint=None,
    conStyle="shaded",
    ax=None,
    colors=None,
    cmap=None,
    levels=None,
    labelAxes=True,
):
    """Generate a contour plot of a 2D constrained optimisation problem

    Parameters
    ----------
    obj : function
        Objective function, should accept inputs in the form f = obj(x, y) where x and y are 2D arrays
    xRange : list or array
        Upper and lower limits of the plot in x
    yRange : list or array
        Upper and lower limits of the plot in y
    ineqCon : function or list of functions, optional
        Inequality constraint functions, should accept inputs in the form g = g(x, y) where x and y are 2D arrays.
        Constraints are assumed to be of the form g <= 0
    eqCon : functions or list of functions, optional
        Equality constraint functions, should accept inputs in the form h = h(x, y) where x and y are 2D arrays.
        Constraints are assumed to be of the form h == 0
    nPoints : int, optional
        Number of points in each direction to evaluate the objective and constraint functions at
    optPoint : list or array, optional
        Optimal Point, if you want to plot a point there, by default None
    conStyle : str, optional
        Controls how inequality constraints are represented, "shaded" will shade the infeasible regions while "hashed"
        will place hashed lines on the infeasible side of the feasible boundary, by default "shaded", note the "hashed"
        option only works for matplotlib >= 3.4
    ax : matplotlib axes object, optional
        axes to plot, by default None, in which case a new figure will be created and returned by the function
    colors : list, optional
        List of colors to use for the constraint lines, by default uses the current matplotlib color cycle. Each
        constraint (inequality constraints first, then equality constraints) uses the next color in the list,
        wrapping around if there are more constraints than colors
    cmap : colormap, optional
        Colormap to use for the objective contours, by default will use nicePlots' parula map
    levels : list, array, int, optional
        Number or values of contour lines to plot for the objective function
    labelAxes : bool, optional
        Whether to label the x and y axes, by default True, in which case the axes will be labelled, "$x_1$" and "$x_2$"

    Returns
    -------
    fig : matplotlib figure object
        Figure containing the plot. Returned only if no input ax object is specified
    ax : matplotlib axes object
        Axes containing the plot. Returned only if no input ax object is specified

    Raises
    ------
    ValueError
        If ``conStyle`` is neither "shaded" nor "hashed"
    """
    # --- Check that conStyle contains a supported value to avoid random conStyle arguments ---
    if conStyle.lower() not in ["shaded", "hashed"]:
        raise ValueError(f"conStyle: {conStyle} is not supported")
    conStyle = conStyle.lower()

    # --- Check if user has a recent enough version of matplotlib to use hashed boundaries ---
    if conStyle == "hashed" and not hasattr(patheffects, "withTickedStroke"):
        warnings.warn(
            "matplotlib >= 3.4 is required for hashed inequality constrain boundaries, switching to shaded inequality constraint style",
            stacklevel=2,
        )
        conStyle = "shaded"

    def _as_iterable(con):
        # A single constraint function is wrapped in a list, None becomes an empty list
        if con is None:
            return []
        return con if hasattr(con, "__iter__") else [con]

    ineqCons = _as_iterable(ineqCon)
    eqCons = _as_iterable(eqCon)

    # --- Create a new figure if the user did not supply an ax object ---
    returnFig = False
    if ax is None:
        fig, ax = plt.subplots()
        returnFig = True

    # --- Define some default values if the user didn't provide them ---
    if cmap is None:
        cmap = parula_map
    if colors is None:
        colors = get_colors_list()
    nColor = len(colors)

    # --- Create grid of points for evaluating functions ---
    X, Y = np.meshgrid(
        np.linspace(xRange[0], xRange[1], nPoints),
        np.linspace(yRange[0], yRange[1], nPoints),
    )

    # --- Evaluate objective and constraint functions ---
    Fobj = obj(X, Y)
    g = [ineq(X, Y) for ineq in ineqCons]
    h = [eq(X, Y) for eq in eqCons]

    # --- Plot objective contours ---
    adjust_spines(ax, outward=True)
    ax.contour(X, Y, Fobj, levels=levels, cmap=cmap)

    # --- Plot constraint boundaries, every constraint gets the next color ---
    colorIndex = 0
    for conValue in g:
        conColor = colors[colorIndex % nColor]
        contour = ax.contour(X, Y, conValue, levels=[0.0], colors=conColor)
        if conStyle == "hashed":
            effects = [patheffects.withTickedStroke(angle=60, length=2)]
            if hasattr(contour, "set_path_effects"):
                # matplotlib >= 3.8: the contour set is itself an artist and
                # its ``collections`` attribute is deprecated
                contour.set_path_effects(effects)
            else:
                plt.setp(contour.collections, path_effects=effects)
        else:
            # Shade the infeasible region (g > 0)
            ax.contourf(X, Y, conValue, levels=[0.0, np.inf], colors=conColor, alpha=0.4)
        colorIndex += 1

    for conValue in h:
        ax.contour(X, Y, conValue, levels=[0.0], colors=colors[colorIndex % nColor])
        # Advance the color so that multiple equality constraints can be told apart
        colorIndex += 1

    # --- Plot optimal point if provided ---
    if optPoint is not None:
        ax.plot(
            optPoint[0],
            optPoint[1],
            "o",
            color="black",
            markeredgecolor=ax.get_facecolor(),
            markersize=10,
            clip_on=False,
        )

    # --- Label axes if required ---
    if labelAxes:
        ax.set_xlabel("$x_1$")
        ax.set_ylabel("$x_2$", rotation="horizontal", ha="right")

    if returnFig:
        return fig, ax
    return None
[docs]
def plot_colored_line(
    x,
    y,
    c,
    cmap=None,
    fig=None,
    ax=None,
    addColorBar=False,
    cRange=None,
    cBarLabel=None,
    norm=None,
    **kwargs,
):
    """Plot an XY line whose color is determined by some other variable C

    Parameters
    ----------
    x : iterable of length n
        x data
    y : iterable of length n
        y data
    c : iterable of length n
        Data for linecolor
    cmap : str or matplotlib colormap, optional
        Colormap to use for the line, by default will use nicePlots' parula map
    fig : matplotlib figure object, optional
        figure to plot on, by default None. If ``ax`` is given without ``fig``, the figure that ``ax`` belongs to
        is used. If neither is given a new figure is created and returned by the function
    ax : matplotlib axes object, optional
        axes to plot on, by default None. If ``fig`` is given without ``ax``, the figure's current axes are used.
        If neither is given a new figure is created and returned by the function
    addColorBar : bool, optional
        Whether to add a colorbar to the axes, by default False
    cRange : iterable of length 2, optional
        Upper and lower limit for the colormap, by default None, in which case the min and max values of c are used.
    cBarLabel : str, optional
        Label for the colormap, by default None
    norm : matplotlib.colors.Normalize, optional
        Specify colormap mapping; both this and cRange cannot be specified, it must be one or the other (or neither)
    kwargs
        Additional keyword arguments passed to matplotlib's LineCollection

    Returns
    -------
    fig : matplotlib figure object
        Figure containing the plot. Returned only if no input ax object is specified
    ax : matplotlib axes object
        Axis with the colored line. Returned only if no input ax object is specified

    Raises
    ------
    ValueError
        If both ``cRange`` and ``norm`` are specified
    """
    # Validate the inputs before anything is created so a bad call does not leave an empty figure behind
    if cRange is not None and norm is not None:
        raise ValueError("cRange and norm cannot both be specified")

    # --- Work out which figure and axes to draw on ---
    returnFig = ax is None
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    elif fig is None:
        # Draw on the axes the user passed in rather than discarding it for a new figure
        fig = ax.figure

    if cmap is None:
        cmap = parula_map

    # --- Convert inputs to flattened arrays ---
    xFlat = np.asarray(x).flatten()
    yFlat = np.asarray(y).flatten()
    cFlat = np.asarray(c).flatten()

    # --- Create points and segments, each segment joins two consecutive points ---
    points = np.array([xFlat, yFlat]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    if cRange is not None:
        norm = plt.Normalize(cRange[0], cRange[1])
    lc = LineCollection(segments, cmap=cmap, norm=norm, **kwargs)

    # Set the values used for colormapping
    lc.set_array(cFlat)
    line = ax.add_collection(lc)

    if addColorBar:
        cBar = fig.colorbar(line, ax=ax)
        if cBarLabel is not None:
            cBar.set_label(cBarLabel)

    # Collections do not update the data limits of the view on their own
    ax.autoscale()

    if returnFig:
        return fig, ax
    return None
[docs]
def plot_nested_pie(
    data,
    colors=None,
    alphas=None,
    ax=None,
    innerKwargs=None,
    outerKwargs=None,
):
    """Create a two-level pie chart where the inner pie chart is a sum of related categories from the outer one.
    The labels are by default set to the keys in the data dictionary.

    Parameters
    ----------
    data : nested dict
        Data to plot. Formatted as::

            {
                "Category 1": {
                    "Subcategory 1": 0.5,
                    "Subcategory 2": 1.5,
                },
                "Category 2": {
                    "Subcategory 1": 2.5,
                },
                ...
            }

    colors : str or list of str with hex colors, optional
        Colors to use for the inner wedges. Can either specify a qualitative matplotlib colormap (it will assume
        this is the case if a string is specified), or a list of colors specified with hex codes (e.g., "#F4A103"),
        by default will use nice colors (niceplots default). Loops through the colors if more categories than
        colors are specified. The list passed in is not modified.
    alphas : iterable of floats at least as long as the max number of subcategories for a given category
        Opacities between 0 and 1 used to vary the color in the outer categories, the i-th value is used for the
        i-th subcategory of every category. Values outside of [0, 1] are clipped.
    ax : matplotlib axes object, optional
        axes to plot on, by default None, in which case a new figure will be created and returned by the function
    innerKwargs : dict
        Dictionary of keyword arguments to pass to matplotlib's pie function for the inner pie chart. "colors" and
        "radius" are important ones for the nested pie chart and I recommend not touching those unless you know
        what you're doing. labels, rotatelabels, wedgeprops, and textprops are also all set by default in this
        function, but can be overridden using this parameter. The dictionary passed in is not modified.
    outerKwargs : dict
        Dictionary of keyword arguments to pass to matplotlib's pie function for the outer pie chart. "colors" and
        "radius" are important ones for the nested pie chart and I recommend not touching those unless you know
        what you're doing. labels, rotatelabels, wedgeprops, and textprops are also all set by default in this
        function, but can be overridden using this parameter. The dictionary passed in is not modified.

    Returns
    -------
    pieObjects : dict of matplotlib.patches.Wedge and matplotlib Text objects
        Wedges and text objects for the pie plot, formatted similarly to the input data dict::

            {
                "Category 1": {
                    "wedge": Category 1 wedge
                    "text": Category 1 text
                    "Subcategory 1": {"wedge": Subcategory 1 wedge, "text": Subcategory 1 text},
                    "Subcategory 2": {"wedge": Subcategory 2 wedge, "text": Subcategory 2 text},
                },
                "Category 2": {
                    "wedge": Category 2 wedge
                    "text": Category 2 text
                    "Subcategory 1": {"wedge": Subcategory 1 wedge, "text": Subcategory 1 text},
                },
                ...
            }

    fig : matplotlib figure object
        Figure containing the plot. Returned only if no input ax object is specified
    ax : matplotlib axes object
        Axes containing the pie chart. Returned only if no input ax object is specified

    Raises
    ------
    ValueError
        If a color does not start with "#"
    """

    def _alpha_to_hex(alpha):
        # Two-digit hex alpha channel for an opacity between 0 and 1
        clipped = min(max(float(alpha), 0.0), 1.0)
        return format(int(round(clipped * 255)), "02x")

    # If colors is not specified, use the style's colors
    if colors is None:
        palette = list(get_colors().values())
    # If colors is given as a qualitative matplotlib colormap, turn it into a list of hex colors
    elif isinstance(colors, str):
        colormap = plt.colormaps[colors]
        palette = [mcolor.rgb2hex(colormap(i)) for i in range(len(data))]
    else:
        palette = list(colors)

    # Only keep the color information (not transparency), without touching the caller's list
    baseColors = []
    for color in palette:
        if color[0] != "#":
            raise ValueError("Colors specified as a string must start with a #")
        baseColors.append(color[:7])
    numColors = len(baseColors)

    # Sum categories and collapse subcategories
    innerLabels = list(data.keys())
    innerVals = []
    outerVals = []
    outerLabels = []
    maxSubcat = 0  # max number of subcategories for a given category
    for subcats in data.values():
        maxSubcat = max(maxSubcat, len(subcats))
        catTotal = 0.0
        for subcat, subcatVal in subcats.items():
            catTotal += subcatVal
            outerVals.append(subcatVal)
            outerLabels.append(subcat)
        innerVals.append(catTotal)

    # Define alphas if not specified. The range 0.9 -> 0.5 reproduces the opacities that the previous defaults
    # actually rendered with
    if alphas is None:
        alphas = np.linspace(0.5, 0.9, maxSubcat)[::-1]

    # The inner wedges use the plain colors, the outer wedges add a per-subcategory alpha channel
    innerColors = [baseColors[i % numColors] for i in range(len(data))]
    outerColors = []
    for iCat, subcats in enumerate(data.values()):
        for iSubcat in range(len(subcats)):
            outerColors.append(baseColors[iCat % numColors] + _alpha_to_hex(alphas[iSubcat]))

    # Nested plot fitting params
    size = 0.3
    buffer = 0.01

    # Create figure if it's not passed in
    returnFig = False
    if ax is None:
        fig, ax = plt.subplots()
        returnFig = True

    # Set keyword arguments, user-supplied values take precedence over the defaults
    outerKwargDefaults = {
        "radius": 1.0,
        "colors": outerColors,
        "wedgeprops": dict(width=size, edgecolor=None),
        "textprops": dict(rotation_mode="anchor", va="center", ha="center", color="w"),
        "labels": outerLabels,
        "rotatelabels": False,
        "labeldistance": 0.85,
    }
    innerKwargDefaults = {
        "radius": 1.0 - size - buffer,
        "colors": innerColors,
        "wedgeprops": dict(width=size, edgecolor=None),
        "textprops": dict(rotation_mode="anchor", va="center", ha="center", color="w"),
        "labels": innerLabels,
        "rotatelabels": False,
        "labeldistance": 0.75,
    }
    outerOptions = {**outerKwargDefaults, **(outerKwargs or {})}
    innerOptions = {**innerKwargDefaults, **(innerKwargs or {})}

    # Create the pie charts. Only the wedges and label texts are used, slicing keeps this working when the
    # user requests autopct (which makes pie return a third item)
    outerWedges, outerText = ax.pie(outerVals, **outerOptions)[:2]
    innerWedges, innerText = ax.pie(innerVals, **innerOptions)[:2]
    ax.set(aspect="equal")

    # Compile the wedge and text objects into the output dictionary
    pieObjects = {}
    iOuter = 0
    for iCat, (cat, subcats) in enumerate(data.items()):
        pieObjects[cat] = {
            "wedge": innerWedges[iCat],
            "text": innerText[iCat],
        }
        for subcat in subcats:
            pieObjects[cat][subcat] = {
                "wedge": outerWedges[iOuter],
                "text": outerText[iOuter],
            }
            iOuter += 1

    if returnFig:
        return pieObjects, fig, ax
    return pieObjects
[docs]
def plot_spline(x, y, ax=None, spline_type="non-overshoot", num_interp_pts=100, spline_options=None, **plot_kwargs):
    """
    Fits a spline to the data points and plots the spline.

    Parameters
    ----------
    x : array-like
        X-coordinates of points to interpolate and plot
    y : array-like
        Y-coordinates of points to interpolate and plot
    ax : matplotlib Axis object, optional
        Axes on which to plot, by default creates and returns new figure and axis
    spline_type : str, optional
        Type of spline from the following list of options (by default non-overshoot):

        - "non-overshoot": Cubic spline that does not overshoot data (SciPy's Akima1DInterpolator)
        - "b-spline": B-spline (SciPy's make_interp_spline)
    num_interp_pts : int, optional
        Number of points at which to evaluate the spline (linearly spaced between
        min and max x values), by default 100
    spline_options : dict, optional
        Options to pass to the spline, by default none if spline_type is non-overshoot or sets the
        spline order to the minimum of 3 and one less than the length of x if b-spline. The available
        options can be found here:

        - "non-overshoot": https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.Akima1DInterpolator.html
        - "b-spline": https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.make_interp_spline.html
    plot_kwargs
        Keyword arguments to pass to matplotlib's plot function

    Returns
    -------
    fig, ax
        If ax is not provided, generates and returns new matplotlib figure and axis

    Raises
    ------
    ValueError
        If ``spline_type`` is not one of the supported options
    """
    # Pick the spline constructor and its default options
    if spline_type == "non-overshoot":
        make_spline = Akima1DInterpolator
        option_defaults = {}
    elif spline_type == "b-spline":
        make_spline = make_interp_spline
        option_defaults = {"k": min(len(x) - 1, 3)}
    else:
        raise ValueError(f"Unknown spline_type {spline_type}")

    # User supplied options override the defaults
    options = {**option_defaults, **(spline_options or {})}

    # Build the spline before creating any figure so that a failure does not leave an empty figure behind
    spline = make_spline(x, y, **options)

    return_fig = False
    if ax is None:
        fig, ax = plt.subplots()
        return_fig = True

    # Interpolate and plot
    x_interp = np.linspace(np.min(x), np.max(x), num_interp_pts)
    ax.plot(x_interp, spline(x_interp), **plot_kwargs)

    if return_fig:
        return fig, ax
    return None
[docs]
def save_figs(fig, name, formats, format_kwargs=None, **kwargs):
    """Save a figure in multiple formats

    Parameters
    ----------
    fig : Matplotlib figure
        The figure to save
    name : str
        Output path for the files, e.g "path/to/file/file_name", no file extension required. Any extension on
        the name (everything after the last dot in the file name) is removed before the format is appended
    formats : str, list[str]
        file formats to save the figure in, e.g. "png", "pdf", "svg", with or without a leading dot
    format_kwargs : dict, optional
        A dictionary of dictionaries, where the keys are the file formats (without leading dot) and the values are
        any keyword arguments that should only be applied to that format. These kwargs will be added to ones passed
        to all formats, by default None
    kwargs :
        Any keyword arguments to pass to the figure's ``savefig`` method for all formats
    """
    # --- Strip any extension from the name ---
    fileName = os.path.splitext(name)[0]

    # --- Convert the file format to a list if only one given ---
    if isinstance(formats, str):
        formats = [formats]

    per_format = {} if format_kwargs is None else format_kwargs

    # --- Save the figures ---
    for ext in formats:
        if ext.startswith("."):
            ext = ext[1:]

        # Merge the shared and format-specific kwargs into a fresh dict. A shallow merge is enough because the
        # values are never modified, and it avoids cloning objects such as artists passed in the kwargs
        ext_kwargs = {**kwargs, **per_format.get(ext, {})}
        fig.savefig(fileName + "." + ext, **ext_kwargs)
[docs]
def All():
    """Runs commonly called functions provided in this module.

    Applies ``adjust_spines`` and ``draggable_legend`` with their default
    arguments and registers ``handle_close`` as the handler for the current
    figure's ``close_event``.
    """
    for apply_helper in (adjust_spines, draggable_legend):
        apply_helper()

    # Hook the save-on-close handler up to the current figure's canvas
    canvas = plt.gcf().canvas
    canvas.mpl_connect("close_event", handle_close)