RAMjET

ramjet.trial

Boilerplate code for running trials.

Module Contents

Classes

LoggingToolName – The names of the supported logging tools. Members are also (and must be) strings.

Functions

infer(model: tf.keras.Model, dataset: tf.data.Dataset, infer_results_path: Path, number_of_top_predictions_to_keep: int = None) – Performs inference of a model on a dataset, saving the results to a file.
save_results(confidences_data_frame: pd.DataFrame, infer_results_path: Path, number_of_top_predictions_to_keep: int = None) → pd.DataFrame – Saves a predictions data frame to a file.
create_logging_metrics() → List[tf.metrics.Metric] – Creates the standard metrics to be used in logging.
create_logging_callbacks(logs_directory: Path, trial_name: str, database: StandardAndInjectedLightCurveDatabase, logging_tool_name: LoggingToolName = LoggingToolName.WANDB, wandb_entity: Optional[str] = None, wandb_project: Optional[str] = None, light_curve_logging: bool = False) → List[callbacks.Callback] – Creates the callbacks to perform the logging.
class LoggingToolName

Bases: enum.StrEnum

The names of the supported logging tools. As with any enum.StrEnum, members are also (and must be) strings.

WANDB = 'wandb'
TENSORBOARD = 'tensorboard'
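A minimal sketch of how a StrEnum like LoggingToolName behaves, assuming the member values listed above ('wandb' and 'tensorboard'). Because members are strings, a plain string (for example, a command-line value) converts directly to a member and compares equal to it. The fallback base class is only there so the sketch also runs on Python versions older than 3.11, which lack enum.StrEnum.

```python
import enum

try:
    StrEnum = enum.StrEnum  # Available from Python 3.11 onwards.
except AttributeError:
    class StrEnum(str, enum.Enum):
        """Minimal stand-in for enum.StrEnum on older Python versions."""

        def __str__(self) -> str:
            return str(self.value)


class LoggingToolName(StrEnum):
    WANDB = 'wandb'
    TENSORBOARD = 'tensorboard'


# A plain string converts to the member whose value matches it.
tool = LoggingToolName('tensorboard')
print(tool is LoggingToolName.TENSORBOARD)  # True
print(LoggingToolName.WANDB == 'wandb')  # True: members are real strings
print(str(tool))  # tensorboard
```

Using string-valued members keeps configuration files and command-line flags readable while still giving code a closed set of choices to dispatch on.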
infer(model: tf.keras.Model, dataset: tf.data.Dataset, infer_results_path: Path, number_of_top_predictions_to_keep: int = None)

Performs inference of a model on a dataset, saving the results to a file.

Parameters:
  • model – The model to infer with.
  • dataset – The dataset to infer on.
  • infer_results_path – The path to save the resulting predictions to.
  • number_of_top_predictions_to_keep – The number of top predictions to keep. If None, all results are saved.
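A dependency-free sketch of the inference loop described above. It assumes, hypothetically, that the dataset yields (identifiers, examples) batches and that the model is a callable returning one confidence per example; the real tf.data.Dataset layout and the output file format may differ. infer_sketch and toy_model are illustrative names only, and the sketch returns the kept rows for convenience even though the documented infer signature lists no return value.

```python
import csv
import heapq
import tempfile
from pathlib import Path
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

# Hypothetical batch layout: (identifiers, examples). The real dataset structure may differ.
Batch = Tuple[Sequence[str], Sequence[float]]


def infer_sketch(model: Callable[[Sequence[float]], Sequence[float]],
                 dataset: Iterable[Batch],
                 infer_results_path: Path,
                 number_of_top_predictions_to_keep: Optional[int] = None) -> List[Tuple[str, float]]:
    """Runs the model batch by batch and writes (identifier, confidence) rows to a CSV file."""
    predictions: List[Tuple[str, float]] = []
    for identifiers, examples in dataset:
        confidences = model(examples)
        predictions.extend(zip(identifiers, (float(c) for c in confidences)))
    if number_of_top_predictions_to_keep is None:
        # Keep every prediction, ordered from most to least confident.
        kept = sorted(predictions, key=lambda row: row[1], reverse=True)
    else:
        # nlargest avoids sorting the full list when only the top k rows are needed.
        kept = heapq.nlargest(number_of_top_predictions_to_keep, predictions, key=lambda row: row[1])
    with infer_results_path.open('w', newline='') as results_file:
        writer = csv.writer(results_file)
        writer.writerow(['identifier', 'confidence'])
        writer.writerows(kept)
    return kept


def toy_model(examples: Sequence[float]) -> List[float]:
    # Stand-in "model": clamps each scalar example into [0, 1].
    return [min(max(x, 0.0), 1.0) for x in examples]


toy_dataset = [(['a', 'b'], [0.2, 0.9]), (['c'], [0.5])]
with tempfile.TemporaryDirectory() as directory:
    results_path = Path(directory) / 'infer_results.csv'
    top_two = infer_sketch(toy_model, toy_dataset, results_path, number_of_top_predictions_to_keep=2)
    written_lines = results_path.read_text().splitlines()

print(top_two)  # [('b', 0.9), ('c', 0.5)]
```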
save_results(confidences_data_frame: pd.DataFrame, infer_results_path: Path, number_of_top_predictions_to_keep: int = None) → pd.DataFrame

Saves a predictions data frame to a file.

Parameters:
  • confidences_data_frame – The data frame of predictions to save.
  • infer_results_path – The path to save the resulting predictions to.
  • number_of_top_predictions_to_keep – The number of top predictions to keep. If None, all results are saved.
Returns:

The updated data frame.
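A sketch of the keep-the-top-predictions behavior, assuming the data frame has a single confidence column and that "top" means highest confidence. To avoid a pandas dependency, rows are plain dictionaries and the output goes to a text stream rather than a Path; with pandas, the equivalent steps would be DataFrame.sort_values(..., ascending=False), DataFrame.head(n), and DataFrame.to_csv(...). The names save_results_sketch, confidence_column, and the identifier column are hypothetical.

```python
import csv
import io
from typing import Dict, List, Optional, TextIO

Row = Dict[str, object]


def save_results_sketch(confidences_rows: List[Row],
                        results_file: TextIO,
                        number_of_top_predictions_to_keep: Optional[int] = None,
                        confidence_column: str = 'confidence') -> List[Row]:
    """Sorts rows by confidence, optionally truncates them, writes them as CSV, and returns them."""
    # sorted() builds a new list, so the caller's rows are left untouched.
    updated_rows = sorted(confidences_rows, key=lambda row: row[confidence_column], reverse=True)
    if number_of_top_predictions_to_keep is not None:
        updated_rows = updated_rows[:number_of_top_predictions_to_keep]
    if updated_rows:  # Nothing (not even a header) is written for an empty result.
        writer = csv.DictWriter(results_file, fieldnames=list(updated_rows[0].keys()))
        writer.writeheader()
        writer.writerows(updated_rows)
    return updated_rows


rows = [
    {'identifier': 'a', 'confidence': 0.1},
    {'identifier': 'b', 'confidence': 0.8},
    {'identifier': 'c', 'confidence': 0.6},
]
buffer = io.StringIO()
top_rows = save_results_sketch(rows, buffer, number_of_top_predictions_to_keep=2)
print([row['identifier'] for row in top_rows])  # ['b', 'c']
```

Returning the sorted, truncated rows mirrors the documented return of "the updated data frame", so a caller can inspect exactly what was written.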

create_logging_metrics() → List[tf.metrics.Metric]

Creates the standard metrics to be used in logging.

Returns:

The list of metrics.
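This section does not say which metrics are returned. Keras metric objects are stateful (they accumulate values through update_state until reset_state is called), which is one reason to build them in a function: each call can return fresh instances that no other trial has touched. The sketch below mimics that interface with two hypothetical pure-Python classes, a thresholded binary accuracy and a precision, so it runs without TensorFlow; the metrics RAMjET actually creates may differ.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class BinaryAccuracySketch:
    """Hypothetical stand-in mimicking the update_state/result/reset_state interface of Keras metrics."""
    name: str = 'binary_accuracy'
    threshold: float = 0.5
    correct: int = 0
    total: int = 0

    def update_state(self, y_true: Sequence[int], y_pred: Sequence[float]) -> None:
        for label, confidence in zip(y_true, y_pred):
            predicted = 1 if confidence > self.threshold else 0
            self.correct += int(predicted == label)
            self.total += 1

    def result(self) -> float:
        return self.correct / self.total if self.total else 0.0

    def reset_state(self) -> None:
        self.correct = 0
        self.total = 0


@dataclass
class PrecisionSketch:
    """Hypothetical precision metric: true positives over predicted positives."""
    name: str = 'precision'
    threshold: float = 0.5
    true_positives: int = 0
    predicted_positives: int = 0

    def update_state(self, y_true: Sequence[int], y_pred: Sequence[float]) -> None:
        for label, confidence in zip(y_true, y_pred):
            if confidence > self.threshold:
                self.predicted_positives += 1
                self.true_positives += int(label == 1)

    def result(self) -> float:
        return self.true_positives / self.predicted_positives if self.predicted_positives else 0.0

    def reset_state(self) -> None:
        self.true_positives = 0
        self.predicted_positives = 0


Metric = Union[BinaryAccuracySketch, PrecisionSketch]


def create_logging_metrics_sketch() -> List[Metric]:
    # Every call builds new objects, so accumulated state is never shared between trials.
    return [BinaryAccuracySketch(), PrecisionSketch()]


first_trial_metrics = create_logging_metrics_sketch()
second_trial_metrics = create_logging_metrics_sketch()
for metric in first_trial_metrics:
    metric.update_state([1, 0, 1], [0.9, 0.2, 0.4])
print([round(metric.result(), 3) for metric in first_trial_metrics])  # [0.667, 1.0]
```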
create_logging_callbacks(logs_directory: Path, trial_name: str, database: StandardAndInjectedLightCurveDatabase, logging_tool_name: LoggingToolName = LoggingToolName.WANDB, wandb_entity: Optional[str] = None, wandb_project: Optional[str] = None, light_curve_logging: bool = False) → List[callbacks.Callback]

Creates the callbacks to perform the logging.

Parameters:
  • logs_directory – The directory to log to.
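  • trial_name – The name of the trial.
  • database – The StandardAndInjectedLightCurveDatabase used by the trial.
  • logging_tool_name – The logging tool to use. Defaults to LoggingToolName.WANDB.
  • wandb_entity – The Weights & Biases entity, used when logging with WandB. Optional.
  • wandb_project – The Weights & Biases project, used when logging with WandB. Optional.
  • light_curve_logging – Whether to also log light curves. Defaults to False.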
Returns:

The callbacks to perform the logging.
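A hypothetical sketch of how the logging_tool_name argument could select the callbacks. The callback classes are placeholders standing in for real ones (for example tf.keras.callbacks.TensorBoard, or a Weights & Biases Keras callback); the per-trial log subdirectory and the routing of light_curve_logging to the WandB branch are assumptions, and the database argument is omitted for brevity.

```python
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union


class LoggingToolName(str, Enum):
    # Redefined here so the sketch is self-contained; values mirror the documented members.
    WANDB = 'wandb'
    TENSORBOARD = 'tensorboard'


@dataclass
class TensorBoardCallbackSpec:
    """Hypothetical placeholder for a TensorBoard callback."""
    log_dir: Path


@dataclass
class WandbCallbackSpec:
    """Hypothetical placeholder for a Weights & Biases callback."""
    run_name: str
    entity: Optional[str]
    project: Optional[str]
    log_light_curves: bool


CallbackSpec = Union[TensorBoardCallbackSpec, WandbCallbackSpec]


def create_logging_callbacks_sketch(
        logs_directory: Path,
        trial_name: str,
        logging_tool_name: LoggingToolName = LoggingToolName.WANDB,
        wandb_entity: Optional[str] = None,
        wandb_project: Optional[str] = None,
        light_curve_logging: bool = False,
) -> List[CallbackSpec]:
    """Picks the logging callback for the requested tool."""
    if logging_tool_name == LoggingToolName.TENSORBOARD:
        # Assumption: each trial logs into its own subdirectory of the logs directory.
        return [TensorBoardCallbackSpec(log_dir=logs_directory / trial_name)]
    if logging_tool_name == LoggingToolName.WANDB:
        # The entity and project only matter for the WandB branch.
        return [WandbCallbackSpec(run_name=trial_name, entity=wandb_entity,
                                  project=wandb_project, log_light_curves=light_curve_logging)]
    raise ValueError(f'Unknown logging tool: {logging_tool_name}')


tensorboard_callbacks = create_logging_callbacks_sketch(Path('logs'), 'trial_1', LoggingToolName.TENSORBOARD)
wandb_callbacks = create_logging_callbacks_sketch(Path('logs'), 'trial_1', wandb_project='ramjet')
```

Keeping the tool choice in a single enum argument lets one trial script switch between TensorBoard and WandB without changing any other code.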