Coverage for moptipyapps / utils / sampling.py: 87%
277 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 04:40 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 04:40 +0000
"""
Some utilities for random sampling.

The goal that we follow with class :class:`Distribution` is to have
clearly defined random distributions producing integer or floating point
numbers. We want to be able to say exactly how to generate some random
numbers.

A distribution can be sampled using the method :meth:`~Distribution.sample`.
Each distribution has a mean value, which may either be an exact value or
an approximate result, that can be obtained via :meth:`~Distribution.mean`.
Sometimes, distributions can be simplified, which is supported by
:meth:`~Distribution.simplify`.

>>> from moptipy.utils.nputils import rand_generator
>>> from statistics import mean
>>> rnd = rand_generator(0)

>>> const = Const(12.3)
>>> const
Const(v=12.3)
>>> const.mean()
12.3
>>> const.sample(rnd)
12.3

>>> normal = Normal(1, 2.0)
>>> normal
Normal(mu=1, sd=2)
>>> x = [normal.sample(rnd) for _ in range(200)]
>>> x[:20]
[1.2514604421867865, 0.7357902734173962, 2.280845300886564,\
 1.2098002343060794, -0.07133874632222192, 1.7231901098189695,\
 3.6080000902602745, 2.8941619262584846, -0.4074704716139852,\
 -1.530842942092105, -0.24654892507470438, 1.0826519586944872,\
 -3.6500615492776687, 0.5624166721349085, -1.4918218945061303,\
 -0.4645347094069032, -0.08851796571461978, 0.3673996872616909,\
 1.8232610727482657, 3.085026738885355]
>>> mean(x) / normal.mean()
1.0305262793198813

>>> exponential = Exponential(3)
>>> x = [exponential.sample(rnd) for _ in range(200)]
>>> x[:20]
[1.3203895772033505, 1.152983246827425, 4.527545171626064,\
 5.080441711712409, 0.6498200242245252, 3.408652958826374,\
 2.842677620245357, 2.1132194336022, 3.5838508479713393,\
 1.2223825336486978, 2.8454397976498504, 0.8653980789905962,\
 0.19572166348792452, 3.2506725793229854, 0.9426446058446336,\
 3.246902386754473, 10.294412603282666, 5.275683923067543,\
 0.49517363091492944, 0.2982336402218482]
>>> mean(x) / exponential.mean()
1.088857332669081

>>> gamma = Gamma(3.0, 0.26)
>>> gamma
Gamma(k=3, theta=0.26)
>>> x = [gamma.sample(rnd) for _ in range(200)]
>>> x[:20]
[1.1617763743292708, 1.305755109480284, 0.7627948403389954,\
 1.2735522637897285, 0.7742951665621697, 1.074233520618276,\
 0.6324661100546898, 1.4627037699922791, 0.5739033567160827,\
 0.5555065636904546, 0.629236234283296, 0.3666171387296996,\
 0.3780976936750937, 0.9511433672028858, 1.2607313263258062,\
 1.4442096466925938, 0.48758642085808085, 1.247724803721524,\
 1.9359140456080306, 1.3935246884396764]
>>> mean(x) / gamma.mean()
0.962254253258444

>>> Gamma.from_alpha_beta(2, 0.5)
Erlang(k=2, theta=2)

>>> Gamma.from_alpha_beta(2.5, 0.5)
Gamma(k=2.5, theta=2)

>>> Erlang.from_alpha_beta(2, 0.5)
Erlang(k=2, theta=2)

>>> erlang = Erlang(1.0, 0.26)
>>> erlang
Erlang(k=1, theta=0.26)
>>> x = [erlang.sample(rnd) for _ in range(200)]
>>> x[:20]
[0.29911329228981776, 0.3630768060267626, 0.14111385731543394,\
 0.0745673280234536, 0.029950507989979877, 0.04741877104350835,\
 0.38599089026561223, 0.047919114170390194, 0.06921557868837301,\
 0.4066084140331242, 0.07170887998378667, 0.022061870233843223,\
 0.04904717644388396, 0.32082097064821674, 0.001884448999141546,\
 0.6687964040577958, 0.060598863807579915, 0.21491377996577304,\
 0.23088301776258766, 0.23667780086315618]
>>> mean(x) / erlang.mean()
0.9212193955320179

>>> uniform = Uniform(10.5, 17)
>>> uniform
Uniform(low=10.5, high=17)
>>> x = [uniform.sample(rnd) for _ in range(200)]
>>> x[:20]
[11.235929257903324, 13.89459856391694, 14.196833332199233,\
 13.871456027849515, 14.485310406702393, 16.20468251006836,\
 13.7773314921396, 12.964459911549145, 12.167722566772781,\
 12.49450317595701, 14.145245846833722, 15.669920217581366,\
 13.367286694558258, 10.764955127505994, 11.723004274328007,\
 11.089239517262232, 12.666732191193558, 14.948448277461127,\
 14.339645757381653, 14.803829334728565]
>>> mean(x) / uniform.mean()
0.9968578545459643

>>> choice = Choice((Const(2), gamma, normal))
>>> choice
Choice(ch=(Const(v=2), Gamma(k=3, theta=0.26), Normal(mu=1, sd=2)))
>>> x = [choice.sample(rnd) for _ in range(200)]
>>> x[:20]
[0.5499364190576, -0.10428920251005325, 2, 2, 0.44263544084840273,\
 0.6088189450771303, 2, 2, 2, 2, -0.32003290715104904, 2,\
 0.6165299227577784, 1.0445083345086352, 2, 2, 2, -0.5970322857539738,\
 -1.6672705710198277, 2]
>>> mean(x) / choice.mean()
0.9849730453309948

>>> lower = AtLeast(2, normal)
>>> lower
AtLeast(lb=2, d=Normal(mu=1, sd=2))
>>> x = [lower.sample(rnd) for _ in range(200)]
>>> x[:20]
[2.9715155848474772, 3.407503493034132, 2.59634461519561,\
 2.0472458472897714, 3.9865670827840334, 2.015117730344058,\
 2.0316584714999935, 5.36625737470408, 3.042848158343226,\
 3.390407444638451, 4.5503184215686, 3.7682459882073007,\
 2.0539760253651305, 2.134886147372958, 2.500182239395479,\
 2.891402111997337, 2.7393826524907228, 2.3449577842364766,\
 2.9043074694017195, 4.7173482582723825]
>>> mean(x) / lower.mean()
1.0178179732075965

>>> interval = In(1, 10, gamma)
>>> interval
In(lb=1, ub=10, d=Gamma(k=3, theta=0.26))
>>> x = [interval.sample(rnd) for _ in range(200)]
>>> x[:20]
[1.3488423972739083, 1.0929631399361601, 1.681162621901135,\
 1.655614926246918, 1.041948891842002, 2.0773395958990175,\
 1.3338891374921853, 1.2478964188175743, 1.9417070894505217,\
 1.3990572178987266, 1.1216118870312337, 1.0641160239253207,\
 2.131253600219639, 1.0337453883221577, 1.396499618416345,\
 1.9865175145136038, 1.2555473269031396, 1.4412027583435465,\
 1.320740351247919, 1.0407411942999665]
>>> mean(x) / interval.mean()
1.0256329849020747

>>> erlang2 = Gamma.from_k_and_mean(3, 10)
>>> erlang2
Erlang(k=3, theta=3.3333333333333335)
>>> erlang2.mean()
10
>>> x = [erlang2.sample(rnd) for _ in range(200)]
>>> x[:20]
[10.087506399523226, 12.928131914870168, 12.330250639007767,\
 5.305123692562998, 21.085037136404374, 6.6603691824173135,\
 2.961302890492059, 9.810557147180853, 8.051620919921454,\
 8.750329405836668, 3.9511445189935763, 5.570300668751883,\
 16.70132947692463, 7.831425379483914, 11.154757962484842,\
 8.78943102381046, 8.395847820234795, 16.42251602814587,\
 17.1628992966332, 9.684008648356015]
>>> mean(x) / erlang2.mean()
1.022444624140316

>>> AtLeast.greater_than_zero(Gamma(1, 1))
AtLeast(lb=5e-324, d=Exponential(eta=1))

>>> AtLeast.greater_than_zero(Gamma(2, 0.5))
Erlang(k=2, theta=0.5)

>>> AtLeast.greater_than_zero(Gamma(2.5, 0.5))
Gamma(k=2.5, theta=0.5)
"""
177from dataclasses import dataclass
178from math import fsum, isfinite, nextafter
179from typing import Callable, Final, cast
181from moptipy.utils.nputils import rand_generator, rand_seeds_from_str
182from numpy.random import Generator
183from pycommons.math.int_math import try_int
184from pycommons.types import check_int_range, type_error
#: the maximum number of trials during a sampling process, i.e., how often
#: the rejection-sampling loops of :class:`AtLeast` and :class:`In` may draw
#: from their inner distribution before giving up with a `ValueError`
_MAX_TRIALS: Final[int] = 1_000_000

