Coverage for moptipyapps/utils/sampling.py: 87%

277 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-28 09:42 +0000

1""" 

2Some utilities for random sampling. 

3 

4The goal that we follow with class :class:`Distribution` is to have 

5clearly defined integer-producing random distributions. We want to be 

6able to say exactly how to generate some random numbers. 

7 

8A distribution can be sampled using the method :meth:`~Distribution.sample`. 

9Each distribution has a mean value, which either may be an exact value or 

10an approximate result, that can be obtained via :meth:`~Distribution.mean`. 

11Sometimes, distributions can be simplified, which is supported by 

12:meth:`~Distribution.simplify`. 

13 

14>>> from moptipy.utils.nputils import rand_generator 

15>>> from statistics import mean 

16>>> rnd = rand_generator(0) 

17 

18>>> const = Const(12.3) 

19>>> const 

20Const(v=12.3) 

21>>> const.mean() 

2212.3 

23>>> const.sample(rnd) 

2412.3 

25 

26>>> normal = Normal(1, 2.0) 

27>>> normal 

28Normal(mu=1, sd=2) 

29>>> x = [normal.sample(rnd) for _ in range(200)] 

30>>> x[:20] 

31[1.2514604421867865, 0.7357902734173962, 2.280845300886564,\ 

32 1.2098002343060794, -0.07133874632222192, 1.7231901098189695,\ 

33 3.6080000902602745, 2.8941619262584846, -0.4074704716139852,\ 

34 -1.530842942092105, -0.24654892507470438, 1.0826519586944872,\ 

35 -3.6500615492776687, 0.5624166721349085, -1.4918218945061303,\ 

36 -0.4645347094069032, -0.08851796571461978, 0.3673996872616909,\ 

37 1.8232610727482657, 3.085026738885355] 

38>>> mean(x) / normal.mean() 

391.0305262793198813 

40 

41>>> exponential = Exponential(3) 

42>>> x = [exponential.sample(rnd) for _ in range(200)] 

43>>> x[:20] 

44[1.3203895772033505, 1.152983246827425, 4.527545171626064,\ 

45 5.080441711712409, 0.6498200242245252, 3.408652958826374,\ 

46 2.842677620245357, 2.1132194336022, 3.5838508479713393,\ 

47 1.2223825336486978, 2.8454397976498504, 0.8653980789905962,\ 

48 0.19572166348792452, 3.2506725793229854, 0.9426446058446336,\ 

49 3.246902386754473, 10.294412603282666, 5.275683923067543,\ 

50 0.49517363091492944, 0.2982336402218482] 

51>>> mean(x) / exponential.mean() 

521.088857332669081 

53 

54>>> gamma = Gamma(3.0, 0.26) 

55>>> gamma 

56Gamma(k=3, theta=0.26) 

57>>> x = [gamma.sample(rnd) for _ in range(200)] 

58>>> x[:20] 

59[1.1617763743292708, 1.305755109480284, 0.7627948403389954,\ 

60 1.2735522637897285, 0.7742951665621697, 1.074233520618276,\ 

61 0.6324661100546898, 1.4627037699922791, 0.5739033567160827,\ 

62 0.5555065636904546, 0.629236234283296, 0.3666171387296996,\ 

63 0.3780976936750937, 0.9511433672028858, 1.2607313263258062,\ 

64 1.4442096466925938, 0.48758642085808085, 1.247724803721524,\ 

65 1.9359140456080306, 1.3935246884396764] 

66>>> mean(x) / gamma.mean() 

670.962254253258444 

68 

69>>> Gamma.from_alpha_beta(2, 0.5) 

70Erlang(k=2, theta=2) 

71 

72>>> Gamma.from_alpha_beta(2.5, 0.5) 

73Gamma(k=2.5, theta=2) 

74 

75>>> Erlang.from_alpha_beta(2, 0.5) 

76Erlang(k=2, theta=2) 

77 

78>>> erlang = Erlang(1.0, 0.26) 

79>>> erlang 

80Erlang(k=1, theta=0.26) 

81>>> x = [erlang.sample(rnd) for _ in range(200)] 

82>>> x[:20] 

83[0.29911329228981776, 0.3630768060267626, 0.14111385731543394,\ 

84 0.0745673280234536, 0.029950507989979877, 0.04741877104350835,\ 

85 0.38599089026561223, 0.047919114170390194, 0.06921557868837301,\ 

86 0.4066084140331242, 0.07170887998378667, 0.022061870233843223,\ 

87 0.04904717644388396, 0.32082097064821674, 0.001884448999141546,\ 

88 0.6687964040577958, 0.060598863807579915, 0.21491377996577304,\ 

89 0.23088301776258766, 0.23667780086315618] 

90>>> erlang.mean() 

910.26 

92>>> mean(x) 

930.23951704283832467 

94>>> mean(x) / erlang.mean() 

950.9212193955320179 

96 

97>>> uniform = Uniform(10.5, 17) 

98>>> uniform 

99Uniform(low=10.5, high=17) 

100>>> x = [uniform.sample(rnd) for _ in range(200)] 

101>>> x[:20] 

102[11.235929257903324, 13.89459856391694, 14.196833332199233,\ 

103 13.871456027849515, 14.485310406702393, 16.20468251006836,\ 

104 13.7773314921396, 12.964459911549145, 12.167722566772781,\ 

105 12.49450317595701, 14.145245846833722, 15.669920217581366,\ 

106 13.367286694558258, 10.764955127505994, 11.723004274328007,\ 

107 11.089239517262232, 12.666732191193558, 14.948448277461127,\ 

108 14.339645757381653, 14.803829334728565] 

109>>> mean(x) / uniform.mean() 

1100.9968578545459643 

111 

112>>> choice = Choice((Const(2), gamma, normal)) 

113>>> choice 

114Choice(ch=(Const(v=2), Gamma(k=3, theta=0.26), Normal(mu=1, sd=2))) 

115>>> x = [choice.sample(rnd) for _ in range(200)] 

116>>> x[:20] 

117[0.5499364190576, -0.10428920251005325, 2, 2, 0.44263544084840273,\ 

118 0.6088189450771303, 2, 2, 2, 2, -0.32003290715104904, 2,\ 

119 0.6165299227577784, 1.0445083345086352, 2, 2, 2, -0.5970322857539738,\ 

120 -1.6672705710198277, 2] 

121>>> mean(x) / choice.mean() 

1220.9849730453309948 

123 

124>>> lower = AtLeast(2, normal) 

125>>> lower 

126AtLeast(lb=2, d=Normal(mu=1, sd=2)) 

127>>> x = [lower.sample(rnd) for _ in range(200)] 

128>>> x[:20] 

129[2.9715155848474772, 3.407503493034132, 2.59634461519561,\ 

130 2.0472458472897714, 3.9865670827840334, 2.015117730344058,\ 

131 2.0316584714999935, 5.36625737470408, 3.042848158343226,\ 

132 3.390407444638451, 4.5503184215686, 3.7682459882073007,\ 

133 2.0539760253651305, 2.134886147372958, 2.500182239395479,\ 

134 2.891402111997337, 2.7393826524907228, 2.3449577842364766,\ 

135 2.9043074694017195, 4.7173482582723825] 

136>>> mean(x) / lower.mean() 

1371.0178179732075965 

138 

139>>> interval = In(1, 10, gamma) 

140>>> interval 

141In(lb=1, ub=10, d=Gamma(k=3, theta=0.26)) 

142>>> x = [interval.sample(rnd) for _ in range(200)] 

143>>> x[:20] 

144[1.3488423972739083, 1.0929631399361601, 1.681162621901135,\ 

145 1.655614926246918, 1.041948891842002, 2.0773395958990175,\ 

146 1.3338891374921853, 1.2478964188175743, 1.9417070894505217,\ 

147 1.3990572178987266, 1.1216118870312337, 1.0641160239253207,\ 

148 2.131253600219639, 1.0337453883221577, 1.396499618416345,\ 

149 1.9865175145136038, 1.2555473269031396, 1.4412027583435465,\ 

150 1.320740351247919, 1.0407411942999665] 

151>>> mean(x) / interval.mean() 

1521.0256329849020747 

153 

154>>> erlang2 = Gamma.from_k_and_mean(3, 10) 

155>>> erlang2 

156Erlang(k=3, theta=3.3333333333333335) 

157>>> erlang2.mean() 

15810 

159>>> x = [erlang2.sample(rnd) for _ in range(200)] 

160>>> x[:20] 

161[10.087506399523226, 12.928131914870168, 12.330250639007767,\ 

162 5.305123692562998, 21.085037136404374, 6.6603691824173135,\ 

163 2.961302890492059, 9.810557147180853, 8.051620919921454,\ 

164 8.750329405836668, 3.9511445189935763, 5.570300668751883,\ 

165 16.70132947692463, 7.831425379483914, 11.154757962484842,\ 

166 8.78943102381046, 8.395847820234795, 16.42251602814587,\ 

167 17.1628992966332, 9.684008648356015] 

168>>> mean(x) / erlang2.mean() 

1691.022444624140316 

170 

171>>> AtLeast.greater_than_zero(Gamma(1, 0.5)) 

172AtLeast(lb=5e-324, d=Exponential(eta=0.5)) 

173 

174>>> AtLeast.greater_than_zero(Gamma(2, 0.5)) 

175Erlang(k=2, theta=0.5) 

176 

177>>> AtLeast.greater_than_zero(Gamma(2.5, 0.5)) 

178Gamma(k=2.5, theta=0.5) 

179 

180>>> expo = erlang.simplify() 

181>>> expo 

182Exponential(eta=0.26) 

183>>> x = [expo.sample(rnd) for _ in range(200)] 

184>>> x[:20] 

185[0.07166050776665364, 0.16168119850191573, 0.22421276308859897, \ 

1860.0159144585889376, 0.2814946917726687, 0.18627887507033897, \ 

1870.023154315537313133, 0.04580416404670023, 0.6022196871728142, \ 

1880.025107668179499498, 0.8833625673263216, 0.15071056549661394, \ 

1890.00971431623116772, 0.024629860397399155, 0.5601666963985541, \ 

1900.14266231845989083, 0.05929349838310316, 0.34738399195888336, \ 

1910.048260650616228924, 0.28230843185818594] 

192>>> mean(x) 

1930.23545459359353754 

194>>> expo.mean() 

1950.26 

196>>> mean(x) / expo.mean() 

1970.9055945907443751 

198""" 

