Coverage for moptipy / evaluation / stat_run.py: 77%
222 statements
1"""Statistic runs are time-depending statistics over several runs."""
2from dataclasses import dataclass
3from math import erf, sqrt
4from typing import Any, Callable, Final, Iterable
6import numba # type: ignore
7import numpy as np
8from pycommons.math.sample_statistics import (
9 KEY_MAXIMUM,
10 KEY_MEAN_ARITH,
11 KEY_MEAN_GEOM,
12 KEY_MEDIAN,
13 KEY_MINIMUM,
14 KEY_STDDEV,
15)
16from pycommons.types import type_error
18from moptipy.evaluation.base import MultiRun2DData, MultiRunData, PerRunData
19from moptipy.evaluation.progress import Progress
20from moptipy.utils.nputils import DEFAULT_FLOAT, DEFAULT_INT, is_np_float
22#: The value of the CDF of the standard normal distribution CDF at -1,
23#: which corresponds to "mean - 1 * sd".
24__Q159: Final[float] = (1.0 + erf(-1.0 / sqrt(2.0))) / 2.0
26#: The value of the CDF of the standard normal distribution CDF at +1,
27#: which corresponds to "mean + 1 * sd".
28__Q841: Final[float] = (1.0 + erf(1.0 / sqrt(2.0))) / 2.0
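

# The demo below is an illustrative sketch, not part of the module's API:
# it samples a large standard normal data set and shows that the quantiles
# at `__Q159` and `__Q841` roughly coincide with mean - sd and mean + sd.
def __demo_normal_quantiles() -> None:
    """Relate `__Q159` and `__Q841` to mean +/- sd (sketch only)."""
    rng = np.random.default_rng(42)  # fixed seed, arbitrary choice
    sample = rng.normal(0.0, 1.0, 100_000)
    print(np.quantile(sample, __Q159), sample.mean() - sample.std())
    print(np.quantile(sample, __Q841), sample.mean() + sample.std())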


def __unique_floats_1d(data: list[np.ndarray]) -> np.ndarray:
    """
    Get all unique values that are >= the largest minimum of all arrays.

    :param data: the data
    :return: the `ndarray` with the sorted, unique values
    """
    res: np.ndarray = np.unique(np.concatenate(data).astype(DEFAULT_FLOAT))
    mini = res[0]  # the overall smallest value
    for d in data:  # find the largest minimum over all (sorted) arrays
        if d[0] > mini:
            mini = d[0]
    i: Final[int] = int(np.searchsorted(res, mini))
    if i > 0:
        return res[i:]
    return res
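

# A worked micro-example (hedged sketch, the arrays are made up): from the
# sorted arrays [1, 3, 5] and [2, 3, 6], the unique values are
# [1, 2, 3, 5, 6]; the largest per-array minimum is 2, so 1 is dropped.
def __demo_unique_floats() -> None:
    """Illustrate `__unique_floats_1d` on two tiny arrays (sketch only)."""
    a = np.array([1.0, 3.0, 5.0])
    b = np.array([2.0, 3.0, 6.0])
    print(__unique_floats_1d([a, b]))  # -> [2. 3. 5. 6.]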


def __apply_fun(x_unique: np.ndarray,
                x_raw: list[np.ndarray],
                y_raw: list[np.ndarray],
                stat_func: Callable,
                out_len: int,
                dest_y: np.ndarray,
                stat_dim: int,
                values_buf: np.ndarray,
                pos_buf: np.ndarray) -> np.ndarray:
    """
    Perform the work of computing the time-dependent statistic.

    The unique x-values `x_unique` have separately been computed with
    :func:`__unique_floats_1d` from `x_raw` so that they can be reused.
    `x_raw` and `y_raw` are lists with the raw time and objective data,
    respectively. `stat_func` is the statistic function that will be applied
    to the step-wise generated data filled into `values_buf`. `pos_buf` will
    be used to maintain the current indices into `x_raw` and `y_raw`.
    `dest_y` will be filled with the computed statistic for each element of
    `x_unique`. In a final step, we will remove all redundant elements of
    both arrays: If `x_unique` increases but `dest_y` remains the same, then
    the corresponding point is deleted if it is not the last point in the
    list. As a result, a two-dimensional time/value array is returned.

    :param x_unique: the unique time coordinates
    :param x_raw: a list of several x-data arrays
    :param y_raw: a list of several y-data arrays
    :param stat_func: a statistic function which must have been jitted with
        numba
    :param out_len: the length of `dest_y` and `x_unique`
    :param dest_y: the destination array for the computed statistics
    :param stat_dim: the length of the lists `x_raw` and `y_raw`
    :param values_buf: the buffer for the values to be passed to `stat_func`
    :param pos_buf: the position buffer
    :return: the two-dimensional `np.ndarray` where the first column is the
        time and the second column is the statistic value
    """
    for i in range(out_len - 1, -1, -1):  # reverse iteration
        x = x_unique[i]  # x_unique holds all unique x values
        for j in range(stat_dim):  # for all Progress datasets do
            idx = pos_buf[j]  # get the current position
            if x < x_raw[j][idx]:  # if x < the current time value
                idx -= 1  # step back by one
                pos_buf[j] = idx  # now x >= x_raw[j][idx]
            values_buf[j] = y_raw[j][idx]
        dest_y[i] = stat_func(values_buf)

    changes = 1 + np.flatnonzero(dest_y[1:] != dest_y[:-1])
    dest_len = len(dest_y) - 1
    changes_len = len(changes)
    if changes_len < 1:  # corner case: all values are the same
        # if there is only one value, use only that value;
        # otherwise, use the first and the last value
        indexes = np.array([0]) if dest_len <= 1 else np.array([0, dest_len])
    elif changes[-1] != dest_len:  # always keep the last point
        indexes = np.concatenate((np.array([0]), changes,
                                  np.array([dest_len])))
    else:
        indexes = np.concatenate((np.array([0]), changes))
    return np.column_stack((x_unique[indexes], dest_y[indexes]))
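

# The redundancy removal above can be illustrated stand-alone (a hedged
# sketch with made-up numbers): points where the statistic does not change
# are dropped, while the first point, every change point, and the last
# point survive.
def __demo_step_compression() -> None:
    """Show the redundancy-removal idea of `__apply_fun` (sketch only)."""
    x = np.array([1.0, 2.0, 3.0, 4.0])
    y = np.array([5.0, 5.0, 4.0, 4.0])
    changes = 1 + np.flatnonzero(y[1:] != y[:-1])  # -> [2]
    idx = np.concatenate((np.array([0]), changes, np.array([len(y) - 1])))
    print(np.column_stack((x[idx], y[idx])))  # rows (1, 5), (3, 4), (4, 4)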


def __do_apply_fun(x_unique: np.ndarray, x_raw: list[np.ndarray],
                   y_raw: list[np.ndarray],
                   stat_func: Callable[[np.ndarray], float]) -> np.ndarray:
    """
    Compute a time-dependent statistic.

    The unique x-values `x_unique` have separately been computed with
    :func:`__unique_floats_1d` from `x_raw` so that they can be reused.
    `x_raw` and `y_raw` are lists with the raw time and objective data,
    respectively. `stat_func` is the statistic function that will be
    applied. In a final step, we will remove all redundant elements of both
    arrays: If `x_unique` increases but the statistic remains the same, then
    the corresponding point is deleted if it is not the last point in the
    list. As a result, a two-dimensional time/value array is returned. This
    function uses :func:`__apply_fun` as its internal work horse.

    :param x_unique: the unique time coordinates
    :param x_raw: a list of several x-data arrays
    :param y_raw: a list of several y-data arrays
    :param stat_func: a statistic function which must have been jitted with
        numba
    :return: the two-dimensional `numpy.ndarray` where the first column is
        the time and the second column is the statistic value
    """
    out_len: Final[int] = len(x_unique)
    dest_y: Final[np.ndarray] = np.zeros(out_len, DEFAULT_FLOAT)
    stat_dim: Final[int] = len(x_raw)
    values: Final[np.ndarray] = np.zeros(stat_dim, DEFAULT_FLOAT)
    pos: Final[np.ndarray] = np.array([len(x) - 1 for x in x_raw],
                                      DEFAULT_INT)
    return __apply_fun(x_unique, x_raw, y_raw, stat_func, out_len,
                       dest_y, stat_dim, values, pos)
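

# End-to-end micro-example (hedged sketch, the data is made up): two runs
# with sorted time arrays are reduced to a (time, minimum) step curve.
def __demo_do_apply_fun() -> None:
    """Drive `__do_apply_fun` on toy data (sketch only)."""
    x_raw = [np.array([1.0, 3.0]), np.array([2.0, 3.0])]
    y_raw = [np.array([9.0, 5.0]), np.array([8.0, 4.0])]
    x_unique = __unique_floats_1d(x_raw)  # -> [2. 3.]
    print(__do_apply_fun(x_unique, x_raw, y_raw, __stat_min))
    # -> rows (2, 8) and (3, 4)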


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False,
            parallel=True)
def __stat_arith_mean(data: np.ndarray) -> np.number:
    """
    Compute the arithmetic mean.

    :param data: the data
    :return: the arithmetic mean
    """
    return data.mean()


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False,
            parallel=True)
def __stat_geo_mean(data: np.ndarray) -> np.number:
    """
    Compute the geometric mean.

    :param data: the data
    :return: the geometric mean
    """
    return np.exp(np.mean(np.log(data)))


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False,
            parallel=True)
def __stat_min(data: np.ndarray) -> np.number:
    """
    Compute the minimum.

    :param data: the data
    :return: the minimum
    """
    return data.min()


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False,
            parallel=True)
def __stat_max(data: np.ndarray) -> np.number:
    """
    Compute the maximum.

    :param data: the data
    :return: the maximum
    """
    return data.max()


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False)
def __stat_median(data: np.ndarray) -> np.number:
    """
    Compute the median.

    :param data: the data
    :return: the median
    """
    return np.median(data)


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False,
            parallel=True)
def __stat_sd(data: np.ndarray) -> np.number:
    """
    Compute the standard deviation.

    :param data: the data
    :return: the standard deviation
    """
    return data.std()


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False,
            parallel=True)
def __stat_mean_minus_sd(data: np.ndarray) -> np.number:
    """
    Compute the arithmetic mean minus the standard deviation.

    :param data: the data
    :return: the arithmetic mean minus the standard deviation
    """
    return data.mean() - data.std()


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False,
            parallel=True)
def __stat_mean_plus_sd(data: np.ndarray) -> np.number:
    """
    Compute the arithmetic mean plus the standard deviation.

    :param data: the data
    :return: the arithmetic mean plus the standard deviation
    """
    return data.mean() + data.std()


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False)
def __stat_quantile_10(data: np.ndarray) -> np.number:
    """
    Compute the 10% quantile.

    :param data: the data
    :return: the 10% quantile
    """
    length: Final[int] = len(data)
    if (length > 10) and ((length % 10) == 1):
        data.sort()  # for n = 10k + 1 samples, the quantile is exact
        return data[(length - 1) // 10]
    return np.quantile(data, 0.1)


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False)
def __stat_quantile_90(data: np.ndarray) -> np.number:
    """
    Compute the 90% quantile.

    :param data: the data
    :return: the 90% quantile
    """
    length: Final[int] = len(data)
    if (length > 10) and ((length % 10) == 1):
        data.sort()  # for n = 10k + 1 samples, the quantile is exact
        return data[(9 * (length - 1)) // 10]
    return np.quantile(data, 0.9)


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False)
def __stat_quantile_159(data: np.ndarray) -> np.number:
    """
    Compute the 15.9% quantile, which equals mean-sd in normal distributions.

    :param data: the data
    :return: the 15.9% quantile
    """
    return np.quantile(data, __Q159)


@numba.njit(cache=True, inline="always", fastmath=False, boundscheck=False)
def __stat_quantile_841(data: np.ndarray) -> np.number:
    """
    Compute the 84.1% quantile, which equals mean+sd in normal distributions.

    :param data: the data
    :return: the 84.1% quantile
    """
    return np.quantile(data, __Q841)
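

# Quick sanity sketch (illustrative only): the jitted statistics can be
# called like plain functions; on [1, 2, 4], the geometric mean is the
# cube root of 8, i.e., exactly 2.
def __demo_stat_funcs() -> None:
    """Call a few of the jitted statistics on a toy sample (sketch only)."""
    sample = np.array([1.0, 2.0, 4.0])
    print(__stat_arith_mean(sample))  # -> 2.333...
    print(__stat_geo_mean(sample))    # -> 2.0
    print(__stat_median(sample))      # -> 2.0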


#: The statistics key for the minimum.
STAT_MINIMUM: Final[str] = KEY_MINIMUM
#: The statistics key for the median.
STAT_MEDIAN: Final[str] = KEY_MEDIAN
#: The statistics key for the arithmetic mean.
STAT_MEAN_ARITH: Final[str] = KEY_MEAN_ARITH
#: The statistics key for the geometric mean.
STAT_MEAN_GEOM: Final[str] = KEY_MEAN_GEOM
#: The statistics key for the maximum.
STAT_MAXIMUM: Final[str] = KEY_MAXIMUM
#: The statistics key for the standard deviation.
STAT_STDDEV: Final[str] = KEY_STDDEV
#: The key for the arithmetic mean minus the standard deviation.
STAT_MEAN_MINUS_STDDEV: Final[str] = f"{STAT_MEAN_ARITH}-{STAT_STDDEV}"
#: The key for the arithmetic mean plus the standard deviation.
STAT_MEAN_PLUS_STDDEV: Final[str] = f"{STAT_MEAN_ARITH}+{STAT_STDDEV}"
#: The key for the 10% quantile.
STAT_Q10: Final[str] = "q10"
#: The key for the 90% quantile.
STAT_Q90: Final[str] = "q90"
#: The key for the 15.9% quantile. In a normal distribution, this quantile
#: is where "mean - standard deviation" is located.
STAT_Q159: Final[str] = "q159"
#: The key for the 84.1% quantile. In a normal distribution, this quantile
#: is where "mean + standard deviation" is located.
STAT_Q841: Final[str] = "q841"

#: The internal function map.
_FUNC_MAP: Final[dict[str, Callable[[np.ndarray], float]]] = {
    STAT_MINIMUM: __stat_min,
    STAT_MEDIAN: __stat_median,
    STAT_MEAN_ARITH: __stat_arith_mean,
    STAT_MEAN_GEOM: __stat_geo_mean,
    STAT_MAXIMUM: __stat_max,
    STAT_STDDEV: __stat_sd,
    STAT_MEAN_MINUS_STDDEV: __stat_mean_minus_sd,
    STAT_MEAN_PLUS_STDDEV: __stat_mean_plus_sd,
    STAT_Q10: __stat_quantile_10,
    STAT_Q90: __stat_quantile_90,
    STAT_Q159: __stat_quantile_159,
    STAT_Q841: __stat_quantile_841,
}
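

# Lookup sketch (illustrative only): statistics are resolved by key, and
# `create` below uses exactly this mapping to validate statistic names.
def __demo_func_map() -> None:
    """Resolve a statistic function by its key (sketch only)."""
    func = _FUNC_MAP[STAT_MEDIAN]
    print(func(np.array([3.0, 1.0, 2.0])))  # -> 2.0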


@dataclass(frozen=True, init=False, order=False, eq=False)
class StatRun(MultiRun2DData):
    """A time-value statistic over a set of runs."""

    #: The name of this statistic.
    stat_name: str
    #: The time-dependent statistic.
    stat: np.ndarray

    def __init__(self,
                 algorithm: str | None,
                 instance: str | None,
                 objective: str | None,
                 encoding: str | None,
                 n: int,
                 time_unit: str,
                 f_name: str,
                 stat_name: str,
                 stat: np.ndarray):
        """
        Create the time-based statistics of an algorithm-setup combination.

        :param algorithm: the algorithm name, if all runs are with the same
            algorithm, `None` otherwise
        :param instance: the instance name, if all runs are on the same
            instance, `None` otherwise
        :param objective: the objective name, if all runs are on the same
            objective function, `None` otherwise
        :param encoding: the encoding name, if all runs are on the same
            encoding and an encoding was actually used, `None` otherwise
        :param n: the total number of runs
        :param time_unit: the time unit
        :param f_name: the objective dimension name
        :param stat_name: the name of the statistic
        :param stat: the statistic itself
        """
        super().__init__(algorithm, instance, objective, encoding, n,
                         time_unit, f_name)
        if not isinstance(stat_name, str):
            raise type_error(stat_name, "stat_name", str)
        object.__setattr__(self, "stat_name", stat_name)
        if not isinstance(stat, np.ndarray):
            raise type_error(stat, "statistic data", np.ndarray)
        stat.flags.writeable = False
        if (len(stat.shape) != 2) or (stat.shape[1] != 2) or \
                (stat.shape[0] <= 0):
            raise ValueError(
                "statistics array must be two-dimensional and have two "
                f"columns and at least one row, but has shape {stat.shape}.")
        if not is_np_float(stat.dtype):
            raise ValueError("statistics array must be float-typed, but has "
                             f"dtype {stat.dtype}.")
        object.__setattr__(self, "stat", stat)
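

# Construction sketch (hedged: the names "rls", "inst1", and the objective
# "f" are made up, and "FEs"/"plainF" are assumed to be valid time-unit and
# f-name values of the base class): a StatRun wraps an immutable
# (time, value) matrix.
def __demo_stat_run() -> None:
    """Build a tiny `StatRun` by hand (sketch only)."""
    stat = np.array([[1.0, 9.0], [5.0, 4.0]])  # two (time, value) rows
    run = StatRun("rls", "inst1", "f", None, 10, "FEs", "plainF",
                  STAT_MEDIAN, stat)
    print(run.stat_name, run.stat.shape)  # -> a (2, 2) matrix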


def create(source: Iterable[Progress],
           statistics: str | Iterable[str],
           consumer: Callable[[StatRun], Any]) -> None:
    """
    Compute statistics from an iterable of `Progress` objects.

    :param source: the progress data
    :param statistics: the statistics to be computed
    :param consumer: the consumer for the statistics
    """
    if not isinstance(source, Iterable):
        raise type_error(source, "source", Iterable)
    if isinstance(statistics, str):
        statistics = [statistics]
    if not isinstance(statistics, Iterable):
        raise type_error(statistics, "statistics", Iterable)
    if not callable(consumer):
        raise type_error(consumer, "consumer", call=True)

    algorithm: str | None = None
    instance: str | None = None
    objective: str | None = None
    encoding: str | None = None
    time_unit: str | None = None
    f_name: str | None = None
    time: list[np.ndarray] = []
    f: list[np.ndarray] = []
    n: int = 0

    for progress in source:
        if not isinstance(progress, Progress):
            raise type_error(progress, "stat run data source", Progress)
        if n <= 0:
            algorithm = progress.algorithm
            instance = progress.instance
            objective = progress.objective
            encoding = progress.encoding
            time_unit = progress.time_unit
            f_name = progress.f_name
        else:
            if algorithm != progress.algorithm:
                algorithm = None
            if instance != progress.instance:
                instance = None
            if objective != progress.objective:
                objective = None
            if encoding != progress.encoding:
                encoding = None
            if time_unit != progress.time_unit:
                raise ValueError(
                    f"Cannot mix time units {time_unit} "
                    f"and {progress.time_unit}.")
            if f_name != progress.f_name:
                raise ValueError(f"Cannot mix f-names {f_name} "
                                 f"and {progress.f_name}.")
        n += 1
        time.append(progress.time)
        f.append(progress.f)

    if n <= 0:
        raise ValueError("Did not encounter any progress information.")

    x_unique: Final[np.ndarray] = __unique_floats_1d(time)
    if not isinstance(x_unique, np.ndarray):
        raise type_error(x_unique, "x_unique", np.ndarray)
    if not is_np_float(x_unique.dtype):
        raise TypeError(
            f"x_unique must be floats, but is {x_unique.dtype}.")
    if (len(x_unique.shape) != 1) or (x_unique.shape[0] <= 0):
        raise ValueError(
            f"Invalid shape of unique values {x_unique.shape}.")

    count = 0
    for name in statistics:
        if not isinstance(name, str):
            raise type_error(name, "statistic name", str)
        if name not in _FUNC_MAP:
            raise ValueError(f"Unknown statistic name {name!r}.")
        consumer(StatRun(algorithm, instance, objective, encoding, n,
                         time_unit, f_name, name,
                         __do_apply_fun(x_unique, time, f, _FUNC_MAP[name])))
        count += 1

    if count <= 0:
        raise ValueError("No statistic names provided.")
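

# Usage sketch (hedged): assuming `progs` was filled with parsed Progress
# records elsewhere, `create` emits one StatRun per requested statistic to
# the consumer, here the `append` method of a list.
def __demo_create() -> None:
    """Show how `create` might be driven (sketch only)."""
    progs: list[Progress] = []  # assume this gets filled from log data
    runs: list[StatRun] = []
    if progs:  # guard: `create` raises on an empty source
        create(progs, [STAT_MEDIAN, STAT_MEAN_ARITH], runs.append)
        for run in runs:
            print(run.stat_name, run.stat.shape)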


def from_progress(source: Iterable[Progress],
                  statistics: str | Iterable[str],
                  consumer: Callable[[StatRun], Any],
                  join_all_algorithms: bool = False,
                  join_all_instances: bool = False,
                  join_all_objectives: bool = False,
                  join_all_encodings: bool = False) -> None:
    """
    Aggregate statistic runs over a stream of progress data.

    :param source: the stream of progress data
    :param statistics: the statistics that should be computed per group
    :param consumer: the destination to which the new stat runs will be
        passed, can be the `append` method of a :class:`list`
    :param join_all_algorithms: should the statistics be aggregated over
        all algorithms?
    :param join_all_instances: should the statistics be aggregated over
        all instances?
    :param join_all_objectives: should the statistics be aggregated over
        all objective functions?
    :param join_all_encodings: should the statistics be aggregated over
        all encodings?
    """
    if not isinstance(source, Iterable):
        raise type_error(source, "source", Iterable)
    if isinstance(statistics, str):
        statistics = [statistics]
    if not isinstance(statistics, Iterable):
        raise type_error(statistics, "statistics", Iterable)
    if not callable(consumer):
        raise type_error(consumer, "consumer", call=True)
    if not isinstance(join_all_algorithms, bool):
        raise type_error(join_all_algorithms, "join_all_algorithms", bool)
    if not isinstance(join_all_instances, bool):
        raise type_error(join_all_instances, "join_all_instances", bool)
    if not isinstance(join_all_objectives, bool):
        raise type_error(join_all_objectives, "join_all_objectives", bool)
    if not isinstance(join_all_encodings, bool):
        raise type_error(join_all_encodings, "join_all_encodings", bool)

    sorter: dict[tuple[str, str, str, str, str, str], list[Progress]] = {}
    for prog in source:
        if not isinstance(prog, Progress):
            raise type_error(prog, "progress source", Progress)
        key = ("" if join_all_algorithms else prog.algorithm,
               "" if join_all_instances else prog.instance,
               "" if join_all_objectives else prog.objective,
               "" if join_all_encodings else (
                   "" if prog.encoding is None else prog.encoding),
               prog.time_unit, prog.f_name)

        if key in sorter:
            lst = sorter[key]
        else:
            lst = []
            sorter[key] = lst
        lst.append(prog)

    if len(sorter) <= 0:
        raise ValueError("source must not be empty")

    if len(sorter) > 1:
        keys = list(sorter.keys())
        keys.sort()
        for key in keys:
            create(sorter[key], statistics, consumer)
    else:
        create(next(iter(sorter.values())), statistics, consumer)
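

# Usage sketch (hedged): with `join_all_algorithms=True`, runs of all
# algorithms on the same instance/objective/encoding are grouped together,
# and one StatRun per group and statistic is passed to the consumer.
def __demo_from_progress() -> None:
    """Show how `from_progress` might be driven (sketch only)."""
    progs: list[Progress] = []  # assume this gets filled from log data
    runs: list[StatRun] = []
    if progs:  # guard: `from_progress` raises on an empty source
        from_progress(progs, STAT_MEDIAN, runs.append,
                      join_all_algorithms=True)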


def get_statistic(obj: PerRunData | MultiRunData) -> str | None:
    """
    Get the statistic of a given object.

    :param obj: the object
    :return: the statistic string, or `None` if no statistic is specified
    """
    return obj.stat_name if isinstance(obj, StatRun) else None
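

# Usage sketch (hedged): since non-StatRun objects yield `None`, mixed data
# streams can be filtered uniformly, e.g.:
#
#     medians = [r for r in runs if get_statistic(r) == STAT_MEDIAN]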