Coverage for moptipy / evaluation / plot_end_results.py: 84%
219 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-24 08:49 +0000
1"""Violin plots for end results."""
2from typing import Any, Callable, Final, Iterable, cast
4import matplotlib.collections as mc # type: ignore
5from matplotlib.axes import Axes # type: ignore
6from matplotlib.figure import Figure # type: ignore
7from matplotlib.lines import Line2D # type: ignore
8from pycommons.io.console import logger
9from pycommons.types import type_error
11import moptipy.utils.plot_defaults as pd
12import moptipy.utils.plot_utils as pu
13from moptipy.evaluation.axis_ranger import AxisRanger
14from moptipy.evaluation.base import F_NAME_SCALED
15from moptipy.evaluation.end_results import EndResult
16from moptipy.evaluation.end_results import getter as end_result_getter
17from moptipy.utils.lang import Lang
def plot_end_results(
        end_results: Iterable[EndResult],
        figure: Axes | Figure,
        dimension: str = F_NAME_SCALED,
        y_axis: AxisRanger | Callable[[str], AxisRanger] =
        AxisRanger.for_axis,
        distinct_colors_func: Callable[[int], Any] = pd.distinct_colors,
        importance_to_line_width_func: Callable[[int], float] =
        pd.importance_to_line_width,
        importance_to_font_size_func: Callable[[int], float] =
        pd.importance_to_font_size,
        y_grid: bool = True,
        x_grid: bool = True,
        x_label: str | Callable[[str], str] | None = Lang.translate,
        x_label_inside: bool = True,
        x_label_location: float = 1.0,
        y_label: str | Callable[[str], str] | None = Lang.translate,
        y_label_inside: bool = True,
        y_label_location: float = 0.5,
        legend_pos: str = "best",
        instance_sort_key: Callable[[str], Any] = lambda x: x,
        algorithm_sort_key: Callable[[str], Any] = lambda x: x,
        instance_namer: Callable[[str], str] = lambda x: x,
        algorithm_namer: Callable[[str], str] = lambda x: x) -> Axes:
    """
    Plot end results as box plots over violin plots in one chart.

    In this plot, we combine two visualizations of data distributions: box
    plots in the foreground and violin plots in the background.

    The box plots show you the median, the 25% and 75% quantiles, the 95%
    confidence interval around the median (as notches), the 5% and 95%
    quantiles (as whiskers), the arithmetic mean (as triangle), and the
    outliers on both ends of the spectrum. This also allows you to compare
    data from different distributions rather comfortably, as you can, e.g.,
    see whether the confidence intervals overlap.

    The violin plots in the background are something like smoothed-out,
    vertical, and mirror-symmetric histograms. They give you a better
    impression about shape and modality of the distribution of the results.

    There is one box/violin per (instance, algorithm) pair that occurs in
    the data. The pairs are grouped by instance (ordered via
    `instance_sort_key`) and, within each instance, ordered by algorithm
    (via `algorithm_sort_key`). Every algorithm gets one fixed color from
    `distinct_colors_func`. If there are several instances, the x-axis ticks
    name the instances. If there is only one instance but several
    algorithms, the ticks name the algorithms. A legend of the algorithm
    colors is only drawn if there are several instances AND several
    algorithms.

    :param end_results: the iterable of end results
    :param figure: the figure to plot in
    :param dimension: the dimension to display
    :param y_axis: the y_axis ranger, or a callable that receives
        `dimension` and returns the ranger
    :param distinct_colors_func: the function returning the palette
    :param importance_to_line_width_func: the function converting importance
        values to line widths
    :param importance_to_font_size_func: the function converting importance
        values to font sizes
    :param y_grid: should we have a grid along the y-axis?
    :param x_grid: should we draw vertical separator lines between the
        instance groups? (only done if there are several instances and
        several algorithms)
    :param x_label: a callable returning the label for the x-axis, a label
        string, or `None` if no label should be put. If it is a callable and
        exactly one of the number of algorithms or instances is one, then
        the name of that single algorithm or instance is used as label
        instead. Otherwise the callable is invoked with
        `"algorithm_on_instance"`.
    :param x_label_inside: put the x-axis label inside the plot (so that
        it does not consume additional vertical space)
    :param x_label_location: the location of the x-label
    :param y_label: a callable returning the label for the y-axis (invoked
        with `dimension`), a label string, or `None` if no label should be
        put
    :param y_label_inside: put the y-axis label inside the plot (so that it
        does not consume additional horizontal space)
    :param y_label_location: the location of the y-label
    :param legend_pos: the legend position
    :param instance_sort_key: the sort key function for instances
    :param algorithm_sort_key: the sort key function for algorithms
    :param instance_namer: the name function for instances receives an
        instance ID and returns an instance name; default=identity function
    :param algorithm_namer: the name function for algorithms receives an
        algorithm ID and returns an algorithm name; default=identity function
    :returns: the axes object to allow you to add further plot elements
    :raises ValueError: if `end_results` yields no data
    :raises TypeError: the error built by `type_error` is raised if any
        parameter, any element of `end_results`, or any extracted value has
        the wrong type (presumably a `TypeError` -- confirm in pycommons)
    """
    # Before doing anything, let's do some type checking on the parameters.
    # I want to ensure that this function is called correctly before we begin
    # to actually process the data. It is better to fail early than to deliver
    # some incorrect results.
    if not isinstance(end_results, Iterable):
        raise type_error(end_results, "end_results", Iterable)
    if not isinstance(figure, Axes | Figure):
        raise type_error(figure, "figure", (Axes, Figure))
    if not isinstance(dimension, str):
        raise type_error(dimension, "dimension", str)
    if not callable(distinct_colors_func):
        raise type_error(
            distinct_colors_func, "distinct_colors_func", call=True)
    if not callable(importance_to_line_width_func):
        raise type_error(importance_to_line_width_func,
                         "importance_to_line_width_func", call=True)
    if not callable(importance_to_font_size_func):
        raise type_error(importance_to_font_size_func,
                         "importance_to_font_size_func", call=True)
    if not isinstance(y_grid, bool):
        raise type_error(y_grid, "y_grid", bool)
    if not isinstance(x_grid, bool):
        raise type_error(x_grid, "x_grid", bool)
    if not ((x_label is None) or callable(x_label)
            or isinstance(x_label, str)):
        raise type_error(x_label, "x_label", (str, None), call=True)
    if not isinstance(x_label_inside, bool):
        raise type_error(x_label_inside, "x_label_inside", bool)
    if not isinstance(x_label_location, float):
        raise type_error(x_label_location, "x_label_location", float)
    if not ((y_label is None) or callable(y_label)
            or isinstance(y_label, str)):
        raise type_error(y_label, "y_label", (str, None), call=True)
    if not isinstance(y_label_inside, bool):
        raise type_error(y_label_inside, "y_label_inside", bool)
    if not isinstance(y_label_location, float):
        raise type_error(y_label_location, "y_label_location", float)
    if not isinstance(legend_pos, str):
        raise type_error(legend_pos, "legend_pos", str)
    if not callable(instance_sort_key):
        raise type_error(instance_sort_key, "instance_sort_key", call=True)
    if not callable(algorithm_sort_key):
        raise type_error(algorithm_sort_key, "algorithm_sort_key", call=True)
    if not callable(instance_namer):
        raise type_error(instance_namer, "instance_namer", call=True)
    if not callable(algorithm_namer):
        raise type_error(algorithm_namer, "algorithm_namer", call=True)

    getter: Final[Callable[[EndResult], int | float]] \
        = end_result_getter(dimension)
    logger(f"now plotting end violins for dimension {dimension}.")

    if callable(y_axis):
        y_axis = y_axis(dimension)
    if not isinstance(y_axis, AxisRanger):
        raise type_error(y_axis, f"y_axis for {dimension}", AxisRanger)

    # instance -> algorithm -> values
    data: dict[str, dict[str, list[int | float]]] = {}
    algo_set: set[str] = set()

    # We now collect instances, the algorithms, and the measured values.
    for res in end_results:
        if not isinstance(res, EndResult):
            raise type_error(res, "violin plot element", EndResult)
        algo_set.add(res.algorithm)
        value: int | float = getter(res)
        if not isinstance(value, int | float):
            raise type_error(value, "value", (int, float))
        data.setdefault(res.instance, {}).setdefault(
            res.algorithm, []).append(value)
        # the y-axis range must cover every plotted value
        y_axis.register_value(value)

    # We now know the number of instances and algorithms and have the data in
    # the hierarchical structure instance->algorithms->values.
    n_instances: Final[int] = len(data)
    n_algorithms: Final[int] = len(algo_set)
    if (n_instances <= 0) or (n_algorithms <= 0):
        raise ValueError("Data cannot be empty but found "
                         f"{n_instances} and {n_algorithms}.")
    algorithms: Final[tuple[str, ...]] = \
        tuple(sorted(algo_set, key=algorithm_sort_key))
    logger(f"- {n_algorithms} algorithms ({algorithms}) "
           f"and {n_instances} instances ({data.keys()}).")

    # Compile the data: one list of values per (instance, algorithm) bar, in
    # plotting order. `inst_algos` remembers which algorithms each instance
    # group contains, since not every algorithm needs to occur on every
    # instance.
    inst_algos: list[tuple[str, list[str]]] = []
    plot_data: list[list[int | float]] = []
    instances: Final[list[str]] = sorted(data.keys(), key=instance_sort_key)
    for inst in instances:
        per_inst_data = data[inst]
        algo_names: list[str] = sorted(per_inst_data.keys(),
                                       key=algorithm_sort_key)
        inst_algos.append((inst, algo_names))
        for algo in algo_names:
            inst_algo_data = per_inst_data[algo]
            inst_algo_data.sort()
            plot_data.append(inst_algo_data)

    # compute the violin positions: bars sit at x = 1, 2, ..., n_bars
    n_bars: Final[int] = len(plot_data)
    if n_bars < max(n_instances, n_algorithms):
        raise ValueError(f"Huh? {n_bars}, {n_instances}, {n_algorithms}")
    bar_positions: Final[tuple[int, ...]] = tuple(range(1, n_bars + 1))

    # Now we got all instances and all algorithms and know the axis ranges.
    font_size_0: Final[float] = importance_to_font_size_func(0)

    # set up the graphics area
    axes: Final[Axes] = pu.get_axes(figure)
    axes.tick_params(axis="y", labelsize=font_size_0)
    axes.tick_params(axis="x", labelsize=font_size_0)

    # Every artist gets a strictly increasing z-order so that later elements
    # are drawn on top of earlier ones.
    z_order: int = 0

    # draw the grid
    grid_lwd: int | float | None = None
    if y_grid:
        grid_lwd = importance_to_line_width_func(-1)
        z_order += 1
        axes.grid(axis="y", color=pd.GRID_COLOR, linewidth=grid_lwd,
                  zorder=z_order)

    x_axis: Final[AxisRanger] = AxisRanger(
        chosen_min=0.5, chosen_max=bar_positions[-1] + 0.5)

    # manually add x grid lines between instances
    if x_grid and (n_instances > 1) and (n_algorithms > 1):
        if grid_lwd is None:  # not yet computed for the y grid
            grid_lwd = importance_to_line_width_func(-1)
        counter: int = 0
        for _, group_algos in inst_algos:
            if counter > 0:
                # separator halfway between last bar of the previous group
                # and first bar of this group
                z_order += 1
                axes.axvline(x=counter + 0.5,
                             color=pd.GRID_COLOR,
                             linewidth=grid_lwd,
                             zorder=z_order)
            counter += len(group_algos)

    y_axis.apply(axes, "y")
    x_axis.apply(axes, "x")

    violin_width: Final[float] = 3 / 4
    z_order += 1
    violins: Final[dict[str, Any]] = axes.violinplot(
        dataset=plot_data, positions=bar_positions, orientation="vertical",
        widths=violin_width, showmeans=False, showextrema=False,
        showmedians=False)

    # fix the algorithm colors: one color per algorithm across all instances
    unique_colors: Final[tuple[Any, ...]] = distinct_colors_func(n_algorithms)
    algo_colors: Final[dict[str, tuple[float, float, float]]] = {
        algo: unique_colors[i] for i, algo in enumerate(algorithms)}

    # color the violin bodies; `use_colors[k]` is the color of bar k + 1
    bodies: Final[list[mc.PolyCollection]] = violins["bodies"]
    use_colors: Final[list[tuple[float, float, float]]] = []
    counter = 0
    for _, group_algos in inst_algos:
        for algo in group_algos:
            z_order += 1
            bd = bodies[counter]
            color = algo_colors[algo]
            use_colors.append(color)
            bd.set_edgecolor("none")
            bd.set_facecolor(color)
            bd.set_alpha(0.6666666666)
            counter += 1
            bd.set_zorder(z_order)

    # The boxes are drawn twice: a background copy (styled white below) and a
    # foreground copy (styled in the algorithm colors on top of it).
    z_order += 1
    boxes_bg: Final[dict[str, Any]] = axes.boxplot(
        x=plot_data, positions=bar_positions, widths=violin_width,
        showmeans=True, patch_artist=False, notch=True,
        orientation="vertical", whis=(5.0, 95.0), manage_ticks=False,
        zorder=z_order)
    z_order += 1
    boxes_fg: Final[dict[str, Any]] = axes.boxplot(
        x=plot_data, positions=bar_positions, widths=violin_width,
        showmeans=True, patch_artist=False, notch=True,
        orientation="vertical", whis=(5.0, 95.0), manage_ticks=False,
        zorder=z_order)

    # hide any violin decoration lines that matplotlib may have created
    for tkey in ("cmeans", "cmins", "cmaxes", "cbars", "cmedians",
                 "cquantiles"):
        if tkey in violins:
            violins[tkey].set_color("none")

    lwd_fg = importance_to_line_width_func(0)
    lwd_bg = importance_to_line_width_func(1)

    for bid, boxes in enumerate([boxes_bg, boxes_fg]):
        for tkey in ("boxes", "medians", "whiskers", "caps", "fliers",
                     "means"):
            if tkey not in boxes:
                continue
            polys: list[Line2D] = boxes[tkey]
            for line in polys:
                xdata = cast("list", line.get_xdata(True))
                if len(xdata) <= 0:
                    line.remove()
                    continue
                # Each box artist extends at most violin_width / 2 (< 0.5) to
                # the right of its integer bar position p, so truncating the
                # largest x coordinate recovers p; bars are 1-based.
                index = int(max(xdata)) - 1
                thecolor: str | tuple[float, float, float] = \
                    "white" if bid == 0 else use_colors[index]
                width = lwd_bg if bid == 0 else lwd_fg
                line.set_solid_joinstyle("round")
                line.set_solid_capstyle("round")
                line.set_color(thecolor)
                line.set_linewidth(width)
                line.set_markeredgecolor(thecolor)
                line.set_markerfacecolor("none")
                line.set_markeredgewidth(width)
                z_order += 1
                line.set_zorder(z_order)

    # compute the labels for the x-axis
    labels_str: list[str] = []
    labels_x: list[float] = []
    needs_legend: bool = False

    counter = 0
    if n_instances > 1:
        # use only the instances as labels, centered below their group
        for inst, group_algos in inst_algos:
            current = counter
            counter += len(group_algos)
            labels_str.append(instance_namer(inst))
            labels_x.append(0.5 * (bar_positions[current]
                                   + bar_positions[counter - 1]))
        # colors then identify the algorithms, which needs a legend
        needs_legend = (n_algorithms > 1)
    elif n_algorithms > 1:
        # only use algorithms as key
        for _, group_algos in inst_algos:
            for algo in group_algos:
                labels_str.append(algorithm_namer(algo))
                labels_x.append(bar_positions[counter])
                counter += 1

    if labels_str:
        axes.set_xticks(ticks=labels_x, labels=labels_str, minor=False)
    else:
        axes.set_xticks([])

    # compute the x-label
    if (x_label is not None) and (not isinstance(x_label, str)):
        if not callable(x_label):
            raise type_error(x_label, "x_label", str, True)
        if (n_algorithms == 1) and (n_instances > 1):
            # ticks name the instances, so the label names the algorithm
            x_label = algorithm_namer(algorithms[0])
        elif (n_algorithms > 1) and (n_instances == 1):
            # ticks name the algorithms, so the label names the instance
            x_label = instance_namer(instances[0])
        else:
            x_label = x_label("algorithm_on_instance")

    z_order += 1
    pu.label_axes(axes=axes,
                  x_label=cast("str | None", x_label),
                  x_label_inside=x_label_inside,
                  x_label_location=x_label_location,
                  y_label=y_label(dimension) if callable(y_label) else y_label,
                  y_label_inside=y_label_inside,
                  y_label_location=y_label_location,
                  font_size=font_size_0,
                  z_order=z_order)

    if needs_legend:
        handles: Final[list[Line2D]] = []

        for algo in algorithms:
            linestyle = pd.create_line_style()
            linestyle["label"] = algorithm_namer(algo)
            legcol = algo_colors[algo]
            linestyle["color"] = legcol
            linestyle["markeredgecolor"] = legcol
            linestyle["xdata"] = []
            linestyle["ydata"] = []
            linestyle["linewidth"] = 6
            handles.append(Line2D(**linestyle))  # type: ignore
        z_order += 1

        axes.legend(handles=handles, loc=legend_pos,
                    labelcolor=pu.get_label_colors(handles),
                    fontsize=font_size_0).set_zorder(z_order)

    logger(f"done plotting {n_bars} end result boxes.")
    return axes