Coverage for moptipy / evaluation / plot_progress.py: 74%
237 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-24 08:49 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-24 08:49 +0000
1"""Plot a set of `Progress` or `StatRun` objects into one figure."""
2from math import isfinite
3from typing import Any, Callable, Final, Iterable
5from matplotlib.artist import Artist # type: ignore
6from matplotlib.axes import Axes # type: ignore
7from matplotlib.figure import Figure # type: ignore
8from pycommons.types import type_error
10import moptipy.utils.plot_defaults as pd
11import moptipy.utils.plot_utils as pu
12from moptipy.evaluation.axis_ranger import AxisRanger
13from moptipy.evaluation.base import get_algorithm, get_instance, sort_key
14from moptipy.evaluation.progress import Progress
15from moptipy.evaluation.stat_run import StatRun, get_statistic
16from moptipy.evaluation.styler import Styler
17from moptipy.utils.lang import Lang
def plot_progress(
        progresses: Iterable[Progress | StatRun],
        figure: Axes | Figure,
        x_axis: AxisRanger | Callable[[str], AxisRanger] =
        AxisRanger.for_axis,
        y_axis: AxisRanger | Callable[[str], AxisRanger] =
        AxisRanger.for_axis,
        legend: bool = True,
        distinct_colors_func: Callable[[int], Any] = pd.distinct_colors,
        distinct_line_dashes_func: Callable[[int], Any] =
        pd.distinct_line_dashes,
        importance_to_line_width_func: Callable[[int], float] =
        pd.importance_to_line_width,
        importance_to_alpha_func: Callable[[int], float] =
        pd.importance_to_alpha,
        importance_to_font_size_func: Callable[[int], float] =
        pd.importance_to_font_size,
        x_grid: bool = True,
        y_grid: bool = True,
        x_label: str | Callable[[str], str] | None = Lang.translate,
        x_label_inside: bool = True,
        x_label_location: float = 0.5,
        y_label: str | Callable[[str], str] | None = Lang.translate,
        y_label_inside: bool = True,
        y_label_location: float = 1.0,
        instance_priority: float = 0.666,
        algorithm_priority: float = 0.333,
        stat_priority: float = 0.0,
        instance_sort_key: Callable[[str], Any] = lambda x: x,
        algorithm_sort_key: Callable[[str], Any] = lambda x: x,
        stat_sort_key: Callable[[str], Any] = lambda x: x,
        color_algorithms_as_fallback_group: bool = True,
        instance_namer: Callable[[str], str] = lambda x: x,
        algorithm_namer: Callable[[str], str] = lambda x: x) -> Axes:
    """
    Plot a set of progress or statistical run lines into one chart.

    Each `Progress` object is drawn as a post-step line of `f` over `time`.
    Each `StatRun` is drawn as a post-step line of column 1 over column 0 of
    its `stat` array. Lines are styled (color, dash, width, alpha) by
    instance, algorithm, and statistic, depending on which of these actually
    vary in the data.

    :param progresses: the iterable of progresses and statistical runs
    :param figure: the figure to plot in
    :param x_axis: the x_axis ranger
    :param y_axis: the y_axis ranger
    :param legend: should we plot the legend?
    :param distinct_colors_func: the function returning the palette
    :param distinct_line_dashes_func: the function returning the line styles
    :param importance_to_line_width_func: the function converting importance
        values to line widths
    :param importance_to_alpha_func: the function converting importance
        values to alphas
    :param importance_to_font_size_func: the function converting importance
        values to font sizes
    :param x_grid: should we have a grid along the x-axis?
    :param y_grid: should we have a grid along the y-axis?
    :param x_label: a callable returning the label for the x-axis, a label
        string, or `None` if no label should be put
    :param x_label_inside: put the x-axis label inside the plot (so that
        it does not consume additional vertical space)
    :param x_label_location: the location of the x-axis label
    :param y_label: a callable returning the label for the y-axis, a label
        string, or `None` if no label should be put
    :param y_label_inside: put the y-axis label inside the plot (so that
        it does not consume additional horizontal space)
    :param y_label_location: the location of the y-axis label
    :param instance_priority: the style priority for instances
    :param algorithm_priority: the style priority for algorithms
    :param stat_priority: the style priority for statistics
    :param instance_sort_key: the sort key function for instances
    :param algorithm_sort_key: the sort key function for algorithms
    :param stat_sort_key: the sort key function for statistics
    :param color_algorithms_as_fallback_group: if only a single group of data
        was found, use algorithms as group and put them in the legend
    :param instance_namer: the name function for instances receives an
        instance ID and returns an instance name; default=identity function
    :param algorithm_namer: the name function for algorithms receives an
        algorithm ID and returns an algorithm name; default=identity function
    :returns: the axes object to allow you to add further plot elements
    :raises TypeError: if a parameter, or an element of `progresses`, has an
        unexpected type
    :raises ValueError: if a priority is not finite, if the plotted elements
        use different time units or objective names, if `progresses` is
        empty, or if a third style group would need more than two
        importance values
    """
    # Before doing anything, let's do some type checking on the parameters.
    # I want to ensure that this function is called correctly before we begin
    # to actually process the data. It is better to fail early than to deliver
    # some incorrect results.
    if not isinstance(progresses, Iterable):
        raise type_error(progresses, "progresses", Iterable)
    if not isinstance(figure, Axes | Figure):
        raise type_error(figure, "figure", (Axes, Figure))
    if not isinstance(legend, bool):
        raise type_error(legend, "legend", bool)
    if not callable(distinct_colors_func):
        raise type_error(
            distinct_colors_func, "distinct_colors_func", call=True)
    if not callable(distinct_line_dashes_func):
        raise type_error(
            distinct_line_dashes_func, "distinct_line_dashes_func", call=True)
    if not callable(importance_to_line_width_func):
        raise type_error(importance_to_line_width_func,
                         "importance_to_line_width_func", call=True)
    if not callable(importance_to_alpha_func):
        raise type_error(
            importance_to_alpha_func, "importance_to_alpha_func", call=True)
    if not callable(importance_to_font_size_func):
        raise type_error(importance_to_font_size_func,
                         "importance_to_font_size_func", call=True)
    if not isinstance(x_grid, bool):
        raise type_error(x_grid, "x_grid", bool)
    if not isinstance(y_grid, bool):
        raise type_error(y_grid, "y_grid", bool)
    if not ((x_label is None) or callable(x_label)
            or isinstance(x_label, str)):
        raise type_error(x_label, "x_label", (str, None), call=True)
    if not isinstance(x_label_inside, bool):
        raise type_error(x_label_inside, "x_label_inside", bool)
    if not isinstance(x_label_location, float):
        raise type_error(x_label_location, "x_label_location", float)
    if not ((y_label is None) or callable(y_label)
            or isinstance(y_label, str)):
        raise type_error(y_label, "y_label", (str, None), call=True)
    if not isinstance(y_label_inside, bool):
        raise type_error(y_label_inside, "y_label_inside", bool)
    if not isinstance(y_label_location, float):
        raise type_error(y_label_location, "y_label_location", float)
    if not isinstance(instance_priority, float):
        raise type_error(instance_priority, "instance_priority", float)
    if not isfinite(instance_priority):
        raise ValueError(f"instance_priority cannot be {instance_priority}.")
    if not isinstance(algorithm_priority, float):
        raise type_error(algorithm_priority, "algorithm_priority", float)
    if not isfinite(algorithm_priority):
        raise ValueError(f"algorithm_priority cannot be {algorithm_priority}.")
    if not isinstance(stat_priority, float):
        raise type_error(stat_priority, "stat_priority", float)
    if not isfinite(stat_priority):
        raise ValueError(f"stat_priority cannot be {stat_priority}.")
    if not callable(instance_sort_key):
        raise type_error(instance_sort_key, "instance_sort_key", call=True)
    if not callable(algorithm_sort_key):
        raise type_error(algorithm_sort_key, "algorithm_sort_key", call=True)
    if not callable(stat_sort_key):
        raise type_error(stat_sort_key, "stat_sort_key", call=True)
    if not callable(instance_namer):
        raise type_error(instance_namer, "instance_namer", call=True)
    if not callable(algorithm_namer):
        raise type_error(algorithm_namer, "algorithm_namer", call=True)
    if not isinstance(color_algorithms_as_fallback_group, bool):
        raise type_error(color_algorithms_as_fallback_group,
                         "color_algorithms_as_fallback_group", bool)

    # First, we try to find groups of data to plot together in the same
    # color/style. We distinguish progress objects from statistical runs.
    instances: Final[Styler] = Styler(key_func=get_instance,
                                      namer=instance_namer,
                                      none_name=Lang.translate("all_insts"),
                                      priority=instance_priority,
                                      name_sort_function=instance_sort_key)
    algorithms: Final[Styler] = Styler(key_func=get_algorithm,
                                       namer=algorithm_namer,
                                       none_name=Lang.translate("all_algos"),
                                       priority=algorithm_priority,
                                       name_sort_function=algorithm_sort_key)
    statistics: Final[Styler] = Styler(key_func=get_statistic,
                                       none_name=Lang.translate("single_run"),
                                       priority=stat_priority,
                                       name_sort_function=stat_sort_key)
    x_dim: str | None = None
    y_dim: str | None = None
    progress_list: list[Progress] = []
    statrun_list: list[StatRun] = []

    # First pass: find out the statistics, instances, algorithms, and types
    for prg in progresses:
        instances.add(prg)
        algorithms.add(prg)
        statistics.add(prg)
        if isinstance(prg, Progress):
            progress_list.append(prg)
        elif isinstance(prg, StatRun):
            statrun_list.append(prg)
        else:
            raise type_error(prg, "progress plot element",
                             (Progress, StatRun))

        # Validate that we have consistent time and objective units: all
        # lines share one pair of axes, so they must measure the same things.
        if x_dim is None:
            x_dim = prg.time_unit
        elif x_dim != prg.time_unit:
            raise ValueError(
                f"Time units {x_dim} and {prg.time_unit} do not fit!")

        if y_dim is None:
            y_dim = prg.f_name
        elif y_dim != prg.f_name:
            raise ValueError(
                f"F-units {y_dim} and {prg.f_name} do not fit!")
    del progresses

    if (len(progress_list) + len(statrun_list)) <= 0:
        raise ValueError("Empty input data?")

    if (x_dim is None) or (y_dim is None):
        raise ValueError("Illegal state?")

    instances.finalize()
    algorithms.finalize()
    statistics.finalize()

    # Pick the sorting order for the statistical runs. If exactly one of the
    # dimensions (instance, algorithm, statistic) varies, order by the
    # user-supplied sort key of that dimension.
    sf: Callable[[StatRun | Progress], Any] = sort_key
    if (instances.count > 1) and (algorithms.count == 1) \
            and (statistics.count == 1):
        def __x1(r: StatRun | Progress, ssf=instance_sort_key) -> Any:
            return ssf(r.instance)
        sf = __x1
    elif (instances.count == 1) and (algorithms.count > 1) \
            and (statistics.count == 1):
        def __x2(r: StatRun | Progress, ssf=algorithm_sort_key) -> Any:
            return ssf(r.algorithm)
        sf = __x2
    elif (instances.count == 1) and (algorithms.count == 1) \
            and (statistics.count > 1):
        # Only the statistic differs between the runs here, so sort by the
        # statistic name (the same key the `statistics` styler groups by).
        # All runs share one instance, so keying on it would not order them.
        def __x3(r: StatRun | Progress, ssf=stat_sort_key) -> Any:
            return ssf(get_statistic(r))
        sf = __x3
    elif (instances.count > 1) and (algorithms.count > 1):
        # The dimension with the higher style priority becomes the primary
        # sort key.
        def __x4(r: StatRun | Progress, sas=algorithm_sort_key,
                 ias=instance_sort_key,
                 ag=algorithm_priority > instance_priority) \
                -> tuple[Any, Any]:
            k1 = ias(r.instance)
            k2 = sas(r.algorithm)
            return (k2, k1) if ag else (k1, k2)
        sf = __x4

    statrun_list.sort(key=sf)
    progress_list.sort()

    def __set_importance(st: Styler) -> None:
        # Give the first key of the styler one importance and all others
        # another. For the statistics styler, the first line is drawn less
        # important (-1) than the rest (1). For other stylers, the first
        # line is drawn more important (1) than the rest (0).
        if st is statistics:
            none = -1
            not_none = 1
        else:
            none = 1
            not_none = 0
        none_lw = importance_to_line_width_func(none)
        not_none_lw = importance_to_line_width_func(not_none)
        st.set_line_width(lambda x: [none_lw if i <= 0 else not_none_lw
                                     for i in range(x)])
        none_a = importance_to_alpha_func(none)
        not_none_a = importance_to_alpha_func(not_none)
        st.set_line_alpha(lambda x: [none_a if i <= 0 else not_none_a
                                     for i in range(x)])

    # determine the style groups
    groups: list[Styler] = []

    no_importance = True
    if instances.count > 1:
        groups.append(instances)
    if algorithms.count > 1:
        groups.append(algorithms)
    add_stat_to_groups = False
    if statistics.count > 1:
        if statistics.has_none and (statistics.count == 2):
            # Exactly one statistic plus the raw (`None`) runs: distinguish
            # them by importance instead of spending a color/dash on them.
            __set_importance(statistics)
            no_importance = False
            add_stat_to_groups = True
        else:
            groups.append(statistics)

    if len(groups) > 0:
        groups.sort()
        groups[0].set_line_color(distinct_colors_func)

        if len(groups) > 1:
            groups[1].set_line_dash(distinct_line_dashes_func)

            if (len(groups) > 2) and no_importance:
                g = groups[2]
                if g.count > 2:
                    raise ValueError(
                        f"Cannot have {g.count} importance values.")
                __set_importance(g)
                no_importance = False
    elif color_algorithms_as_fallback_group:
        algorithms.set_line_color(distinct_colors_func)
        groups.append(algorithms)

    if add_stat_to_groups:
        groups.append(statistics)

    # If we only have <= 2 groups, we can mark None and not-None values with
    # different importance.
    if no_importance and statistics.has_none and (statistics.count > 1):
        __set_importance(statistics)
        no_importance = False
    if no_importance and instances.has_none and (instances.count > 1):
        __set_importance(instances)
        no_importance = False
    if no_importance and algorithms.has_none and (algorithms.count > 1):
        __set_importance(algorithms)

    # we will collect all lines to plot in plot_list
    plot_list: list[dict] = []

    # first we collect all progress object
    for prgs in progress_list:
        style = pd.create_line_style()
        for g in groups:
            g.add_line_style(prgs, style)
        style["x"] = prgs.time
        style["y"] = prgs.f
        plot_list.append(style)
    del progress_list

    # now collect the plot data for the statistics, one statistic at a time
    # in the order of the statistics styler's keys
    for sn in statistics.keys:
        if sn is None:
            continue
        for sr in statrun_list:
            if statistics.key_func(sr) != sn:
                continue

            style = pd.create_line_style()
            for g in groups:
                g.add_line_style(sr, style)
            style["x"] = sr.stat[:, 0]
            style["y"] = sr.stat[:, 1]
            plot_list.append(style)
    del statrun_list

    font_size_0: Final[float] = importance_to_font_size_func(0)

    # set up the graphics area
    axes: Final[Axes] = pu.get_axes(figure)
    axes.tick_params(axis="x", labelsize=font_size_0)
    axes.tick_params(axis="y", labelsize=font_size_0)

    # draw the grid
    if x_grid or y_grid:
        grid_lwd = importance_to_line_width_func(-1)
        if x_grid:
            axes.grid(axis="x", color=pd.GRID_COLOR, linewidth=grid_lwd)
        if y_grid:
            axes.grid(axis="y", color=pd.GRID_COLOR, linewidth=grid_lwd)

    # set up the axis rangers
    if callable(x_axis):
        x_axis = x_axis(x_dim)
    if not isinstance(x_axis, AxisRanger):
        raise type_error(x_axis, "x_axis", AxisRanger)

    if callable(y_axis):
        y_axis = y_axis(y_dim)
    if not isinstance(y_axis, AxisRanger):
        raise type_error(y_axis, "y_axis", AxisRanger)

    # plot the lines as post-steps: a value holds until the next data point
    for line in plot_list:
        axes.step(where="post", **line)
        x_axis.register_array(line["x"])
        y_axis.register_array(line["y"])
    del plot_list

    x_axis.apply(axes, "x")
    y_axis.apply(axes, "y")

    if legend:
        handles: list[Artist] = []

        # Add the style groups first and mark them as done, so that they
        # are not added a second time by the checks below.
        for g in groups:
            g.add_to_legend(handles.append)
            g.has_style = False

        # Stylers that only received an importance style are not in
        # `groups`, but still need a legend entry.
        if instances.has_style:
            instances.add_to_legend(handles.append)
        if algorithms.has_style:
            algorithms.add_to_legend(handles.append)
        if statistics.has_style:
            statistics.add_to_legend(handles.append)

        if len(handles) > 0:
            axes.legend(loc="upper right",
                        handles=handles,
                        labelcolor=[art.color if hasattr(art, "color")
                                    else pd.COLOR_BLACK for art in handles],
                        fontsize=font_size_0)

    pu.label_axes(axes=axes,
                  x_label=x_label(x_dim) if callable(x_label) else x_label,
                  x_label_inside=x_label_inside,
                  x_label_location=x_label_location,
                  y_label=y_label(y_dim) if callable(y_label) else y_label,
                  y_label_inside=y_label_inside,
                  y_label_location=y_label_location,
                  font_size=font_size_0)
    return axes