# Copyright (c) 2026 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import Any, Optional
import pandas as pd
from discrete_optimization.generic_tools.do_problem import Problem
from discrete_optimization.generic_tools.do_solver import SolverDO, StatusSolver
from discrete_optimization.generic_tools.study import Hdf5Database
from discrete_optimization.generic_tools.study.config import ConfigStore
from discrete_optimization.generic_tools.study.database import (
IS_EMPTY,
is_empty_metrics,
)
from discrete_optimization.generic_tools.study.experiment import (
CONFIG,
INSTANCE,
STATUS,
Experiment,
SolverConfig,
SolverJsonableConfig,
)
# Logger dedicated to this module (named after it, so it can be configured
# independently of the root logger).
logger = logging.getLogger(name=__name__)
[docs]
class Study(Iterable[tuple[Problem, SolverDO, dict[str, Any]]]):
    """Small wrapper to manage d-o studies

    - loop over instance x config
    - manage mapping instance_name <-> instance, config_name <-> config
    - allow warmstart from previous run, with skip of successful experiments
    - automatic retry of unsuccessful experiments

    Manage database to store experiment.
    For now, use Hdf5Database with default path being <study_name>.h5

    To loop over a study, do:

    ```python
    for problem, solver, solver_kwargs in study:
        ...
        # solve
        solver.solve(..., **solver_kwargs)
        # retrieve metrics
        metrics = ...
        study.store_current_xp(metrics, ...)
    ```

    """

    def __init__(
        self,
        name: str,
        instances: list[str],
        solver_configs: dict[str, SolverConfig],
        problem_factory: Callable[[str], Problem],
        overwrite: bool = False,
        max_retry: int = 0,
        database_filepath: Optional[str] = None,
        solver_factory: Optional[Callable[[Problem, SolverConfig], SolverDO]] = None,
    ):
        """

        Args:
            name: study name
            instances: list of instance names (as they are to appear in database/dashboard)
            solver_configs: mapping a config name to an actual solver config
            problem_factory: mapping an instance name to a d-o problem instance
            overwrite: if True, the previous database is erased. Else we warmstart with previous database
                to avoid relaunching successful experiments
            max_retry: when an experiment is unsuccessful (potentially from a previous study run)
                we retry it until reaching max_retry.
                Default to 0, which means no retry.
            database_filepath: hdf5 database file path, default to <name>.h5
            solver_factory: mapping a problem and solver config to a solver.
                Default to calling solver_config.cls __init__ + init_model with solver_config.kwargs.

        """
        self.max_retry = max_retry
        self.overwrite = overwrite
        self.name = name
        self.instances = instances
        self.solver_configs = solver_configs
        self.problem_factory = problem_factory
        if database_filepath is None:
            self.database_filepath = f"{self.name}.h5"
        else:
            self.database_filepath = database_filepath
        if solver_factory is None:

            def solver_factory(
                problem: Problem, solver_config: SolverConfig
            ) -> SolverDO:
                solver = solver_config.cls(problem, **solver_config.kwargs)
                solver.init_model(**solver_config.kwargs)
                return solver

        self.solver_factory = solver_factory
        self.config_store = ConfigStore()
        # keyed by (instance name, normalized config name)
        self.config_instance_success: dict[tuple[str, str], bool] = defaultdict(bool)
        self.config_instance_n_attempts: dict[tuple[str, str], int] = defaultdict(int)
        if self.overwrite:
            # erase previous database
            try:
                os.remove(self.database_filepath)
            except FileNotFoundError:
                # nothing to erase
                pass
        else:
            # load success and n_attempts from previous database
            self.load_metadatas()
        self._in_the_loop = False

    def __iter__(self):
        """Loop over instances x solver configs, yielding (problem, solver, solver_kwargs).

        When not overwriting, experiments already successful in a previous run, or already
        attempted more than `max_retry` times, are skipped. Each experiment is yielded again
        until `store_current_xp` marks it successful or `max_retry + 1` attempts are reached.

        """
        self._in_the_loop = True
        try:
            yield from self._iter_experiments()
        finally:
            # once the loop is over (exhausted or closed), storing an experiment would
            # attach it to a stale instance x config: make the guard reflect that.
            self._in_the_loop = False

    def _iter_experiments(self):
        """Generator implementing the instance x config loop of `__iter__`."""
        for i_instance, instance in enumerate(self.instances):
            self._instance = instance
            for i_config, (config_name, solver_config) in enumerate(
                self.solver_configs.items()
            ):
                self._solver_config = solver_config
                logger.info(
                    f"###### Instance {i_instance + 1}/{len(self.instances)}: {instance}, config {i_config + 1}/{len(self.solver_configs)}: {config_name} ######"
                )
                if not self.overwrite:
                    # check if config x instance already done
                    config = SolverJsonableConfig.from_solver_config(
                        solver_config, name=config_name
                    ).as_nested_dict()
                    self.config_store.add(config)
                    # the database identifies configs by the name given by the config store
                    config_name_normalized = self.config_store.get_name(config)
                    key = (instance, config_name_normalized)
                    if self.config_instance_success[key]:
                        logger.info(
                            "config x instance already run successfully in a previous attempt. Skipping it."
                        )
                        continue
                    if self.config_instance_n_attempts[key] > self.max_retry:
                        logger.info(
                            f"config x instance already run (unsuccessfully) {self.config_instance_n_attempts[key]} times. Skipping it."
                        )
                        continue
                    n_attempts = self.config_instance_n_attempts[key]
                    self._success = self.config_instance_success[key]
                    self._config_name = config_name_normalized
                else:
                    n_attempts = 0
                    self._success = False
                    self._config_name = config_name
                # retry it if failing (self._success is updated by store_current_xp)
                while not self._success and n_attempts <= self.max_retry:
                    n_attempts += 1
                    if self.max_retry > 0:
                        logger.info(
                            f"###### Attempt {n_attempts} / {self.max_retry + 1}"
                        )
                    try:
                        problem = self.problem_factory(instance)
                        solver = self.solver_factory(problem, solver_config)
                    except Exception as e:
                        # failed experiment from problem/solver instantiation:
                        # record it as an unsuccessful attempt (keeping the traceback in logs)
                        logger.exception(
                            "Problem/solver instantiation failed: %s", e
                        )
                        self.store_current_xp(
                            metrics=pd.DataFrame([]),
                            status=StatusSolver.ERROR,
                            reason=f"{type(e)}: {str(e)}",
                            success=False,
                        )
                    else:
                        yield problem, solver, solver_config.kwargs

    def store_current_xp(
        self,
        metrics: pd.DataFrame,
        status: str | StatusSolver,
        reason: str = "",
        success: Optional[bool] = None,
    ):
        """Store the current experiment (current instance x config) in the database.

        Args:
            metrics: metrics collected during the solve.
            status: solver status of the run.
            reason: free text explaining the status (e.g. an error message).
            success: whether the experiment is considered successful, which decides
                whether the study retries it. If None, the experiment is considered
                successful iff `is_empty_metrics(metrics)` is False.

        Raises:
            RuntimeError: if called outside of a loop over the study.

        """
        if not self._in_the_loop:
            raise RuntimeError(
                "This method should be called inside a loop over the study:\n"
                "`for problem, solver, solver_kwargs in study: ...; study.store_current_xp(...)`"
            )
        if success is None:
            # an experiment without recorded metrics is a failed one
            success = not is_empty_metrics(metrics)
        # ensure closing the database at the end of computation (even if error)
        with Hdf5Database(self.database_filepath) as database:
            xp_id = database.get_new_experiment_id()
            xp = Experiment.from_solver_config(
                xp_id=xp_id,
                instance=self._instance,
                config_name=self._config_name,
                solver_config=self._solver_config,
                metrics=metrics,
                status=status,
                reason=reason,
            )
            database.store(xp)
        self._success = success

    def get_current_instance(self) -> str:
        """Return the name of the instance of the experiment currently being run."""
        return self._instance

    def get_current_config_name(self) -> str:
        """Return the config name of the experiment currently being run.

        It is the name normalized by the config store when not overwriting,
        else the key given in `solver_configs`.

        """
        return self._config_name
# Label naming the attempt index of repeated runs.
# NOTE(review): its consumers are not visible in this chunk — confirm where it is used.
I_RUN_LABEL: str = "attempt"
[docs]
def normalize_config(config: Any, config_store: ConfigStore) -> str:
    """Map a result's config to its config name.

    Args:
        config: either the config itself as a dict (named through `config_store`)
            or already its name as a string.
        config_store: store used to get the name of a config dict.

    Returns:
        the config name.

    Raises:
        ValueError: if `config` is neither a dict nor a str.

    """
    if isinstance(config, str):
        # already a name: nothing to normalize
        return config
    if isinstance(config, dict):
        return config_store.get_name(config)
    raise ValueError(
        "For each result df, df.attrs['config'] must be either a dictionary "
        "or a string representing its name."
    )