Source code for alts.modules.query.query_sampler
from __future__ import annotations
from typing import TYPE_CHECKING
from math import ceil
import random
from dataclasses import dataclass
from alts.core.data.data_pools import ResultDataPools, StreamDataPools, ProcessDataPools
from alts.core.oracle.oracles import POracles
import numpy as np
from scipy.stats import qmc # type: ignore
from alts.core.query.query_sampler import QuerySampler
from alts.core.configuration import init
from alts.core.data.queried_data_pool import QueriedDataPool
if TYPE_CHECKING:
from typing import Tuple, List, Union, Literal
from nptyping import NDArray, Number, Shape
[docs]
@dataclass
class OptimalQuerySampler(QuerySampler):
    """Query sampler that draws its queries from a configured set of query batches.

    Each entry of ``optimal_queries`` is an array whose first axis enumerates
    queries. ``sample`` picks whole batches at random (with replacement),
    concatenates them and truncates the result to the requested length.
    """

    # Candidate batches; annotated as arrays of shape (query_nr, ..., query_dims).
    optimal_queries: Tuple[NDArray[Shape["query_nr, ... query_dims"], Number], ...] = init() # type: ignore

    def post_init(self):
        """Check every configured batch against the oracles' query constraint.

        Raises:
            ValueError: If ``constrains_met`` rejects any of the batches.
        """
        super().post_init()
        for optimal_query in self.optimal_queries:
            if not self.oracles.query_constrain().constrains_met(optimal_query):
                raise ValueError("optimal_queries do not meet oracles.query_constrain")

    def sample(self, num_queries = None):
        """Return up to ``num_queries`` queries taken from randomly chosen batches.

        Args:
            num_queries: Number of queries to return; falls back to
                ``self.num_queries`` when ``None``.

        Returns:
            The concatenation of the chosen batches, cut to ``num_queries`` rows.
            The number of batches to draw is derived from the length of the
            *first* batch only, so if other batches are shorter the result may
            hold fewer than ``num_queries`` rows.
        """
        if num_queries is None:
            num_queries = self.num_queries

        first_batch = self.optimal_queries[0]
        if num_queries == 0:
            # Zero requested queries means zero batches are drawn, and
            # np.concatenate raises ValueError on an empty sequence. Return an
            # empty slice of the first batch instead, which keeps the trailing
            # query dimensions and the dtype.
            return first_batch[:0]

        query_nr = first_batch.shape[0]
        # Smallest number of first-batch-sized batches that covers the request.
        k = ceil(num_queries / query_nr)
        queries = np.concatenate(random.choices(self.optimal_queries, k=k))
        return queries[:num_queries]
[docs]
@dataclass
class FixedQuerySampler(QuerySampler):
    """Query sampler that proposes the same configured query every time."""

    # The single query that is repeated; annotated as shape (..., query_dims).
    fixed_query: NDArray[Shape["... query_dims"], Number] = init() # type: ignore

    def post_init(self):
        """Check the fixed query against the oracles' query constraint.

        Raises:
            ValueError: If ``constrains_met`` rejects ``fixed_query``.
        """
        super().post_init()
        accepted = self.oracles.query_constrain().constrains_met(self.fixed_query)
        if not accepted:
            raise ValueError("fixed_query does not meet oracles.query_constrain")

    def sample(self, num_queries = None):
        """Return ``fixed_query`` repeated ``num_queries`` times along a new first axis.

        Args:
            num_queries: Number of repetitions; falls back to
                ``self.num_queries`` when ``None``.
        """
        repetitions = self.num_queries if num_queries is None else num_queries
        # Add a leading axis of length one, then repeat along it.
        single = self.fixed_query[np.newaxis, ...]
        return np.repeat(single, repetitions, axis=0)
[docs]
@dataclass
class LatinHypercubeQuerySampler(QuerySampler):
    """Query sampler backed by ``scipy.stats.qmc.LatinHypercube``.

    Samples are drawn in a flat space whose dimension is the product of the
    query constraint's shape, reshaped to that shape and mapped to queries via
    ``queries_from_norm_pos``.
    """

    def post_init(self):
        """Create the Latin hypercube engine sized to the flattened query shape."""
        super().post_init()
        flat_dim = 1
        for axis_len in self.oracles.query_constrain().shape:
            flat_dim = flat_dim * axis_len
        self.sampler = qmc.LatinHypercube(d=flat_dim)

    def sample(self, num_queries = None):
        """Draw ``num_queries`` Latin hypercube points and convert them to queries.

        Args:
            num_queries: Number of queries to draw; falls back to
                ``self.num_queries`` when ``None``.

        Raises:
            ValueError: If the query constraint has no ``ranges``.
        """
        amount = self.num_queries if num_queries is None else num_queries

        if self.oracles.query_constrain().ranges is None:
            raise ValueError("Not for discrete Pools")

        # The engine yields an (amount, flat_dim) array; restore the query shape.
        unit_points = self.sampler.random(n=amount)
        target_shape = (amount, *self.oracles.query_constrain().shape)
        unit_points = np.reshape(unit_points, target_shape)
        return self.oracles.query_constrain().queries_from_norm_pos(unit_points)