#: the smallest positive floating point number (`5e-324`), used as the lower
#: bound by :meth:`AtLeast.greater_than_zero`
_SMALLEST_POSITIVE_NUMBER: Final[float] = nextafter(0.0, 1.0)
193class Distribution:
194 """A base class for distributions."""
196 def sample(self, random: Generator) -> int | float:
197 """
198 Sample a random number following this distribution generator.
200 Each call to this function returns exactly one number.
202 :param random: the random number generator
203 :return: the number
204 """
205 raise NotImplementedError
207 def simplify(self) -> "Distribution":
208 """
209 Try to simplify this distribution.
211 Some distributions can trivially be simplified. For example, if you
212 have applied a range limit (:class:`In`) to a constant distribution
213 (class:`Const`), then this can be simplified to just the constant.
214 If such simplification is possible, this method returns the simplified
215 distribution. Otherwise, it just returns the distribution itself.
217 :returns: a simplified version of this distribution
218 """
219 if not isinstance(self, Distribution):
220 raise type_error(self, "self", Distribution)
221 return self
223 def mean(self) -> int | float:
224 """
225 Get the mean or approximate mean of the distribution.
227 Some distribution overwrite this method to produce an exact computed
228 expected value or mean. This default implementation just computes the
229 arithmetic mean of 10'000 samples of the distribution. This serves as
230 baseline approximation for any case where a closed form mathematical
231 definition of the expected value is not available.
233 :return: the mean or approximated mean
234 """
235 sample: Callable[[Generator], int | float] = self.sample
236 gen: Final[Generator] = rand_generator(rand_seeds_from_str(
237 repr(self), 1)[0])
238 return try_int(fsum(sample(gen) for _ in range(10_000)) / 10_000)
@dataclass(order=True, frozen=True)
class Const(Distribution):
    """
    A distribution that always yields one and the same value.

    On construction, the value is normalized via `try_int`, so that, e.g.,
    `Const(1.0)` stores the integer `1`.
    """

    #: the constant value
    v: int | float

    def __post_init__(self) -> None:
        """Normalize the stored constant value."""
        normalized: int | float = try_int(self.v)
        object.__setattr__(self, "v", normalized)

    def sample(self, random: Generator) -> int | float:
        """
        Return the constant; the random number generator is not used.

        :param random: the random number generator (ignored)
        :return: the constant value
        """
        return self.v

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, which is the constant itself.

        :return: the constant value
        """
        return self.v

    def simplify(self) -> "Distribution":
        """
        Simplify this constant.

        An integer-valued constant becomes an :class:`IntConst`; any other
        constant stays as it is.

        :return: the simplified constant

        >>> Const(1.5).simplify()
        Const(v=1.5)
        >>> Const(1.0).simplify()
        IntConst(v=1)
        """
        if not isinstance(self.v, int):
            return self
        return IntConst(self.v)
@dataclass(order=True, frozen=True)
class Normal(Distribution):
    """
    A normal (Gaussian) distribution with mean `mu` and deviation `sd`.

    >>> Normal(1, 2.0)
    Normal(mu=1, sd=2)
    """

    #: the expected value and center of the distribution
    mu: int | float
    #: the standard deviation of the distribution
    sd: int | float

    def __post_init__(self) -> None:
        """
        Validate and normalize the parameters.

        :raises ValueError: if `mu` or `sd` is not finite or if `sd` is not
            positive
        """
        center, spread = self.mu, self.sd
        if isfinite(center) and isfinite(spread) and (spread > 0):
            object.__setattr__(self, "mu", try_int(center))
            object.__setattr__(self, "sd", try_int(spread))
            return
        raise ValueError(f"Invalid parameters {self}.")

    def sample(self, random: Generator) -> float:
        """
        Draw one value from the normal distribution.

        :param random: the random number generator
        :return: the sampled value
        """
        return random.normal(self.mu, self.sd)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, i.e., `mu`.

        :return: the mean
        """
        return self.mu
