Source code for alts.modules.queried_data_pool
from __future__ import annotations
from typing import TYPE_CHECKING
from random import choice
import numpy as np
from alts.core.data.constrains import QueryConstrain, ResultConstrain
from alts.core.data.queried_data_pool import QueriedDataPool
if TYPE_CHECKING:
from typing import Dict
[docs]
class FlatQueriedDataPool(QueriedDataPool):
    """
    Pool of already labeled data, indexed by the exact value of each query.

    Every ``(query, result)`` pair passed to :meth:`add` is recorded in
    ``query_index`` under ``tuple(query)``; :meth:`query` answers a query by
    drawing one of the results recorded for exactly that query.
    """

    def __init__(self):
        # NOTE(review): this calls ``init`` (not ``__init__``) on the parent and
        # passes the class itself. Kept unchanged because the base class is not
        # visible here -- confirm QueriedDataPool really provides ``init``.
        super().init(FlatQueriedDataPool)
        # Maps tuple(query) -> list of every result added for that query.
        # Duplicates are kept, so a result added twice is drawn twice as often.
        self.query_index: Dict[tuple, list] = {}

    @staticmethod
    def _index_key(query) -> tuple:
        """Return the hashable key under which *query* is stored in ``query_index``."""
        # tuple() of a 1-D row yields its scalar elements; a query whose elements
        # are themselves arrays would not be hashable. TODO confirm queries are 1-D.
        return tuple(query)

    def query(self, queries):
        """
        Look up a stored result for each query.

        :param queries: iterable of queries; each must have been passed to
            :meth:`add` earlier (matched exactly via ``tuple(query)``).
        :return: ``(queries, results)`` where ``queries`` is returned unchanged
            and ``results`` is an ``np.ndarray`` holding, for each query, one
            result chosen uniformly at random (``random.choice``) among those
            recorded for it.
        :raises IndexError: if a query has no recorded result.
        """
        result_list = []
        for query in queries:
            candidates = self.query_index.get(self._index_key(query))
            if not candidates:
                # Same exception type random.choice raises on an empty list,
                # but with a message that identifies the offending query.
                raise IndexError(f"no result has been added for query {query!r}")
            result_list.append(choice(candidates))
        results: np.ndarray = np.asarray(result_list)
        return queries, results

    def add(self, data_points):
        """
        Record new labeled data.

        :param data_points: pair ``(queries, results)``; the i-th query is
            paired with the i-th result via ``zip``. The unchanged pair is then
            forwarded to the parent's ``add``.
        """
        queries, results = data_points
        for query, result in zip(queries, results):
            # Append in place so a frequently repeated query does not rebuild
            # its whole result list on every insert.
            self.query_index.setdefault(self._index_key(query), []).append(result)
        super().add(data_points)

    def query_constrain(self) -> QueryConstrain:
        """
        Build the query constraint for this pool.

        ``count`` is the number of rows in ``self.queries`` and ``ranges`` is
        ``self.queries`` itself -- presumably restricting valid queries to the
        ones already stored (``self.queries`` is maintained outside this class;
        verify against QueriedDataPool). ``shape`` is taken from
        ``self._query_constrain()``.
        """
        return QueryConstrain(
            count=self.queries.shape[0],
            shape=self._query_constrain().shape,
            ranges=self.queries,
        )

    def result_constrain(self) -> ResultConstrain:
        """Return the result constraint produced by ``self._result_constrain()``."""
        return self._result_constrain()