# Using code https://github.com/TUPLES-Trustworthy-AI/beluga-challenge-tools-internal/blob/97aeafcb75c3e570a8b4c1d38a13ef345d8cc636/generator/configurations/random_state.py
from typing import TypeVar
import numpy as np
from numpy.random import Generator
from scipy.stats import truncnorm, uniform
# Tolerance allowed when checking that a probability vector sums to 1.
epsilon = 0.00001


class RandomState:
    """Seeded source of randomness for the sampling helpers below.

    Every draw is taken from ``self.rng``, so two instances created with the
    same seed produce the same sequence of samples.
    """

    # Generic element type for ``get_random_element_prop``.
    T = TypeVar("T")

    def __init__(self, seed) -> None:
        """Create the underlying generator.

        Args:
            seed: Forwarded unchanged to ``numpy.random.default_rng``.
        """
        self.rng: Generator = np.random.default_rng(seed)

    def get_random_element_prop(self, list: list[T], probs: list[float]) -> T:
        """Return one element of ``list`` chosen with the given probabilities.

        Args:
            list: Non-empty sequence of candidates. The name shadows the
                builtin but is kept so keyword callers keep working.
            probs: Probability of each candidate, same length as ``list``;
                must sum to 1 within ``epsilon``.

        Returns:
            The selected element. Entries with probability 0 are never chosen.

        Raises:
            AssertionError: If ``list`` is empty, the lengths differ, or the
                probabilities do not sum to 1 within ``epsilon``.
        """
        assert len(list) > 0
        assert len(list) == len(probs)
        sum_props = sum(probs)
        assert (
            1 - epsilon <= sum_props <= 1 + epsilon
        ), f"Sum of probabilities must be 1 +-{epsilon} but is {sum_props}"

        # Draw from the seeded generator (not SciPy's global state) so results
        # are reproducible, and scale by the actual total so a sum slightly
        # below 1 cannot leave the draw beyond the last cumulative bound.
        threshold = self.rng.random() * sum_props
        cumulative = 0.0
        last_positive = None
        for index, prob in enumerate(probs):
            if prob > 0:
                last_positive = index
            cumulative += prob
            # Strict "<": a zero-probability entry never widens the interval,
            # so it cannot be hit even by a draw of exactly 0.0.
            if threshold < cumulative:
                return list[index]

        # Only reachable through floating-point rounding in the running sum;
        # fall back to the last element that actually carries probability mass.
        assert last_positive is not None
        return list[last_positive]

    def get_discrete_truncated_normal_sample(
        self, center: int, sigma: int, lower: int, upper: int
    ) -> int:
        """Return an integer from a normal distribution truncated to [lower, upper].

        A continuous sample from a normal with mean ``center`` and standard
        deviation ``sigma``, restricted to ``[lower, upper]``, is drawn and
        rounded to the nearest integer. Because of the rounding, ``lower`` and
        ``upper`` each collect only a half-unit interval of mass, so they are
        less likely than an interior value at the same density.

        Args:
            center: Mean of the untruncated normal.
            sigma: Standard deviation; 0 yields ``center`` clamped to the bounds.
            lower: Inclusive lower bound.
            upper: Inclusive upper bound.

        Returns:
            An integer in ``[lower, upper]``.

        Raises:
            ValueError: If ``lower > upper``.
        """
        if lower == upper:
            return lower
        if lower > upper:
            raise ValueError(f"lower ({lower}) must not exceed upper ({upper})")
        if sigma == 0:
            # Degenerate case: as sigma -> 0 the truncated normal collapses to
            # the point of [lower, upper] nearest to center. This also avoids
            # dividing by zero below.
            return int(round(min(max(center, lower), upper)))
        # truncnorm expects the bounds in standard-normal units.
        low = (lower - center) / sigma
        up = (upper - center) / sigma
        r = truncnorm.rvs(low, up, loc=center, scale=sigma, random_state=self.rng)
        return int(round(r))