Coverage for moptipyapps / utils / sampling.py: 87%

277 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-11 04:40 +0000

1""" 

2Some utilities for random sampling. 

3 

The goal that we follow with class :class:`Distribution` is to have
clearly defined random distributions producing integer or floating point
numbers. We want to be

6able to say exactly how to generate some random numbers. 

7 

8A distribution can be sampled using the method :meth:`~Distribution.sample`. 

Each distribution has a mean value, which may be either an exact value or
an approximate result and which can be obtained via :meth:`~Distribution.mean`.

11Sometimes, distributions can be simplified, which is supported by 

12:meth:`~Distribution.simplify`. 

13 

14>>> from moptipy.utils.nputils import rand_generator 

15>>> from statistics import mean 

16>>> rnd = rand_generator(0) 

17 

18>>> const = Const(12.3) 

19>>> const 

20Const(v=12.3) 

21>>> const.mean() 

2212.3 

23>>> const.sample(rnd) 

2412.3 

25 

26>>> normal = Normal(1, 2.0) 

27>>> normal 

28Normal(mu=1, sd=2) 

29>>> x = [normal.sample(rnd) for _ in range(200)] 

30>>> x[:20] 

31[1.2514604421867865, 0.7357902734173962, 2.280845300886564,\ 

32 1.2098002343060794, -0.07133874632222192, 1.7231901098189695,\ 

33 3.6080000902602745, 2.8941619262584846, -0.4074704716139852,\ 

34 -1.530842942092105, -0.24654892507470438, 1.0826519586944872,\ 

35 -3.6500615492776687, 0.5624166721349085, -1.4918218945061303,\ 

36 -0.4645347094069032, -0.08851796571461978, 0.3673996872616909,\ 

37 1.8232610727482657, 3.085026738885355] 

38>>> mean(x) / normal.mean() 

391.0305262793198813 

40 

41>>> exponential = Exponential(3) 

42>>> x = [exponential.sample(rnd) for _ in range(200)] 

43>>> x[:20] 

44[1.3203895772033505, 1.152983246827425, 4.527545171626064,\ 

45 5.080441711712409, 0.6498200242245252, 3.408652958826374,\ 

46 2.842677620245357, 2.1132194336022, 3.5838508479713393,\ 

47 1.2223825336486978, 2.8454397976498504, 0.8653980789905962,\ 

48 0.19572166348792452, 3.2506725793229854, 0.9426446058446336,\ 

49 3.246902386754473, 10.294412603282666, 5.275683923067543,\ 

50 0.49517363091492944, 0.2982336402218482] 

51>>> mean(x) / exponential.mean() 

521.088857332669081 

53 

54>>> gamma = Gamma(3.0, 0.26) 

55>>> gamma 

56Gamma(k=3, theta=0.26) 

57>>> x = [gamma.sample(rnd) for _ in range(200)] 

58>>> x[:20] 

59[1.1617763743292708, 1.305755109480284, 0.7627948403389954,\ 

60 1.2735522637897285, 0.7742951665621697, 1.074233520618276,\ 

61 0.6324661100546898, 1.4627037699922791, 0.5739033567160827,\ 

62 0.5555065636904546, 0.629236234283296, 0.3666171387296996,\ 

63 0.3780976936750937, 0.9511433672028858, 1.2607313263258062,\ 

64 1.4442096466925938, 0.48758642085808085, 1.247724803721524,\ 

65 1.9359140456080306, 1.3935246884396764] 

66>>> mean(x) / gamma.mean() 

670.962254253258444 

68 

69>>> Gamma.from_alpha_beta(2, 0.5) 

70Erlang(k=2, theta=2) 

71 

72>>> Gamma.from_alpha_beta(2.5, 0.5) 

73Gamma(k=2.5, theta=2) 

74 

75>>> Erlang.from_alpha_beta(2, 0.5) 

76Erlang(k=2, theta=2) 

77 

78>>> erlang = Erlang(1.0, 0.26) 

79>>> erlang 

80Erlang(k=1, theta=0.26) 

81>>> x = [erlang.sample(rnd) for _ in range(200)] 

82>>> x[:20] 

83[0.29911329228981776, 0.3630768060267626, 0.14111385731543394,\ 

84 0.0745673280234536, 0.029950507989979877, 0.04741877104350835,\ 

85 0.38599089026561223, 0.047919114170390194, 0.06921557868837301,\ 

86 0.4066084140331242, 0.07170887998378667, 0.022061870233843223,\ 

87 0.04904717644388396, 0.32082097064821674, 0.001884448999141546,\ 

88 0.6687964040577958, 0.060598863807579915, 0.21491377996577304,\ 

89 0.23088301776258766, 0.23667780086315618] 

90>>> mean(x) / erlang.mean() 

910.9212193955320179 

92 

93>>> uniform = Uniform(10.5, 17) 

94>>> uniform 

95Uniform(low=10.5, high=17) 

96>>> x = [uniform.sample(rnd) for _ in range(200)] 

97>>> x[:20] 

98[11.235929257903324, 13.89459856391694, 14.196833332199233,\ 

99 13.871456027849515, 14.485310406702393, 16.20468251006836,\ 

100 13.7773314921396, 12.964459911549145, 12.167722566772781,\ 

101 12.49450317595701, 14.145245846833722, 15.669920217581366,\ 

102 13.367286694558258, 10.764955127505994, 11.723004274328007,\ 

103 11.089239517262232, 12.666732191193558, 14.948448277461127,\ 

104 14.339645757381653, 14.803829334728565] 

105>>> mean(x) / uniform.mean() 

1060.9968578545459643 

107 

108>>> choice = Choice((Const(2), gamma, normal)) 

109>>> choice 

110Choice(ch=(Const(v=2), Gamma(k=3, theta=0.26), Normal(mu=1, sd=2))) 

111>>> x = [choice.sample(rnd) for _ in range(200)] 

112>>> x[:20] 

113[0.5499364190576, -0.10428920251005325, 2, 2, 0.44263544084840273,\ 

114 0.6088189450771303, 2, 2, 2, 2, -0.32003290715104904, 2,\ 

115 0.6165299227577784, 1.0445083345086352, 2, 2, 2, -0.5970322857539738,\ 

116 -1.6672705710198277, 2] 

117>>> mean(x) / choice.mean() 

1180.9849730453309948 

119 

120>>> lower = AtLeast(2, normal) 

121>>> lower 

122AtLeast(lb=2, d=Normal(mu=1, sd=2)) 

123>>> x = [lower.sample(rnd) for _ in range(200)] 

124>>> x[:20] 

125[2.9715155848474772, 3.407503493034132, 2.59634461519561,\ 

126 2.0472458472897714, 3.9865670827840334, 2.015117730344058,\ 

127 2.0316584714999935, 5.36625737470408, 3.042848158343226,\ 

128 3.390407444638451, 4.5503184215686, 3.7682459882073007,\ 

129 2.0539760253651305, 2.134886147372958, 2.500182239395479,\ 

130 2.891402111997337, 2.7393826524907228, 2.3449577842364766,\ 

131 2.9043074694017195, 4.7173482582723825] 

132>>> mean(x) / lower.mean() 

1331.0178179732075965 

134 

135>>> interval = In(1, 10, gamma) 

136>>> interval 

137In(lb=1, ub=10, d=Gamma(k=3, theta=0.26)) 

138>>> x = [interval.sample(rnd) for _ in range(200)] 

139>>> x[:20] 

140[1.3488423972739083, 1.0929631399361601, 1.681162621901135,\ 

141 1.655614926246918, 1.041948891842002, 2.0773395958990175,\ 

142 1.3338891374921853, 1.2478964188175743, 1.9417070894505217,\ 

143 1.3990572178987266, 1.1216118870312337, 1.0641160239253207,\ 

144 2.131253600219639, 1.0337453883221577, 1.396499618416345,\ 

145 1.9865175145136038, 1.2555473269031396, 1.4412027583435465,\ 

146 1.320740351247919, 1.0407411942999665] 

147>>> mean(x) / interval.mean() 

1481.0256329849020747 

149 

150>>> erlang2 = Gamma.from_k_and_mean(3, 10) 

151>>> erlang2 

152Erlang(k=3, theta=3.3333333333333335) 

153>>> erlang2.mean() 

15410 

155>>> x = [erlang2.sample(rnd) for _ in range(200)] 

156>>> x[:20] 

157[10.087506399523226, 12.928131914870168, 12.330250639007767,\ 

158 5.305123692562998, 21.085037136404374, 6.6603691824173135,\ 

159 2.961302890492059, 9.810557147180853, 8.051620919921454,\ 

160 8.750329405836668, 3.9511445189935763, 5.570300668751883,\ 

161 16.70132947692463, 7.831425379483914, 11.154757962484842,\ 

162 8.78943102381046, 8.395847820234795, 16.42251602814587,\ 

163 17.1628992966332, 9.684008648356015] 

164>>> mean(x) / erlang2.mean() 

1651.022444624140316 

166 

167>>> AtLeast.greater_than_zero(Gamma(1, 0.5)) 

168AtLeast(lb=5e-324, d=Exponential(eta=1)) 

169 

170>>> AtLeast.greater_than_zero(Gamma(2, 0.5)) 

171Erlang(k=2, theta=0.5) 

172 

173>>> AtLeast.greater_than_zero(Gamma(2.5, 0.5)) 

174Gamma(k=2.5, theta=0.5) 

175""" 