199 

200from dataclasses import dataclass 

201from math import fsum, isfinite, nextafter 

202from typing import Callable, Final, cast 

203 

204from moptipy.utils.nputils import rand_generator, rand_seeds_from_str 

205from numpy.random import Generator 

206from pycommons.math.int_math import try_int 

207from pycommons.types import check_int_range, type_error 

208 

209#: the maximum number of trials during a sampling process 

210_MAX_TRIALS: int = 1_000_000 

211 

212#: the smallest positive number 

213_SMALLEST_POSITIVE_NUMBER: Final[float] = nextafter(0.0, 1.0) 

214 

215 

216class Distribution: 

217 """A base class for distributions.""" 

218 

219 def sample(self, random: Generator) -> int | float: 

220 """ 

221 Sample a random number following this distribution generator. 

222 

223 Each call to this function returns exactly one number. 

224 

225 :param random: the random number generator 

226 :return: the number 

227 """ 

228 raise NotImplementedError 

229 

230 def simplify(self) -> "Distribution": 

231 """ 

232 Try to simplify this distribution. 

233 

234 Some distributions can trivially be simplified. For example, if you 

235 have applied a range limit (:class:`In`) to a constant distribution 

236 (class:`Const`), then this can be simplified to just the constant. 

237 If such simplification is possible, this method returns the simplified 

238 distribution. Otherwise, it just returns the distribution itself. 

239 

240 :returns: a simplified version of this distribution 

241 """ 