[docs]
class RandomChoiceQuerySampler(QuerySampler):
    """Query sampler that picks uniformly random indices from a countable query pool."""

    def sample(self, num_queries = None):
        """Return queries for ``num_queries`` random indices (drawn with replacement).

        Args:
            num_queries: Number of indices to draw; falls back to
                ``self.num_queries`` when ``None``.

        Returns:
            The result of ``queries_from_index`` for the drawn indices, or an
            empty ``int32`` array when the constraint's ``count`` is zero.

        Raises:
            ValueError: If the query constraint has no ``count``.
        """
        requested = self.num_queries if num_queries is None else num_queries

        if self.oracles.query_constrain().count is None:
            raise ValueError("Not for continues pools")

        count = self.oracles.query_constrain().count
        if count == 0:
            # Nothing to choose from.
            return np.asarray([], dtype=np.int32)

        constraint = self.oracles.query_constrain()
        # randint's upper bound is exclusive, so indices lie in [0, count).
        indices = np.random.randint(low=0, high=count, size=(requested,))
        return constraint.queries_from_index(indices)
[docs]
class ProcessQuerySampler(QuerySampler):
    """Base class for query samplers that require their oracles to be ``POracles``."""

    def post_init(self):
        """Verify that the inherited ``oracles`` object is a ``POracles`` instance.

        Raises:
            TypeError: If the oracles are of any other type.
        """
        super().post_init()
        if isinstance(super().oracles, POracles):
            return
        raise TypeError("ProcessQuerySampler requires POracles")

    @property
    def oracles(self) -> POracles:
        """The parent's ``oracles``, re-exposed with the narrower ``POracles`` type."""
        return super().oracles
[docs]
@dataclass
class LastProcessQuerySampler(ProcessQuerySampler):
    """Query sampler that returns ``latest_add`` of the oracles' process."""

    # Optional here: defaults to None.
    num_queries: int = init(default=None)

    def sample(self, num_queries = None):
        """Return ``self.oracles.process.latest_add``.

        Args:
            num_queries: Resolved against ``self.num_queries`` when ``None``,
                but it does not influence the returned value.
        """
        if num_queries is None:
            # Kept for parity with the other samplers; the value is not used below.
            num_queries = self.num_queries
        process = self.oracles.process
        return process.latest_add
[docs]
@dataclass
class ProcessQueueQuerySampler(ProcessQuerySampler):
    """Query sampler that returns the ``queries`` attribute of the oracles' process."""

    # Optional here: defaults to None.
    num_queries: int = init(default=None)

    def sample(self, num_queries = None):
        """Return ``self.oracles.process.queries``.

        Args:
            num_queries: Resolved against ``self.num_queries`` when ``None``,
                but it does not influence the returned value.
        """
        if num_queries is None:
            # Kept for parity with the other samplers; the value is not used below.
            num_queries = self.num_queries
        process = self.oracles.process
        return process.queries
[docs]
@dataclass
class DataPoolQuerySampler(QuerySampler):
    """Base class for query samplers that read their queries from a data pool.

    Subclasses override ``pool`` to select the pool to read from.
    """

    # Optional here: defaults to None.
    num_queries: int = init(default=None)

    def pool(self) -> QueriedDataPool:
        """Return the data pool to sample from; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        message = "Please use a non abstract ...PoolQuerySampler."
        raise NotImplementedError(message)
[docs]
@dataclass
class AllDataPoolQuerySampler(DataPoolQuerySampler):
    """Query sampler that returns the ``queries`` attribute of the selected pool."""

    # Optional here: defaults to None.
    num_queries: int = init(default=None)

    def sample(self, num_queries = None):
        """Return ``self.pool().queries``.

        Args:
            num_queries: Resolved against ``self.num_queries`` when ``None``,
                but it does not influence the returned value.
        """
        if num_queries is None:
            # Kept for parity with the other samplers; the value is not used below.
            num_queries = self.num_queries
        selected_pool = self.pool()
        return selected_pool.queries
[docs]
@dataclass
class AllResultPoolQuerySampler(AllDataPoolQuerySampler):
    """``AllDataPoolQuerySampler`` bound to the ``result`` pool of ``ResultDataPools``."""

    def post_init(self):
        """Verify that ``data_pools`` is a ``ResultDataPools`` instance.

        Raises:
            TypeError: If the data pools are of any other type.
        """
        super().post_init()
        if isinstance(self.data_pools, ResultDataPools):
            return
        raise TypeError("ResultPoolQuerySampler requires ResultDataPools")

    @property
    def data_pools(self) -> ResultDataPools:
        """The parent's ``data_pools``, re-exposed with the narrower type."""
        pools = super().data_pools
        return pools

    def pool(self) -> QueriedDataPool:
        """Return the ``result`` pool of the data pools."""
        pools = self.data_pools
        return pools.result
