Coverage for moptipy / evaluation / plot_end_results.py: 84%

219 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-24 08:49 +0000

1"""Violin plots for end results.""" 

2from typing import Any, Callable, Final, Iterable, cast 

3 

4import matplotlib.collections as mc # type: ignore 

5from matplotlib.axes import Axes # type: ignore 

6from matplotlib.figure import Figure # type: ignore 

7from matplotlib.lines import Line2D # type: ignore 

8from pycommons.io.console import logger 

9from pycommons.types import type_error 

10 

11import moptipy.utils.plot_defaults as pd 

12import moptipy.utils.plot_utils as pu 

13from moptipy.evaluation.axis_ranger import AxisRanger 

14from moptipy.evaluation.base import F_NAME_SCALED 

15from moptipy.evaluation.end_results import EndResult 

16from moptipy.evaluation.end_results import getter as end_result_getter 

17from moptipy.utils.lang import Lang 

18 

19 

def plot_end_results(
        end_results: Iterable[EndResult],
        figure: Axes | Figure,
        dimension: str = F_NAME_SCALED,
        y_axis: AxisRanger | Callable[[str], AxisRanger] =
        AxisRanger.for_axis,
        distinct_colors_func: Callable[[int], Any] = pd.distinct_colors,
        importance_to_line_width_func: Callable[[int], float] =
        pd.importance_to_line_width,
        importance_to_font_size_func: Callable[[int], float] =
        pd.importance_to_font_size,
        y_grid: bool = True,
        x_grid: bool = True,
        x_label: str | Callable[[str], str] | None = Lang.translate,
        x_label_inside: bool = True,
        x_label_location: float = 1.0,
        y_label: str | Callable[[str], str] | None = Lang.translate,
        y_label_inside: bool = True,
        y_label_location: float = 0.5,
        legend_pos: str = "best",
        instance_sort_key: Callable[[str], Any] = lambda x: x,
        algorithm_sort_key: Callable[[str], Any] = lambda x: x,
        instance_namer: Callable[[str], str] = lambda x: x,
        algorithm_namer: Callable[[str], str] = lambda x: x) -> Axes:
    """
    Plot end results as box plots drawn over violin plots in one chart.

    In this plot, we combine two visualizations of data distributions: box
    plots in the foreground and violin plots in the background.

    The box plots show you the median, the 25% and 75% quantiles, the 95%
    confidence interval around the median (as notches), the 5% and 95%
    quantiles (as whiskers), the arithmetic mean (as triangle), and the
    outliers on both ends of the spectrum. This also allows you to compare
    data from different distributions rather comfortably, as you can, e.g.,
    see whether the confidence intervals overlap.

    The violin plots in the background are something like smoothed-out,
    vertical, and mirror-symmetric histograms. They give you a better
    impression about shape and modality of the distribution of the results.

    One box/violin pair is drawn per (instance, algorithm) combination that
    occurs in `end_results`. The pairs are grouped by instance (ordered via
    `instance_sort_key`) and, within each instance, ordered via
    `algorithm_sort_key`. Each algorithm gets its own color from
    `distinct_colors_func`. If there are several instances, the x-axis ticks
    name the instances and, if there are also several algorithms, a legend
    maps colors to algorithms. If there is only one instance but several
    algorithms, the x-axis ticks name the algorithms.

    :param end_results: the iterable of end results
    :param figure: the figure (or axes) to plot in
    :param dimension: the dimension to display
    :param y_axis: the y_axis ranger, or a callable that creates one from
        the dimension name
    :param distinct_colors_func: the function returning the palette; it
        receives the number of algorithms and must return at least that
        many colors
    :param importance_to_line_width_func: the function converting importance
        values to line widths
    :param importance_to_font_size_func: the function converting importance
        values to font sizes
    :param y_grid: should we have a grid along the y-axis?
    :param x_grid: should we draw separator lines between the instance
        groups along the x-axis? (only done if there are several instances
        and several algorithms)
    :param x_label: a callable returning the label for the x-axis, a label
        string, or `None` if no label should be put
    :param x_label_inside: put the x-axis label inside the plot (so that
        it does not consume additional vertical space)
    :param x_label_location: the location of the x-label
    :param y_label: a callable returning the label for the y-axis, a label
        string, or `None` if no label should be put
    :param y_label_inside: put the y-axis label inside the plot (so that it
        does not consume additional horizontal space)
    :param y_label_location: the location of the y-label
    :param legend_pos: the legend position
    :param instance_sort_key: the sort key function for instances
    :param algorithm_sort_key: the sort key function for algorithms
    :param instance_namer: the name function for instances receives an
        instance ID and returns an instance name; default=identity function
    :param algorithm_namer: the name function for algorithms receives an
        algorithm ID and returns an algorithm name; default=identity function
    :returns: the axes object to allow you to add further plot elements
    :raises ValueError: if `end_results` contains no data
    :raises TypeError: (the error built by `type_error`) if a parameter, an
        element of `end_results`, or an extracted value has the wrong type
    """
    # Before doing anything, let's do some type checking on the parameters.
    # I want to ensure that this function is called correctly before we begin
    # to actually process the data. It is better to fail early than to deliver
    # some incorrect results.
    if not isinstance(end_results, Iterable):
        raise type_error(end_results, "end_results", Iterable)
    if not isinstance(figure, Axes | Figure):
        raise type_error(figure, "figure", (Axes, Figure))
    if not isinstance(dimension, str):
        raise type_error(dimension, "dimension", str)
    if not callable(distinct_colors_func):
        raise type_error(
            distinct_colors_func, "distinct_colors_func", call=True)
    if not callable(importance_to_line_width_func):
        raise type_error(importance_to_line_width_func,
                         "importance_to_line_width_func", call=True)
    if not callable(importance_to_font_size_func):
        raise type_error(importance_to_font_size_func,
                         "importance_to_font_size_func", call=True)
    if not isinstance(y_grid, bool):
        raise type_error(y_grid, "y_grid", bool)
    if not isinstance(x_grid, bool):
        raise type_error(x_grid, "x_grid", bool)
    if not ((x_label is None) or callable(x_label)
            or isinstance(x_label, str)):
        raise type_error(x_label, "x_label", (str, None), call=True)
    if not isinstance(x_label_inside, bool):
        raise type_error(x_label_inside, "x_label_inside", bool)
    # NOTE: only real floats are accepted here; an int such as `1` is
    # rejected.
    if not isinstance(x_label_location, float):
        raise type_error(x_label_location, "x_label_location", float)
    if not ((y_label is None) or callable(y_label)
            or isinstance(y_label, str)):
        raise type_error(y_label, "y_label", (str, None), call=True)
    if not isinstance(y_label_inside, bool):
        raise type_error(y_label_inside, "y_label_inside", bool)
    if not isinstance(y_label_location, float):
        raise type_error(y_label_location, "y_label_location", float)
    if not isinstance(legend_pos, str):
        raise type_error(legend_pos, "legend_pos", str)
    if not callable(instance_sort_key):
        raise type_error(instance_sort_key, "instance_sort_key", call=True)
    if not callable(algorithm_sort_key):
        raise type_error(algorithm_sort_key, "algorithm_sort_key", call=True)
    if not callable(instance_namer):
        raise type_error(instance_namer, "instance_namer", call=True)
    if not callable(algorithm_namer):
        raise type_error(algorithm_namer, "algorithm_namer", call=True)

    # the function extracting the plotted value from each end result
    getter: Final[Callable[[EndResult], int | float]] \
        = end_result_getter(dimension)
    logger(f"now plotting end violins for dimension {dimension}.")

    # y_axis may be a factory (e.g., AxisRanger.for_axis): instantiate it for
    # the chosen dimension
    if callable(y_axis):
        y_axis = y_axis(dimension)
    if not isinstance(y_axis, AxisRanger):
        raise type_error(y_axis, f"y_axis for {dimension}", AxisRanger)

    # instance -> algorithm -> values
    data: dict[str, dict[str, list[int | float]]] = {}
    algo_set: set[str] = set()

    # We now collect instances, the algorithms, and the measured values.
    for res in end_results:
        if not isinstance(res, EndResult):
            raise type_error(res, "violin plot element", EndResult)

        algo_set.add(res.algorithm)

        # get or create the algorithm -> values map of this instance
        per_inst_data: dict[str, list[int | float]]
        if res.instance not in data:
            data[res.instance] = per_inst_data = {}
        else:
            per_inst_data = data[res.instance]
        # get or create the value list of this (instance, algorithm) pair
        inst_algo_data: list[int | float]
        if res.algorithm not in per_inst_data:
            per_inst_data[res.algorithm] = inst_algo_data = []
        else:
            inst_algo_data = per_inst_data[res.algorithm]

        value: int | float = getter(res)
        if not isinstance(value, int | float):
            raise type_error(value, "value", (int, float))
        inst_algo_data.append(value)
        # every value is registered with the y-axis ranger, presumably so
        # that `y_axis.apply(...)` below can derive the y-range from them
        y_axis.register_value(value)

    # We now know the number of instances and algorithms and have the data in
    # the hierarchical structure instance->algorithms->values.
    n_instances: Final[int] = len(data)
    n_algorithms: Final[int] = len(algo_set)
    if (n_instances <= 0) or (n_algorithms <= 0):
        raise ValueError("Data cannot be empty but found "
                         f"{n_instances} and {n_algorithms}.")
    # all algorithms (over all instances), in the order used for the colors
    # and the legend
    algorithms: Final[tuple[str, ...]] = \
        tuple(sorted(algo_set, key=algorithm_sort_key))
    logger(f"- {n_algorithms} algorithms ({algorithms}) "
           f"and {n_instances} instances ({data.keys()}).")

    # Compile the data into one flat list of samples, one per bar, ordered by
    # instance and then by algorithm. `inst_algos` remembers the group
    # structure (instance, algorithms present on it) in the same order; the
    # later loops rely on this order matching `plot_data` element by element.
    inst_algos: list[tuple[str, list[str]]] = []
    plot_data: list[list[int | float]] = []
    # NOTE(review): plot_algos is filled but never read below.
    plot_algos: list[str] = []
    instances: Final[list[str]] = sorted(data.keys(), key=instance_sort_key)
    for inst in instances:
        per_inst_data = data[inst]
        algo_names: list[str] = sorted(per_inst_data.keys(),
                                       key=algorithm_sort_key)
        plot_algos.extend(algo_names)
        inst_algos.append((inst, algo_names))
        for algo in algo_names:
            inst_algo_data = per_inst_data[algo]
            inst_algo_data.sort()  # sort each sample ascending, in place
            plot_data.append(inst_algo_data)

    # compute the violin positions: bar i (0-based) sits at x = i + 1
    n_bars: Final[int] = len(plot_data)
    # sanity check: each instance and each algorithm contributes >= 1 bar
    if n_bars < max(n_instances, n_algorithms):
        raise ValueError(f"Huh? {n_bars}, {n_instances}, {n_algorithms}")
    bar_positions: Final[tuple[int, ...]] = \
        tuple(range(1, len(plot_data) + 1))

    # Now we got all instances and all algorithms and know the axis ranges.
    font_size_0: Final[float] = importance_to_font_size_func(0)

    # set up the graphics area
    axes: Final[Axes] = pu.get_axes(figure)
    axes.tick_params(axis="y", labelsize=font_size_0)
    axes.tick_params(axis="x", labelsize=font_size_0)

    # Running z-order: it is incremented before (or after) each artist is
    # placed, so later artists are drawn on top of earlier ones.
    z_order: int = 0

    # draw the grid
    grid_lwd: int | float | None = None
    if y_grid:
        grid_lwd = importance_to_line_width_func(-1)
        z_order += 1
        axes.grid(axis="y", color=pd.GRID_COLOR, linewidth=grid_lwd,
                  zorder=z_order)

    # half a bar slot of margin left of the first and right of the last bar
    x_axis: Final[AxisRanger] = AxisRanger(
        chosen_min=0.5, chosen_max=bar_positions[-1] + 0.5)

    # manually add x grid lines between instances
    if x_grid and (n_instances > 1) and (n_algorithms > 1):
        if not grid_lwd:
            grid_lwd = importance_to_line_width_func(-1)
        # `counter` is the number of bars in all groups before `key`, i.e.,
        # the position of the last bar of the previous group; the separator
        # goes halfway between that bar and the first bar of this group.
        counter: int = 0
        for key in inst_algos:
            if counter > 0:
                z_order += 1
                axes.axvline(x=counter + 0.5,
                             color=pd.GRID_COLOR,
                             linewidth=grid_lwd,
                             zorder=z_order)
            counter += len(key[1])

    y_axis.apply(axes, "y")
    x_axis.apply(axes, "x")

    # width of violins and boxes, in units of one bar slot
    violin_width: Final[float] = 3 / 4
    z_order += 1
    violins: Final[dict[str, Any]] = axes.violinplot(
        dataset=plot_data, positions=bar_positions, orientation="vertical",
        widths=violin_width, showmeans=False, showextrema=False,
        showmedians=False)

    # fix the algorithm colors: one distinct color per algorithm, assigned in
    # the sorted order of `algorithms`
    unique_colors: Final[tuple[Any, ...]] = distinct_colors_func(n_algorithms)
    algo_colors: Final[dict[str, tuple[float, float, float]]] = {}
    for i, algo in enumerate(algorithms):
        algo_colors[algo] = unique_colors[i]

    # Style the violin bodies. They come back in the order of `plot_data`,
    # which is the order of `inst_algos`, so a running counter pairs each body
    # with its algorithm. `use_colors[i]` records the color of bar i for the
    # box styling below.
    bodies: Final[list[mc.PolyCollection]] = violins["bodies"]
    use_colors: Final[list[tuple[float, float, float]]] = []
    counter = 0
    for key in inst_algos:
        for algo in key[1]:
            z_order += 1
            bd = bodies[counter]
            color = algo_colors[algo]
            use_colors.append(color)
            bd.set_edgecolor("none")
            bd.set_facecolor(color)
            bd.set_alpha(0.6666666666)
            counter += 1
            bd.set_zorder(z_order)

    # The box plots are drawn twice with identical geometry: a background copy
    # (styled white with width lwd_bg below) and a foreground copy (styled
    # with the algorithm color and width lwd_fg), the foreground having the
    # higher z-order. Presumably this gives the colored lines a white outline
    # against the violins -- depends on importance_to_line_width_func.
    z_order += 1
    boxes_bg: Final[dict[str, Any]] = axes.boxplot(
        x=plot_data, positions=bar_positions, widths=violin_width,
        showmeans=True, patch_artist=False, notch=True,
        orientation="vertical", whis=(5.0, 95.0), manage_ticks=False,
        zorder=z_order)
    z_order += 1
    boxes_fg: Final[dict[str, Any]] = axes.boxplot(
        x=plot_data, positions=bar_positions, widths=violin_width,
        showmeans=True, patch_artist=False, notch=True,
        orientation="vertical", whis=(5.0, 95.0), manage_ticks=False,
        zorder=z_order)

    # hide any statistics glyphs of the violin plots that may exist; the
    # statistics are shown by the box plots instead
    for tkey in ("cmeans", "cmins", "cmaxes", "cbars", "cmedians",
                 "cquantiles"):
        if tkey in violins:
            violins[tkey].set_color("none")

    lwd_fg = importance_to_line_width_func(0)
    lwd_bg = importance_to_line_width_func(1)

    # bid == 0: background copy, bid == 1: foreground copy
    for bid, boxes in enumerate([boxes_bg, boxes_fg]):
        for tkey in ("boxes", "medians", "whiskers", "caps", "fliers",
                     "means"):
            if tkey not in boxes:
                continue
            polys: list[Line2D] = boxes[tkey]
            for line in polys:
                xdata = cast("list", line.get_xdata(True))
                # artists without data points (presumably e.g. flier sets
                # when a sample has no outliers) are removed
                if len(xdata) <= 0:
                    line.remove()
                    continue
                # Recover the bar index from the artist's x-coordinates: all
                # parts of the box at position p lie within p +/- 3/8 (the
                # half of violin_width), so int(max(x)) == p, and bar p has
                # index p - 1.
                index = int(max(xdata)) - 1
                thecolor: str | tuple[float, float, float] = \
                    "white" if bid == 0 else use_colors[index]
                width = lwd_bg if bid == 0 else lwd_fg
                line.set_solid_joinstyle("round")
                line.set_solid_capstyle("round")
                line.set_color(thecolor)
                line.set_linewidth(width)
                line.set_markeredgecolor(thecolor)
                line.set_markerfacecolor("none")
                line.set_markeredgewidth(width)
                z_order += 1
                line.set_zorder(z_order)

    # compute the labels for the x-axis
    labels_str: list[str] = []
    labels_x: list[float] = []
    needs_legend: bool = False

    counter = 0
    if n_instances > 1:
        # use only the instances as labels, each centered below its group of
        # bars; the algorithms are then identified by color via a legend
        for key in inst_algos:
            current = counter
            counter += len(key[1])
            labels_str.append(instance_namer(key[0]))
            labels_x.append(0.5 * (bar_positions[current]
                                   + bar_positions[counter - 1]))
        needs_legend = (n_algorithms > 1)
    elif n_algorithms > 1:
        # single instance: label each bar with its algorithm name
        for key in inst_algos:
            for algo in key[1]:
                labels_str.append(algorithm_namer(algo))
                labels_x.append(bar_positions[counter])
                counter += 1

    # with one instance and one algorithm, no x-ticks are shown at all
    if labels_str:
        axes.set_xticks(ticks=labels_x, labels=labels_str, minor=False)
    else:
        axes.set_xticks([])

    # Compute the x-label if it is a callable: name the single algorithm or
    # the single instance that is not already named by the tick labels,
    # otherwise pass the key "algorithm_on_instance" to the callable.
    if (x_label is not None) and (not isinstance(x_label, str)):
        # NOTE(review): unreachable -- x_label was already validated at the
        # top of the function; kept as a defensive check.
        if not callable(x_label):
            raise type_error(x_label, "x_label", str, True)
        if (n_algorithms == 1) and (n_instances > 1):
            x_label = algorithm_namer(algorithms[0])
        elif (n_algorithms > 1) and (n_instances == 1):
            x_label = instance_namer(instances[0])
        else:
            x_label = x_label("algorithm_on_instance")

    z_order += 1
    pu.label_axes(axes=axes,
                  x_label=cast("str | None", x_label),
                  x_label_inside=x_label_inside,
                  x_label_location=x_label_location,
                  y_label=y_label(dimension) if callable(y_label) else y_label,
                  y_label_inside=y_label_inside,
                  y_label_location=y_label_location,
                  font_size=font_size_0,
                  z_order=z_order)

    if needs_legend:
        # Build proxy legend handles: data-less lines in each algorithm's
        # color, so the legend maps colors to algorithm names.
        handles: Final[list[Line2D]] = []

        for algo in algorithms:
            linestyle = pd.create_line_style()
            linestyle["label"] = algorithm_namer(algo)
            legcol = algo_colors[algo]
            linestyle["color"] = legcol
            linestyle["markeredgecolor"] = legcol
            linestyle["xdata"] = []
            linestyle["ydata"] = []
            linestyle["linewidth"] = 6
            handles.append(Line2D(**linestyle))  # type: ignore
            z_order += 1

        axes.legend(handles=handles, loc=legend_pos,
                    labelcolor=pu.get_label_colors(handles),
                    fontsize=font_size_0).set_zorder(z_order)

    logger(f"done plotting {n_bars} end result boxes.")
    return axes