176 

177from dataclasses import dataclass 

178from math import fsum, isfinite, nextafter 

179from typing import Callable, Final, cast 

180 

181from moptipy.utils.nputils import rand_generator, rand_seeds_from_str 

182from numpy.random import Generator 

183from pycommons.math.int_math import try_int 

184from pycommons.types import check_int_range, type_error 

185 

186#: the maximum number of trials during a sampling process 

187_MAX_TRIALS: int = 1_000_000 

188 

189#: the smallest positive number 

190_SMALLEST_POSITIVE_NUMBER: Final[float] = nextafter(0.0, 1.0) 

191 

192 

193class Distribution: 

194 """A base class for distributions.""" 

195 

196 def sample(self, random: Generator) -> int | float: 

197 """ 

198 Sample a random number following this distribution generator. 

199 

200 Each call to this function returns exactly one number. 

201 

202 :param random: the random number generator 

203 :return: the number 

204 """ 

205 raise NotImplementedError 

206 

207 def simplify(self) -> "Distribution": 

208 """ 

209 Try to simplify this distribution. 

210 

211 Some distributions can trivially be simplified. For example, if you 

212 have applied a range limit (:class:`In`) to a constant distribution 

213 (class:`Const`), then this can be simplified to just the constant. 

214 If such simplification is possible, this method returns the simplified 

215 distribution. Otherwise, it just returns the distribution itself. 

216 

217 :returns: a simplified version of this distribution 

218 """ 

