Coverage for moptipyapps/utils/sampling.py: 87%
277 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-28 09:42 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-05-28 09:42 +0000
1"""
2Some utilities for random sampling.
The goal that we follow with class :class:`Distribution` is to have
clearly defined random distributions that produce numbers (integers or
floating point values). We want to be able to say exactly how to generate
some random numbers.
A distribution can be sampled using the method :meth:`~Distribution.sample`.
Each distribution has a mean value, which may either be an exact value or
an approximation, and which can be obtained via :meth:`~Distribution.mean`.
Sometimes, distributions can be simplified, which is supported by
:meth:`~Distribution.simplify`.
14>>> from moptipy.utils.nputils import rand_generator
15>>> from statistics import mean
16>>> rnd = rand_generator(0)
18>>> const = Const(12.3)
19>>> const
20Const(v=12.3)
21>>> const.mean()
2212.3
23>>> const.sample(rnd)
2412.3
26>>> normal = Normal(1, 2.0)
27>>> normal
28Normal(mu=1, sd=2)
29>>> x = [normal.sample(rnd) for _ in range(200)]
30>>> x[:20]
31[1.2514604421867865, 0.7357902734173962, 2.280845300886564,\
32 1.2098002343060794, -0.07133874632222192, 1.7231901098189695,\
33 3.6080000902602745, 2.8941619262584846, -0.4074704716139852,\
34 -1.530842942092105, -0.24654892507470438, 1.0826519586944872,\
35 -3.6500615492776687, 0.5624166721349085, -1.4918218945061303,\
36 -0.4645347094069032, -0.08851796571461978, 0.3673996872616909,\
37 1.8232610727482657, 3.085026738885355]
38>>> mean(x) / normal.mean()
391.0305262793198813
41>>> exponential = Exponential(3)
42>>> x = [exponential.sample(rnd) for _ in range(200)]
43>>> x[:20]
44[1.3203895772033505, 1.152983246827425, 4.527545171626064,\
45 5.080441711712409, 0.6498200242245252, 3.408652958826374,\
46 2.842677620245357, 2.1132194336022, 3.5838508479713393,\
47 1.2223825336486978, 2.8454397976498504, 0.8653980789905962,\
48 0.19572166348792452, 3.2506725793229854, 0.9426446058446336,\
49 3.246902386754473, 10.294412603282666, 5.275683923067543,\
50 0.49517363091492944, 0.2982336402218482]
51>>> mean(x) / exponential.mean()
521.088857332669081
54>>> gamma = Gamma(3.0, 0.26)
55>>> gamma
56Gamma(k=3, theta=0.26)
57>>> x = [gamma.sample(rnd) for _ in range(200)]
58>>> x[:20]
59[1.1617763743292708, 1.305755109480284, 0.7627948403389954,\
60 1.2735522637897285, 0.7742951665621697, 1.074233520618276,\
61 0.6324661100546898, 1.4627037699922791, 0.5739033567160827,\
62 0.5555065636904546, 0.629236234283296, 0.3666171387296996,\
63 0.3780976936750937, 0.9511433672028858, 1.2607313263258062,\
64 1.4442096466925938, 0.48758642085808085, 1.247724803721524,\
65 1.9359140456080306, 1.3935246884396764]
66>>> mean(x) / gamma.mean()
670.962254253258444
69>>> Gamma.from_alpha_beta(2, 0.5)
70Erlang(k=2, theta=2)
72>>> Gamma.from_alpha_beta(2.5, 0.5)
73Gamma(k=2.5, theta=2)
75>>> Erlang.from_alpha_beta(2, 0.5)
76Erlang(k=2, theta=2)
78>>> erlang = Erlang(1.0, 0.26)
79>>> erlang
80Erlang(k=1, theta=0.26)
81>>> x = [erlang.sample(rnd) for _ in range(200)]
82>>> x[:20]
83[0.29911329228981776, 0.3630768060267626, 0.14111385731543394,\
84 0.0745673280234536, 0.029950507989979877, 0.04741877104350835,\
85 0.38599089026561223, 0.047919114170390194, 0.06921557868837301,\
86 0.4066084140331242, 0.07170887998378667, 0.022061870233843223,\
87 0.04904717644388396, 0.32082097064821674, 0.001884448999141546,\
88 0.6687964040577958, 0.060598863807579915, 0.21491377996577304,\
89 0.23088301776258766, 0.23667780086315618]
90>>> erlang.mean()
910.26
92>>> mean(x)
930.23951704283832467
94>>> mean(x) / erlang.mean()
950.9212193955320179
97>>> uniform = Uniform(10.5, 17)
98>>> uniform
99Uniform(low=10.5, high=17)
100>>> x = [uniform.sample(rnd) for _ in range(200)]
101>>> x[:20]
102[11.235929257903324, 13.89459856391694, 14.196833332199233,\
103 13.871456027849515, 14.485310406702393, 16.20468251006836,\
104 13.7773314921396, 12.964459911549145, 12.167722566772781,\
105 12.49450317595701, 14.145245846833722, 15.669920217581366,\
106 13.367286694558258, 10.764955127505994, 11.723004274328007,\
107 11.089239517262232, 12.666732191193558, 14.948448277461127,\
108 14.339645757381653, 14.803829334728565]
109>>> mean(x) / uniform.mean()
1100.9968578545459643
112>>> choice = Choice((Const(2), gamma, normal))
113>>> choice
114Choice(ch=(Const(v=2), Gamma(k=3, theta=0.26), Normal(mu=1, sd=2)))
115>>> x = [choice.sample(rnd) for _ in range(200)]
116>>> x[:20]
117[0.5499364190576, -0.10428920251005325, 2, 2, 0.44263544084840273,\
118 0.6088189450771303, 2, 2, 2, 2, -0.32003290715104904, 2,\
119 0.6165299227577784, 1.0445083345086352, 2, 2, 2, -0.5970322857539738,\
120 -1.6672705710198277, 2]
121>>> mean(x) / choice.mean()
1220.9849730453309948
124>>> lower = AtLeast(2, normal)
125>>> lower
126AtLeast(lb=2, d=Normal(mu=1, sd=2))
127>>> x = [lower.sample(rnd) for _ in range(200)]
128>>> x[:20]
129[2.9715155848474772, 3.407503493034132, 2.59634461519561,\
130 2.0472458472897714, 3.9865670827840334, 2.015117730344058,\
131 2.0316584714999935, 5.36625737470408, 3.042848158343226,\
132 3.390407444638451, 4.5503184215686, 3.7682459882073007,\
133 2.0539760253651305, 2.134886147372958, 2.500182239395479,\
134 2.891402111997337, 2.7393826524907228, 2.3449577842364766,\
135 2.9043074694017195, 4.7173482582723825]
136>>> mean(x) / lower.mean()
1371.0178179732075965
139>>> interval = In(1, 10, gamma)
140>>> interval
141In(lb=1, ub=10, d=Gamma(k=3, theta=0.26))
142>>> x = [interval.sample(rnd) for _ in range(200)]
143>>> x[:20]
144[1.3488423972739083, 1.0929631399361601, 1.681162621901135,\
145 1.655614926246918, 1.041948891842002, 2.0773395958990175,\
146 1.3338891374921853, 1.2478964188175743, 1.9417070894505217,\
147 1.3990572178987266, 1.1216118870312337, 1.0641160239253207,\
148 2.131253600219639, 1.0337453883221577, 1.396499618416345,\
149 1.9865175145136038, 1.2555473269031396, 1.4412027583435465,\
150 1.320740351247919, 1.0407411942999665]
151>>> mean(x) / interval.mean()
1521.0256329849020747
154>>> erlang2 = Gamma.from_k_and_mean(3, 10)
155>>> erlang2
156Erlang(k=3, theta=3.3333333333333335)
157>>> erlang2.mean()
15810
159>>> x = [erlang2.sample(rnd) for _ in range(200)]
160>>> x[:20]
161[10.087506399523226, 12.928131914870168, 12.330250639007767,\
162 5.305123692562998, 21.085037136404374, 6.6603691824173135,\
163 2.961302890492059, 9.810557147180853, 8.051620919921454,\
164 8.750329405836668, 3.9511445189935763, 5.570300668751883,\
165 16.70132947692463, 7.831425379483914, 11.154757962484842,\
166 8.78943102381046, 8.395847820234795, 16.42251602814587,\
167 17.1628992966332, 9.684008648356015]
168>>> mean(x) / erlang2.mean()
1691.022444624140316
171>>> AtLeast.greater_than_zero(Gamma(1, 0.5))
172AtLeast(lb=5e-324, d=Exponential(eta=0.5))
174>>> AtLeast.greater_than_zero(Gamma(2, 0.5))
175Erlang(k=2, theta=0.5)
177>>> AtLeast.greater_than_zero(Gamma(2.5, 0.5))
178Gamma(k=2.5, theta=0.5)
180>>> expo = erlang.simplify()
181>>> expo
182Exponential(eta=0.26)
183>>> x = [expo.sample(rnd) for _ in range(200)]
184>>> x[:20]
185[0.07166050776665364, 0.16168119850191573, 0.22421276308859897, \
1860.0159144585889376, 0.2814946917726687, 0.18627887507033897, \
1870.023154315537313133, 0.04580416404670023, 0.6022196871728142, \
1880.025107668179499498, 0.8833625673263216, 0.15071056549661394, \
1890.00971431623116772, 0.024629860397399155, 0.5601666963985541, \
1900.14266231845989083, 0.05929349838310316, 0.34738399195888336, \
1910.048260650616228924, 0.28230843185818594]
192>>> mean(x)
1930.23545459359353754
194>>> expo.mean()
1950.26
196>>> mean(x) / expo.mean()
1970.9055945907443751
198"""
200from dataclasses import dataclass
201from math import fsum, isfinite, nextafter
202from typing import Callable, Final, cast
204from moptipy.utils.nputils import rand_generator, rand_seeds_from_str
205from numpy.random import Generator
206from pycommons.math.int_math import try_int
207from pycommons.types import check_int_range, type_error
#: the maximum number of attempts a bounded distribution makes to draw an
#: admissible sample before giving up with a :class:`ValueError`
_MAX_TRIALS: Final[int] = 1_000_000