242 if not isinstance(self, Distribution): 

243 raise type_error(self, "self", Distribution) 

244 return self 

245 

246 def mean(self) -> int | float: 

247 """ 

248 Get the mean or approximate mean of the distribution. 

249 

250 Some distribution overwrite this method to produce an exact computed 

251 expected value or mean. This default implementation just computes the 

252 arithmetic mean of 10'000 samples of the distribution. This serves as 

253 baseline approximation for any case where a closed form mathematical 

254 definition of the expected value is not available. 

255 

256 :return: the mean or approximated mean 

257 """ 

258 sample: Callable[[Generator], int | float] = self.sample 

259 gen: Final[Generator] = rand_generator(rand_seeds_from_str( 

260 repr(self), 1)[0]) 

261 return try_int(fsum(sample(gen) for _ in range(10_000)) / 10_000) 

262 

263 

264@dataclass(order=True, frozen=True) 

265class Const(Distribution): 

266 """A constant value.""" 

267 

268 #: the constant value 

269 v: int | float 

270 

271 def __post_init__(self) -> None: 

272 """Perform some basic sanity checks and cleanup.""" 

273 object.__setattr__(self, "v", try_int(self.v)) 

274 

275 def sample(self, random: Generator) -> int | float: 

276 """ 

277 Sample the constant integer. 

278 

279 :param random: the random number generator 

280 :return: the integer 

281 """ 