219 if not isinstance(self, Distribution): 

220 raise type_error(self, "self", Distribution) 

221 return self 

222 

223 def mean(self) -> int | float: 

224 """ 

225 Get the mean or approximate mean of the distribution. 

226 

227 Some distribution overwrite this method to produce an exact computed 

228 expected value or mean. This default implementation just computes the 

229 arithmetic mean of 10'000 samples of the distribution. This serves as 

230 baseline approximation for any case where a closed form mathematical 

231 definition of the expected value is not available. 

232 

233 :return: the mean or approximated mean 

234 """ 

235 sample: Callable[[Generator], int | float] = self.sample 

236 gen: Final[Generator] = rand_generator(rand_seeds_from_str( 

237 repr(self), 1)[0]) 

238 return try_int(fsum(sample(gen) for _ in range(10_000)) / 10_000) 

239 

240 

@dataclass(order=True, frozen=True)
class Const(Distribution):
    """A distribution that always yields the same value."""

    #: the value returned by every call to :meth:`sample`
    v: int | float

    def __post_init__(self) -> None:
        """Normalize the stored value via `try_int`."""
        object.__setattr__(self, "v", try_int(self.v))

    def sample(self, random: Generator) -> int | float:
        """
        Return the constant; the generator is not consulted.

        :param random: the random number generator (ignored)
        :return: the constant value
        """
        return self.v

    def mean(self) -> int | float:
        """
        Get the mean, which is the constant itself.

        :return: the constant value
        """
        return self.v

    def simplify(self) -> "Distribution":
        """
        Simplify this constant.

        A constant holding an integer becomes an :class:`IntConst`, any
        other constant is returned unchanged.

        :return: the simplified constant

        >>> Const(1.5).simplify()
        Const(v=1.5)
        >>> Const(1.0).simplify()
        IntConst(v=1)
        """
        if not isinstance(self.v, int):
            return self
        return IntConst(self.v)

281 

282 

