# Source code for alts.modules.data_sampler
from __future__ import annotations
from typing import TYPE_CHECKING
from dataclasses import dataclass, field
import numpy as np
from sklearn.neighbors import NearestNeighbors
from alts.core.data.data_sampler import ResultDataSampler
from alts.core.data.constrains import QueryConstrain, ResultConstrain
from alts.core.configuration import init
if TYPE_CHECKING:
from typing import Tuple
from alts.core.subscribable import Subscribable
@dataclass
class KDTreeKNNDataSampler(ResultDataSampler):
    """Return the k nearest stored (query, result) pairs around each query.

    A ``NearestNeighbors`` index over the result pool's queries is refit on
    every result update.  The number of neighbours returned per query starts
    at ``size`` (default ``sample_size_max``), is limited to
    ``ceil(count / sample_size_data_fraction)``, raised to at least
    ``sample_size_min``, and finally capped at the number of points the index
    was fitted on.
    """

    # Default (and upper) neighbour count per query.
    sample_size_max: int = init(default=80)
    # Minimum neighbour count per query (still capped by the fitted size).
    sample_size_min: int = init(default=5)
    # The neighbour count is limited to ceil(count / sample_size_data_fraction).
    sample_size_data_fraction: int = init(default=6)

    def post_init(self):
        """Create the (not yet fitted) nearest-neighbour index."""
        super().post_init()
        self._knn = NearestNeighbors(n_neighbors=self.sample_size_max)

    def result_update(self, subscription: Subscribable):
        """Refit the neighbour index on the queries stored in the result pool."""
        super().result_update(subscription)
        # NearestNeighbors.fit ignores its ``y`` argument; results are passed
        # only for API symmetry.
        self._knn.fit(self.data_pools.result.queries, self.data_pools.result.results)

    def query(self, queries, size=None):
        """Look up the nearest stored queries/results for each query point.

        Args:
            queries: Points to look up; forwarded to ``NearestNeighbors.kneighbors``.
            size: Requested neighbour count; defaults to ``sample_size_max``.

        Returns:
            ``(neighbor_queries, kneighbors)``: the pool's stored queries and
            results indexed by the neighbour-index matrix from ``kneighbors``.
        """
        if size is None:
            size = self.sample_size_max
        # ``count`` as reported by the result pool's constrain (presumably the
        # number of stored queries -- confirm).  Fetched once instead of twice.
        count = self.data_pools.result.query_constrain().count
        if count // self.sample_size_data_fraction < size:
            size = np.ceil(count / self.sample_size_data_fraction)
        if size < self.sample_size_min:
            size = self.sample_size_min
        size = int(size)
        # kneighbors raises ValueError when n_neighbors exceeds the number of
        # fitted samples (e.g. a pool smaller than sample_size_min), so cap it.
        # If the index is not fitted yet, leave size alone so kneighbors raises
        # its usual NotFittedError.
        n_fitted = getattr(self._knn, "n_samples_fit_", None)
        if n_fitted is not None:
            size = min(size, n_fitted)
        kneighbor_indexes = self._knn.kneighbors(queries, n_neighbors=size, return_distance=False)
        neighbor_queries = self.data_pools.result.queries[kneighbor_indexes]
        kneighbors = self.data_pools.result.results[kneighbor_indexes]
        return (neighbor_queries, kneighbors)

    def query_constrain(self):
        """Build a ``QueryConstrain`` whose ranges are the pool's stored queries.

        Count, shape and initial ranges are copied from the result pool's own
        query constrain.  ``ranges`` is then replaced by the stored queries with
        a new axis inserted at position 1, and ``_last_queries`` by the pool's
        ``last_queries``.
        """
        pool = self.data_pools.result
        # Fetched once; assumes the pool's query_constrain() has no side
        # effects -- confirm.
        pool_constrain = pool.query_constrain()
        query_constrain = QueryConstrain(
            count=pool_constrain.count,
            shape=pool_constrain.shape,
            ranges=pool_constrain.ranges,
        )
        query_constrain.ranges = pool.queries[:, None]
        query_constrain._last_queries = pool.last_queries
        return query_constrain

    def result_constrain(self):
        """Return the result pool's own result constrain unchanged."""
        return self.data_pools.result.result_constrain()
@dataclass
class KDTreeRegionDataSampler(ResultDataSampler):
    """Return every stored (query, result) pair within ``region_size`` of each query.

    A ``NearestNeighbors`` index over the result pool's queries is refit on
    every result update, and lookups use ``radius_neighbors``.
    """

    # Radius passed to NearestNeighbors.radius_neighbors.
    region_size: float = init(default=0.1)

    def post_init(self):
        """Create the (not yet fitted) nearest-neighbour index."""
        super().post_init()
        self._knn = NearestNeighbors()

    def result_update(self, subscription: Subscribable):
        """Refit the neighbour index on the queries stored in the result pool."""
        super().result_update(subscription)
        # NearestNeighbors.fit ignores its ``y`` argument; results are passed
        # only for API symmetry.
        self._knn.fit(self.data_pools.result.queries, self.data_pools.result.results)

    @staticmethod
    def _ragged_take(values, index_arrays):
        """Gather ``values[idx]`` for each index array into a 1-D object array."""
        # Filling a preallocated object array keeps the output 1-D even when
        # every index array has the same length; np.asarray(..., dtype=object)
        # would instead stack equal-length pieces into a multi-dimensional array.
        gathered = np.empty(len(index_arrays), dtype=object)
        for position, idx in enumerate(index_arrays):
            gathered[position] = values[idx]
        return gathered

    def query(self, queries, size=None):
        """Collect the stored queries/results within ``region_size`` of each query.

        Args:
            queries: Points to look up; forwarded to ``radius_neighbors``.
            size: Unused; accepted for call compatibility with other samplers.

        Returns:
            ``(neighbor_queries, kneighbors)``: two 1-D object arrays with one
            entry per query, each entry holding the stored queries (resp.
            results) found within the radius (possibly empty).
        """
        kneighbor_indexes = self._knn.radius_neighbors(queries, radius=self.region_size, return_distance=False)
        neighbor_queries = self._ragged_take(self.data_pools.result.queries, kneighbor_indexes)
        kneighbors = self._ragged_take(self.data_pools.result.results, kneighbor_indexes)
        return (neighbor_queries, kneighbors)

    def query_constrain(self):
        """Build a ``QueryConstrain`` whose ranges are the pool's recorded queries.

        Count, shape and initial ranges are copied from the result pool's query
        constrain.  ``ranges`` is then replaced by that constrain's
        ``all_queries()`` with a new axis inserted at position 1, and
        ``_last_queries`` by its ``last_queries()``.
        """
        # Fetched once instead of five times; assumes the pool's
        # query_constrain() has no side effects -- confirm.
        # NOTE(review): KDTreeKNNDataSampler reads ``queries``/``last_queries``
        # directly from the pool instead; confirm the two are meant to agree.
        pool_constrain = self.data_pools.result.query_constrain()
        query_constrain = QueryConstrain(
            count=pool_constrain.count,
            shape=pool_constrain.shape,
            ranges=pool_constrain.ranges,
        )
        query_constrain.ranges = pool_constrain.all_queries()[:, None]
        query_constrain._last_queries = pool_constrain.last_queries()
        return query_constrain

    def result_constrain(self):
        """Return the result pool's own result constrain unchanged."""
        return self.data_pools.result.result_constrain()