282 return self.v 

283 

284 def mean(self) -> int | float: 

285 """ 

286 Get the mean of this distribution. 

287 

288 :return: the mean 

289 """ 

290 return self.v 

291 

292 def simplify(self) -> "Distribution": 

293 """ 

294 Simplify this constat. 

295 

296 :return: the simplified constant 

297 

298 >>> Const(1.5).simplify() 

299 Const(v=1.5) 

300 >>> Const(1.0).simplify() 

301 IntConst(v=1) 

302 """ 

303 return IntConst(self.v) if isinstance(self.v, int) else self 

304 

305 

306@dataclass(order=True, frozen=True) 

307class Normal(Distribution): 

308 """A class representing a normal distribution.""" 

309 

310 #: the expected value and center of the distribution 

311 mu: int | float 

312 #: the standard deviation of the distribution 

313 sd: int | float 

314 

315 def __post_init__(self) -> None: 

316 """Perform some basic sanity checks and cleanup.""" 

317 if not (isfinite(self.mu) and isfinite(self.sd) and (self.sd > 0)): 

318 raise ValueError(f"Invalid parameters {self}.") 

319 object.__setattr__(self, "mu", try_int(self.mu)) 

320 object.__setattr__(self, "sd", try_int(self.sd)) 

321 

322 def sample(self, random: Generator) -> float: 

323 """ 

324 Sample from the normal distribution. 

325 

326 :param random: the random number generator 

327 :return: the result 

328 """ 

329 return random.normal(self.mu, self.sd) 

330 

331 def mean(self) -> int | float: 

332 """ 

333 Get the mean of this distribution. 

334 

335 :return: the mean 

336 """ 

337 return self.mu 

338 

339 

340@dataclass(order=True, frozen=True) 

341class Exponential(Distribution): 

342 """A class representing an exponential distribution.""" 

343 

344 #: the exponential distribution parameter 

345 eta: int | float 

346 

347 def __post_init__(self) -> None: 

348 """Perform some basic sanity checks and cleanup.""" 

349 object.__setattr__(self, "eta", try_int(self.eta)) 

350 if self.eta <= 0: 

351 raise ValueError(f"Invalid setup {self!r}.") 

352 

353 def sample(self, random: Generator) -> float: 

354 """ 

355 Sample from the Exponential distribution. 

356 

357 :param random: the random number generator 

358 :return: the result 

359 """ 

360 return random.exponential(self.eta) 

361 

362 def mean(self) -> int | float: 

363 """ 

364 Get the mean of this distribution. 

365 

366 :return: the mean 

367 """ 

368 return try_int(self.eta) 

369 

370 

371@dataclass(order=True, frozen=True) 

372class Gamma(Distribution): 

373 """ 

374 A class representing a Gamma distribution. 

375 

376 Here, `k` is the shape and `theta` is the scale parameter. 

377 If you use a parameterization with `alpha` and `beta`, you need to create 

378 the distribution using :meth:`~Gamma.from_alpha_beta` instead. The reason 

379 is that `shape = 1/beta`, see 

380 https://www.statlect.com/probability-distributions/gamma-distribution. 

381 """ 

382 

383 #: the shape parameter 

384 k: int | float 

385 #: the scale parameter 

386 theta: int | float 

387 

388 def __post_init__(self) -> None: 

389 """Perform some basic sanity checks and cleanup.""" 

390 object.__setattr__(self, "k", try_int(self.k)) 

391 object.__setattr__(self, "theta", try_int(self.theta)) 

392 if not (self.k > 0) and (self.theta > 0): 

393 raise ValueError(f"Invalid parameters {self}.") 

394 

395 def sample(self, random: Generator) -> float: 

396 """ 

397 Sample from the Gamma distribution. 

398 

399 :param random: the random number generator 

400 :return: the result 

401 """ 

402 return random.gamma(self.k, self.theta) 

403 

404 def simplify(self) -> "Distribution": 

405 """ 

406 Try to simplify this distribution. 

407 

408 A Gamma distribution may simplify to either an :class:`Erlang` or an 

409 :class:`Exponential` distribution, depending on its parameters. 

410 If the :attr:`~Gamma.k` is `1`, then it is actually an 

411 :class:`Exponential` distribution. If :attr:`~Gamma.k` is an integer, 

412 then the distribution is an :class:`Erlang` distribution. 

413 

414 1. https://www.statisticshowto.com/gamma-distribution 

415 2. https://www.statisticshowto.com/erlang-distribution 

416 

417 :returns: a simplified version of this distribution 

418 """ 