@dataclass(order=True, frozen=True)
class Normal(Distribution):
    """A Gaussian distribution given by its mean and standard deviation."""

    #: the expected value, which is also the center of the distribution
    mu: int | float
    #: the standard deviation, which must be finite and positive
    sd: int | float

    def __post_init__(self) -> None:
        """Validate both parameters, then normalize them via `try_int`."""
        finite: bool = isfinite(self.mu) and isfinite(self.sd)
        if not (finite and (self.sd > 0)):
            raise ValueError(f"Invalid parameters {self}.")
        for name in ("mu", "sd"):
            object.__setattr__(self, name, try_int(getattr(self, name)))

    def sample(self, random: Generator) -> float:
        """
        Draw one value from the normal distribution.

        :param random: the random number generator to use
        :return: the sampled value
        """
        center, spread = self.mu, self.sd
        return random.normal(center, spread)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution, which is :attr:`mu`.

        :return: the mean
        """
        return self.mu

315 

316 

@dataclass(order=True, frozen=True)
class Exponential(Distribution):
    """
    A class representing an exponential distribution.

    The parameter :attr:`eta` is handed to
    :meth:`numpy.random.Generator.exponential` as its `scale`, so it is the
    mean of the distribution (the reciprocal of the rate).
    """

    #: the scale parameter, which equals the mean; must be finite and > 0
    eta: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if `eta` is not a finite, positive number
        """
        # `eta <= 0` alone would let NaN (every comparison with NaN is False)
        # and infinity through, so finiteness is checked explicitly
        if not (isfinite(self.eta) and (self.eta > 0)):
            raise ValueError(f"Invalid setup {self!r}.")
        object.__setattr__(self, "eta", try_int(self.eta))

    def sample(self, random: Generator) -> float:
        """
        Sample from the Exponential distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.exponential(self.eta)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution.

        :return: the mean, which is the scale parameter `eta`
        """
        # `eta` has already been normalized by `try_int` in `__post_init__`
        return self.eta

346 

347 

@dataclass(order=True, frozen=True)
class Gamma(Distribution):
    """
    A class representing a Gamma distribution.

    Here, `k` is the shape and `theta` is the scale parameter.
    If you use a parameterization with shape `alpha` and rate `beta`, you
    need to create the distribution using :meth:`~Gamma.from_alpha_beta`
    instead. The reason is that `k = alpha` but `theta = 1/beta`, see
    https://www.statlect.com/probability-distributions/gamma-distribution.
    """

    #: the shape parameter, must be finite and > 0
    k: int | float
    #: the scale parameter, must be finite and > 0
    theta: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if `k` or `theta` is not a finite, positive number
        """
        # Both parameters must be positive. The whole conjunction is negated;
        # NaN fails every comparison and is therefore rejected as well.
        if not (isfinite(self.k) and isfinite(self.theta)
                and (self.k > 0) and (self.theta > 0)):
            raise ValueError(f"Invalid parameters {self}.")
        object.__setattr__(self, "k", try_int(self.k))
        object.__setattr__(self, "theta", try_int(self.theta))

    def sample(self, random: Generator) -> float:
        """
        Sample from the Gamma distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.gamma(self.k, self.theta)

    def simplify(self) -> "Distribution":
        """
        Try to simplify this distribution.

        A Gamma distribution may simplify to either an :class:`Erlang` or an
        :class:`Exponential` distribution, depending on its parameters.
        If :attr:`~Gamma.k` is `1`, then it is actually an
        :class:`Exponential` distribution with scale :attr:`~Gamma.theta`.
        If :attr:`~Gamma.k` is an integer, then the distribution is an
        :class:`Erlang` distribution.

        1. https://www.statisticshowto.com/gamma-distribution
        2. https://www.statisticshowto.com/erlang-distribution

        :returns: a simplified version of this distribution

        >>> Gamma(1, 0.5).simplify()
        Exponential(eta=0.5)
        >>> Gamma(2.0, 0.5).simplify()
        Erlang(k=2, theta=0.5)
        >>> Gamma(2.5, 0.5).simplify()
        Gamma(k=2.5, theta=0.5)
        """
        if self.k == 1:
            # Gamma(1, theta) has the density exp(-x / theta) / theta, i.e.,
            # it is the exponential distribution with scale (=mean) theta
            return Exponential(self.theta)
        if isinstance(self.k, int) and not isinstance(self, Erlang):
            return Erlang(self.k, self.theta)
        return self

    def mean(self) -> int | float:
        """
        Get the mean of this distribution.

        :return: the mean, i.e., `k * theta`
        """
        return try_int(self.k * self.theta)

    @classmethod
    def from_alpha_beta(cls, alpha: int | float, beta: int | float) \
            -> "Distribution":
        """
        Create a Gamma distribution from `alpha` and `beta`.

        Here, `alpha` is the shape and `beta` is the rate, so the resulting
        distribution has shape `k = alpha` and scale `theta = 1 / beta`.

        :param alpha: the alpha (shape) parameter
        :param beta: the beta (rate) parameter
        :return: the (simplified) distribution
        :raises ValueError: if `beta` is zero or the resulting parameters are
            not valid
        """
        beta = try_int(beta)
        if beta == 0:
            raise ValueError(f"beta cannot be {beta}.")
        return cls(alpha, 1 / beta).simplify()

    @classmethod
    def from_k_and_mean(cls, k: int | float, mean: int | float) \
            -> "Distribution":
        """
        Create the Gamma distribution from the value of `k` and a mean.

        Since the mean is `k * theta`, the scale is set to `mean / k`.

        :param k: the shape parameter
        :param mean: the mean
        :return: the (simplified) distribution
        :raises ValueError: if `k` or `mean` is not positive
        """
        k = try_int(k)
        mean = try_int(mean)
        if (mean <= 0) or (k <= 0):
            raise ValueError(f"Invalid values k={k}, mean={mean}.")
        return Gamma(k, mean / k).simplify()

