Coverage for moptipy / evaluation / plot_end_statistics_over_parameter.py: 76%

208 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-24 08:49 +0000

1"""Plot the end results over a parameter.""" 

2from math import isfinite 

3from typing import Any, Callable, Final, Iterable, cast 

4 

5from matplotlib.artist import Artist # type: ignore 

6from matplotlib.axes import Axes # type: ignore 

7from matplotlib.figure import Figure # type: ignore 

8from pycommons.io.csv import SCOPE_SEPARATOR 

9from pycommons.math.sample_statistics import KEY_MEAN_GEOM 

10from pycommons.types import type_error 

11 

12import moptipy.utils.plot_defaults as pd 

13import moptipy.utils.plot_utils as pu 

14from moptipy.evaluation.axis_ranger import AxisRanger 

15from moptipy.evaluation.base import F_NAME_SCALED 

16from moptipy.evaluation.end_statistics import EndStatistics 

17from moptipy.evaluation.end_statistics import getter as end_stat_getter 

18from moptipy.evaluation.styler import Styler 

19from moptipy.utils.lang import Lang 

20 

21 

def __make_y_label(y_dim: str) -> str:
    """
    Create the text of the y-axis label from a dimension key.

    If the scope separator occurs in `y_dim` at an index greater than zero,
    the text before it is used as the dimension and the text after it as the
    statistic. The label is then obtained by applying the translation
    function looked up for the statistic to the dimension. In all other
    cases, `y_dim` is translated as a whole.

    :param y_dim: the y dimension
    :returns: the y label
    """
    sep_at: Final[int] = y_dim.find(SCOPE_SEPARATOR)
    if sep_at <= 0:
        # No separator, or nothing in front of it: translate the key as is.
        return Lang.translate(y_dim)
    # Exactly one character is skipped after `sep_at` to reach the statistic.
    return Lang.translate_func(y_dim[sep_at + 1:])(y_dim[:sep_at])

35 

36 

def __make_y_axis(y_dim: str) -> AxisRanger:
    """
    Create the axis ranger for the y-axis from a dimension key.

    If the scope separator occurs in `y_dim` at an index greater than zero,
    only the text in front of it is handed to `AxisRanger.for_axis`;
    otherwise the complete `y_dim` is used.

    :param y_dim: the y dimension
    :returns: the y axis
    """
    sep_at: Final[int] = y_dim.find(SCOPE_SEPARATOR)
    # Strip the statistic suffix, if any, before selecting the axis ranger.
    base_dim: Final[str] = y_dim[:sep_at] if sep_at > 0 else y_dim
    return AxisRanger.for_axis(base_dim)

48 

49 