419 return Exponential(self.theta) if self.k == 1 else ( 

420 Erlang(self.k, self.theta) if isinstance( 

421 self.k, int) and not isinstance(self, Erlang) else self) 

422 

423 def mean(self) -> int | float: 

424 """ 

425 Get the mean of this distribution. 

426 

427 :return: the mean 

428 """ 

429 return try_int(self.k * self.theta) 

430 

431 @classmethod 

432 def from_alpha_beta(cls, alpha: int | float, beta: int | float) \ 

433 -> "Distribution": 

434 """ 

435 Create a Gamma distribution from `alpha` and `beta`. 

436 

437 :param alpha: the alpha parameter 

438 :param beta: the beta parameter 

439 :return: the distribution 

440 

441 >>> Gamma.from_alpha_beta(1, 1) 

442 Exponential(eta=1) 

443 

444 >>> Gamma.from_alpha_beta(2, 1) 

445 Erlang(k=2, theta=1) 

446 

447 >>> Gamma.from_alpha_beta(1, 2) 

448 Exponential(eta=0.5) 

449 

450 >>> Gamma.from_alpha_beta(2, 2) 

451 Erlang(k=2, theta=0.5) 

452 

453 >>> Gamma.from_alpha_beta(1.5, 1) 

454 Gamma(k=1.5, theta=1) 

455 

456 >>> Gamma.from_alpha_beta(1, 1.5) 

457 Exponential(eta=0.6666666666666666) 

458 

459 >>> Gamma.from_alpha_beta(1.5, 1.5) 

460 Gamma(k=1.5, theta=0.6666666666666666) 

461 """ 

462 beta = try_int(beta) 

463 if beta == 0: 

464 raise ValueError(f"beta cannot be {beta}.") 

465 return cls(alpha, 1 / beta).simplify() 

466 

467 @classmethod 

468 def from_k_and_mean(cls, k: int | float, mean: int | float) \ 

469 -> "Distribution": 

470 """ 

471 Create the Gamma distribution from the value of `k` and a mean. 

472 

473 :param k: the shape parameter 

474 :param mean: the mean 

475 :return: the distribution 

476 

477 >>> Gamma.from_k_and_mean(1, 1) 

478 Exponential(eta=1) 

479 

480 >>> Gamma.from_k_and_mean(1, 2) 

481 Exponential(eta=2) 

482 

483 >>> Gamma.from_k_and_mean(2, 1) 

484 Erlang(k=2, theta=0.5) 

485 

486 >>> Gamma.from_k_and_mean(2, 2) 

487 Erlang(k=2, theta=1) 

488 

489 >>> Gamma.from_k_and_mean(1.5, 1) 

490 Gamma(k=1.5, theta=0.6666666666666666) 

491 

492 >>> Gamma.from_k_and_mean(1, 1.5) 

493 Exponential(eta=1.5) 

494 """ 

495 k = try_int(k) 

496 mean = try_int(mean) 

497 if (mean <= 0) or (k <= 0): 

498 raise ValueError(f"Invalid values k={k}, mean={mean}.") 

499 return Gamma(k, mean / k).simplify() 

500 

501 

502class Erlang(Gamma): 

503 """The Erlang distribution.""" 

504 

505 def __post_init__(self) -> None: 

506 """Perform some basic sanity checks and cleanup.""" 

507 super().__post_init__() 

508 if not isinstance(self.k, int): 

509 raise type_error(self.k, "k", int) 

510 

511 

512@dataclass(order=True, frozen=True) 

513class Uniform(Distribution): 

514 """A class representing a uniform distribution.""" 

515 

516 #: the lowest permitted value 

517 low: int | float 

518 #: the highest permitted value 

519 high: int | float 

520 

521 def __post_init__(self) -> None: 

522 """Perform some basic sanity checks and cleanup.""" 

523 object.__setattr__(self, "low", try_int(self.low)) 

524 object.__setattr__(self, "high", try_int(self.high)) 

525 if self.high <= self.low: 

526 raise ValueError(f"Invalid parameters {self}.") 

527 

528 def sample(self, random: Generator) -> float: 

529 """ 

530 Sample from the uniform distribution. 

531 

532 :param random: the random number generator 

533 :return: the result 

534 """ 

535 return random.uniform(self.low, self.high) 

536 

537 def mean(self) -> int | float: 

538 """ 

539 Get the mean of this distribution. 

540 

541 :return: the mean 

542 """ 