438 

439 

class Erlang(Gamma):
    """
    The Erlang distribution, i.e., a Gamma distribution with integer shape.

    Instances pass all checks of :class:`Gamma` and, additionally, must have
    an :class:`int` shape :attr:`~Gamma.k`.
    """

    def __post_init__(self) -> None:
        """Run the Gamma checks, then insist on an integer shape."""
        super().__post_init__()
        shape = self.k
        if isinstance(shape, int):
            return
        raise type_error(shape, "k", int)

448 

449 

@dataclass(order=True, frozen=True)
class Uniform(Distribution):
    """
    A class representing a uniform distribution.

    Values are drawn via :meth:`numpy.random.Generator.uniform`, i.e., from
    the half-open interval `[low, high)`.
    """

    #: the lowest permitted value (inclusive), must be finite
    low: int | float
    #: the upper end of the range (exclusive), finite and greater than `low`
    high: int | float

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises ValueError: if a bound is not finite or `high <= low`
        """
        # `high <= low` alone would accept NaN bounds (every comparison with
        # NaN is False) as well as infinite ones, neither of which can yield
        # meaningful samples
        if not (isfinite(self.low) and isfinite(self.high)
                and (self.low < self.high)):
            raise ValueError(f"Invalid parameters {self}.")
        object.__setattr__(self, "low", try_int(self.low))
        object.__setattr__(self, "high", try_int(self.high))

    def sample(self, random: Generator) -> float:
        """
        Sample from the uniform distribution.

        :param random: the random number generator
        :return: the result
        """
        return random.uniform(self.low, self.high)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution.

        :return: the mean, i.e., the center of the range
        """
        return try_int((self.high + self.low) / 2)

482 

483 

