Source code for discrete_optimization.generic_tools.dashboard.config

from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import Hashable
from typing import Any, Optional, Union

from discrete_optimization.generic_tools.study.experiment import NAME, SOLVER

logger = logging.getLogger(__name__)

# type aliases
DictConfig = Union[dict[str, "DictConfig"], Hashable]
HashableConfig = Union[tuple[str, "HashableConfig"], Hashable]


[docs] def get_config_name(config: DictConfig) -> str: if isinstance(config, dict): if NAME in config: return config[NAME] elif SOLVER in config: return f"{config[SOLVER]}-{get_config_name({k: v for k,v in config.items() if k != SOLVER})}" else: return "-".join(f"{k}-{get_config_name(v)}" for k, v in config.items()) else: return str(config)
[docs] def convert_config_dict2hashable(config: DictConfig) -> HashableConfig: if isinstance(config, dict): return tuple((k, convert_config_dict2hashable(v)) for k, v in config.items()) else: return config
[docs] def is_tupleddict(config: Any) -> bool: return isinstance(config, tuple) and all( isinstance(configitem, tuple) and len(configitem) == 2 and isinstance(configitem[0], str) for configitem in config )
[docs] def convert_config_hashable2dict(config: HashableConfig) -> DictConfig: if is_tupleddict(config): return {k: convert_config_hashable2dict(v) for k, v in config} else: return config
[docs] class ConfigStore: """Store experiments config and mapping to their names""" def __init__(self): self.map_name2configs: dict[str, set[HashableConfig]] = defaultdict(set) self.map_config2name: dict[HashableConfig, str] = {} self.map_config2hasusername: dict[HashableConfig, bool] = {}
[docs] def add(self, config: DictConfig) -> None: """Add a config to the store. Ensure bijection between names and configs. If name already given, use it. If not, construct it from solver parameters. If 2 names given in different occurences raise an error. If 2 different config share the same name, raise an error. """ config = dict(config) # copy dict to avoid inplace modification name: Optional[str] = None # name given by user? if NAME in config: name = config[NAME] del config[NAME] # hashable version of the config hashable_config = convert_config_dict2hashable(config) # get a name and check that previous similar config share the name (or had not already one defined) if ( hashable_config in self.map_config2name ): # config already added (with potentially no name or another one) if name is None: name = self.map_config2name[hashable_config] else: if self.map_config2hasusername[ hashable_config ]: # had already a name given by user assert name == self.map_config2name[hashable_config], ( "Same configs should share same names. " f"The names {name} and {self.map_config2name[hashable_config]} were used." ) else: # stored name was a name constructed by default => replace it with name given by user self.map_config2name[hashable_config] = name self.map_config2hasusername[hashable_config] = True else: if name is None: # no name given by user: construct one from parameters name = get_config_name(config) self.map_config2hasusername[hashable_config] = False else: # name given by user self.map_config2hasusername[hashable_config] = True self.map_config2name[hashable_config] = name # check that only one config corresponds to this name => only warning if more than 1 # it is probably an error on study-side but we allow it to avoid the dashboard crashing n_configs_with_name_before = len(self.map_name2configs[name]) self.map_name2configs[name].add(hashable_config) if ( n_configs_with_name_before > 0 and len(self.map_name2configs[name]) > n_configs_with_name_before ): logger.warning( f"More than one config share name {name}:\n" + "\n".join( repr(convert_config_hashable2dict(c)) for c in self.map_name2configs[name] ) )
[docs] def get_name(self, config: DictConfig) -> str: config = dict(config) # copy dict to avoid inplace modification if NAME in config: del config[NAME] hashable_config = convert_config_dict2hashable(config) if hashable_config not in self.map_config2name: raise RuntimeError( "Be sure to add all configs to the store before extracting names. " "That way, we can ensure same configs share names " "and that no name is used for several different configs." ) else: return self.map_config2name[hashable_config]
[docs] def get_configs(self, name: str) -> list[DictConfig]: return [convert_config_hashable2dict(c) for c in self.map_name2configs[name]]