@dataclass(order=True, frozen=True)
class Exponential(Distribution):
    """
    A class representing an exponential distribution.

    The parameter :attr:`~Exponential.eta` is the *scale* of the
    distribution: it is passed as `scale` to
    :meth:`numpy.random.Generator.exponential` and it is also the mean.
    An exponential distribution with rate `lambda` hence corresponds to
    `Exponential(1 / lambda)`.

    >>> Exponential(2.0)
    Exponential(eta=2)
    >>> Exponential(2.5).mean()
    2.5
    """

    #: the scale parameter, which is also the mean of the distribution
    eta: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if `eta` is not a finite positive number
        """
        object.__setattr__(self, "eta", try_int(self.eta))
        # `nan <= 0` is False, so NaN (and infinity) need an explicit check.
        if not (isfinite(self.eta) and (self.eta > 0)):
            raise ValueError(f"Invalid setup {self!r}.")

    def sample(self, random: Generator) -> float:
        """
        Sample from the Exponential distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.exponential(self.eta)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, which equals the scale `eta`.

        :return: the mean
        """
        return self.eta
@dataclass(order=True, frozen=True)
class Gamma(Distribution):
    """
    A class representing a Gamma distribution.

    Here, `k` is the shape and `theta` is the scale parameter: samples are
    drawn via :meth:`numpy.random.Generator.gamma` with shape `k` and scale
    `theta`, and the mean is `k * theta`.
    If you use a parameterization with `alpha` (shape) and `beta` (rate),
    you need to create the distribution using :meth:`~Gamma.from_alpha_beta`
    instead. The reason is that `theta = 1/beta`, see
    https://www.statlect.com/probability-distributions/gamma-distribution.

    >>> Gamma(3.0, 0.26)
    Gamma(k=3, theta=0.26)

    >>> try:
    ...     Gamma(-1, -1)
    ... except ValueError as ve:
    ...     print(ve)
    Invalid parameters Gamma(k=-1, theta=-1).
    """

    #: the shape parameter
    k: int | float
    #: the scale parameter
    theta: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if `k` or `theta` is not a finite positive number
        """
        object.__setattr__(self, "k", try_int(self.k))
        object.__setattr__(self, "theta", try_int(self.theta))
        # Both parameters must be valid at the same time, so the whole
        # conjunction is negated (not only its first term).
        if not (isfinite(self.k) and isfinite(self.theta)
                and (self.k > 0) and (self.theta > 0)):
            raise ValueError(f"Invalid parameters {self}.")

    def sample(self, random: Generator) -> float:
        """
        Sample from the Gamma distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.gamma(self.k, self.theta)

    def simplify(self) -> "Distribution":
        """
        Try to simplify this distribution.

        A Gamma distribution may simplify to either an :class:`Erlang` or an
        :class:`Exponential` distribution, depending on its parameters.
        If :attr:`~Gamma.k` is `1`, then it is actually an
        :class:`Exponential` distribution with scale :attr:`~Gamma.theta`.
        If :attr:`~Gamma.k` is an integer, then the distribution is an
        :class:`Erlang` distribution.

        1. https://www.statisticshowto.com/gamma-distribution
        2. https://www.statisticshowto.com/erlang-distribution

        :returns: a simplified version of this distribution

        >>> Gamma(1, 0.5).simplify()
        Exponential(eta=0.5)
        >>> Gamma(2, 0.5).simplify()
        Erlang(k=2, theta=0.5)
        >>> Gamma(2.5, 0.5).simplify()
        Gamma(k=2.5, theta=0.5)
        """
        if self.k == 1:
            # Gamma(1, theta) has the same law (and mean) as an exponential
            # distribution with scale theta.
            return Exponential(self.theta)
        if isinstance(self.k, int) and not isinstance(self, Erlang):
            return Erlang(self.k, self.theta)
        return self

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, i.e., `k * theta`.

        :return: the mean
        """
        return try_int(self.k * self.theta)

    @classmethod
    def from_alpha_beta(cls, alpha: int | float, beta: int | float) \
            -> "Distribution":
        """
        Create a Gamma distribution from `alpha` and `beta`.

        Here, `alpha` is the shape (i.e., `k`) and `beta` is the rate, i.e.,
        the inverse of the scale `theta`.

        :param alpha: the shape parameter
        :param beta: the rate parameter, i.e., `1 / theta`
        :return: the (simplified) distribution
        :raises ValueError: if `beta` is zero or the resulting parameters are
            invalid

        >>> Gamma.from_alpha_beta(2, 0.5)
        Erlang(k=2, theta=2)
        >>> Gamma.from_alpha_beta(2.5, 0.5)
        Gamma(k=2.5, theta=2)
        """
        beta = try_int(beta)
        if beta == 0:
            raise ValueError(f"beta cannot be {beta}.")
        return cls(alpha, 1 / beta).simplify()

    @classmethod
    def from_k_and_mean(cls, k: int | float, mean: int | float) \
            -> "Distribution":
        """
        Create the Gamma distribution from the value of `k` and a mean.

        The scale is chosen as `mean / k`, so that `k * theta == mean`.

        :param k: the shape parameter
        :param mean: the mean
        :return: the (simplified) distribution
        :raises ValueError: if `k` or `mean` is not positive

        >>> Gamma.from_k_and_mean(3, 10)
        Erlang(k=3, theta=3.3333333333333335)
        """
        k = try_int(k)
        mean = try_int(mean)
        if (mean <= 0) or (k <= 0):
            raise ValueError(f"Invalid values k={k}, mean={mean}.")
        return Gamma(k, mean / k).simplify()
