# Source code for alts.modules.query.query_sampler

from __future__ import annotations
from typing import TYPE_CHECKING

from math import ceil
import random
from dataclasses import dataclass
from alts.core.data.data_pools import ResultDataPools, StreamDataPools, ProcessDataPools
from alts.core.oracle.oracles import POracles


import numpy as np
from scipy.stats import qmc # type: ignore

from alts.core.query.query_sampler import QuerySampler
from alts.core.configuration import init
from alts.core.data.queried_data_pool import QueriedDataPool

if TYPE_CHECKING:
    from typing import Tuple, List, Union, Literal
    from nptyping import NDArray, Number, Shape

@dataclass
class OptimalQuerySampler(QuerySampler):
    """Sample queries by randomly drawing from a fixed set of query batches.

    ``optimal_queries`` is a tuple of arrays, each stacking one or more
    queries along axis 0. :meth:`sample` draws whole batches with
    replacement (via :func:`random.choices`), concatenates them along
    axis 0 and truncates the result to the requested number of queries.
    """

    # Tuple of query batches; every batch is checked against
    # oracles.query_constrain() in post_init.
    optimal_queries: Tuple[NDArray[Shape["query_nr, ... query_dims"], Number], ...] = init() # type: ignore

    def post_init(self):
        """Validate every optimal query batch against the query constraint.

        Raises:
            ValueError: if any entry of ``optimal_queries`` does not satisfy
                ``oracles.query_constrain()``.
        """
        super().post_init()
        for optimal_query in self.optimal_queries:
            if not self.oracles.query_constrain().constrains_met(optimal_query):
                raise ValueError("optimal_queries do not meet oracles.query_constrain")

    def sample(self, num_queries = None):
        """Return ``num_queries`` queries drawn from ``optimal_queries``.

        Args:
            num_queries: number of queries to return; ``None`` falls back to
                ``self.num_queries``.

        Returns:
            The concatenation of randomly chosen batches, truncated to
            ``num_queries`` rows along axis 0. For ``num_queries == 0`` an
            empty slice of the first batch (same trailing dimensions) is
            returned.
        """
        if num_queries is None:
            num_queries = self.num_queries
        if num_queries == 0:
            # np.concatenate rejects an empty list, so return an empty batch
            # that keeps the trailing query dimensions.
            return self.optimal_queries[0][:0]
        query_nr = self.optimal_queries[0].shape[0]
        k = ceil(num_queries / query_nr)
        queries = random.choices(self.optimal_queries, k=k)
        drawn = sum(batch.shape[0] for batch in queries)
        # k is based on the first batch's length only; when batches differ in
        # length the draws may be short, so keep drawing until enough rows exist.
        while drawn < num_queries:
            extra = random.choices(self.optimal_queries, k=ceil((num_queries - drawn) / query_nr))
            queries.extend(extra)
            drawn += sum(batch.shape[0] for batch in extra)
        return np.concatenate(queries)[:num_queries]
@dataclass
class FixedQuerySampler(QuerySampler):
    """Sampler that always answers with copies of a single fixed query."""

    # The query repeated by sample(); checked against the oracles' query
    # constraint in post_init.
    fixed_query: NDArray[Shape["... query_dims"], Number] = init() # type: ignore

    def post_init(self):
        """Check that ``fixed_query`` satisfies ``oracles.query_constrain()``.

        Raises:
            ValueError: if the fixed query violates the constraint.
        """
        super().post_init()
        constraint = self.oracles.query_constrain()
        if constraint.constrains_met(self.fixed_query):
            return
        raise ValueError("fixed_query does not meet oracles.query_constrain")

    def sample(self, num_queries = None):
        """Return ``fixed_query`` stacked ``num_queries`` times along a new axis 0.

        Args:
            num_queries: number of copies; ``None`` falls back to
                ``self.num_queries``.
        """
        copies = self.num_queries if num_queries is None else num_queries
        single_batch = self.fixed_query[np.newaxis, ...]
        return np.repeat(single_batch, copies, axis=0)
@dataclass
class UniformQuerySampler(QuerySampler):
    """Sample queries from uniformly distributed normalized positions.

    Only applicable when the query constraint defines ``ranges``.
    """

    def sample(self, num_queries = None):
        """Draw ``num_queries`` uniform samples and map them to queries.

        Args:
            num_queries: number of queries; ``None`` falls back to
                ``self.num_queries``.

        Raises:
            ValueError: if the query constraint has no ``ranges``.
        """
        if num_queries is None:
            num_queries = self.num_queries
        constraint = self.oracles.query_constrain()
        if constraint.ranges is None:
            raise ValueError("Not for discrete Pools")
        # np.random.uniform defaults to [0, 1); queries_from_norm_pos presumably
        # maps these normalized positions into the query space — confirm.
        unit_positions = np.random.uniform(size=(num_queries, *constraint.shape))
        return constraint.queries_from_norm_pos(unit_positions)
