ramjet.photometric_database.py_mapper
Code for use with TensorFlow's Dataset class that allows CPU-bound map functions to run in parallel via multiprocessing.
Module Contents
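class PyMapper(map_function: Callable, number_of_parallel_calls: int)
A class that maps a py_function over a TensorFlow dataset in parallel on the CPU.
Parameters: - map_function – The function to map over the dataset.
- number_of_parallel_calls – The number of calls of the mapping function to run in parallel.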
send_to_map_pool(self, element_tensor)
Sends the tensor element to the pool for processing.
Parameters: element_tensor – The element to be processed by the pool.
Returns: The output of the map function on the element.
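A minimal sketch of sending one element to a pool, assuming the pool is a standard-library multiprocessing.Pool and that the call blocks until a worker returns (this page does not show the ramjet internals). The names double and send_to_map_pool here are hypothetical stand-ins, and the element is a plain number rather than a tensor.

```python
import multiprocessing


def double(value):
    # Hypothetical CPU-bound map function. It is defined at module level so the
    # pool can pickle it by reference and run it in a worker process.
    return value * 2


def send_to_map_pool(pool, map_function, element):
    # Hypothetical helper: submit one element to the pool and block until a
    # worker returns the output of the map function on that element.
    return pool.apply(map_function, (element,))


if __name__ == "__main__":
    # The "fork" start method (Unix only) lets workers inherit double
    # without re-importing this script.
    with multiprocessing.get_context("fork").Pool(processes=2) as pool:
        print(send_to_map_pool(pool, double, 21))  # prints 42
```

Because Pool.apply blocks the caller, a single call gains nothing from the pool by itself; parallelism only appears when several callers submit elements at the same time.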
map_to_dataset(self, dataset: tf.data.Dataset, output_types: Union[List[tf.dtypes.DType], tf.dtypes.DType] = tf.float32)
Applies the map function to each element of the given dataset.
Parameters: - dataset – The dataset to apply the map function to.
- output_types – The TensorFlow output type(s) that the map function's results are converted to.
Returns: The mapped dataset.
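The docs above do not say how the parallelism is achieved. One plausible design, stated here purely as an assumption, is that map_to_dataset wraps send_to_map_pool in tf.py_function and calls dataset.map with num_parallel_calls, so each of TensorFlow's parallel map threads blocks on its own pool worker. The standard-library sketch below models that idea without TensorFlow: a ThreadPoolExecutor stands in for the dataset's parallel map calls, and each thread hands its element to a process pool. PoolMapperSketch and map_elements are hypothetical names.

```python
import multiprocessing
from concurrent.futures import ThreadPoolExecutor


def square(value):
    # Hypothetical CPU-bound map function, defined at module level so it pickles.
    return value * value


class PoolMapperSketch:
    """Hypothetical model of the assumed PyMapper design; not ramjet's code."""

    def __init__(self, map_function, number_of_parallel_calls):
        self.map_function = map_function
        self.number_of_parallel_calls = number_of_parallel_calls
        # "fork" (Unix only) lets workers inherit module-level functions
        # without re-importing the main script.
        context = multiprocessing.get_context("fork")
        self.pool = context.Pool(processes=number_of_parallel_calls)

    def send_to_map_pool(self, element):
        # Blocks the calling thread until a worker process returns the result.
        return self.pool.apply(self.map_function, (element,))

    def map_elements(self, elements):
        # Each thread plays the role of one parallel map call. While it blocks
        # on the pool, the computation itself runs in a separate process.
        # Executor.map yields results in input order.
        with ThreadPoolExecutor(max_workers=self.number_of_parallel_calls) as threads:
            return list(threads.map(self.send_to_map_pool, elements))

    def close(self):
        # Stop accepting work and wait for the worker processes to exit.
        self.pool.close()
        self.pool.join()


if __name__ == "__main__":
    mapper = PoolMapperSketch(square, number_of_parallel_calls=3)
    try:
        print(mapper.map_elements(range(6)))  # prints [0, 1, 4, 9, 16, 25]
    finally:
        mapper.close()
```

Threads alone would not help a CPU-bound Python function because of the global interpreter lock; in this model they only wait, and the processes do the work.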
map_py_function_to_dataset(dataset: tf.data.Dataset, map_function: Callable, number_of_parallel_calls: int, output_types: Union[List[tf.dtypes.DType], tf.dtypes.DType] = tf.float32) → tf.data.Dataset
A one-line wrapper for mapping a py function over a dataset in parallel.
Parameters: - dataset – The dataset whose elements the mapping function will be applied to.
- map_function – The function to map over the dataset.
- number_of_parallel_calls – The number of calls of the mapping function to run in parallel.
- output_types – The TensorFlow output type(s) that the function's results are converted to.
Returns: The mapped dataset.
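Since these functions move work into separate processes, the map function presumably has to be picklable. This page does not state that requirement; it is an assumption based on how standard multiprocessing pools send callables to workers. Module-level functions pickle by reference, whereas lambdas and nested functions do not. The hypothetical helper is_picklable below is a cheap pre-check one might run before building a pipeline.

```python
import pickle


def normalize(value):
    # Module-level function: pickled by reference, so a worker can look it up.
    return value / 10.0


def is_picklable(function):
    # Hypothetical helper: True if function can be serialized for a worker.
    # Depending on the Python version, failures surface as PicklingError or
    # AttributeError, so both are caught.
    try:
        pickle.dumps(function)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


def make_local_function():
    def local(value):
        # Nested functions have no importable qualified name, so pickling fails.
        return value + 1
    return local


if __name__ == "__main__":
    print(is_picklable(normalize))  # prints True
    print(is_picklable(lambda value: value))  # prints False
    print(is_picklable(make_local_function()))  # prints False
```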