@dataclass(order=True, frozen=True)
class Choice(Distribution):
    """
    A class representing a uniform choice.

    Each sample is produced by first picking one of the distributions in
    :attr:`ch` uniformly at random and then sampling from it.

    >>> Choice((Uniform(1, 2), Uniform(3, 4))).simplify()
    Choice(ch=(Uniform(low=1, high=2), Uniform(low=3, high=4)))

    >>> Choice((Uniform(1, 2), Uniform(1.0, 2))).simplify()
    Uniform(low=1, high=2)

    Nested choices are flattened such that the probabilities stay the same:

    >>> Choice((Const(1), Choice((Const(2), Const(3))))).simplify()
    Choice(ch=(IntConst(v=1), IntConst(v=1), IntConst(v=2), IntConst(v=3)))

    >>> c = Choice((Const(1), Choice((Const(2), Const(3)))))
    >>> c.mean(), c.simplify().mean()
    (1.75, 1.75)
    """

    #: the choices: a non-empty tuple of distributions
    ch: tuple[Distribution, ...]

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks.

        :raises TypeError: if :attr:`ch` is not a tuple
        :raises ValueError: if :attr:`ch` is empty
        """
        if not isinstance(self.ch, tuple):
            raise TypeError(f"Invalid types {self}.")
        # an empty choice could neither be sampled, nor have a mean, nor be
        # simplified, so it is rejected right away
        if tuple.__len__(self.ch) <= 0:
            raise ValueError(f"Invalid empty choice {self}.")
        for v in self.ch:
            if not isinstance(v, Distribution):
                raise type_error(v, "choice", Distribution)

    def mean(self) -> int | float:
        """
        Get the mean of this distribution.

        :return: the arithmetic mean of the means of all choices
        """
        return try_int(fsum(d.mean() for d in self.ch) / tuple.__len__(
            self.ch))

    def sample(self, random: Generator) -> int | float:
        """
        Pick one choice uniformly at random and sample from it.

        :param random: the random number generator
        :return: the result
        """
        return self.ch[random.integers(tuple.__len__(self.ch))].sample(random)

    def simplify(self) -> Distribution:
        """
        Try to simplify this distribution.

        All choices are simplified. Nested :class:`Choice` instances are
        flattened while preserving the selection probabilities. If all
        resulting choices are equal, that single distribution is returned.

        :returns: a simplified version of this distribution
        """
        from math import lcm  # only needed for flattening nested choices

        ch: Final[tuple[Distribution, ...]] = self.ch
        if tuple.__len__(ch) == 1:
            return ch[0].simplify()
        parts: Final[list[tuple[Distribution, ...]]] = []
        needs: bool = False  # whether the result differs from `ch`
        for dist in ch:
            use: Distribution = dist.simplify()
            if use != dist:
                needs = True
            if isinstance(use, Choice):
                needs = True
                parts.append(use.ch)
            else:
                parts.append((use,))

        # Entry i of this choice is picked with probability 1/n. Each of the
        # m_i entries of a nested choice therefore has probability 1/(n*m_i).
        # Repeating every part L/m_i times, where L is the least common
        # multiple of all m_i, gives each flattened entry exactly that
        # probability under a uniform choice over the flattened tuple.
        mult: Final[int] = lcm(*(tuple.__len__(p) for p in parts))
        done: Final[list[Distribution]] = []
        for p in parts:
            done.extend(p * (mult // tuple.__len__(p)))

        first: Final[Distribution] = done[0]
        if all(oth == first for oth in done):
            return first
        return Choice(tuple(done)) if needs else self

561 

562 

@dataclass(order=True, frozen=True)
class AtLeast(Distribution):
    """
    A distribution that is lower-bounded.

    Samples are drawn from the inner distribution :attr:`d` and rejected
    until one is at least :attr:`lb` (at most `_MAX_TRIALS` attempts).

    >>> AtLeast(5, Const(7))
    AtLeast(lb=5, d=Const(v=7))

    >>> AtLeast(5, AtLeast(8, Const(17)))
    AtLeast(lb=8, d=Const(v=17))

    >>> AtLeast(8, AtLeast(5, Const(17)))
    AtLeast(lb=8, d=Const(v=17))
    """

    #: the inclusive lower bound
    lb: int | float
    #: the inner distribution to sample from
    d: Distribution

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        A directly nested :class:`AtLeast` is merged into this one by keeping
        the larger bound. The (possibly unwrapped) inner distribution is then
        checked for whether it can produce values of at least :attr:`lb`.

        :raises TypeError: if :attr:`d` is not a :class:`Distribution`
        :raises ValueError: if the inner distribution cannot reach the bound
        """
        object.__setattr__(self, "lb", try_int(self.lb))
        if not isinstance(self.d, Distribution):
            raise TypeError(f"Invalid types {self}.")
        dd: Distribution = self.d
        if isinstance(dd, AtLeast):
            # the inner `AtLeast` was normalized by its own `__post_init__`,
            # so its `d` is never an `AtLeast`: one unwrapping suffices
            object.__setattr__(self, "lb", max(self.lb, dd.lb))
            dd = dd.d
            object.__setattr__(self, "d", dd)
        # validate against the merged bound and the unwrapped distribution,
        # so that, e.g., `AtLeast(20, AtLeast(5, Const(17)))` is rejected
        lb: Final[int | float] = self.lb
        if isinstance(dd, Const):
            if dd.v < lb:
                raise ValueError(f"Invalid distribution {self!r}.")
        elif isinstance(dd, In):
            if max(dd.lb, lb) >= dd.ub:
                raise ValueError(f"Invalid distribution {self!r}.")
        elif isinstance(dd, Uniform) and (max(dd.low, lb) >= dd.high):
            raise ValueError(f"Invalid distribution {self!r}.")

    def simplify(self) -> "Distribution":
        """
        Try to simplify this distribution.

        :returns: a simplified version of this distribution

        >>> AtLeast(1, Uniform(3, 4)).simplify()
        Uniform(low=3, high=4)

        >>> AtLeast(3.5, Uniform(3, 4)).simplify()
        Uniform(low=3.5, high=4)
        """
        dd: Final[Distribution] = self.d
        if isinstance(dd, Const):  # validated to be >= lb on construction
            return dd
        if isinstance(dd, In):
            idd: In = cast("In", dd)
            return In(max(idd.lb, self.lb), idd.ub, idd.d).simplify()
        if isinstance(dd, Uniform):
            udd: Uniform = cast("Uniform", dd)
            if udd.low >= self.lb:
                return udd
            return Uniform(self.lb, udd.high).simplify()
        # exponential and Gamma (including Erlang) samples are never negative
        if (self.lb <= 0) and (isinstance(dd, Exponential | Gamma)):
            return dd
        if (self.lb <= _SMALLEST_POSITIVE_NUMBER) and isinstance(
                dd, Gamma) and (
                cast("Gamma", dd).k > 1):  # pylint: disable=E1101
            return dd
        return self

    def sample(self, random: Generator) -> int | float:
        """
        Sample from the lower-bounded distribution.

        :param random: the random number generator
        :return: the result
        :raises ValueError: if no sample of at least :attr:`lb` was found
            within `_MAX_TRIALS` attempts
        """
        s: Final[Callable[[Generator], int | float]] = self.d.sample
        lb: Final[int | float] = self.lb
        for _ in range(_MAX_TRIALS):
            v = s(random)
            if lb <= v:
                return v
        raise ValueError(f"Failed to sample from {self!r}.")

    @classmethod
    def greater_than_zero(cls, d: int | float | Distribution) -> Distribution:
        """
        Ensure that all samples are greater than zero.

        :param d: the original distribution
        :return: a distribution which is always greater than zero
        """
        return cls(_SMALLEST_POSITIVE_NUMBER, distribution(d)).simplify()