@dataclass
class LatinHypercubeQuerySampler(QuerySampler):
    """Sample queries using a scipy Latin hypercube over the flattened query shape."""

    def post_init(self):
        """Create the Latin hypercube sampler.

        Its dimensionality is the number of elements in one query, i.e. the
        product of ``oracles.query_constrain().shape``.
        """
        super().post_init()
        query_shape = self.oracles.query_constrain().shape
        flat_dim = int(np.prod(query_shape, dtype=int))
        self.sampler = qmc.LatinHypercube(d=flat_dim)

    def sample(self, num_queries = None):
        """Draw ``num_queries`` Latin hypercube points and map them to queries.

        Args:
            num_queries: number of queries; ``None`` falls back to
                ``self.num_queries``.

        Raises:
            ValueError: if the query constraint has no ``ranges``.
        """
        if num_queries is None:
            num_queries = self.num_queries
        constraint = self.oracles.query_constrain()
        if constraint.ranges is None:
            raise ValueError("Not for discrete Pools")
        flat_points = self.sampler.random(n=num_queries)
        # Restore the per-query shape from the flat hypercube coordinates.
        shaped_points = flat_points.reshape((num_queries, *constraint.shape))
        return constraint.queries_from_norm_pos(shaped_points)
class RandomChoiceQuerySampler(QuerySampler):
    """Sample queries from a discrete pool by random index.

    Indices are drawn uniformly (with repetition) via ``np.random.randint``.
    """

    def sample(self, num_queries = None):
        """Return ``num_queries`` queries picked at random indices of the pool.

        Args:
            num_queries: number of queries; ``None`` falls back to
                ``self.num_queries``.

        Returns:
            ``queries_from_index`` applied to the drawn indices, or an empty
            int32 array when the pool ``count`` is 0.

        Raises:
            ValueError: if the query constraint has no ``count``.
        """
        if num_queries is None:
            num_queries = self.num_queries
        constraint = self.oracles.query_constrain()
        if constraint.count is None:
            raise ValueError("Not for continues pools")
        pool_size = constraint.count
        if pool_size == 0:
            # NOTE(review): this empty result bypasses queries_from_index, so
            # its shape/dtype may differ from non-empty results — confirm.
            return np.asarray([], dtype=np.int32)
        picked = np.random.randint(low=0, high=pool_size, size=(num_queries,))
        return constraint.queries_from_index(picked)
class ProcessQuerySampler(QuerySampler):
    """Base class for samplers that require ``POracles``."""

    def post_init(self):
        """Ensure the configured oracles are ``POracles``.

        Raises:
            TypeError: if the oracles are not ``POracles``.
        """
        super().post_init()
        base_oracles = super().oracles
        if isinstance(base_oracles, POracles):
            return
        raise TypeError("ProcessQuerySampler requires POracles")

    @property
    def oracles(self) -> POracles:
        """The base class's oracles, annotated as ``POracles`` for type checking."""
        process_oracles: POracles = super().oracles
        return process_oracles
@dataclass
class LastProcessQuerySampler(ProcessQuerySampler):
    """Return the queries most recently added to the process oracle."""

    # Declared for configuration compatibility (default None); sample() does
    # not use it.
    num_queries: int = init(default=None)

    def sample(self, num_queries = None):
        """Return ``self.oracles.process.latest_add``.

        Args:
            num_queries: accepted for interface compatibility with other
                samplers but ignored; the result is not truncated or padded.
        """
        return self.oracles.process.latest_add
@dataclass
class ProcessQueueQuerySampler(ProcessQuerySampler):
    """Return the queries currently held by the process oracle."""

    # Declared for configuration compatibility (default None); sample() does
    # not use it.
    num_queries: int = init(default=None)

    def sample(self, num_queries = None):
        """Return ``self.oracles.process.queries``.

        Args:
            num_queries: accepted for interface compatibility with other
                samplers but ignored; the result is not truncated or padded.
        """
        return self.oracles.process.queries
@dataclass
class DataPoolQuerySampler(QuerySampler):
    """Abstract base for samplers that read queries from a ``QueriedDataPool``.

    Subclasses select the pool by overriding :meth:`pool`.
    """

    # Optional query count; defaults to None.
    num_queries: int = init(default=None)

    def pool(self) -> QueriedDataPool:
        """Return the pool to read from; must be overridden.

        Raises:
            NotImplementedError: always, on this abstract base.
        """
        message = "Please use a non abstract ...PoolQuerySampler."
        raise NotImplementedError(message)
@dataclass
class AllDataPoolQuerySampler(DataPoolQuerySampler):
    """Return every query stored in the pool selected by :meth:`pool`."""

    # Declared for configuration compatibility (default None); sample() does
    # not use it.
    num_queries: int = init(default=None)

    def sample(self, num_queries = None):
        """Return ``self.pool().queries``.

        Args:
            num_queries: accepted for interface compatibility with other
                samplers but ignored; all stored queries are returned.
        """
        return self.pool().queries