543 return try_int((self.high + self.low) / 2) 

544 

545 

546@dataclass(order=True, frozen=True) 

547class Choice(Distribution): 

548 """ 

549 A class representing a uniform choice. 

550 

551 >>> Choice((Uniform(1, 2), Uniform(3, 4))).simplify() 

552 Choice(ch=(Uniform(low=1, high=2), Uniform(low=3, high=4))) 

553 

554 >>> Choice((Uniform(1, 2), Uniform(1.0, 2))).simplify() 

555 Uniform(low=1, high=2) 

556 

557 >>> Choice((Uniform(1, 2), Choice( 

558 ... (Const(1), Uniform(1, 2))))).simplify() 

559 Choice(ch=(Uniform(low=1, high=2), IntConst(v=1), Uniform(low=1, high=2))) 

560 """ 

561 

562 #: the choices 

563 ch: tuple[Distribution, ...] 

564 

565 def __post_init__(self) -> None: 

566 """Perform some basic sanity checks and cleanup.""" 

567 if not isinstance(self.ch, tuple): 

568 raise TypeError(f"Invalid types {self}.") 

569 for v in self.ch: 

570 if not isinstance(v, Distribution): 

571 raise type_error(v, "choice", Distribution) 

572 

573 def mean(self) -> int | float: 

574 """ 

575 Get the mean of this distribution. 

576 

577 :return: the mean 

578 """ 

579 return try_int(fsum(d.mean() for d in self.ch) / tuple.__len__( 

580 self.ch)) 

581 

582 def sample(self, random: Generator) -> int | float: 

583 """ 

584 Sample from the uniform distribution. 

585 

586 :param random: the random number generator 

587 :return: the result 

588 """ 

589 return self.ch[random.integers(tuple.__len__(self.ch))].sample(random) 

590 

591 def simplify(self) -> Distribution: 

592 """ 

593 Try to simplify this distribution. 

594 

595 :returns: a simplified version of this distribution 

596 """ 

597 ch: Final[tuple[Distribution, ...]] = self.ch 

598 if tuple.__len__(ch) <= 1: 

599 return ch[0].simplify() 

600 done: list[Distribution] = [] 

601 needs: bool = False 

602 for dist in ch: 

603 use: Distribution = dist.simplify() 

604 if use != dist: 

605 needs = True 

606 if isinstance(use, Choice): 

607 needs = True 

608 done.extend(use.ch) 

609 else: 

610 done.append(use) 

611 

612 total: int = list.__len__(done) 

613 dc: Distribution = done[0] 

614 if total <= 1: 

615 return dc 

616 

617 all_same: bool = True 

618 for oth in done: 

619 if oth != dc: 

620 all_same = False 

621 break 

622 return dc if all_same else (Choice(tuple(done)) if needs else self) 

623 

624 

625@dataclass(order=True, frozen=True) 

626class AtLeast(Distribution): 

627 """ 

628 A distribution that is lower-bounded. 

629 

630 >>> AtLeast(5, Const(7)) 

631 AtLeast(lb=5, d=Const(v=7)) 

632 

633 >>> AtLeast(5, AtLeast(8, Const(17))) 

634 AtLeast(lb=8, d=Const(v=17)) 

635 

636 >>> AtLeast(8, AtLeast(5, Const(17))) 

637 AtLeast(lb=8, d=Const(v=17)) 

638 """ 

639 

640 #: the inclusive lower bound 

641 lb: int | float 

642 #: the inner distribution to sample from 

643 d: Distribution 

644 

645 def __post_init__(self) -> None: 

646 """Perform some basic sanity checks and cleanup.""" 

647 object.__setattr__(self, "lb", try_int(self.lb)) 

648 dd: Final[Distribution] = self.d 

649 if isinstance(dd, Const): 

650 if cast("Const", dd).v < self.lb: 

651 raise ValueError(f"Invalid distribution {self!r}.") 

652 elif isinstance(dd, AtLeast): 

653 dlb: AtLeast = cast("AtLeast", dd) 

654 object.__setattr__(self, "lb", max(self.lb, dlb.lb)) 

655 object.__setattr__(self, "d", dlb.d) 

656 elif isinstance(dd, In): 

657 idd: In = cast("In", dd) 

658 ulb: int | float = max(idd.lb, self.lb) 

659 if ulb >= idd.ub: 

660 raise ValueError(f"Invalid distribution {self!r}.") 

661 elif isinstance(dd, Uniform): 

662 udd: Uniform = cast("Uniform", dd) 