#: the smallest positive floating point number, i.e., `nextafter(0.0, 1.0)`
_SMALLEST_POSITIVE_NUMBER: Final[float] = nextafter(0.0, 1.0)
216class Distribution:
217 """A base class for distributions."""
219 def sample(self, random: Generator) -> int | float:
220 """
221 Sample a random number following this distribution generator.
223 Each call to this function returns exactly one number.
225 :param random: the random number generator
226 :return: the number
227 """
228 raise NotImplementedError
230 def simplify(self) -> "Distribution":
231 """
232 Try to simplify this distribution.
234 Some distributions can trivially be simplified. For example, if you
235 have applied a range limit (:class:`In`) to a constant distribution
236 (class:`Const`), then this can be simplified to just the constant.
237 If such simplification is possible, this method returns the simplified
238 distribution. Otherwise, it just returns the distribution itself.
240 :returns: a simplified version of this distribution
241 """
242 if not isinstance(self, Distribution):
243 raise type_error(self, "self", Distribution)
244 return self
246 def mean(self) -> int | float:
247 """
248 Get the mean or approximate mean of the distribution.
250 Some distribution overwrite this method to produce an exact computed
251 expected value or mean. This default implementation just computes the
252 arithmetic mean of 10'000 samples of the distribution. This serves as
253 baseline approximation for any case where a closed form mathematical
254 definition of the expected value is not available.
256 :return: the mean or approximated mean
257 """
258 sample: Callable[[Generator], int | float] = self.sample
259 gen: Final[Generator] = rand_generator(rand_seeds_from_str(
260 repr(self), 1)[0])
261 return try_int(fsum(sample(gen) for _ in range(10_000)) / 10_000)
@dataclass(order=True, frozen=True)
class Const(Distribution):
    """A distribution that always returns the same constant value."""

    #: the constant value
    v: int | float

    def __post_init__(self) -> None:
        """Normalize the constant, e.g., turning `1.0` into `1`."""
        normalized: Final[int | float] = try_int(self.v)
        object.__setattr__(self, "v", normalized)

    def sample(self, random: Generator) -> int | float:
        """
        Return the constant value.

        :param random: the random number generator (ignored)
        :return: the constant value
        """
        value: Final[int | float] = self.v
        return value

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, which is the constant itself.

        :return: the mean
        """
        value: Final[int | float] = self.v
        return value

    def simplify(self) -> "Distribution":
        """
        Simplify this constant.

        An integer-valued constant becomes an :class:`IntConst`. Any other
        constant is returned unchanged.

        :return: the simplified constant

        >>> Const(1.5).simplify()
        Const(v=1.5)
        >>> Const(1.0).simplify()
        IntConst(v=1)
        """
        if not isinstance(self.v, int):
            return self
        return IntConst(self.v)
@dataclass(order=True, frozen=True)
class Normal(Distribution):
    """A normal (Gaussian) distribution."""

    #: the expected value and center of the distribution
    mu: int | float
    #: the standard deviation of the distribution
    sd: int | float

    def __post_init__(self) -> None:
        """
        Validate and normalize the parameters.

        :raises ValueError: if `mu` or `sd` is not finite or `sd <= 0`
        """
        valid: Final[bool] = isfinite(self.mu) and isfinite(self.sd) \
            and (self.sd > 0)
        if not valid:
            raise ValueError(f"Invalid parameters {self}.")
        for field_name in ("mu", "sd"):
            object.__setattr__(
                self, field_name, try_int(getattr(self, field_name)))

    def sample(self, random: Generator) -> float:
        """
        Draw one value from the normal distribution.

        :param random: the random number generator
        :return: the result
        """
        center: Final[int | float] = self.mu
        spread: Final[int | float] = self.sd
        return random.normal(center, spread)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, which is :attr:`mu`.

        :return: the mean
        """
        center: Final[int | float] = self.mu
        return center
