Source code for alts.modules.oracle.augmentation
#Version 1.1.1 conform as of 29.11.2024
from __future__ import annotations
"""
| *alts.modules.oracle.augmentation*
| :doc:`Core Module </core/oracle/augmentation>`
"""
from typing import TYPE_CHECKING
from dataclasses import dataclass
from alts.core.oracle.augmentation import Augmentation
from alts.core.configuration import init
import numpy as np
if TYPE_CHECKING:
from typing_extensions import Self
from typing import Tuple
from nptyping import NDArray, Number, Shape
[docs]
@dataclass
class NoiseAugmentation(Augmentation):
"""
NoiseAugmentation(noise_ratio)
| **Description**
| Adds noise to the results of the augmented :doc:`DataSource </core/oracle/data_source>`.
:param noise_ratio: Standard deviation from actual result (default = 0.01)
:type noise_ratio: float
"""
noise_ratio: float = init(default=0.01)
rng = np.random.default_rng()
[docs]
def query(self, queries: NDArray[ Shape["query_nr, ... query_dim"], Number]) -> Tuple[NDArray[Shape["query_nr, ... query_dim"], Number], NDArray[Shape["query_nr, ... result_dim"], Number]]: # type: ignore
"""
query(self, queries) -> data_points
| **Description**
| Applies random noise with the given standard deviation ``noise_ratio`` to the result.
:param queries: Requested Query
:type queries: `NDArray <https://numpy.org/doc/stable/reference/arrays.ndarray.html>`_
:return: Processed Query, Result
:rtype: A tuple of two `NDArray <https://numpy.org/doc/stable/reference/arrays.ndarray.html>`_
"""
queries, results = self.data_source.query(queries)
augmented = self.rng.normal(results, self.noise_ratio) # type: ignore
return queries, augmented