660 

661 

@dataclass(order=True, frozen=True)
class In(Distribution):
    """
    A distribution that is lower and upper-bounded.

    Samples are drawn from the inner distribution :attr:`d` and rejected
    until one lies in `[lb, ub)` (at most `_MAX_TRIALS` attempts). Nested
    bounds are merged on construction and a uniform inner distribution is
    clipped to the bounds.

    >>> In(1, 10, Const(6))
    In(lb=1, ub=10, d=Const(v=6))

    >>> In(1, 10, In(5, 12, Const(6)))
    In(lb=5, ub=10, d=Const(v=6))

    >>> In(1, 10, AtLeast(6, Const(6)))
    In(lb=6, ub=10, d=Const(v=6))

    >>> In(1, 10, AtLeast(0, Uniform(0, 20)))
    In(lb=1, ub=10, d=Uniform(low=1, high=10))
    """

    #: the inclusive lower bound
    lb: int | float
    #: the exclusive upper bound
    ub: int | float
    #: the inner distribution to sample from
    d: Distribution

    def __post_init__(self) -> None:
        """
        Perform some basic sanity checks and cleanup.

        :raises TypeError: if :attr:`d` is not a :class:`Distribution`
        :raises ValueError: if the range is empty or has NaN bounds, or if it
            cannot contain any sample of the inner distribution
        """
        object.__setattr__(self, "lb", try_int(self.lb))
        object.__setattr__(self, "ub", try_int(self.ub))
        if not isinstance(self.d, Distribution):
            raise TypeError(f"Invalid types {self}.")
        # `not (lb < ub)` also rejects NaN bounds, which `ub <= lb` would let
        # through because every comparison with NaN is False
        if not (self.lb < self.ub):
            raise ValueError(f"Invalid range {self}.")
        dd: Distribution = self.d
        lb: int | float = self.lb
        ub: int | float = self.ub

        # Restricting an already restricted distribution is the same as
        # restricting the inner distribution to the intersection of both
        # ranges. Unwrap all nested bounds first, so that the clipping and
        # the checks below see the actual underlying distribution.
        while isinstance(dd, In | AtLeast):
            if isinstance(dd, In):
                ub = min(dd.ub, ub)
            lb = max(dd.lb, lb)
            if lb >= ub:
                raise ValueError(f"Invalid distribution {self!r}.")
            dd = dd.d

        if isinstance(dd, Uniform):  # clip a uniform distribution
            ulb: int | float = max(dd.low, lb)
            uub: int | float = min(dd.high, ub)
            if ulb >= uub:
                raise ValueError(f"Invalid distribution {self!r}.")
            if (ulb != dd.low) or (uub != dd.high):
                dd = Uniform(ulb, uub)
            lb, ub = ulb, uub

        object.__setattr__(self, "lb", lb)
        object.__setattr__(self, "ub", ub)
        object.__setattr__(self, "d", dd)
        if isinstance(dd, Const) and not lb <= dd.v < ub:
            raise ValueError(f"Invalid distribution {self!r}.")

    def sample(self, random: Generator) -> int | float:
        """
        Sample from the lower- and upper-bounded distribution.

        :param random: the random number generator
        :return: the result, which is in `[lb, ub)`
        :raises ValueError: if no sample within the bounds was found within
            `_MAX_TRIALS` attempts
        """
        s: Final[Callable[[Generator], int | float]] = self.d.sample
        lb: Final[int | float] = self.lb
        ub: Final[int | float] = self.ub
        for _ in range(_MAX_TRIALS):
            v = s(random)
            if lb <= v < ub:
                return v
        raise ValueError(f"Failed to sample from {self!r}.")

    def simplify(self) -> Distribution:
        """
        Simplify this distribution.

        A constant was validated to lie within the bounds and a uniform
        distribution was clipped to them on construction, so both can be
        returned directly.

        :return: the simplified distribution

        >>> In(-3, 4, Const(3)).simplify()
        Const(v=3)
        >>> ii = In(-10, 10, Uniform(3, 20))
        >>> ii
        In(lb=3, ub=10, d=Uniform(low=3, high=10))
        >>> ii.simplify()
        Uniform(low=3, high=10)
        """
        return self.d if isinstance(self.d, Const | Uniform) else self

    def mean(self) -> int | float:
        """
        Get the mean of this distribution.

        :return: the exact mean for a constant or uniform inner distribution,
            otherwise the sampling-based approximation
        """
        return self.d.mean() if isinstance(self.d, Const | Uniform) \
            else super().mean()