663 ulb = max(udd.low, self.lb) 

664 if ulb >= udd.high: 

665 raise ValueError(f"Invalid distribution {self!r}.") 

666 

667 def simplify(self) -> "Distribution": 

668 """ 

669 Try to simplify this distribution. 

670 

671 :returns: a simplified version of this distribution 

672 

673 >>> AtLeast(1, Uniform(3, 4)).simplify() 

674 Uniform(low=3, high=4) 

675 

676 >>> AtLeast(3.5, Uniform(3, 4)).simplify() 

677 Uniform(low=3.5, high=4) 

678 """ 

679 dd: Final[Distribution] = self.d 

680 if isinstance(dd, Const): 

681 return dd 

682 if isinstance(dd, In): 

683 idd: In = cast("In", dd) 

684 return In(max(idd.lb, self.lb), idd.ub, idd.d).simplify() 

685 if isinstance(dd, Uniform): 

686 udd: Uniform = cast("Uniform", dd) 

687 if udd.low >= self.lb: 

688 return udd 

689 return Uniform(self.lb, udd.high).simplify() 

690 if (self.lb <= 0) and (isinstance(dd, Exponential | Gamma | Erlang)): 

691 return dd 

692 if (self.lb <= _SMALLEST_POSITIVE_NUMBER) and isinstance( 

693 dd, Gamma | Erlang) and ( 

694 cast("Gamma", dd).k > 1): # pylint: disable=E1101 

695 return dd 

696 return self 

697 

698 def sample(self, random: Generator) -> int | float: 

699 """ 

700 Sample from the lower-bounded distribution. 

701 

702 :param random: the random number generator 

703 :return: the result 

704 """ 

705 s: Final[Callable[[Generator], int | float]] = self.d.sample 

706 lb: Final[int | float] = self.lb 

707 for _ in range(_MAX_TRIALS): 

708 v = s(random) 

709 if lb <= v: 

710 return v 

711 raise ValueError(f"Failed to sample from {self!r}.") 

712 

713 @classmethod 

714 def greater_than_zero(cls, d: int | float | Distribution) -> Distribution: 

715 """ 

716 Ensure that all samples are greater than zero. 

717 

718 :param d: the original distribution 

719 :return: a distribution which is always greater than zero 

720 """ 

721 return cls(_SMALLEST_POSITIVE_NUMBER, distribution(d)).simplify() 

722 

723 

724@dataclass(order=True, frozen=True) 

725class In(Distribution): 

726 """ 

727 A distribution that is lower and upper-bounded. 

728 

729 >>> In(1, 10, Const(6)) 

730 In(lb=1, ub=10, d=Const(v=6)) 

731 

732 >>> In(1, 10, In(5, 12, Const(6))) 

733 In(lb=5, ub=10, d=Const(v=6)) 

734 

735 >>> In(1, 10, AtLeast(6, Const(6))) 

736 In(lb=6, ub=10, d=Const(v=6)) 

737 """ 

738 

739 #: the inclusive lower bound 

740 lb: int | float 

741 #: the exclusive upper bound 

742 ub: int | float 

743 #: the inner distribution to sample from 

744 d: Distribution 

745 

746 def __post_init__(self) -> None: 

747 """Perform some basic sanity checks and cleanup.""" 

748 object.__setattr__(self, "lb", try_int(self.lb)) 

749 object.__setattr__(self, "ub", try_int(self.ub)) 

750 if not isinstance(self.d, Distribution): 

751 raise TypeError(f"Invalid types {self}.") 

752 if self.ub <= self.lb: 

753 raise ValueError(f"Invalid range {self}.") 

754 dd: Distribution = self.d 

755 lb: Final[int | float] = self.lb 

756 ub: Final[int | float] = self.ub 

757 if isinstance(dd, In): 

758 idd: In = cast("In", dd) 

759 ulb: int | float = max(idd.lb, lb) 

760 uub: int | float = min(idd.ub, ub) 

761 if ulb >= uub: 

762 raise ValueError(f"Invalid distribution {self!r}.") 

763 object.__setattr__(self, "lb", ulb) 

764 object.__setattr__(self, "ub", uub) 

765 dd = idd.d 

766 object.__setattr__(self, "d", dd) 

767 elif isinstance(dd, AtLeast): 

768 ldd: AtLeast = cast("AtLeast", dd) 

769 ulb = max(ldd.lb, lb) 

770 if ulb >= ub: 

771 raise ValueError(f"Invalid distribution {self!r}.") 