class Erlang(Gamma):
    """
    The Erlang distribution: a :class:`Gamma` distribution with integer `k`.

    >>> Erlang(2.0, 0.5)
    Erlang(k=2, theta=0.5)
    """

    def __post_init__(self) -> None:
        """
        Run the checks of :class:`Gamma` and then require an integer shape.

        :raises TypeError: if the (normalized) shape `k` is not an `int`
        """
        super().__post_init__()
        shape = self.k
        if isinstance(shape, int):
            return
        raise type_error(shape, "k", int)
@dataclass(order=True, frozen=True)
class Uniform(Distribution):
    """
    A class representing a uniform distribution.

    Samples are drawn via :meth:`numpy.random.Generator.uniform`, i.e., from
    the half-open interval `[low, high)`.

    >>> Uniform(10.5, 17)
    Uniform(low=10.5, high=17)
    >>> Uniform(1, 3).mean()
    2
    """

    #: the lowest permitted value (inclusive)
    low: int | float
    #: the upper end of the range (nominally exclusive)
    high: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if a bound is not finite or if `high <= low`
        """
        object.__setattr__(self, "low", try_int(self.low))
        object.__setattr__(self, "high", try_int(self.high))
        # NaN bounds would slip through a plain `high <= low` comparison.
        if not (isfinite(self.low) and isfinite(self.high)
                and (self.low < self.high)):
            raise ValueError(f"Invalid parameters {self}.")

    def sample(self, random: Generator) -> float:
        """
        Sample from the uniform distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.uniform(self.low, self.high)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, i.e., the center of the range.

        :return: the mean
        """
        return try_int((self.high + self.low) / 2)
@dataclass(order=True, frozen=True)
class Choice(Distribution):
    """
    A class representing a uniform choice.

    Each sample is obtained by picking one of the distributions in
    :attr:`~Choice.ch` uniformly at random and sampling from it.

    >>> Choice((Uniform(1, 2), Uniform(3, 4))).simplify()
    Choice(ch=(Uniform(low=1, high=2), Uniform(low=3, high=4)))

    >>> Choice((Uniform(1, 2), Uniform(1.0, 2))).simplify()
    Uniform(low=1, high=2)

    >>> Choice((Uniform(1, 2), Choice(
    ...     (Const(1), Uniform(1, 2))))).simplify()
    Choice(ch=(Uniform(low=1, high=2), Uniform(low=1, high=2), IntConst(v=1), Uniform(low=1, high=2)))

    >>> c = Choice((Const(0), Choice((Const(1), Const(2)))))
    >>> c.mean() == c.simplify().mean()
    True

    >>> try:
    ...     Choice(())
    ... except ValueError as ve:
    ...     print(ve)
    Invalid parameters Choice(ch=()).
    """

    #: the choices
    ch: tuple[Distribution, ...]

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks.

        :raises TypeError: if `ch` is not a tuple of distributions
        :raises ValueError: if `ch` is empty
        """
        if not isinstance(self.ch, tuple):
            raise TypeError(f"Invalid types {self}.")
        # An empty choice could neither be sampled nor simplified, and it
        # would not have a mean.
        if tuple.__len__(self.ch) <= 0:
            raise ValueError(f"Invalid parameters {self}.")
        for v in self.ch:
            if not isinstance(v, Distribution):
                raise type_error(v, "choice", Distribution)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution.

        Since each option is picked with the same probability, this is the
        arithmetic mean of the means of the options.

        :return: the mean
        """
        return try_int(fsum(d.mean() for d in self.ch) / tuple.__len__(
            self.ch))

    def sample(self, random: Generator) -> int | float:
        """
        Pick one option uniformly at random and sample from it.

        :param random: the random number generator
        :return: the result
        """
        return self.ch[random.integers(tuple.__len__(self.ch))].sample(random)

    def simplify(self) -> Distribution:
        """
        Try to simplify this distribution.

        A choice with a single option simplifies to that option. Otherwise,
        all options are simplified and nested :class:`Choice` distributions
        are flattened into this one in a probability-preserving way. If all
        resulting options are equal, that single option is returned.

        :returns: a simplified version of this distribution
        """
        from math import lcm  # local import: only needed for flattening

        ch: Final[tuple[Distribution, ...]] = self.ch
        if tuple.__len__(ch) == 1:
            return ch[0].simplify()

        needs: bool = False  # becomes True if anything has changed
        simplified: list[Distribution] = []
        for dist in ch:
            use: Distribution = dist.simplify()
            if use != dist:
                needs = True
            simplified.append(use)

        # In a choice with `n` options, an option of a nested choice with
        # `s` options is picked with probability `1/(n*s)`. If `m` is the
        # least common multiple of all option counts (plain options count
        # as 1), repeating each plain option `m` times and each nested
        # option `m/s` times yields a flat uniform choice with exactly the
        # same probabilities.
        sizes: Final[list[int]] = [
            tuple.__len__(u.ch) if isinstance(u, Choice) else 1
            for u in simplified]
        m: Final[int] = lcm(*sizes)
        done: list[Distribution] = []
        for use, size in zip(simplified, sizes):
            if isinstance(use, Choice):
                needs = True
                done.extend(use.ch * (m // size))
            else:
                done.extend([use] * m)

        first: Final[Distribution] = done[0]
        if all(oth == first for oth in done):
            return first
        return Choice(tuple(done)) if needs else self
@dataclass(order=True, frozen=True)
class AtLeast(Distribution):
    """
    A distribution that is lower-bounded.

    Samples are drawn from the inner distribution :attr:`~AtLeast.d` and
    rejected until one is at least :attr:`~AtLeast.lb`. A directly nested
    :class:`AtLeast` is merged into this one on construction.

    >>> AtLeast(5, Const(7))
    AtLeast(lb=5, d=Const(v=7))

    >>> AtLeast(5, AtLeast(8, Const(17)))
    AtLeast(lb=8, d=Const(v=17))

    >>> AtLeast(8, AtLeast(5, Const(17)))
    AtLeast(lb=8, d=Const(v=17))

    >>> try:
    ...     AtLeast(1, 2)
    ... except TypeError as te:
    ...     print(te)
    Invalid types AtLeast(lb=1, d=2).
    """

    #: the inclusive lower bound
    lb: int | float
    #: the inner distribution to sample from
    d: Distribution

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises TypeError: if `d` is not a :class:`Distribution`
        :raises ValueError: if `d` is a constant, a range-limited
            (:class:`In`), or a uniform distribution that cannot produce
            any value of at least `lb`
        """
        object.__setattr__(self, "lb", try_int(self.lb))
        dd: Final[Distribution] = self.d
        # Like `In`, reject inner objects that cannot be sampled from.
        if not isinstance(dd, Distribution):
            raise TypeError(f"Invalid types {self}.")
        if isinstance(dd, Const):
            if cast("Const", dd).v < self.lb:
                raise ValueError(f"Invalid distribution {self!r}.")
        elif isinstance(dd, AtLeast):
            # merge nested lower bounds: the larger (tighter) one wins
            dlb: AtLeast = cast("AtLeast", dd)
            object.__setattr__(self, "lb", max(self.lb, dlb.lb))
            object.__setattr__(self, "d", dlb.d)
        elif isinstance(dd, In):
            idd: In = cast("In", dd)
            ulb: int | float = max(idd.lb, self.lb)
            if ulb >= idd.ub:
                raise ValueError(f"Invalid distribution {self!r}.")
        elif isinstance(dd, Uniform):
            udd: Uniform = cast("Uniform", dd)
            ulb = max(udd.low, self.lb)
            if ulb >= udd.high:
                raise ValueError(f"Invalid distribution {self!r}.")

    def simplify(self) -> "Distribution":
        """
        Try to simplify this distribution.

        A constant inner distribution (already validated to be at least
        `lb`) is returned directly. An inner :class:`In` absorbs the bound,
        and an inner :class:`Uniform` is clipped. The bound is dropped for
        exponential and Gamma distributions if it is at most zero, and for
        Gamma distributions with `k > 1` if it is at most the smallest
        positive float.

        :returns: a simplified version of this distribution

        >>> AtLeast(1, Uniform(3, 4)).simplify()
        Uniform(low=3, high=4)

        >>> AtLeast(3.5, Uniform(3, 4)).simplify()
        Uniform(low=3.5, high=4)
        """
        dd: Final[Distribution] = self.d
        if isinstance(dd, Const):
            return dd
        if isinstance(dd, In):
            idd: In = cast("In", dd)
            return In(max(idd.lb, self.lb), idd.ub, idd.d).simplify()
        if isinstance(dd, Uniform):
            udd: Uniform = cast("Uniform", dd)
            if udd.low >= self.lb:
                return udd
            return Uniform(self.lb, udd.high).simplify()
        if (self.lb <= 0) and (isinstance(dd, Exponential | Gamma | Erlang)):
            return dd
        if (self.lb <= _SMALLEST_POSITIVE_NUMBER) and isinstance(
                dd, Gamma | Erlang) and (
                cast("Gamma", dd).k > 1):  # pylint: disable=E1101
            return dd
        return self

    def sample(self, random: Generator) -> int | float:
        """
        Sample from the lower-bounded distribution via rejection sampling.

        :param random: the random number generator
        :return: the result
        :raises ValueError: if no value of at least `lb` was drawn within
            `_MAX_TRIALS` attempts
        """
        s: Final[Callable[[Generator], int | float]] = self.d.sample
        lb: Final[int | float] = self.lb
        for _ in range(_MAX_TRIALS):
            v = s(random)
            if lb <= v:
                return v
        raise ValueError(f"Failed to sample from {self!r}.")

    @classmethod
    def greater_than_zero(cls, d: int | float | Distribution) -> Distribution:
        """
        Ensure that all samples are greater than zero.

        :param d: the original distribution (numbers are converted via
            :func:`distribution`)
        :return: a distribution which is always greater than zero
        """
        return cls(_SMALLEST_POSITIVE_NUMBER, distribution(d)).simplify()
