ramjet.trial
Boilerplate code for running trials.
Module Contents
Classes
LoggingToolName | Enum where members are also (and must be) strings.
Functions
infer(model: tf.keras.Model, dataset: tf.data.Dataset, infer_results_path: Path, number_of_top_predictions_to_keep: int = None) | Performs inference of a model on a dataset, saving the results to a file.
save_results(confidences_data_frame: pd.DataFrame, infer_results_path: Path, number_of_top_predictions_to_keep: int = None) → pd.DataFrame | Saves a predictions data frame to a file.
create_logging_metrics() → List[tf.metrics.Metric] | Creates the standard metrics to be used in logging.
create_logging_callbacks(logs_directory: Path, trial_name: str, database: StandardAndInjectedLightcurveDatabase, logging_tool_name: LoggingToolName = LoggingToolName.WANDB, wandb_entity: Optional[str] = None, wandb_project: Optional[str] = None, light_curve_logging: bool = False) → List[callbacks.Callback] | Creates the callbacks to perform the logging.
class LoggingToolName
Bases: enum.StrEnum
Enum where members are also (and must be) strings.
WANDB = 'wandb'
TENSORBOARD = 'tensorboard'
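Because LoggingToolName derives from enum.StrEnum, its members compare equal to their plain string values, so a tool name read from a config file or the command line can be converted with a value lookup. The sketch below mirrors the two documented members rather than copying ramjet's source; parse_logging_tool_name is a hypothetical helper, and the fallback base class exists only so the example also runs on Python versions older than 3.11, which lack enum.StrEnum.

```python
from enum import Enum

try:
    from enum import StrEnum  # available from Python 3.11
except ImportError:  # older interpreters: a minimal stand-in
    class StrEnum(str, Enum):
        """Enum whose members are also str instances."""

        def __str__(self) -> str:
            return str(self.value)


class LoggingToolName(StrEnum):
    """Mirror of the documented members; not copied from ramjet's source."""

    WANDB = "wandb"
    TENSORBOARD = "tensorboard"


def parse_logging_tool_name(text: str) -> LoggingToolName:
    """Hypothetical helper: turn user input (e.g. a CLI flag) into a member."""
    # Lookup by value; raises ValueError for an unknown tool name.
    return LoggingToolName(text.strip().lower())


if __name__ == "__main__":
    tool = parse_logging_tool_name("TensorBoard")
    print(tool is LoggingToolName.TENSORBOARD)  # True
    print(tool == "tensorboard")  # True: members are str instances
    print(f"logging with {LoggingToolName.WANDB}")  # logging with wandb
```

Since the members are real strings, they can also be passed anywhere a plain string such as "wandb" is expected.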
infer(model: tf.keras.Model, dataset: tf.data.Dataset, infer_results_path: Path, number_of_top_predictions_to_keep: int = None)
Performs inference of a model on a dataset, saving the results to a file.
Parameters:
- model – The model to infer with.
- dataset – The dataset to infer on.
- infer_results_path – The path to save the resulting predictions to.
- number_of_top_predictions_to_keep – The number of top predictions to keep. If None, all predictions are saved.
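A minimal, framework-free sketch of the flow infer describes: run the model batch by batch, pair each prediction with an identifier, keep the top predictions (or all of them when the limit is None) and write them to a file. It assumes, hypothetically, that each dataset element is an (identifiers, inputs) batch, that the model is any callable returning one confidence per input, and that results go to a CSV file with identifier and confidence columns. The real function takes a tf.keras.Model and a tf.data.Dataset; infer_sketch is not its implementation.

```python
import csv
from pathlib import Path
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

# Hypothetical stand-ins for tf.data / tf.keras types: a batch is
# (identifiers, inputs) and a model maps a batch of inputs to confidences.
Batch = Tuple[Sequence[str], Sequence[Sequence[float]]]
Model = Callable[[Sequence[Sequence[float]]], Sequence[float]]


def infer_sketch(model: Model, dataset: Iterable[Batch], infer_results_path: Path,
                 number_of_top_predictions_to_keep: Optional[int] = None
                 ) -> List[Tuple[str, float]]:
    """Predict batch by batch, then write (identifier, confidence) rows to CSV."""
    predictions: List[Tuple[str, float]] = []
    for identifiers, inputs in dataset:
        confidences = model(inputs)
        # Pair every example's identifier with its predicted confidence.
        predictions.extend(zip(identifiers, (float(c) for c in confidences)))
    # Highest confidence first; a limit of None keeps every prediction.
    predictions.sort(key=lambda row: row[1], reverse=True)
    if number_of_top_predictions_to_keep is not None:
        predictions = predictions[:number_of_top_predictions_to_keep]
    with infer_results_path.open("w", newline="") as results_file:
        writer = csv.writer(results_file)
        writer.writerow(["identifier", "confidence"])
        writer.writerows(predictions)
    return predictions


if __name__ == "__main__":
    import tempfile

    def mean_model(inputs):
        # Toy model: the confidence is the mean of each input vector.
        return [sum(vector) / len(vector) for vector in inputs]

    batches = [(["a", "b"], [[0.0, 0.5], [1.0, 0.5]]), (["c"], [[0.5, 0.5]])]
    with tempfile.TemporaryDirectory() as directory:
        top = infer_sketch(mean_model, batches, Path(directory) / "results.csv", 2)
    print(top)  # [('b', 0.75), ('c', 0.5)]
```

Collecting every batch before sorting keeps the sketch simple; for very large datasets a bounded heap would cap memory at the number of predictions kept.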
save_results(confidences_data_frame: pd.DataFrame, infer_results_path: Path, number_of_top_predictions_to_keep: int = None) → pd.DataFrame
Saves a predictions data frame to a file.
Parameters:
- confidences_data_frame – The data frame of predictions to save.
- infer_results_path – The path to save the resulting predictions to.
- number_of_top_predictions_to_keep – The number of top predictions to keep. If None, all predictions are saved.
Returns: The updated data frame.
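The truncation step of save_results can be sketched with heapq.nlargest, which picks the n highest-confidence rows and is documented as equivalent to sorting in descending order and slicing. Here a list of dicts stands in for the pandas DataFrame, and the confidence column name and CSV output format are assumptions; save_results_sketch is a hypothetical name, not part of ramjet.trial.

```python
import csv
import heapq
from pathlib import Path
from typing import Any, Dict, List, Optional

Row = Dict[str, Any]  # hypothetical stand-in for one data frame row


def save_results_sketch(confidences_rows: List[Row], infer_results_path: Path,
                        number_of_top_predictions_to_keep: Optional[int] = None
                        ) -> List[Row]:
    """Keep the highest-confidence rows (all rows if the limit is None) and save them as CSV."""
    def by_confidence(row: Row) -> float:
        return row["confidence"]

    if number_of_top_predictions_to_keep is None:
        kept = sorted(confidences_rows, key=by_confidence, reverse=True)
    else:
        # Same result as sorted(..., reverse=True)[:n]; can avoid a full sort when n is small.
        kept = heapq.nlargest(number_of_top_predictions_to_keep, confidences_rows,
                              key=by_confidence)
    # Column order follows the first input row.
    fieldnames = list(confidences_rows[0]) if confidences_rows else []
    with infer_results_path.open("w", newline="") as results_file:
        writer = csv.DictWriter(results_file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(kept)
    return kept


if __name__ == "__main__":
    import tempfile

    rows = [{"identifier": "a", "confidence": 0.25},
            {"identifier": "b", "confidence": 0.75},
            {"identifier": "c", "confidence": 0.5}]
    with tempfile.TemporaryDirectory() as directory:
        kept = save_results_sketch(rows, Path(directory) / "top.csv", 2)
    print([row["identifier"] for row in kept])  # ['b', 'c']
```

Returning the kept rows mirrors the documented "updated data frame" return value, so a caller can inspect exactly what was written.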
create_logging_metrics() → List[tf.metrics.Metric]
Creates the standard metrics to be used in logging.
Returns: The list of metrics.
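The documentation does not list which metrics create_logging_metrics returns. The sketch below only illustrates the accumulate-then-report protocol that tf.keras metrics follow (update_state per batch, result for the running value, reset_state between epochs), using precision and recall as assumed examples in plain Python; all class and function names here are hypothetical.

```python
from typing import Iterable, List


class ConfusionCountsMetric:
    """Plain-Python stand-in following the update_state / result / reset_state
    protocol of tf.keras.metrics.Metric; not a TensorFlow class."""

    name = "confusion_counts"

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold
        self.reset_state()

    def reset_state(self) -> None:
        # Called between epochs so each epoch is measured from scratch.
        self.true_positives = 0
        self.false_positives = 0
        self.false_negatives = 0

    def update_state(self, y_true: Iterable[int], y_pred: Iterable[float]) -> None:
        # Accumulate counts across batches instead of averaging per-batch values.
        for label, confidence in zip(y_true, y_pred):
            predicted_positive = confidence > self.threshold
            if predicted_positive and label == 1:
                self.true_positives += 1
            elif predicted_positive:
                self.false_positives += 1
            elif label == 1:
                self.false_negatives += 1

    def result(self) -> float:
        raise NotImplementedError


class StreamingPrecision(ConfusionCountsMetric):
    name = "precision"

    def result(self) -> float:
        predicted_positives = self.true_positives + self.false_positives
        return self.true_positives / predicted_positives if predicted_positives else 0.0


class StreamingRecall(ConfusionCountsMetric):
    name = "recall"

    def result(self) -> float:
        actual_positives = self.true_positives + self.false_negatives
        return self.true_positives / actual_positives if actual_positives else 0.0


def create_logging_metrics_sketch() -> List[ConfusionCountsMetric]:
    """Hypothetical: return fresh metric objects, one per logged quantity."""
    return [StreamingPrecision(), StreamingRecall()]


if __name__ == "__main__":
    precision, recall = create_logging_metrics_sketch()
    for y_true, y_pred in [([1, 0, 1], [0.9, 0.8, 0.2]), ([1, 1], [0.7, 0.3])]:
        for metric in (precision, recall):
            metric.update_state(y_true, y_pred)
    print(precision.name, precision.result())  # 2 of 3 predicted positives are correct
    print(recall.name, recall.result())  # recall 0.5
```

Accumulating raw counts, rather than averaging per-batch ratios, is what makes the reported value correct when batches differ in size or class balance.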
create_logging_callbacks(logs_directory: Path, trial_name: str, database: StandardAndInjectedLightcurveDatabase, logging_tool_name: LoggingToolName = LoggingToolName.WANDB, wandb_entity: Optional[str] = None, wandb_project: Optional[str] = None, light_curve_logging: bool = False) → List[callbacks.Callback]
Creates the callbacks to perform the logging.
Parameters:
- logs_directory – The directory to log to.
- trial_name – The name of the trial.
Returns: The callbacks to perform the logging.
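One plausible shape for create_logging_callbacks is a dispatch on logging_tool_name. In this hypothetical sketch, each CallbackSpec only records which callback would be built and with what options (the option names echo the log_dir parameter of tf.keras.callbacks.TensorBoard and the entity, project, name and dir parameters of wandb.init). The logs_directory / trial_name layout and the separate light-curve entry are assumptions, and the database argument is left out because its role is not documented here.

```python
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional


class LoggingToolName(str, Enum):
    """Stand-in for ramjet.trial.LoggingToolName (str-valued members)."""

    WANDB = "wandb"
    TENSORBOARD = "tensorboard"


@dataclass
class CallbackSpec:
    """Hypothetical record of a callback to build; not a Keras class."""

    kind: str
    options: Dict[str, Any] = field(default_factory=dict)


def create_logging_callbacks_sketch(
        logs_directory: Path,
        trial_name: str,
        logging_tool_name: LoggingToolName = LoggingToolName.WANDB,
        wandb_entity: Optional[str] = None,
        wandb_project: Optional[str] = None,
        light_curve_logging: bool = False,
) -> List[CallbackSpec]:
    # Assumed layout: one subdirectory per trial inside logs_directory.
    trial_directory = logs_directory / trial_name
    if logging_tool_name == LoggingToolName.TENSORBOARD:
        specs = [CallbackSpec("tensorboard", {"log_dir": str(trial_directory)})]
    elif logging_tool_name == LoggingToolName.WANDB:
        specs = [CallbackSpec("wandb", {"entity": wandb_entity,
                                        "project": wandb_project,
                                        "name": trial_name,
                                        "dir": str(trial_directory)})]
    else:
        raise ValueError(f"Unsupported logging tool: {logging_tool_name!r}")
    if light_curve_logging:
        # Hypothetical extra callback for the optional light-curve logging.
        specs.append(CallbackSpec("light_curve", {"log_dir": str(trial_directory)}))
    return specs


if __name__ == "__main__":
    for spec in create_logging_callbacks_sketch(Path("logs"), "trial_1", "tensorboard"):
        print(spec.kind, spec.options)
```

Because the stand-in enum mixes in str, passing the plain string "tensorboard" selects the same branch as LoggingToolName.TENSORBOARD.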