[docs]
@dataclass
class AllStreamPoolQuerySampler(AllDataPoolQuerySampler):
    """``AllDataPoolQuerySampler`` bound to the ``stream`` pool of ``StreamDataPools``."""

    def post_init(self):
        """Verify that ``data_pools`` is a ``StreamDataPools`` instance.

        Raises:
            TypeError: If the data pools are of any other type.
        """
        super().post_init()
        if isinstance(self.data_pools, StreamDataPools):
            return
        raise TypeError("StreamPoolQuerySampler requires StreamDataPools")

    @property
    def data_pools(self) -> StreamDataPools:
        """The parent's ``data_pools``, re-exposed with the narrower type."""
        pools = super().data_pools
        return pools

    def pool(self) -> QueriedDataPool:
        """Return the ``stream`` pool of the data pools."""
        pools = self.data_pools
        return pools.stream
[docs]
@dataclass
class AllProcessPoolQuerySampler(AllDataPoolQuerySampler):
    """``AllDataPoolQuerySampler`` bound to the ``process`` pool of ``ProcessDataPools``."""

    def post_init(self):
        """Verify that ``data_pools`` is a ``ProcessDataPools`` instance.

        Raises:
            TypeError: If the data pools are of any other type.
        """
        super().post_init()
        if isinstance(self.data_pools, ProcessDataPools):
            return
        raise TypeError("ProcessPoolQuerySampler requires ProcessDataPools")

    @property
    def data_pools(self) -> ProcessDataPools:
        """The parent's ``data_pools``, re-exposed with the narrower type."""
        pools = super().data_pools
        return pools

    def pool(self) -> QueriedDataPool:
        """Return the ``process`` pool of the data pools."""
        pools = self.data_pools
        return pools.process
[docs]
@dataclass
class LastDataPoolQuerySampler(DataPoolQuerySampler):
    """Query sampler that returns the ``last_queries`` attribute of the selected pool."""

    # Optional here: defaults to None.
    num_queries: int = init(default=None)

    def sample(self, num_queries = None):
        """Return ``self.pool().last_queries``.

        Args:
            num_queries: Resolved against ``self.num_queries`` when ``None``,
                but it does not influence the returned value.
        """
        if num_queries is None:
            # Kept for parity with the other samplers; the value is not used below.
            num_queries = self.num_queries
        selected_pool = self.pool()
        return selected_pool.last_queries
[docs]
@dataclass
class LastResultPoolQuerySampler(LastDataPoolQuerySampler):
    """``LastDataPoolQuerySampler`` bound to the ``result`` pool of ``ResultDataPools``."""

    def post_init(self):
        """Verify that ``data_pools`` is a ``ResultDataPools`` instance.

        Raises:
            TypeError: If the data pools are of any other type.
        """
        super().post_init()
        if isinstance(self.data_pools, ResultDataPools):
            return
        raise TypeError("ResultPoolQuerySampler requires ResultDataPools")

    @property
    def data_pools(self) -> ResultDataPools:
        """The parent's ``data_pools``, re-exposed with the narrower type."""
        pools = super().data_pools
        return pools

    def pool(self) -> QueriedDataPool:
        """Return the ``result`` pool of the data pools."""
        pools = self.data_pools
        return pools.result
[docs]
@dataclass
class LastStreamPoolQuerySampler(LastDataPoolQuerySampler):
    """``LastDataPoolQuerySampler`` bound to the ``stream`` pool of ``StreamDataPools``."""

    def post_init(self):
        """Verify that ``data_pools`` is a ``StreamDataPools`` instance.

        Raises:
            TypeError: If the data pools are of any other type.
        """
        super().post_init()
        if isinstance(self.data_pools, StreamDataPools):
            return
        raise TypeError("StreamPoolQuerySampler requires StreamDataPools")

    @property
    def data_pools(self) -> StreamDataPools:
        """The parent's ``data_pools``, re-exposed with the narrower type."""
        pools = super().data_pools
        return pools

    def pool(self) -> QueriedDataPool:
        """Return the ``stream`` pool of the data pools."""
        pools = self.data_pools
        return pools.stream
[docs]
@dataclass
class LastProcessPoolQuerySampler(LastDataPoolQuerySampler):
    """``LastDataPoolQuerySampler`` bound to the ``process`` pool of ``ProcessDataPools``."""

    def post_init(self):
        """Verify that ``data_pools`` is a ``ProcessDataPools`` instance.

        Raises:
            TypeError: If the data pools are of any other type.
        """
        super().post_init()
        if isinstance(self.data_pools, ProcessDataPools):
            return
        raise TypeError("ProcessPoolQuerySampler requires ProcessDataPools")

    @property
    def data_pools(self) -> ProcessDataPools:
        """The parent's ``data_pools``, re-exposed with the narrower type."""
        pools = super().data_pools
        return pools

    def pool(self) -> QueriedDataPool:
        """Return the ``process`` pool of the data pools."""
        pools = self.data_pools
        return pools.process