766 

767 

class IntDistribution(Distribution):
    """The common base class of distributions that only yield integers."""

    def sample(self, random: Generator) -> int:
        """
        Draw one integer from this distribution.

        Every invocation produces exactly a single value.

        :param random: the random number generator to use
        :return: the sampled number, which always is an integer
        :raises NotImplementedError: always; subclasses must override this
        """
        raise NotImplementedError

    def simplify(self) -> "IntDistribution":
        """
        Return a simpler, equivalent integer distribution.

        This default implementation knows no simplification rule and hands
        back the distribution itself.

        :returns: a simplified version of this integer distribution
        """
        if isinstance(self, IntDistribution):
            return self
        raise type_error(self, "self", IntDistribution)

791 

792 

class IntConst(IntDistribution, Const):
    """
    A constant distribution whose value is an integer.

    :meth:`simplify` comes from :class:`IntDistribution`, which precedes
    :class:`Const` in the method resolution order, so it returns the instance
    itself and repeated simplification reaches a fixed point.
    """

    def __post_init__(self) -> None:
        """Normalize the value via :class:`Const`, then range-check it."""
        super().__post_init__()
        bound: Final[int] = 1_000_000_000_000_000_000
        check_int_range(self.v, "v", -bound, bound)

    def sample(self, random: Generator) -> int:
        """
        Return the integer constant; `random` is not used.

        :param random: the random number generator (ignored)
        :return: the integer value
        """
        value: Final[int] = cast("int", self.v)
        return value

    def mean(self) -> int:
        """
        Get the arithmetic mean, which is the constant itself.

        :return: the integer value
        """
        return cast("int", self.v)

809 

810 

def distribution(d: int | float | Distribution) -> Distribution:
    """
    Get the distribution from the parameter.

    Integers become :class:`IntConst` instances and floats become
    :class:`Const` instances. Distributions are simplified repeatedly until
    :meth:`~Distribution.simplify` returns the very same object.

    :param d: the integer value, floating point value, or distribution
    :return: the canonicalized distribution
    :raises TypeError: (via `type_error`) if `d` is neither a number nor a
        distribution, or if simplification yields a non-distribution

    >>> distribution(7)
    IntConst(v=7)

    >>> distribution(3.4)
    Const(v=3.4)

    >>> distribution(Choice((Const(4.0), )))
    IntConst(v=4)
    """
    if isinstance(d, int):
        return IntConst(d)
    if isinstance(d, float):
        return Const(d)
    if isinstance(d, Distribution):
        old_d: Distribution | None = None
        while old_d is not d:
            old_d = d
            d = d.simplify()
            # never hand back something that is not a distribution
            if not isinstance(d, Distribution):
                raise type_error(d, "simplified distribution", Distribution)
        return d
    raise type_error(d, "d", (Distribution, int, float))