@dataclass
class AllResultPoolQuerySampler(AllDataPoolQuerySampler):
    """Return every query stored in the result pool of ``ResultDataPools``."""

    def post_init(self):
        """Ensure ``data_pools`` is a ``ResultDataPools``.

        Raises:
            TypeError: if ``data_pools`` has the wrong type.
        """
        super().post_init()
        pools = self.data_pools
        if isinstance(pools, ResultDataPools):
            return
        raise TypeError("ResultPoolQuerySampler requires ResultDataPools")

    @property
    def data_pools(self) -> ResultDataPools:
        """The base class's data pools, annotated as ``ResultDataPools``."""
        return super().data_pools

    def pool(self) -> QueriedDataPool:
        """Return the ``result`` pool."""
        result_pool: QueriedDataPool = self.data_pools.result
        return result_pool
@dataclass
class AllStreamPoolQuerySampler(AllDataPoolQuerySampler):
    """Return every query stored in the stream pool of ``StreamDataPools``."""

    def post_init(self):
        """Ensure ``data_pools`` is a ``StreamDataPools``.

        Raises:
            TypeError: if ``data_pools`` has the wrong type.
        """
        super().post_init()
        pools = self.data_pools
        if isinstance(pools, StreamDataPools):
            return
        raise TypeError("StreamPoolQuerySampler requires StreamDataPools")

    @property
    def data_pools(self) -> StreamDataPools:
        """The base class's data pools, annotated as ``StreamDataPools``."""
        return super().data_pools

    def pool(self) -> QueriedDataPool:
        """Return the ``stream`` pool."""
        stream_pool: QueriedDataPool = self.data_pools.stream
        return stream_pool
@dataclass
class AllProcessPoolQuerySampler(AllDataPoolQuerySampler):
    """Return every query stored in the process pool of ``ProcessDataPools``."""

    def post_init(self):
        """Ensure ``data_pools`` is a ``ProcessDataPools``.

        Raises:
            TypeError: if ``data_pools`` has the wrong type.
        """
        super().post_init()
        pools = self.data_pools
        if isinstance(pools, ProcessDataPools):
            return
        raise TypeError("ProcessPoolQuerySampler requires ProcessDataPools")

    @property
    def data_pools(self) -> ProcessDataPools:
        """The base class's data pools, annotated as ``ProcessDataPools``."""
        return super().data_pools

    def pool(self) -> QueriedDataPool:
        """Return the ``process`` pool."""
        process_pool: QueriedDataPool = self.data_pools.process
        return process_pool
@dataclass
class LastDataPoolQuerySampler(DataPoolQuerySampler):
    """Return the ``last_queries`` of the pool selected by :meth:`pool`."""

    # Declared for configuration compatibility (default None); sample() does
    # not use it.
    num_queries: int = init(default=None)

    def sample(self, num_queries = None):
        """Return ``self.pool().last_queries``.

        Args:
            num_queries: accepted for interface compatibility with other
                samplers but ignored; the result is not truncated or padded.
        """
        return self.pool().last_queries
@dataclass
class LastResultPoolQuerySampler(LastDataPoolQuerySampler):
    """Return the last queries of the result pool of ``ResultDataPools``."""

    def post_init(self):
        """Ensure ``data_pools`` is a ``ResultDataPools``.

        Raises:
            TypeError: if ``data_pools`` has the wrong type.
        """
        super().post_init()
        pools = self.data_pools
        if isinstance(pools, ResultDataPools):
            return
        raise TypeError("ResultPoolQuerySampler requires ResultDataPools")

    @property
    def data_pools(self) -> ResultDataPools:
        """The base class's data pools, annotated as ``ResultDataPools``."""
        return super().data_pools

    def pool(self) -> QueriedDataPool:
        """Return the ``result`` pool."""
        result_pool: QueriedDataPool = self.data_pools.result
        return result_pool
@dataclass
class LastStreamPoolQuerySampler(LastDataPoolQuerySampler):
    """Return the last queries of the stream pool of ``StreamDataPools``."""

    def post_init(self):
        """Ensure ``data_pools`` is a ``StreamDataPools``.

        Raises:
            TypeError: if ``data_pools`` has the wrong type.
        """
        super().post_init()
        pools = self.data_pools
        if isinstance(pools, StreamDataPools):
            return
        raise TypeError("StreamPoolQuerySampler requires StreamDataPools")

    @property
    def data_pools(self) -> StreamDataPools:
        """The base class's data pools, annotated as ``StreamDataPools``."""
        return super().data_pools

    def pool(self) -> QueriedDataPool:
        """Return the ``stream`` pool."""
        stream_pool: QueriedDataPool = self.data_pools.stream
        return stream_pool
@dataclass
class LastProcessPoolQuerySampler(LastDataPoolQuerySampler):
    """Return the last queries of the process pool of ``ProcessDataPools``."""

    def post_init(self):
        """Ensure ``data_pools`` is a ``ProcessDataPools``.

        Raises:
            TypeError: if ``data_pools`` has the wrong type.
        """
        super().post_init()
        pools = self.data_pools
        if isinstance(pools, ProcessDataPools):
            return
        raise TypeError("ProcessPoolQuerySampler requires ProcessDataPools")

    @property
    def data_pools(self) -> ProcessDataPools:
        """The base class's data pools, annotated as ``ProcessDataPools``."""
        return super().data_pools

    def pool(self) -> QueriedDataPool:
        """Return the ``process`` pool."""
        process_pool: QueriedDataPool = self.data_pools.process
        return process_pool