from typing import Optional
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
def plot(loss: np.ndarray,
         labels: list,
         cum: bool = True,
         title: Optional[str] = None,
         file_path: Optional[str] = None,
         x_label: Optional[str] = 'Iteration',
         y_label: Optional[str] = 'Cumulative Loss',
         loc: str = 'upper left',
         scale: str = 'linear'):
    """Visualize the losses of multiple learners on one figure.

    Each learner is drawn as a line labelled with the matching entry of
    ``labels``. For a 3-D ``loss`` the line is the mean over axis 1 and a
    shaded band of +/- one (population) standard deviation is drawn around
    it; for a 2-D ``loss`` each row is drawn as-is.

    Args:
        loss (numpy.ndarray): Losses of multiple learners, of shape
            ``(n_learners, n_iterations)`` or
            ``(n_learners, n_repeats, n_iterations)``; axis 1 of a 3-D array
            is averaged over. Any array-like is accepted and converted with
            :func:`numpy.asarray`.
        labels (list): Labels of the learners, one per entry along
            ``loss.shape[0]``.
        cum (bool): If truthy, plot the cumulative loss along the last
            (iteration) axis; otherwise plot the instantaneous loss.
        title (str, optional): Title of the figure.
        file_path (str, optional): File path to save the figure to.
        x_label (str, optional): Label of the :math:`x` axis.
        y_label (str, optional): Label of the :math:`y` axis.
        loc (str, optional): Location of the legend.
        scale (str, optional): Scale of the :math:`y` axis, e.g. 'linear'
            or 'log' (any scale name accepted by
            :func:`matplotlib.pyplot.yscale`).

    Raises:
        ValueError: If ``loss`` is not 2-D or 3-D, or if its first dimension
            does not match ``len(labels)``.
    """
    loss = np.asarray(loss)
    # Validate before creating a figure so bad input does not leave an empty
    # figure behind; explicit raises (unlike ``assert``) survive ``python -O``.
    if loss.ndim not in (2, 3):
        raise ValueError(
            f'loss must be a 2-D or 3-D array, got {loss.ndim}-D')
    if loss.shape[0] != len(labels):
        raise ValueError(
            f'loss.shape[0] ({loss.shape[0]}) must equal '
            f'len(labels) ({len(labels)})')

    plt.figure()
    # NOTE: these assignments change matplotlib's global rcParams, so the
    # font setting persists for figures created after this call.
    matplotlib.rcParams['font.family'] = "sans-serif"
    matplotlib.rcParams['font.sans-serif'] = "Arial"

    # The last axis is the iteration axis for both 2-D and 3-D input.
    iter_axis = loss.ndim - 1
    xaxis = np.arange(0, loss.shape[iter_axis])
    # Truthiness instead of ``is True`` so numpy booleans / truthy values work.
    if cum:
        loss = np.cumsum(loss, axis=iter_axis)
    if loss.ndim == 3:
        loss_mean, loss_std = np.mean(loss, axis=1), np.std(loss, axis=1)
    else:
        # 2-D input has no spread information; a zero-width band keeps the
        # drawing loop uniform.
        loss_mean, loss_std = loss, np.zeros_like(loss)

    plt.grid(linestyle=':', linewidth=0.5)
    # 'linear' is matplotlib's default, so only non-default scales are set.
    if scale != 'linear':
        plt.yscale(scale)
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    for mean, std, label in zip(loss_mean, loss_std, labels):
        plt.plot(xaxis, mean, label=label)
        plt.fill_between(xaxis, mean - std, mean + std, alpha=0.15)
    plt.legend(loc=loc)
    if file_path is not None:
        plt.savefig(file_path)
    plt.show()