# Copyright (c) 2024 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations # see annotations as str
import logging
import time
from typing import Optional
from discrete_optimization.generic_tools.callbacks.callback import Callback
from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
)
logger = logging.getLogger(name=__name__)

# optuna is an optional dependency: when it is missing, only a warning is
# emitted so that importing this module does not fail.
try:
    import optuna
except ImportError:
    logger.warning("You should install optuna to use callbacks for optuna.")
class OptunaCallback(Callback):
    """Callback reporting intermediate values to prune unpromising trials during Optuna hyperparameters tuning.

    Adapted to single objective optimization (res.fit is a float).

    Every `optuna_report_nb_steps` steps, the callback reports to optuna the current best
    fitness, indexed either by the step number or by the elapsed time since the starting time.
    It also updates the user attribute used to store computing time,
    so that pruned or failed trials will still have the user attribute updated.
    If the optuna pruner decides that the trial should be pruned, the appropriate
    `optuna.TrialPruned` exception is raised.

    Args:
        trial:
            A :class:`optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        optuna_report_nb_steps: report intermediate result every `optuna_report_nb_steps` steps.
            When the number of iterations is high, setting this to 1 could slow down the run
            of a single trial too much. Must be at least 1.
        starting_time: float representing the start time of the solving process.
            Should be the result of a call to `time.perf_counter()`.
            Default to `time.perf_counter()` called by `on_solve_start()`.
            Useful to be on par with a clock set outside the callback.
        elapsed_time_attr: key of trial user attribute used to store the elapsed time at each
            reported step, expressed in units of `report_time_unit`.
        report_time: if True, report to optuna intermediate fitness with elapsed time instead of step
        report_time_unit: fraction of second used as unit when reporting time.
            Second is represented by 1, millisecond by 0.001, etc.
            This may be useful starting from optuna 4.0 as report steps are converted to integers.
            Must be strictly positive.
        pruning: if True, use the optuna pruner to decide if the trial should be pruned.
            Else never try to prune.
        **kwargs: ignored.

    Raises:
        ValueError: if `optuna_report_nb_steps` is lower than 1 or if `report_time_unit`
            is not strictly positive.

    """

    def __init__(
        self,
        trial: optuna.trial.Trial,
        optuna_report_nb_steps: int = 1,
        starting_time: Optional[float] = None,
        elapsed_time_attr: str = "elapsed_time",
        report_time: bool = False,
        report_time_unit: float = 1.0,
        pruning: bool = True,
        **kwargs,
    ) -> None:
        # Fail fast on values that would otherwise only break at the first step
        # (modulo by zero, division by zero, or negative steps rejected by optuna).
        if optuna_report_nb_steps < 1:
            raise ValueError(
                "optuna_report_nb_steps must be at least 1, "
                f"got {optuna_report_nb_steps}."
            )
        if report_time_unit <= 0:
            raise ValueError(
                f"report_time_unit must be strictly positive, got {report_time_unit}."
            )
        self.pruning = pruning
        self.report_time = report_time
        self.report_time_unit = report_time_unit
        self.elapsed_time_attr = elapsed_time_attr
        self.report_nb_steps = optuna_report_nb_steps
        self.trial = trial
        self.starting_time = starting_time

    def on_solve_start(self, solver: SolverDO) -> None:
        """Start the clock, unless a starting time was given at construction.

        Args:
            solver: solver using the callback (unused)

        """
        if self.starting_time is None:
            self.starting_time = time.perf_counter()

    def on_step_end(
        self, step: int, res: ResultStorage, solver: SolverDO
    ) -> Optional[bool]:
        """Report the current best fitness to optuna and prune the trial if requested.

        Only steps that are multiples of `optuna_report_nb_steps` are handled; the others
        are ignored. When the result storage yields no fitness yet, only the elapsed time
        user attribute is updated (nothing is reported and no pruning is attempted).

        Args:
            step: index of step
            res: current result storage
            solver: solver using the callback (unused)

        Returns:
            If `True`, the optimization process is stopped, else it goes on.
            This callback never asks to stop: it always returns None.

        Raises:
            optuna.TrialPruned: if pruning is enabled and the optuna pruner decides that
                the trial should be pruned.

        """
        if step % self.report_nb_steps != 0:
            return None

        _, fit = res.get_best_solution_fit()
        # Elapsed time expressed in units of `report_time_unit` (seconds by default).
        step_time = (
            time.perf_counter() - self.starting_time
        ) / self.report_time_unit
        # Stored before any pruning decision so that pruned trials keep it too.
        self.trial.set_user_attr(self.elapsed_time_attr, step_time)

        if fit is None:
            # No fitness available yet (presumably no solution stored so far):
            # float(None) would raise, so there is nothing to report this time.
            return None

        # Report current score and step (or elapsed time) to Optuna's trial.
        report_step = step_time if self.report_time else step
        self.trial.report(float(fit), step=report_step)

        # Prune trial if needed
        if self.pruning and self.trial.should_prune():
            raise optuna.TrialPruned(f"Trial was pruned at step {step}.")
        return None