@dataclass(order=True, frozen=True)
class Exponential(Distribution):
    """
    A class representing an exponential distribution.

    The parameter :attr:`eta` is the *scale* of the distribution, i.e., the
    reciprocal of the rate. It is passed as `scale` to
    :meth:`numpy.random.Generator.exponential` and it is also the mean.

    >>> Exponential(2).mean()
    2
    """

    #: the scale parameter, which equals the mean (and 1/rate)
    eta: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if `eta` is not a finite, positive number
        """
        object.__setattr__(self, "eta", try_int(self.eta))
        # `not (x > 0)` also rejects NaN, which `x <= 0` would let through
        if not (isfinite(self.eta) and (self.eta > 0)):
            raise ValueError(f"Invalid setup {self!r}.")

    def sample(self, random: Generator) -> float:
        """
        Sample from the Exponential distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.exponential(self.eta)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, which equals :attr:`eta`.

        :return: the mean
        """
        return try_int(self.eta)
@dataclass(order=True, frozen=True)
class Gamma(Distribution):
    """
    A class representing a Gamma distribution.

    Here, `k` is the shape and `theta` is the scale parameter.
    If you use a parameterization with `alpha` (shape) and `beta` (rate),
    you need to create the distribution using :meth:`~Gamma.from_alpha_beta`
    instead. The reason is that `scale = 1/beta`, see
    https://www.statlect.com/probability-distributions/gamma-distribution.

    >>> try:
    ...     Gamma(2, -1)
    ... except ValueError:
    ...     print("rejected")
    rejected
    """

    #: the shape parameter
    k: int | float
    #: the scale parameter
    theta: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if `k` or `theta` is not finite and positive
        """
        object.__setattr__(self, "k", try_int(self.k))
        object.__setattr__(self, "theta", try_int(self.theta))
        # The whole conjunction must be negated: both parameters have to be
        # finite and strictly positive (`> 0` is also False for NaN).
        if not (isfinite(self.k) and isfinite(self.theta)
                and (self.k > 0) and (self.theta > 0)):
            raise ValueError(f"Invalid parameters {self}.")

    def sample(self, random: Generator) -> float:
        """
        Sample from the Gamma distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.gamma(self.k, self.theta)

    def simplify(self) -> "Distribution":
        """
        Try to simplify this distribution.

        A Gamma distribution may simplify to either an :class:`Erlang` or an
        :class:`Exponential` distribution, depending on its parameters.
        If :attr:`~Gamma.k` is `1`, then it is actually an
        :class:`Exponential` distribution. If :attr:`~Gamma.k` is an integer,
        then the distribution is an :class:`Erlang` distribution.

        1. https://www.statisticshowto.com/gamma-distribution
        2. https://www.statisticshowto.com/erlang-distribution

        :returns: a simplified version of this distribution
        """
        return Exponential(self.theta) if self.k == 1 else (
            Erlang(self.k, self.theta) if isinstance(
                self.k, int) and not isinstance(self, Erlang) else self)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, which is `k * theta`.

        :return: the mean
        """
        return try_int(self.k * self.theta)

    @classmethod
    def from_alpha_beta(cls, alpha: int | float, beta: int | float) \
            -> "Distribution":
        """
        Create a Gamma distribution from shape `alpha` and rate `beta`.

        :param alpha: the alpha (shape) parameter
        :param beta: the beta (rate) parameter
        :return: the distribution
        :raises ValueError: if `beta` is zero or the resulting parameters
            are not finite and positive

        >>> Gamma.from_alpha_beta(1, 1)
        Exponential(eta=1)

        >>> Gamma.from_alpha_beta(2, 1)
        Erlang(k=2, theta=1)

        >>> Gamma.from_alpha_beta(1, 2)
        Exponential(eta=0.5)

        >>> Gamma.from_alpha_beta(2, 2)
        Erlang(k=2, theta=0.5)

        >>> Gamma.from_alpha_beta(1.5, 1)
        Gamma(k=1.5, theta=1)

        >>> Gamma.from_alpha_beta(1, 1.5)
        Exponential(eta=0.6666666666666666)

        >>> Gamma.from_alpha_beta(1.5, 1.5)
        Gamma(k=1.5, theta=0.6666666666666666)
        """
        beta = try_int(beta)
        if beta == 0:
            raise ValueError(f"beta cannot be {beta}.")
        return cls(alpha, 1 / beta).simplify()

    @classmethod
    def from_k_and_mean(cls, k: int | float, mean: int | float) \
            -> "Distribution":
        """
        Create the Gamma distribution from the value of `k` and a mean.

        :param k: the shape parameter
        :param mean: the mean
        :return: the distribution
        :raises ValueError: if `k` or `mean` is not positive

        >>> Gamma.from_k_and_mean(1, 1)
        Exponential(eta=1)

        >>> Gamma.from_k_and_mean(1, 2)
        Exponential(eta=2)

        >>> Gamma.from_k_and_mean(2, 1)
        Erlang(k=2, theta=0.5)

        >>> Gamma.from_k_and_mean(2, 2)
        Erlang(k=2, theta=1)

        >>> Gamma.from_k_and_mean(1.5, 1)
        Gamma(k=1.5, theta=0.6666666666666666)

        >>> Gamma.from_k_and_mean(1, 1.5)
        Exponential(eta=1.5)
        """
        k = try_int(k)
        mean = try_int(mean)
        if (mean <= 0) or (k <= 0):
            raise ValueError(f"Invalid values k={k}, mean={mean}.")
        # since mean = k * theta, the scale is theta = mean / k
        return Gamma(k, mean / k).simplify()
