Coverage for moptipy / evaluation / plot_progress.py: 74%

237 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-24 08:49 +0000

1"""Plot a set of `Progress` or `StatRun` objects into one figure.""" 

2from math import isfinite 

3from typing import Any, Callable, Final, Iterable 

4 

5from matplotlib.artist import Artist # type: ignore 

6from matplotlib.axes import Axes # type: ignore 

7from matplotlib.figure import Figure # type: ignore 

8from pycommons.types import type_error 

9 

10import moptipy.utils.plot_defaults as pd 

11import moptipy.utils.plot_utils as pu 

12from moptipy.evaluation.axis_ranger import AxisRanger 

13from moptipy.evaluation.base import get_algorithm, get_instance, sort_key 

14from moptipy.evaluation.progress import Progress 

15from moptipy.evaluation.stat_run import StatRun, get_statistic 

16from moptipy.evaluation.styler import Styler 

17from moptipy.utils.lang import Lang 

18 

19 

def plot_progress(
        progresses: Iterable[Progress | StatRun],
        figure: Axes | Figure,
        x_axis: AxisRanger | Callable[[str], AxisRanger] =
        AxisRanger.for_axis,
        y_axis: AxisRanger | Callable[[str], AxisRanger] =
        AxisRanger.for_axis,
        legend: bool = True,
        distinct_colors_func: Callable[[int], Any] = pd.distinct_colors,
        distinct_line_dashes_func: Callable[[int], Any] =
        pd.distinct_line_dashes,
        importance_to_line_width_func: Callable[[int], float] =
        pd.importance_to_line_width,
        importance_to_alpha_func: Callable[[int], float] =
        pd.importance_to_alpha,
        importance_to_font_size_func: Callable[[int], float] =
        pd.importance_to_font_size,
        x_grid: bool = True,
        y_grid: bool = True,
        x_label: str | Callable[[str], str] | None = Lang.translate,
        x_label_inside: bool = True,
        x_label_location: float = 0.5,
        y_label: str | Callable[[str], str] | None = Lang.translate,
        y_label_inside: bool = True,
        y_label_location: float = 1.0,
        instance_priority: float = 0.666,
        algorithm_priority: float = 0.333,
        stat_priority: float = 0.0,
        instance_sort_key: Callable[[str], Any] = lambda x: x,
        algorithm_sort_key: Callable[[str], Any] = lambda x: x,
        stat_sort_key: Callable[[str], Any] = lambda x: x,
        color_algorithms_as_fallback_group: bool = True,
        instance_namer: Callable[[str], str] = lambda x: x,
        algorithm_namer: Callable[[str], str] = lambda x: x) -> Axes:
    """
    Plot a set of progress or statistical run lines into one chart.

    All elements must share the same time unit (x-axis) and the same
    objective name (y-axis); otherwise a `ValueError` is raised. Lines are
    drawn as post-steps. The style (color, dash, width/alpha) of each line
    is derived from the instance, algorithm, and statistic groups found in
    the data.

    :param progresses: the iterable of progresses and statistical runs
    :param figure: the figure to plot in
    :param x_axis: the x_axis ranger
    :param y_axis: the y_axis ranger
    :param legend: should we plot the legend?
    :param distinct_colors_func: the function returning the palette
    :param distinct_line_dashes_func: the function returning the line styles
    :param importance_to_line_width_func: the function converting importance
        values to line widths
    :param importance_to_alpha_func: the function converting importance
        values to alphas
    :param importance_to_font_size_func: the function converting importance
        values to font sizes
    :param x_grid: should we have a grid along the x-axis?
    :param y_grid: should we have a grid along the y-axis?
    :param x_label: a callable returning the label for the x-axis, a label
        string, or `None` if no label should be put
    :param x_label_inside: put the x-axis label inside the plot (so that
        it does not consume additional vertical space)
    :param x_label_location: the location of the x-axis label
    :param y_label: a callable returning the label for the y-axis, a label
        string, or `None` if no label should be put
    :param y_label_inside: put the y-axis label inside the plot (so that
        it does not consume additional horizontal space)
    :param y_label_location: the location of the y-axis label
    :param instance_priority: the style priority for instances
    :param algorithm_priority: the style priority for algorithms
    :param stat_priority: the style priority for statistics
    :param instance_sort_key: the sort key function for instances
    :param algorithm_sort_key: the sort key function for algorithms
    :param stat_sort_key: the sort key function for statistics
    :param color_algorithms_as_fallback_group: if only a single group of data
        was found, use algorithms as group and put them in the legend
    :param instance_namer: the name function for instances receives an
        instance ID and returns an instance name; default=identity function
    :param algorithm_namer: the name function for algorithms receives an
        algorithm ID and returns an algorithm name; default=identity function
    :returns: the axes object to allow you to add further plot elements
    :raises TypeError: if any parameter or input element has the wrong type
    :raises ValueError: if the input is empty, the time units or objective
        names of the elements do not match, a priority is not finite, or
        more than two importance values would be needed
    """
    # Before doing anything, let's do some type checking on the parameters.
    # I want to ensure that this function is called correctly before we begin
    # to actually process the data. It is better to fail early than to deliver
    # some incorrect results.
    if not isinstance(progresses, Iterable):
        raise type_error(progresses, "progresses", Iterable)
    if not isinstance(figure, Axes | Figure):
        raise type_error(figure, "figure", (Axes, Figure))
    if not isinstance(legend, bool):
        raise type_error(legend, "legend", bool)
    if not callable(distinct_colors_func):
        raise type_error(
            distinct_colors_func, "distinct_colors_func", call=True)
    if not callable(distinct_line_dashes_func):
        raise type_error(
            distinct_line_dashes_func, "distinct_line_dashes_func", call=True)
    if not callable(importance_to_line_width_func):
        raise type_error(importance_to_line_width_func,
                         "importance_to_line_width_func", call=True)
    if not callable(importance_to_alpha_func):
        raise type_error(
            importance_to_alpha_func, "importance_to_alpha_func", call=True)
    if not callable(importance_to_font_size_func):
        raise type_error(importance_to_font_size_func,
                         "importance_to_font_size_func", call=True)
    if not isinstance(x_grid, bool):
        raise type_error(x_grid, "x_grid", bool)
    if not isinstance(y_grid, bool):
        raise type_error(y_grid, "y_grid", bool)
    if not ((x_label is None) or callable(x_label)
            or isinstance(x_label, str)):
        raise type_error(x_label, "x_label", (str, None), call=True)
    if not isinstance(x_label_inside, bool):
        raise type_error(x_label_inside, "x_label_inside", bool)
    if not isinstance(x_label_location, float):
        raise type_error(x_label_location, "x_label_location", float)
    if not ((y_label is None) or callable(y_label)
            or isinstance(y_label, str)):
        raise type_error(y_label, "y_label", (str, None), call=True)
    if not isinstance(y_label_inside, bool):
        raise type_error(y_label_inside, "y_label_inside", bool)
    if not isinstance(y_label_location, float):
        raise type_error(y_label_location, "y_label_location", float)
    if not isinstance(instance_priority, float):
        raise type_error(instance_priority, "instance_priority", float)
    if not isfinite(instance_priority):
        raise ValueError(f"instance_priority cannot be {instance_priority}.")
    if not isinstance(algorithm_priority, float):
        raise type_error(algorithm_priority, "algorithm_priority", float)
    if not isfinite(algorithm_priority):
        raise ValueError(f"algorithm_priority cannot be {algorithm_priority}.")
    if not isinstance(stat_priority, float):
        raise type_error(stat_priority, "stat_priority", float)
    if not isfinite(stat_priority):
        raise ValueError(f"stat_priority cannot be {stat_priority}.")
    if not callable(instance_sort_key):
        raise type_error(instance_sort_key, "instance_sort_key", call=True)
    if not callable(algorithm_sort_key):
        raise type_error(algorithm_sort_key, "algorithm_sort_key", call=True)
    if not callable(stat_sort_key):
        raise type_error(stat_sort_key, "stat_sort_key", call=True)
    if not callable(instance_namer):
        raise type_error(instance_namer, "instance_namer", call=True)
    if not callable(algorithm_namer):
        raise type_error(algorithm_namer, "algorithm_namer", call=True)
    if not isinstance(color_algorithms_as_fallback_group, bool):
        raise type_error(color_algorithms_as_fallback_group,
                         "color_algorithms_as_fallback_group", bool)

    # First, we try to find groups of data to plot together in the same
    # color/style. We distinguish progress objects from statistical runs.
    instances: Final[Styler] = Styler(key_func=get_instance,
                                      namer=instance_namer,
                                      none_name=Lang.translate("all_insts"),
                                      priority=instance_priority,
                                      name_sort_function=instance_sort_key)
    algorithms: Final[Styler] = Styler(key_func=get_algorithm,
                                       namer=algorithm_namer,
                                       none_name=Lang.translate("all_algos"),
                                       priority=algorithm_priority,
                                       name_sort_function=algorithm_sort_key)
    statistics: Final[Styler] = Styler(key_func=get_statistic,
                                       none_name=Lang.translate("single_run"),
                                       priority=stat_priority,
                                       name_sort_function=stat_sort_key)
    x_dim: str | None = None
    y_dim: str | None = None
    progress_list: list[Progress] = []
    statrun_list: list[StatRun] = []

    # First pass: find out the statistics, instances, algorithms, and types
    for prg in progresses:
        instances.add(prg)
        algorithms.add(prg)
        statistics.add(prg)
        if isinstance(prg, Progress):
            progress_list.append(prg)
        elif isinstance(prg, StatRun):
            statrun_list.append(prg)
        else:
            raise type_error(prg, "progress plot element",
                             (Progress, StatRun))

        # Validate that we have consistent time and objective units.
        if x_dim is None:
            x_dim = prg.time_unit
        elif x_dim != prg.time_unit:
            raise ValueError(
                f"Time units {x_dim} and {prg.time_unit} do not fit!")

        if y_dim is None:
            y_dim = prg.f_name
        elif y_dim != prg.f_name:
            raise ValueError(
                f"F-units {y_dim} and {prg.f_name} do not fit!")
    del progresses

    if (len(progress_list) + len(statrun_list)) <= 0:
        raise ValueError("Empty input data?")

    if (x_dim is None) or (y_dim is None):
        raise ValueError("Illegal state?")

    instances.finalize()
    algorithms.finalize()
    statistics.finalize()

    # pick the right sorting order for the statistical runs: if only one
    # dimension varies, sort by that dimension's user-supplied key
    sf: Callable[[StatRun | Progress], Any] = sort_key
    if (instances.count > 1) and (algorithms.count == 1) \
            and (statistics.count == 1):
        def __x1(r: StatRun | Progress, ssf=instance_sort_key) -> Any:
            return ssf(r.instance)
        sf = __x1
    elif (instances.count == 1) and (algorithms.count > 1) \
            and (statistics.count == 1):
        def __x2(r: StatRun | Progress, ssf=algorithm_sort_key) -> Any:
            return ssf(r.algorithm)
        sf = __x2
    elif (instances.count == 1) and (algorithms.count == 1) \
            and (statistics.count > 1):
        # Only the statistic differs between the runs here, so the key must
        # be the statistic name (the same key the `statistics` styler uses),
        # not the instance, which is identical for all runs in this branch.
        def __x3(r: StatRun | Progress, ssf=stat_sort_key) -> Any:
            return ssf(get_statistic(r))
        sf = __x3
    elif (instances.count > 1) and (algorithms.count > 1):
        def __x4(r: StatRun | Progress, sas=algorithm_sort_key,
                 ias=instance_sort_key,
                 ag=algorithm_priority > instance_priority) \
                -> tuple[Any, Any]:
            k1 = ias(r.instance)
            k2 = sas(r.algorithm)
            # the dimension with the higher style priority sorts first
            return (k2, k1) if ag else (k1, k2)
        sf = __x4

    statrun_list.sort(key=sf)
    progress_list.sort()

    def __set_importance(st: Styler) -> None:
        # Give the "None" group (e.g., single runs / all instances) and the
        # other group different line widths and alphas. For statistics, the
        # statistic lines are emphasized over the single-run lines.
        if st is statistics:
            none = -1
            not_none = 1
        else:
            none = 1
            not_none = 0
        none_lw = importance_to_line_width_func(none)
        not_none_lw = importance_to_line_width_func(not_none)
        st.set_line_width(lambda x: [none_lw if i <= 0 else not_none_lw
                                     for i in range(x)])
        none_a = importance_to_alpha_func(none)
        not_none_a = importance_to_alpha_func(not_none)
        st.set_line_alpha(lambda x: [none_a if i <= 0 else not_none_a
                                     for i in range(x)])

    # determine the style groups
    groups: list[Styler] = []

    no_importance = True
    if instances.count > 1:
        groups.append(instances)
    if algorithms.count > 1:
        groups.append(algorithms)
    add_stat_to_groups = False
    if statistics.count > 1:
        if statistics.has_none and (statistics.count == 2):
            # exactly "single runs" vs. one statistic: distinguish them by
            # importance (width/alpha) instead of spending a color/dash on it
            __set_importance(statistics)
            no_importance = False
            add_stat_to_groups = True
        else:
            groups.append(statistics)

    if len(groups) > 0:
        groups.sort()
        groups[0].set_line_color(distinct_colors_func)

        if len(groups) > 1:
            groups[1].set_line_dash(distinct_line_dashes_func)

            if (len(groups) > 2) and no_importance:
                g = groups[2]
                if g.count > 2:
                    raise ValueError(
                        f"Cannot have {g.count} importance values.")
                __set_importance(g)
                no_importance = False
    elif color_algorithms_as_fallback_group:
        algorithms.set_line_color(distinct_colors_func)
        groups.append(algorithms)

    if add_stat_to_groups:
        groups.append(statistics)

    # If we only have <= 2 groups, we can mark None and not-None values with
    # different importance.
    if no_importance and statistics.has_none and (statistics.count > 1):
        __set_importance(statistics)
        no_importance = False
    if no_importance and instances.has_none and (instances.count > 1):
        __set_importance(instances)
        no_importance = False
    if no_importance and algorithms.has_none and (algorithms.count > 1):
        __set_importance(algorithms)

    # we will collect all lines to plot in plot_list
    plot_list: list[dict] = []

    # first we collect all progress object
    for prgs in progress_list:
        style = pd.create_line_style()
        for g in groups:
            g.add_line_style(prgs, style)
        style["x"] = prgs.time
        style["y"] = prgs.f
        plot_list.append(style)
    del progress_list

    # now collect the plot data for the statistics, grouped by statistic
    # in the order of the statistics styler's keys
    for sn in statistics.keys:
        if sn is None:
            continue
        for sr in statrun_list:
            if statistics.key_func(sr) != sn:
                continue

            style = pd.create_line_style()
            for g in groups:
                g.add_line_style(sr, style)
            style["x"] = sr.stat[:, 0]
            style["y"] = sr.stat[:, 1]
            plot_list.append(style)
    del statrun_list

    font_size_0: Final[float] = importance_to_font_size_func(0)

    # set up the graphics area
    axes: Final[Axes] = pu.get_axes(figure)
    axes.tick_params(axis="x", labelsize=font_size_0)
    axes.tick_params(axis="y", labelsize=font_size_0)

    # draw the grid
    if x_grid or y_grid:
        grid_lwd = importance_to_line_width_func(-1)
        if x_grid:
            axes.grid(axis="x", color=pd.GRID_COLOR, linewidth=grid_lwd)
        if y_grid:
            axes.grid(axis="y", color=pd.GRID_COLOR, linewidth=grid_lwd)

    # set up the axis rangers
    if callable(x_axis):
        x_axis = x_axis(x_dim)
    if not isinstance(x_axis, AxisRanger):
        raise type_error(x_axis, "x_axis", AxisRanger)

    if callable(y_axis):
        y_axis = y_axis(y_dim)
    if not isinstance(y_axis, AxisRanger):
        raise type_error(y_axis, "y_axis", AxisRanger)

    # plot the lines as post-steps and let the rangers see every data point
    for line in plot_list:
        axes.step(where="post", **line)
        x_axis.register_array(line["x"])
        y_axis.register_array(line["y"])
    del plot_list

    x_axis.apply(axes, "x")
    y_axis.apply(axes, "y")

    if legend:
        handles: list[Artist] = []

        for g in groups:
            g.add_to_legend(handles.append)
            # mark as handled so it is not added a second time below
            g.has_style = False

        if instances.has_style:
            instances.add_to_legend(handles.append)
        if algorithms.has_style:
            algorithms.add_to_legend(handles.append)
        if statistics.has_style:
            statistics.add_to_legend(handles.append)

        if len(handles) > 0:
            axes.legend(loc="upper right",
                        handles=handles,
                        labelcolor=[art.color if hasattr(art, "color")
                                    else pd.COLOR_BLACK for art in handles],
                        fontsize=font_size_0)

    pu.label_axes(axes=axes,
                  x_label=x_label(x_dim) if callable(x_label) else x_label,
                  x_label_inside=x_label_inside,
                  x_label_location=x_label_location,
                  y_label=y_label(y_dim) if callable(y_label) else y_label,
                  y_label_inside=y_label_inside,
                  y_label_location=y_label_location,
                  font_size=font_size_0)
    return axes