772 object.__setattr__(self, "lb", ulb) 

773 dd = ldd.d 

774 object.__setattr__(self, "d", dd) 

775 elif isinstance(dd, Uniform): # fix a uniform distribution 

776 udd: Uniform = cast("Uniform", dd) 

777 ulb = max(udd.low, lb) 

778 uub = min(udd.high, ub) 

779 if ulb >= uub: 

780 raise ValueError(f"Invalid distribution {self!r}.") 

781 if (ulb != udd.low) or (uub != udd.high): 

782 object.__setattr__(self, "d", Uniform(ulb, uub)) 

783 object.__setattr__(self, "lb", ulb) 

784 object.__setattr__(self, "ub", uub) 

785 if isinstance(dd, Const) and not lb <= cast("Const", dd).v < ub: 

786 raise ValueError(f"Invalid distribution {self!r}.") 

787 

788 def sample(self, random: Generator) -> int | float: 

789 """ 

790 Sample from the lower-bounded distribution. 

791 

792 :param random: the random number generator 

793 :return: the result 

794 """ 

795 s: Final[Callable[[Generator], int | float]] = self.d.sample 

796 lb: Final[int | float] = self.lb 

797 ub: Final[int | float] = self.ub 

798 for _ in range(_MAX_TRIALS): 

799 v = s(random) 

800 if lb <= v < ub: 

801 return v 

802 raise ValueError(f"Failed to sample from {self!r}.") 

803 

804 def simplify(self) -> Distribution: 

805 """ 

806 Simplify this distribution. 

807 

808 :return: the simplified distribution 

809 

810 >>> In(-3, 4, Const(3)).simplify() 

811 Const(v=3) 

812 >>> ii = In(-10, 10, Uniform(3, 20)) 

813 >>> ii 

814 In(lb=3, ub=10, d=Uniform(low=3, high=10)) 

815 >>> ii.simplify() 

816 Uniform(low=3, high=10) 

817 """ 

818 return self.d if isinstance(self.d, Const | Uniform) else self 

819 

820 def mean(self) -> int | float: 

821 """ 

822 Get the mean of this distribution. 

823 

824 :return: the mean 

825 """ 

826 return self.d.mean() if isinstance(self.d, Const | Uniform) \ 

827 else super().mean() 

828 

829 

830class IntDistribution(Distribution): 

831 """A base class for integer distributions.""" 

832 

833 def sample(self, random: Generator) -> int: 

834 """ 

835 Sample a random number following this integer distribution generator. 

836 

837 Each call to this function returns exactly one number. 

838 

839 :param random: the random number generator 

840 :return: the number, which always will be integer 

841 """ 

842 raise NotImplementedError 

843 

844 def simplify(self) -> "IntDistribution": 

845 """ 

846 Try to simplify this integer distribution. 

847 

848 :returns: a simplified version of this integer distribution 

849 """ 

850 if not isinstance(self, IntDistribution): 

851 raise type_error(self, "self", IntDistribution) 

852 return self 

853 

854 

855class IntConst(IntDistribution, Const): 

856 """An integer constant.""" 

857 

858 def __post_init__(self) -> None: 

859 """Perform some basic sanity checks and cleanup.""" 

860 super().__post_init__() 

861 check_int_range(self.v, "v", -1_000_000_000_000_000_000, 

862 1_000_000_000_000_000_000) 

863 

864 def sample(self, random: Generator) -> int: 

865 """Get the integer constant value.""" 

866 return cast("int", self.v) 

867 

868 def mean(self) -> int: 

869 """Get the arithmetic mean.""" 

870 return cast("int", self.v) 

871 

872 

873def distribution(d: int | float | Distribution) -> Distribution: 

874 """ 

875 Get the distribution from the parameter. 

876 

877 :param d: the integer value or distribution 

878 :return: the canonicalized distribution 

879 

880 >>> distribution(7) 

881 IntConst(v=7) 

882 

883 >>> distribution(3.4) 

884 Const(v=3.4) 

885 

886 >>> distribution(Choice((Const(4.0), ))) 

887 IntConst(v=4) 

888 """ 

889 if isinstance(d, int): 

890 return IntConst(d) 

891 if isinstance(d, float): 

892 return Const(d) 

893 if isinstance(d, Distribution): 

894 old_d: Distribution | None = None 

895 while old_d is not d: 

896 old_d = d 

897 d = d.simplify() 

898 if not isinstance(d, Distribution): 

899 break 

900 return d 

901 raise type_error(d, "d", (Distribution, int))