def plot_end_statistics_over_param(
        data: Iterable[EndStatistics],
        figure: Axes | Figure,
        x_getter: Callable[[EndStatistics], int | float],
        y_dim: str = f"{F_NAME_SCALED}{SCOPE_SEPARATOR}{KEY_MEAN_GEOM}",
        algorithm_getter: Callable[[EndStatistics], str | None] =
        lambda es: es.algorithm,
        instance_getter: Callable[[EndStatistics], str | None] =
        lambda es: es.instance,
        x_axis: AxisRanger | Callable[[], AxisRanger] = AxisRanger,
        y_axis: AxisRanger | Callable[[str], AxisRanger] =
        __make_y_axis,
        legend: bool = True,
        legend_pos: str = "upper right",
        distinct_colors_func: Callable[[int], Any] = pd.distinct_colors,
        distinct_line_dashes_func: Callable[[int], Any] =
        pd.distinct_line_dashes,
        importance_to_line_width_func: Callable[[int], float] =
        pd.importance_to_line_width,
        importance_to_font_size_func: Callable[[int], float] =
        pd.importance_to_font_size,
        x_grid: bool = True,
        y_grid: bool = True,
        x_label: str | None = None,
        x_label_inside: bool = True,
        x_label_location: float = 0.5,
        y_label: str | Callable[[str], str] | None = __make_y_label,
        y_label_inside: bool = True,
        y_label_location: float = 1.0,
        instance_priority: float = 0.666,
        algorithm_priority: float = 0.333,
        stat_priority: float = 0.0,
        instance_sort_key: Callable[[str], Any] = lambda x: x,
        algorithm_sort_key: Callable[[str], Any] = lambda x: x,
        instance_namer: Callable[[str], str] = lambda x: x,
        algorithm_namer: Callable[[str], str] = lambda x: x,
        stat_sort_key: Callable[[str], str] = lambda x: x,
        color_algorithms_as_fallback_group: bool = True) -> Axes:
    """
    Plot a series of end result statistics over a parameter.

    For each combination of algorithm and instance found in `data`, one
    step line (`where="post"`) is drawn: the x-coordinates are the values
    returned by `x_getter` and the y-coordinates are the values of the
    statistic selected by `y_dim`.

    :param data: the iterable of EndStatistics
    :param figure: the figure to plot in
    :param x_getter: the function computing the x-value for each statistics
        object
    :param y_dim: the dimension to be plotted along the y-axis
    :param algorithm_getter: the algorithm getter
    :param instance_getter: the instance getter
    :param x_axis: the x_axis ranger
    :param y_axis: the y_axis ranger
    :param legend: should we plot the legend?
    :param legend_pos: the legend position
    :param distinct_colors_func: the function returning the palette
    :param distinct_line_dashes_func: the function returning the line styles
    :param importance_to_line_width_func: the function converting importance
        values to line widths
    :param importance_to_font_size_func: the function converting importance
        values to font sizes
    :param x_grid: should we have a grid along the x-axis?
    :param y_grid: should we have a grid along the y-axis?
    :param x_label: the label for the x-axis or `None` if no label should be
        put
    :param x_label_inside: put the x-axis label inside the plot (so that
        it does not consume additional vertical space)
    :param x_label_location: the location of the x-axis label
    :param y_label: a callable returning the label for the y-axis, a label
        string, or `None` if no label should be put
    :param y_label_inside: put the y-axis label inside the plot (so that
        it does not consume additional horizontal space)
    :param y_label_location: the location of the y-axis label
    :param instance_priority: the style priority for instances
    :param algorithm_priority: the style priority for algorithms
    :param stat_priority: the style priority for statistics
    :param instance_sort_key: the sort key function for instances
    :param algorithm_sort_key: the sort key function for algorithms
    :param instance_namer: the name function for instances receives an
        instance ID and returns an instance name; default=identity function
    :param algorithm_namer: the name function for algorithms receives an
        algorithm ID and returns an algorithm name; default=identity
        function
    :param stat_sort_key: the sort key function for statistics
    :param color_algorithms_as_fallback_group: if only a single group of data
        was found, use algorithms as group and put them in the legend
    :returns: the axes object to allow you to add further plot elements
    :raises ValueError: if `y_dim` is empty, if a priority is not finite, if
        `data` contains no element, if the same combination of x-value,
        algorithm, and instance occurs twice, or if an instance has no data
        for one of the algorithms
    :raises TypeError: presumably the kind of error created by `type_error`
        when an argument or a getter result has the wrong type (the exact
        error class is decided by `type_error`)
    """
    # Before doing anything, let's do some type checking on the parameters.
    # I want to ensure that this function is called correctly before we begin
    # to actually process the data. It is better to fail early than to deliver
    # some incorrect results.
    if not isinstance(data, Iterable):
        raise type_error(data, "data", Iterable)
    if not isinstance(figure, Axes | Figure):
        raise type_error(figure, "figure", (Axes, Figure))
    if not callable(x_getter):
        raise type_error(x_getter, "x_getter", call=True)
    if not isinstance(y_dim, str):
        raise type_error(y_dim, "y_dim", str)
    if len(y_dim) <= 0:
        raise ValueError(f"invalid y-dimension {y_dim!r}")
    if not callable(instance_getter):
        raise type_error(instance_getter, "instance_getter", call=True)
    if not callable(algorithm_getter):
        raise type_error(algorithm_getter, "algorithm_getter", call=True)
    if not isinstance(legend, bool):
        raise type_error(legend, "legend", bool)
    if not isinstance(legend_pos, str):
        raise type_error(legend_pos, "legend_pos", str)
    if not callable(distinct_colors_func):
        raise type_error(
            distinct_colors_func, "distinct_colors_func", call=True)
    if not callable(distinct_line_dashes_func):
        raise type_error(
            distinct_line_dashes_func, "distinct_line_dashes_func", call=True)
    # This function is invoked further down, so it must be validated, too.
    if not callable(importance_to_line_width_func):
        raise type_error(importance_to_line_width_func,
                         "importance_to_line_width_func", call=True)
    if not callable(importance_to_font_size_func):
        raise type_error(importance_to_font_size_func,
                         "importance_to_font_size_func", call=True)
    if not isinstance(x_grid, bool):
        raise type_error(x_grid, "x_grid", bool)
    if not isinstance(y_grid, bool):
        raise type_error(y_grid, "y_grid", bool)
    if not ((x_label is None) or isinstance(x_label, str)):
        raise type_error(x_label, "x_label", (str, None))
    if not isinstance(x_label_inside, bool):
        raise type_error(x_label_inside, "x_label_inside", bool)
    if not isinstance(x_label_location, float):
        raise type_error(x_label_location, "x_label_location", float)
    if not ((y_label is None) or callable(y_label)
            or isinstance(y_label, str)):
        raise type_error(y_label, "y_label", (str, None), call=True)
    if not isinstance(y_label_inside, bool):
        raise type_error(y_label_inside, "y_label_inside", bool)
    if not isinstance(y_label_location, float):
        raise type_error(y_label_location, "y_label_location", float)
    if not isinstance(instance_priority, float):
        raise type_error(instance_priority, "instance_priority", float)
    if not isfinite(instance_priority):
        raise ValueError(f"instance_priority cannot be {instance_priority}.")
    if not isinstance(algorithm_priority, float):
        raise type_error(algorithm_priority, "algorithm_priority", float)
    if not isfinite(algorithm_priority):
        raise ValueError(f"algorithm_priority cannot be {algorithm_priority}.")
    if not isinstance(stat_priority, float):
        raise type_error(stat_priority, "stat_priority", float)
    if not isfinite(stat_priority):
        raise ValueError(f"stat_priority cannot be {stat_priority}.")
    if not callable(instance_sort_key):
        raise type_error(instance_sort_key, "instance_sort_key", call=True)
    if not callable(algorithm_sort_key):
        raise type_error(algorithm_sort_key, "algorithm_sort_key", call=True)
    if not callable(stat_sort_key):
        raise type_error(stat_sort_key, "stat_sort_key", call=True)
    if not callable(instance_namer):
        raise type_error(instance_namer, "instance_namer", call=True)
    if not callable(algorithm_namer):
        raise type_error(algorithm_namer, "algorithm_namer", call=True)
    if not isinstance(color_algorithms_as_fallback_group, bool):
        raise type_error(color_algorithms_as_fallback_group,
                         "color_algorithms_as_fallback_group", bool)

    # the getter for the dimension value
    y_getter: Final[Callable[[EndStatistics], int | float]] \
        = cast("Callable[[EndStatistics], int | float]",
               end_stat_getter(y_dim))
    if not callable(y_getter):
        raise type_error(y_getter, "y-getter", call=True)

    # set up the axis rangers: factories are invoked, instances used as is
    if callable(x_axis):
        x_axis = x_axis()
    if not isinstance(x_axis, AxisRanger):
        raise type_error(x_axis, "x_axis", AxisRanger)

    if callable(y_axis):
        y_axis = y_axis(y_dim)
    if not isinstance(y_axis, AxisRanger):
        raise type_error(y_axis, "y_axis", AxisRanger)

    # First, we try to find groups of data to plot together in the same
    # color/style: one styler collects the instances, one the algorithms.
    instances: Final[Styler] = Styler(
        none_name=Lang.translate("all_insts"),
        priority=instance_priority,
        namer=instance_namer,
        name_sort_function=instance_sort_key)
    algorithms: Final[Styler] = Styler(
        none_name=Lang.translate("all_algos"),
        namer=algorithm_namer,
        priority=algorithm_priority, name_sort_function=algorithm_sort_key)

    # we now extract the data: algo -> inst -> x -> y
    dataset: Final[dict[str | None, dict[
        str | None, dict[int | float, int | float]]]] = {}
    for endstat in data:
        if not isinstance(endstat, EndStatistics):
            raise type_error(endstat, "element in data", EndStatistics)
        x_value = x_getter(endstat)
        if not isinstance(x_value, int | float):
            raise type_error(x_value, "x-value", (int, float))
        l_algo = algorithm_getter(endstat)
        if not ((l_algo is None) or isinstance(l_algo, str)):
            raise type_error(l_algo, "algorithm name", (str, None))
        l_inst = instance_getter(endstat)
        if not ((l_inst is None) or isinstance(l_inst, str)):
            raise type_error(l_inst, "instance name", (str, None))
        y_value = y_getter(endstat)
        if not isinstance(y_value, int | float):
            raise type_error(y_value, "y-value", (int, float))
        # create the nested dictionaries on first use
        l1_dataset = dataset.setdefault(l_algo, {})
        l2_dataset = l1_dataset.setdefault(l_inst, {})
        if x_value in l2_dataset:
            raise ValueError(
                f"combination x={x_value}, algo={l_algo!r}, inst={l_inst!r} "
                f"already known as value {l2_dataset[x_value]}, cannot assign "
                f"value {y_value}.")
        l2_dataset[x_value] = y_value
        x_axis.register_value(x_value)
        y_axis.register_value(y_value)
        algorithms.add(l_algo)
        instances.add(l_inst)

    # The emptiness check must come before the `del` below: if `data` had no
    # elements, the loop variables were never bound and deleting them would
    # raise an UnboundLocalError instead of this ValueError.
    if len(dataset) <= 0:
        raise ValueError("no data found?")
    del data, y_getter, x_getter, x_value, y_value

    def __set_importance(st: Styler) -> None:
        """
        Give the first style of a styler a different line width.

        The width for importance 1 is assigned to index 0 and the width for
        importance 0 to all other indices (presumably index 0 is the `None`
        key of the styler -- to be confirmed against `Styler`).

        :param st: the styler whose line widths are set
        """
        none = 1
        not_none = 0
        none_lw = importance_to_line_width_func(none)
        not_none_lw = importance_to_line_width_func(not_none)
        st.set_line_width(lambda p: [none_lw if i <= 0 else not_none_lw
                                     for i in range(p)])

    # determine the style groups
    groups: list[Styler] = []
    instances.finalize()
    algorithms.finalize()

    if instances.count > 1:
        groups.append(instances)
    if algorithms.count > 1:
        groups.append(algorithms)

    if len(groups) > 0:
        # the first group after sorting gets colors, the second one dashes
        groups.sort()
        groups[0].set_line_color(distinct_colors_func)

        if len(groups) > 1:
            groups[1].set_line_dash(distinct_line_dashes_func)
    elif color_algorithms_as_fallback_group:
        algorithms.set_line_color(distinct_colors_func)
        groups.append(algorithms)

    # If we only have <= 2 groups, we can mark None and not-None values with
    # different importance.
    if instances.has_none and (instances.count > 1):
        __set_importance(instances)
    elif algorithms.has_none and (algorithms.count > 1):
        __set_importance(algorithms)

    # we will collect all lines to plot in plot_list
    plot_list: list[dict] = []
    for algo in algorithms.keys:
        l1_dataset = dataset[algo]
        for inst in instances.keys:
            if inst not in l1_dataset:
                raise ValueError(f"instance {inst!r} not in dataset"
                                 f" for algorithm {algo!r}.")
            l2_dataset = l1_dataset[inst]
            style = pd.create_line_style()
            # the points of a line are ordered by ascending x-value
            style["x"] = x_vals = sorted(l2_dataset.keys())
            style["y"] = [l2_dataset[x] for x in x_vals]
            for g in groups:
                g.add_line_style(inst if g is instances else algo, style)
            plot_list.append(style)
    del dataset, l1_dataset, l2_dataset

    # now we have all data, let's move to the actual plotting
    font_size_0: Final[float] = importance_to_font_size_func(0)

    # set up the graphics area
    axes: Final[Axes] = pu.get_axes(figure)
    axes.tick_params(axis="x", labelsize=font_size_0)
    axes.tick_params(axis="y", labelsize=font_size_0)

    # draw the grid
    if x_grid or y_grid:
        grid_lwd = importance_to_line_width_func(-1)
        if x_grid:
            axes.grid(axis="x", color=pd.GRID_COLOR, linewidth=grid_lwd)
        if y_grid:
            axes.grid(axis="y", color=pd.GRID_COLOR, linewidth=grid_lwd)

    # plot the lines
    for line in plot_list:
        axes.step(where="post", **line)
    del plot_list

    # make sure that we can see the maximum of the parameters
    x_axis.pad_detected_range(pad_max=True)

    x_axis.apply(axes, "x")
    y_axis.apply(axes, "y")

    if legend:
        handles: list[Artist] = []

        for g in groups:
            g.add_to_legend(handles.append)
            # clear the flag so that the two checks below do not add this
            # styler to the legend a second time
            g.has_style = False

        if instances.has_style:
            instances.add_to_legend(handles.append)
        if algorithms.has_style:
            algorithms.add_to_legend(handles.append)

        if len(handles) > 0:
            axes.legend(loc=legend_pos,
                        handles=handles,
                        labelcolor=[art.color if hasattr(art, "color")
                                    else pd.COLOR_BLACK for art in handles],
                        fontsize=font_size_0)

    pu.label_axes(axes=axes,
                  x_label=x_label,
                  x_label_inside=x_label_inside,
                  x_label_location=x_label_location,
                  y_label=y_label(y_dim) if callable(y_label) else y_label,
                  y_label_inside=y_label_inside,
                  y_label_location=y_label_location,
                  font_size=font_size_0)
    return axes