@dataclass(order=True, frozen=True)
class In(Distribution):
    """
    A distribution that is lower and upper-bounded.

    Samples are drawn from the inner distribution :attr:`~In.d` and rejected
    until one falls into the half-open interval `[lb, ub)`. On construction,
    directly nested :class:`In` and :class:`AtLeast` limits are merged into
    this range, and an inner :class:`Uniform` distribution is clipped to the
    merged range. :meth:`~In.simplify` and :meth:`~In.mean` rely on this
    clipping when they use an inner uniform distribution directly.

    >>> In(1, 10, Const(6))
    In(lb=1, ub=10, d=Const(v=6))

    >>> In(1, 10, In(5, 12, Const(6)))
    In(lb=5, ub=10, d=Const(v=6))

    >>> In(1, 10, AtLeast(6, Const(6)))
    In(lb=6, ub=10, d=Const(v=6))

    >>> In(2, 3, In(1, 10, Uniform(0, 20)))
    In(lb=2, ub=3, d=Uniform(low=2, high=3))

    >>> In(1, 10, AtLeast(6, Uniform(0, 20)))
    In(lb=6, ub=10, d=Uniform(low=6, high=10))
    """

    #: the inclusive lower bound
    lb: int | float
    #: the exclusive upper bound
    ub: int | float
    #: the inner distribution to sample from
    d: Distribution

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises TypeError: if `d` is not a :class:`Distribution`
        :raises ValueError: if the range is empty, if merging nested limits or
            clipping a uniform distribution yields an empty range, or if a
            constant inner distribution lies outside of the range
        """
        object.__setattr__(self, "lb", try_int(self.lb))
        object.__setattr__(self, "ub", try_int(self.ub))
        if not isinstance(self.d, Distribution):
            raise TypeError(f"Invalid types {self}.")
        if self.ub <= self.lb:
            raise ValueError(f"Invalid range {self}.")

        lb: int | float = self.lb
        ub: int | float = self.ub
        dd: Distribution = self.d
        # Merge all directly nested limits into one range. An AtLeast may
        # wrap an In (AtLeast does not merge those itself), so repeat until
        # no limit wrapper is left.
        while isinstance(dd, In | AtLeast):
            if isinstance(dd, In):
                idd: In = cast("In", dd)
                lb = max(idd.lb, lb)
                ub = min(idd.ub, ub)
                dd = idd.d
            else:
                ldd: AtLeast = cast("AtLeast", dd)
                lb = max(ldd.lb, lb)
                dd = ldd.d
            if lb >= ub:
                raise ValueError(f"Invalid distribution {self!r}.")

        # Clip a uniform distribution to the *merged* range.
        if isinstance(dd, Uniform):
            udd: Uniform = cast("Uniform", dd)
            lb = max(udd.low, lb)
            ub = min(udd.high, ub)
            if lb >= ub:
                raise ValueError(f"Invalid distribution {self!r}.")
            if (lb != udd.low) or (ub != udd.high):
                dd = Uniform(lb, ub)

        if isinstance(dd, Const) and not lb <= cast("Const", dd).v < ub:
            raise ValueError(f"Invalid distribution {self!r}.")

        object.__setattr__(self, "lb", lb)
        object.__setattr__(self, "ub", ub)
        object.__setattr__(self, "d", dd)

    def sample(self, random: Generator) -> int | float:
        """
        Sample from the lower- and upper-bounded distribution.

        Values are drawn from the inner distribution until one lies in
        `[lb, ub)`.

        :param random: the random number generator
        :return: the result
        :raises ValueError: if no value in range was drawn within
            `_MAX_TRIALS` attempts
        """
        s: Final[Callable[[Generator], int | float]] = self.d.sample
        lb: Final[int | float] = self.lb
        ub: Final[int | float] = self.ub
        for _ in range(_MAX_TRIALS):
            v = s(random)
            if lb <= v < ub:
                return v
        raise ValueError(f"Failed to sample from {self!r}.")

    def simplify(self) -> Distribution:
        """
        Simplify this distribution.

        A constant or (already clipped) uniform inner distribution makes the
        range limit redundant, so it is returned directly.

        :return: the simplified distribution

        >>> In(-3, 4, Const(3)).simplify()
        Const(v=3)
        >>> ii = In(-10, 10, Uniform(3, 20))
        >>> ii
        In(lb=3, ub=10, d=Uniform(low=3, high=10))
        >>> ii.simplify()
        Uniform(low=3, high=10)
        >>> In(2, 3, In(1, 10, Uniform(0, 20))).simplify()
        Uniform(low=2, high=3)
        """
        return self.d if isinstance(self.d, Const | Uniform) else self

    def mean(self) -> int | float:
        """
        Get the mean of this distribution.

        For a constant or (already clipped) uniform inner distribution, its
        exact mean is used; otherwise, the mean is estimated from samples.

        :return: the mean
        """
        return self.d.mean() if isinstance(self.d, Const | Uniform) \
            else super().mean()
class IntDistribution(Distribution):
    """The base class for distributions whose samples are always integers."""

    def sample(self, random: Generator) -> int:
        """
        Draw a single integer from this distribution.

        :param random: the random number generator to use
        :return: the sampled integer
        :raises NotImplementedError: always, subclasses must override this
        """
        raise NotImplementedError

    def simplify(self) -> "IntDistribution":
        """
        Return a simpler, equivalent integer distribution, if one is known.

        This default implementation hands back the distribution itself.

        :returns: a simplified version of this integer distribution
        """
        if isinstance(self, IntDistribution):
            return self
        raise type_error(self, "self", IntDistribution)
class IntConst(IntDistribution, Const):
    """
    A constant distribution whose value is an integer.

    The value is checked via `check_int_range` against the bounds
    `-10**18` and `10**18`.
    """

    def __post_init__(self) -> None:
        """Normalize the value like :class:`Const`, then check its range."""
        super().__post_init__()
        limit: Final[int] = 10 ** 18
        check_int_range(self.v, "v", -limit, limit)

    def sample(self, random: Generator) -> int:
        """Return the integer constant; `random` is not used."""
        return cast("int", self.v)

    def mean(self) -> int:
        """Return the integer constant, which is also the mean."""
        return cast("int", self.v)
def distribution(d: int | float | Distribution) -> Distribution:
    """
    Get the distribution from the parameter.

    An `int` becomes an :class:`IntConst`. A `float` becomes a :class:`Const`,
    which is then simplified like any other distribution, so integral floats
    end up as :class:`IntConst` as well. A :class:`Distribution` is simplified
    repeatedly until :meth:`~Distribution.simplify` returns the very same
    object.

    :param d: the integer, float, or distribution
    :return: the canonicalized distribution
    :raises TypeError: if `d` is neither an `int`, a `float`, nor a
        :class:`Distribution`

    >>> distribution(7)
    IntConst(v=7)

    >>> distribution(3.4)
    Const(v=3.4)

    >>> distribution(4.0)
    IntConst(v=4)

    >>> distribution(Choice((Const(4.0), )))
    IntConst(v=4)
    """
    if isinstance(d, int):
        return IntConst(d)
    if isinstance(d, float):
        # canonicalize via simplify(), e.g., Const(v=4) -> IntConst(v=4)
        d = Const(d)
    if isinstance(d, Distribution):
        old_d: Distribution | None = None
        while old_d is not d:
            old_d = d
            d = d.simplify()
            if not isinstance(d, Distribution):
                break
        return d
    raise type_error(d, "d", (Distribution, int, float))