class Erlang(Gamma):
    """
    The Erlang distribution, i.e., a Gamma distribution with integer shape.

    >>> Erlang(3, 0.5)
    Erlang(k=3, theta=0.5)
    """

    def __post_init__(self) -> None:
        """
        Validate like a Gamma distribution and require an integer shape.

        :raises: the error produced by `type_error` if `k` is not an `int`
        """
        super().__post_init__()
        shape: Final[int | float] = self.k
        if isinstance(shape, int):
            return
        raise type_error(shape, "k", int)
@dataclass(order=True, frozen=True)
class Uniform(Distribution):
    """
    A class representing a uniform distribution over `[low, high)`.

    >>> Uniform(1.0, 2.5)
    Uniform(low=1, high=2.5)
    >>> Uniform(1, 2).mean()
    1.5
    """

    #: the lowest permitted value (inclusive)
    low: int | float
    #: the upper end of the range (exclusive)
    high: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if a bound is not finite or `high <= low`
        """
        object.__setattr__(self, "low", try_int(self.low))
        object.__setattr__(self, "high", try_int(self.high))
        # NaN compares False with everything, so `high <= low` alone would
        # accept it; infinite bounds cannot be sampled uniformly either
        if not (isfinite(self.low) and isfinite(self.high)
                and (self.low < self.high)):
            raise ValueError(f"Invalid parameters {self}.")

    def sample(self, random: Generator) -> float:
        """
        Sample from the uniform distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.uniform(self.low, self.high)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, the middle of the interval.

        :return: the mean
        """
        return try_int((self.high + self.low) / 2)
@dataclass(order=True, frozen=True)
class Choice(Distribution):
    """
    A class representing a uniform choice among several distributions.

    Each call to :meth:`sample` first picks one of the distributions in
    :attr:`ch` with equal probability and then samples from it.

    >>> Choice((Uniform(1, 2), Uniform(3, 4))).simplify()
    Choice(ch=(Uniform(low=1, high=2), Uniform(low=3, high=4)))

    >>> Choice((Uniform(1, 2), Uniform(1.0, 2))).simplify()
    Uniform(low=1, high=2)

    Nested choices are only flattened if this preserves the probabilities,
    i.e., if every option is itself a choice with the same number of options:

    >>> Choice((Const(1.5), Choice((Const(1), Const(2))))).simplify()
    Choice(ch=(Const(v=1.5), Choice(ch=(IntConst(v=1), IntConst(v=2)))))

    >>> Choice((Choice((Const(1), Const(2))),
    ...         Choice((Const(3), Const(4))))).simplify()
    Choice(ch=(IntConst(v=1), IntConst(v=2), IntConst(v=3), IntConst(v=4)))
    """

    #: the choices
    ch: tuple[Distribution, ...]

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks.

        :raises TypeError: if :attr:`ch` is not a tuple of distributions
        :raises ValueError: if :attr:`ch` is empty
        """
        if not isinstance(self.ch, tuple):
            raise TypeError(f"Invalid types {self}.")
        if tuple.__len__(self.ch) <= 0:
            # an empty choice can neither be sampled nor have a mean
            raise ValueError(f"Invalid empty choice {self}.")
        for v in self.ch:
            if not isinstance(v, Distribution):
                raise type_error(v, "choice", Distribution)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, the average of the options' means.

        :return: the mean
        """
        return try_int(fsum(d.mean() for d in self.ch) / tuple.__len__(
            self.ch))

    def sample(self, random: Generator) -> int | float:
        """
        Sample from one uniformly chosen option.

        :param random: the random number generator
        :return: the result
        """
        return self.ch[random.integers(tuple.__len__(self.ch))].sample(random)

    def simplify(self) -> Distribution:
        """
        Try to simplify this distribution.

        All options are simplified. If all resulting options are equal, that
        single option is returned. Nested choices are merged into this one
        only if every option is a choice with the same number of options,
        because only then does every leaf keep its sampling probability.

        :returns: a simplified version of this distribution
        """
        ch: Final[tuple[Distribution, ...]] = self.ch
        if tuple.__len__(ch) <= 1:
            return ch[0].simplify()
        simplified: Final[list[Distribution]] = [d.simplify() for d in ch]
        needs: bool = any(use != dist for use, dist in zip(simplified, ch))
        done: list[Distribution] = simplified

        if all(isinstance(use, Choice) for use in simplified):
            nested: Final[list[tuple[Distribution, ...]]] = [
                cast("Choice", use).ch for use in simplified]
            if set.__len__({tuple.__len__(c) for c in nested}) == 1:
                done = [d for c in nested for d in c]
                needs = True

        first: Final[Distribution] = done[0]
        if all(oth == first for oth in done):
            return first
        return Choice(tuple(done)) if needs else self
@dataclass(order=True, frozen=True)
class AtLeast(Distribution):
    """
    A distribution that is lower-bounded.

    Samples from the inner distribution :attr:`d` that fall below the
    inclusive lower bound :attr:`lb` are rejected and drawn again. Nested
    :class:`AtLeast` bounds are merged at construction, and the merged bound
    is validated against the inner distribution.

    >>> AtLeast(5, Const(7))
    AtLeast(lb=5, d=Const(v=7))

    >>> AtLeast(5, AtLeast(8, Const(17)))
    AtLeast(lb=8, d=Const(v=17))

    >>> AtLeast(8, AtLeast(5, Const(17)))
    AtLeast(lb=8, d=Const(v=17))

    >>> try:
    ...     AtLeast(20, AtLeast(8, Const(17)))
    ... except ValueError:
    ...     print("rejected")
    rejected
    """

    #: the inclusive lower bound
    lb: int | float
    #: the inner distribution to sample from
    d: Distribution

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises TypeError: if :attr:`d` is not a :class:`Distribution`
        :raises ValueError: if the bound is detectably incompatible with
            :attr:`d` (checked for :class:`Const`, :class:`In`, and
            :class:`Uniform`)
        """
        object.__setattr__(self, "lb", try_int(self.lb))
        dd: Distribution = self.d
        if not isinstance(dd, Distribution):
            raise type_error(dd, "d", Distribution)
        if isinstance(dd, AtLeast):
            # Merge nested lower bounds (the larger one wins). The inner
            # object has already merged its own nesting during construction,
            # so one level suffices. The checks below use the merged bound.
            dlb: AtLeast = cast("AtLeast", dd)
            object.__setattr__(self, "lb", max(self.lb, dlb.lb))
            dd = dlb.d
            object.__setattr__(self, "d", dd)

        lb: Final[int | float] = self.lb
        if isinstance(dd, Const):
            if cast("Const", dd).v < lb:
                raise ValueError(f"Invalid distribution {self!r}.")
        elif isinstance(dd, In):
            idd: In = cast("In", dd)
            if max(idd.lb, lb) >= idd.ub:
                raise ValueError(f"Invalid distribution {self!r}.")
        elif isinstance(dd, Uniform):
            udd: Uniform = cast("Uniform", dd)
            if max(udd.low, lb) >= udd.high:
                raise ValueError(f"Invalid distribution {self!r}.")

    def simplify(self) -> "Distribution":
        """
        Try to simplify this distribution.

        :returns: a simplified version of this distribution

        >>> AtLeast(1, Uniform(3, 4)).simplify()
        Uniform(low=3, high=4)

        >>> AtLeast(3.5, Uniform(3, 4)).simplify()
        Uniform(low=3.5, high=4)
        """
        dd: Final[Distribution] = self.d
        if isinstance(dd, Const):
            # the constructor verified that the constant satisfies the bound
            return dd
        if isinstance(dd, In):
            idd: In = cast("In", dd)
            return In(max(idd.lb, self.lb), idd.ub, idd.d).simplify()
        if isinstance(dd, Uniform):
            udd: Uniform = cast("Uniform", dd)
            if udd.low >= self.lb:
                return udd
            # a uniform distribution restricted to [lb, high) is uniform
            return Uniform(self.lb, udd.high).simplify()
        if (self.lb <= 0) and (isinstance(dd, Exponential | Gamma | Erlang)):
            return dd
        if (self.lb <= _SMALLEST_POSITIVE_NUMBER) and isinstance(
                dd, Gamma | Erlang) and (
                cast("Gamma", dd).k > 1):  # pylint: disable=E1101
            return dd
        return self

    def sample(self, random: Generator) -> int | float:
        """
        Sample from the lower-bounded distribution by rejection.

        :param random: the random number generator
        :return: the result
        :raises ValueError: if no admissible value was drawn within
            `_MAX_TRIALS` attempts
        """
        s: Final[Callable[[Generator], int | float]] = self.d.sample
        lb: Final[int | float] = self.lb
        for _ in range(_MAX_TRIALS):
            v = s(random)
            if lb <= v:
                return v
        raise ValueError(f"Failed to sample from {self!r}.")

    @classmethod
    def greater_than_zero(cls, d: int | float | Distribution) -> Distribution:
        """
        Ensure that all samples are greater than zero.

        :param d: the original distribution
        :return: a distribution which is always greater than zero
        """
        return cls(_SMALLEST_POSITIVE_NUMBER, distribution(d)).simplify()
@dataclass(order=True, frozen=True)
class In(Distribution):
    """
    A distribution that is lower and upper-bounded.

    Samples from the inner distribution :attr:`d` outside of `[lb, ub)` are
    rejected and drawn again. Nested :class:`In` and :class:`AtLeast` bounds
    are merged at construction, and an inner :class:`Uniform` distribution
    is narrowed to the combined range.

    >>> In(1, 10, Const(6))
    In(lb=1, ub=10, d=Const(v=6))

    >>> In(1, 10, In(5, 12, Const(6)))
    In(lb=5, ub=10, d=Const(v=6))

    >>> In(1, 10, AtLeast(6, Const(6)))
    In(lb=6, ub=10, d=Const(v=6))

    >>> In(1, 10, In(5, 12, Uniform(0, 20)))
    In(lb=5, ub=10, d=Uniform(low=5, high=10))

    >>> In(1, 10, AtLeast(3, Uniform(0, 20))).simplify()
    Uniform(low=3, high=10)
    """

    #: the inclusive lower bound
    lb: int | float
    #: the exclusive upper bound
    ub: int | float
    #: the inner distribution to sample from
    d: Distribution

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises TypeError: if :attr:`d` is not a :class:`Distribution`
        :raises ValueError: if the (combined) range is empty or incompatible
            with the inner distribution
        """
        object.__setattr__(self, "lb", try_int(self.lb))
        object.__setattr__(self, "ub", try_int(self.ub))
        if not isinstance(self.d, Distribution):
            raise TypeError(f"Invalid types {self}.")
        if self.ub <= self.lb:
            raise ValueError(f"Invalid range {self}.")

        dd: Distribution = self.d
        lb: int | float = self.lb
        ub: int | float = self.ub
        # First strip all nested bounds (an AtLeast may wrap an In, so this
        # needs a loop) so that the innermost distribution is checked and
        # narrowed against the combined range below.
        while True:
            if isinstance(dd, In):
                idd: In = cast("In", dd)
                lb = max(idd.lb, lb)
                ub = min(idd.ub, ub)
                dd = idd.d
            elif isinstance(dd, AtLeast):
                ldd: AtLeast = cast("AtLeast", dd)
                lb = max(ldd.lb, lb)
                dd = ldd.d
            else:
                break
        if lb >= ub:
            raise ValueError(f"Invalid distribution {self!r}.")

        if isinstance(dd, Uniform):  # fix a uniform distribution
            udd: Uniform = cast("Uniform", dd)
            ulb: int | float = max(udd.low, lb)
            uub: int | float = min(udd.high, ub)
            if ulb >= uub:
                raise ValueError(f"Invalid distribution {self!r}.")
            lb, ub = ulb, uub
            if (ulb != udd.low) or (uub != udd.high):
                dd = Uniform(ulb, uub)
        elif isinstance(dd, Const) and not lb <= cast("Const", dd).v < ub:
            raise ValueError(f"Invalid distribution {self!r}.")

        object.__setattr__(self, "lb", lb)
        object.__setattr__(self, "ub", ub)
        object.__setattr__(self, "d", dd)

    def sample(self, random: Generator) -> int | float:
        """
        Sample from the lower- and upper-bounded distribution by rejection.

        :param random: the random number generator
        :return: the result
        :raises ValueError: if no admissible value was drawn within
            `_MAX_TRIALS` attempts
        """
        s: Final[Callable[[Generator], int | float]] = self.d.sample
        lb: Final[int | float] = self.lb
        ub: Final[int | float] = self.ub
        for _ in range(_MAX_TRIALS):
            v = s(random)
            if lb <= v < ub:
                return v
        raise ValueError(f"Failed to sample from {self!r}.")

    def simplify(self) -> Distribution:
        """
        Simplify this distribution.

        A bounded constant is just the constant, and a bounded uniform
        distribution has already been narrowed to the bounds during
        construction, so it can be returned directly.

        :return: the simplified distribution

        >>> In(-3, 4, Const(3)).simplify()
        Const(v=3)
        >>> ii = In(-10, 10, Uniform(3, 20))
        >>> ii
        In(lb=3, ub=10, d=Uniform(low=3, high=10))
        >>> ii.simplify()
        Uniform(low=3, high=10)
        """
        return self.d if isinstance(self.d, Const | Uniform) else self

    def mean(self) -> int | float:
        """
        Get the mean of this distribution.

        The mean is exact for a constant or (already narrowed) uniform inner
        distribution and approximated by sampling otherwise.

        :return: the mean
        """
        return self.d.mean() if isinstance(self.d, Const | Uniform) \
            else super().mean()
class IntDistribution(Distribution):
    """The base class for distributions that only produce integers."""

    def sample(self, random: Generator) -> int:
        """
        Draw one random integer from this distribution.

        Every invocation produces exactly one number.

        :param random: the random number generator
        :return: the number, which always is an integer
        """
        raise NotImplementedError

    def simplify(self) -> "IntDistribution":
        """
        Try to simplify this integer distribution.

        :returns: a simplified version of this integer distribution
        """
        if isinstance(self, IntDistribution):
            return self
        raise type_error(self, "self", IntDistribution)
class IntConst(IntDistribution, Const):
    """
    A constant distribution whose value is an integer.

    >>> IntConst(3)
    IntConst(v=3)
    """

    def __post_init__(self) -> None:
        """Normalize the value, then check it via `check_int_range`."""
        super().__post_init__()
        limit: Final[int] = 1_000_000_000_000_000_000
        check_int_range(self.v, "v", -limit, limit)

    def sample(self, random: Generator) -> int:
        """Return the integer constant value."""
        value: Final[int] = cast("int", self.v)
        return value

    def mean(self) -> int:
        """Return the arithmetic mean, i.e., the constant itself."""
        value: Final[int] = cast("int", self.v)
        return value
def distribution(d: int | float | Distribution) -> Distribution:
    """
    Get the distribution from the parameter.

    Integers become :class:`IntConst` and floats become :class:`Const`.
    Distributions are simplified repeatedly until
    :meth:`~Distribution.simplify` no longer changes them.

    :param d: the integer value, float value, or distribution
    :return: the canonicalized distribution
    :raises TypeError: if `d` is neither a number nor a distribution

    >>> distribution(7)
    IntConst(v=7)

    >>> distribution(3.4)
    Const(v=3.4)

    >>> distribution(Choice((Const(4.0), )))
    IntConst(v=4)
    """
    if isinstance(d, int):
        return IntConst(d)
    if isinstance(d, float):
        return Const(d)
    if isinstance(d, Distribution):
        old_d: Distribution | None = None
        # Iterate until a fixed point is reached. Checking equality in
        # addition to identity also stops if `simplify` returns a fresh but
        # equal object, which would otherwise keep the loop running.
        while (old_d is not d) and (old_d != d):
            old_d = d
            d = d.simplify()
            if not isinstance(d, Distribution):
                break
        return d
    raise type_error(d, "d", (Distribution, int, float))