# hub.solver.stable_baselines.stable_baselines
# StableBaseline
This class wraps a stable OpenAI Baselines solver (stable_baselines3) as a scikit-decide solver.
WARNING
Using this class requires Stable Baselines 3 to be installed.
# Constructor StableBaseline
StableBaseline(
domain_factory: Callable[[], Domain],
algo_class: type[BaseAlgorithm],
baselines_policy: Union[str, type[BasePolicy]],
learn_config: Optional[dict[str, Any]] = None,
callback: Callable[[StableBaseline], bool] = <lambda function>,
use_action_masking: bool = False,
**kwargs: Any
) -> None
Initialize StableBaseline.
# Parameters
- domain_factory: A callable with no argument returning the domain to solve (can be a mere domain class). The resulting domain will be auto-cast to the level expected by the solver.
- algo_class: The class of Baselines solver (stable_baselines3) to wrap.
- baselines_policy: The class of Baselines policy network (stable_baselines3.common.policies or str) to use.
- learn_config: The kwargs passed to the sb3 algo's learn() method.
- callback: Function called at each solver iteration. If it returns True, the solve process stops.
- use_action_masking: If True,
  - the domain will be wrapped in a gymnasium environment exposing action_masks(),
  - self.sample_action() will pass action masks to the underlying sb3 algo's predict() (e.g. MaskablePPO or MaskableGraphPPO),
  - self.using_applicable_actions() will return True so that rollout knows to retrieve action masks before sampling actions.
- kwargs: Keyword arguments passed to the algo_class constructor.
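As a minimal usage sketch of the constructor, assuming scikit-decide and stable_baselines3 are installed: the helper names make_call_budget_callback and build_ppo_solver are hypothetical, and PPO with "MlpPolicy" is only one possible choice for algo_class and baselines_policy. The callback receives the solver and returns True to request a stop.

```python
from typing import Any, Callable


def make_call_budget_callback(max_calls: int) -> Callable[[Any], bool]:
    """Build a callback that asks the solver to stop after max_calls calls."""
    calls = 0

    def callback(solver: Any) -> bool:
        nonlocal calls
        calls += 1
        # Returning True tells the solve process to stop.
        return calls >= max_calls

    return callback


def build_ppo_solver(domain_factory: Callable[[], Any], total_timesteps: int = 10_000):
    """Wrap PPO as a scikit-decide solver (hypothetical helper)."""
    # Local imports so that this sketch loads without the optional dependencies.
    from stable_baselines3 import PPO
    from skdecide.hub.solver.stable_baselines import StableBaseline

    return StableBaseline(
        domain_factory=domain_factory,
        algo_class=PPO,
        baselines_policy="MlpPolicy",
        learn_config={"total_timesteps": total_timesteps},  # forwarded to learn()
        callback=make_call_budget_callback(max_calls=100),
        verbose=0,  # extra keyword arguments go to the PPO constructor
    )
```

Calling build_ppo_solver(my_domain_factory) would return a solver ready for solve(); my_domain_factory stands for any callable returning a compatible domain.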
# autocast Solver
autocast(
self,
domain_cls: Optional[type[Domain]] = None
) -> None
Autocast itself to the level corresponding to the given domain class.
# Parameters
- domain_cls: the domain class to which level the solver needs to autocast itself. By default, use the original domain factory passed to its constructor.
# check_domain Solver
check_domain(
domain: Domain
) -> bool
Check whether a domain is compliant with this solver type.
By default, Solver.check_domain() provides some boilerplate code and internally calls Solver._check_domain_additional() (which returns True by default but can be overridden to define specific checks in addition to the "domain requirements"). The boilerplate code automatically checks whether all domain requirements are met.
# Parameters
- domain: The domain to check.
# Returns
True if the domain is compliant with the solver type (False otherwise).
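The two-stage check described above can be sketched as follows. This is an illustration under assumptions, not the library's code: RequirementA, RequirementB and the domain classes are hypothetical stand-ins for classes from skdecide.builders.domain, and the requirement test is assumed to be a plain isinstance check.

```python
class RequirementA:
    """Hypothetical stand-in for a domain characteristic class."""


class RequirementB:
    """Hypothetical stand-in for another domain characteristic class."""


class SketchSolver:
    @classmethod
    def get_domain_requirements(cls) -> list[type]:
        return [RequirementA, RequirementB]

    @classmethod
    def _check_domain_additional(cls, domain) -> bool:
        # Default: no specific check beyond the domain requirements.
        return True

    @classmethod
    def check_domain(cls, domain) -> bool:
        # Boilerplate part: every requirement must be inherited by the domain.
        meets_requirements = all(
            isinstance(domain, requirement)
            for requirement in cls.get_domain_requirements()
        )
        return meets_requirements and cls._check_domain_additional(domain)


class SmallDomainsOnlySolver(SketchSolver):
    @classmethod
    def _check_domain_additional(cls, domain) -> bool:
        # Specific check added on top of the requirements (assumed attribute).
        return domain.n_actions <= 10


class FullDomain(RequirementA, RequirementB):
    def __init__(self, n_actions: int = 4) -> None:
        self.n_actions = n_actions


class PartialDomain(RequirementA):
    n_actions = 4
```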
# complete_with_default_hyperparameters Hyperparametrizable
complete_with_default_hyperparameters(
kwargs: dict[str, Any],
names: Optional[list[str]] = None
)
Add missing hyperparameters to kwargs by using default values.
# Parameters
- kwargs: Keyword arguments to complete (e.g. for __init__, init_model, or solve).
- names: Names of the hyperparameters to add if missing. By default, all available hyperparameters.
# Returns
A new dictionary, completion of kwargs.
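The completion behaviour can be pictured with a small stand-alone sketch. DEFAULTS and complete_with_defaults are hypothetical names; the real method reads the defaults from the hyperparameter definitions of the class rather than from a module-level dictionary.

```python
from typing import Any, Optional

# Hypothetical default values, standing in for those declared by a solver class.
DEFAULTS: dict[str, Any] = {"learning_rate": 3e-4, "n_steps": 2048}


def complete_with_defaults(
    kwargs: dict[str, Any], names: Optional[list[str]] = None
) -> dict[str, Any]:
    """Return a new dict where missing hyperparameters get their default value."""
    if names is None:
        names = list(DEFAULTS)
    completed = dict(kwargs)  # copy, so the caller's dict is left untouched
    for name in names:
        completed.setdefault(name, DEFAULTS[name])
    return completed
```

Values already present in kwargs win over the defaults, and restricting names limits which hyperparameters are filled in.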
# copy_and_update_hyperparameters Hyperparametrizable
copy_and_update_hyperparameters(
names: Optional[list[str]] = None,
**kwargs_by_name: dict[str, Any]
) -> list[Hyperparameter]
Copy hyperparameters definition of this class and update them with specified kwargs.
This is useful, for instance, to define the hyperparameters of a child class when only the choices of a hyperparameter change.
# Parameters
- names: Names of the hyperparameters to copy. Defaults to all.
- **kwargs_by_name: For each hyperparameter specified by its name, the attributes to update. If a given hyperparameter name is not specified, the hyperparameter is copied without further update.
# Returns
The list of copied hyperparameters, updated where specified.
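A rough sketch of the copy-and-update idea, using a frozen dataclass as a hypothetical stand-in for Hyperparameter (SketchHyperparameter, PARENT_HYPERPARAMETERS and copy_and_update are illustrative names, not library objects):

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class SketchHyperparameter:
    """Hypothetical, simplified stand-in for a hyperparameter definition."""

    name: str
    choices: tuple = ()
    default: Any = None


PARENT_HYPERPARAMETERS = [
    SketchHyperparameter("algo", choices=("a", "b", "c"), default="a"),
    SketchHyperparameter("depth", default=3),
]


def copy_and_update(
    hyperparameters: list[SketchHyperparameter],
    names: Optional[list[str]] = None,
    **kwargs_by_name: dict[str, Any],
) -> list[SketchHyperparameter]:
    """Copy the selected definitions, updating only the attributes given by name."""
    if names is None:
        names = [h.name for h in hyperparameters]
    return [
        # replace() builds a new object; an empty update gives a plain copy.
        replace(h, **kwargs_by_name.get(h.name, {}))
        for h in hyperparameters
        if h.name in names
    ]
```

A child class could then narrow a single hyperparameter with copy_and_update(PARENT_HYPERPARAMETERS, algo={"choices": ("a", "b")}) while keeping the others unchanged.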
# get_action_mask Maskable
get_action_mask(
self
) -> Optional[StrDict[Mask]]
Retrieve stored action masks.
To be used by self.sample_action().
Returns None if self.set_action_mask() was not called.
# get_default_hyperparameters Hyperparametrizable
get_default_hyperparameters(
names: Optional[list[str]] = None
) -> dict[str, Any]
Get hyperparameters default values.
# Parameters
- names: Names of the hyperparameters to choose. By default, all available hyperparameters will be suggested.
# Returns
A mapping between each hyperparameter's name_in_kwargs and its default value (None if not specified).
# get_domain_requirements Solver
get_domain_requirements(
) -> list[type]
Get domain requirements for this solver class to be applicable.
Domain requirements are classes from the skdecide.builders.domain package that the domain needs to inherit from.
# Returns
A list of classes to inherit from.
# get_hyperparameter Hyperparametrizable
get_hyperparameter(
name: str
) -> Hyperparameter
Get hyperparameter from given name.
# get_hyperparameters_by_name Hyperparametrizable
get_hyperparameters_by_name(
) -> dict[str, Hyperparameter]
Mapping from name to corresponding hyperparameter.
# get_hyperparameters_names Hyperparametrizable
get_hyperparameters_names(
) -> list[str]
List of hyperparameters names.
# get_policy StableBaseline
get_policy(
self
) -> BasePolicy
Return the computed policy.
# is_policy_defined_for Policies
is_policy_defined_for(
self,
observation: StrDict[D.T_observation]
) -> bool
Check whether the solver's current policy is defined for the given observation.
# Parameters
- observation: The observation to consider.
# Returns
True if the policy is defined for the given observation memory (False otherwise).
# load Restorable
load(
self,
path: str
) -> None
Restore the solver state from given path.
After calling self._load(), autocast itself so that rollout methods apply to the domain original characteristics.
# Parameters
- path: The path where the solver state was saved.
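One way to combine load() with save() is a round trip through a freshly built solver. The sketch below relies on duck typing only: SketchRestorable is a hypothetical stand-in exposing the same two method names, and it is an assumption that a new solver is constructed with the same arguments before load() is called on it.

```python
import json
from typing import Callable


class SketchRestorable:
    """Hypothetical stand-in with the same save/load method names."""

    def __init__(self) -> None:
        self.state: dict = {}

    def save(self, path: str) -> None:
        with open(path, "w") as f:
            json.dump(self.state, f)

    def load(self, path: str) -> None:
        with open(path) as f:
            self.state = json.load(f)


def save_and_restore(trained, make_fresh_solver: Callable[[], object], path: str):
    """Save a trained solver, then restore its state into a new instance."""
    trained.save(path)
    restored = make_fresh_solver()  # same constructor arguments as the original
    restored.load(path)
    return restored
```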
# reset Solver
reset(
self
) -> None
Reset whatever is needed on this solver before running a new episode.
This function does nothing by default but can be overridden if needed (e.g. to reset the hidden state of a LSTM policy network, which carries information about past observations seen in the previous episode).
# retrieve_applicable_actions ApplicableActions
retrieve_applicable_actions(
self,
domain: Domain
) -> None
Retrieve applicable actions and use them for future calls to self.step().
To be called during rollout to get the actual applicable actions from the actual domain used in rollout. Transforms applicable actions into an action_mask to be used when sampling actions.
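The transformation from applicable actions to a mask can be illustrated with a plain-Python sketch. It assumes a finite, enumerable action space; applicable_actions_to_mask is a hypothetical helper, and the real implementation may well use arrays rather than lists.

```python
from typing import Hashable, Iterable, Sequence


def applicable_actions_to_mask(
    all_actions: Sequence[Hashable], applicable_actions: Iterable[Hashable]
) -> list[bool]:
    """Return one flag per action of the space: True if currently applicable."""
    applicable = set(applicable_actions)
    return [action in applicable for action in all_actions]
```

For the action space ["up", "down", "left", "right"] with only "down" and "right" applicable, the mask is [False, True, False, True].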
# sample_action Policies
sample_action(
self,
observation: StrDict[D.T_observation]
) -> StrDict[list[D.T_event]]
Sample an action for the given observation (from the solver's current policy).
# Parameters
- observation: The observation for which an action must be sampled.
# Returns
The sampled action.
# save Restorable
save(
self,
path: str
) -> None
Save the solver state to given path.
# Parameters
- path: The path to store the saved state.
# set_action_mask Maskable
set_action_mask(
self,
action_mask: Optional[StrDict[Mask]]
) -> None
Set the action mask.
To be called during rollout before self.sample_action(), assuming that self.sample_action() knows what to do with it.
Autocastable so that it can use action_mask from the original domain during rollout.
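The expected call order during rollout (set the mask, then sample) can be sketched with a toy solver. SketchMaskedSolver and rollout_step are hypothetical; the real rollout obtains the mask from the domain instead of receiving it as an argument.

```python
import random
from typing import Optional


class SketchMaskedSolver:
    """Hypothetical solver over integer actions 0..n_actions-1."""

    def __init__(self, n_actions: int, seed: int = 0) -> None:
        self._n_actions = n_actions
        self._mask: Optional[list[bool]] = None
        self._rng = random.Random(seed)

    def using_applicable_actions(self) -> bool:
        return True

    def set_action_mask(self, action_mask: Optional[list[bool]]) -> None:
        self._mask = action_mask

    def get_action_mask(self) -> Optional[list[bool]]:
        # None as long as set_action_mask() has not been called.
        return self._mask

    def sample_action(self, observation) -> int:
        mask = self.get_action_mask()
        allowed = [a for a in range(self._n_actions) if mask is None or mask[a]]
        return self._rng.choice(allowed)


def rollout_step(solver: SketchMaskedSolver, observation, mask: list[bool]) -> int:
    """Set the current mask before sampling, as the rollout is expected to do."""
    if solver.using_applicable_actions():
        solver.set_action_mask(mask)
    return solver.sample_action(observation)
```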
# solve FromInitialState
solve(
self
) -> None
Run the solving process.
After solving by calling self._solve(), autocast itself so that rollout methods apply to the domain original characteristics.
TIP
The nature of the solutions produced here depends on other characteristics of the solver, such as policy and assessibility.
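The relationship between solve() and self._solve() follows a template-method pattern, which can be sketched as below. SketchSolver and MySolver are hypothetical, and autocast() is reduced to a log entry so that the call order is visible.

```python
class SketchSolver:
    """Hypothetical base class showing the public/private method split."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def solve(self) -> None:
        # Public entry point: run the solver-specific code, then autocast.
        self._solve()
        self.autocast()

    def _solve(self) -> None:
        raise NotImplementedError

    def autocast(self) -> None:
        self.events.append("autocast")


class MySolver(SketchSolver):
    def _solve(self) -> None:
        # Only the underscore method is overridden by a concrete solver.
        self.events.append("_solve")
```

Under this reading, subclasses implement _solve() while users call solve(); the same split applies to the other public and underscore pairs documented on this page.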
# suggest_hyperparameter_with_optuna Hyperparametrizable
suggest_hyperparameter_with_optuna(
trial: optuna.trial.Trial,
name: str,
prefix: str,
**kwargs
) -> Any
Suggest hyperparameter value during an Optuna trial.
This can be used during Optuna hyperparameters tuning.
kwargs can be used to pass relevant arguments to
- trial.suggest_float()
- trial.suggest_int()
- trial.suggest_categorical()

For instance it can
- add a low/high value if not existing for the hyperparameter, or override it to narrow the search (for float or int hyperparameters)
- add a step or log argument (for float or int hyperparameters, see optuna.trial.Trial.suggest_float())
- override choices for categorical or enum parameters to narrow the search
# Parameters
- trial: Optuna trial during hyperparameters tuning.
- name: Name of the hyperparameter to choose.
- prefix: Prefix to add to the corresponding Optuna parameter name (useful for disambiguating hyperparameters from subsolvers in case of meta-solvers).
- **kwargs: Options for Optuna hyperparameter suggestions.
# Returns
The suggested value for the hyperparameter.
# suggest_hyperparameters_with_optuna Hyperparametrizable
suggest_hyperparameters_with_optuna(
trial: optuna.trial.Trial,
names: Optional[list[str]] = None,
kwargs_by_name: Optional[dict[str, dict[str, Any]]] = None,
fixed_hyperparameters: Optional[dict[str, Any]] = None,
prefix: str
) -> dict[str, Any]
Suggest hyperparameters values during an Optuna trial.
# Parameters
- trial: Optuna trial during hyperparameters tuning.
- names: Names of the hyperparameters to choose. By default, all available hyperparameters will be suggested. If fixed_hyperparameters is provided, the corresponding names are removed from names.
- kwargs_by_name: Options for Optuna hyperparameter suggestions, by hyperparameter name. kwargs_by_name[some_name] will be passed as **kwargs to suggest_hyperparameter_with_optuna(name=some_name).
- fixed_hyperparameters: Values of fixed hyperparameters, useful for suggesting subbrick hyperparameters if the subbrick class is not suggested by this method but already fixed. Will be added to the suggested hyperparameters.
- prefix: Prefix to add to the corresponding Optuna parameters (useful for disambiguating hyperparameters from subsolvers in case of meta-solvers).
# Returns
A mapping between the hyperparameter name and its suggested value. If the hyperparameter has an attribute name_in_kwargs, this is used as the key in the mapping instead of the actual hyperparameter name. The mapping is updated with fixed_hyperparameters.
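How names, kwargs_by_name and fixed_hyperparameters interact can be sketched without Optuna. Here suggest_all and fake_suggest are hypothetical: fake_suggest stands in for the single-hyperparameter suggestion method and simply echoes what it received.

```python
from typing import Any, Callable, Optional


def suggest_all(
    suggest_one: Callable[..., Any],
    names: list[str],
    kwargs_by_name: Optional[dict[str, dict[str, Any]]] = None,
    fixed_hyperparameters: Optional[dict[str, Any]] = None,
    prefix: str = "",
) -> dict[str, Any]:
    kwargs_by_name = kwargs_by_name or {}
    fixed = fixed_hyperparameters or {}
    # Fixed hyperparameters are not suggested: drop them from names.
    names = [name for name in names if name not in fixed]
    suggested = {
        # Per-name options are forwarded as **kwargs to the single suggestion.
        name: suggest_one(name=name, prefix=prefix, **kwargs_by_name.get(name, {}))
        for name in names
    }
    # The fixed values are added to the result as they are.
    suggested.update(fixed)
    return suggested


def fake_suggest(name: str, prefix: str, **kwargs: Any) -> tuple:
    """Echo the prefixed parameter name and the options it was given."""
    return (prefix + name, kwargs)
```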
# using_applicable_actions ApplicableActions
using_applicable_actions(
self
)
Tell if the solver is able to use applicable actions information.
# _check_domain_additional Solver
_check_domain_additional(
domain: Domain
) -> bool
Check whether the given domain is compliant with the specific requirements of this solver type (i.e. the ones in addition to "domain requirements").
This is a helper function called by default from Solver.check_domain(). It focuses on specific checks, as opposed to the latter, which also takes the domain requirements into account.
# Parameters
- domain: The domain to check.
# Returns
True if the domain is compliant with the specific requirements of this solver type (False otherwise).
# _cleanup Solver
_cleanup(
self
)
Runs cleanup code here, or code to be executed at the exit of a 'with' context statement.
# _initialize Solver
_initialize(
self
)
Runs long-lasting initialization code here.
# _is_policy_defined_for Policies
_is_policy_defined_for(
self,
observation: StrDict[D.T_observation]
) -> bool
Check whether the solver's current policy is defined for the given observation.
# Parameters
- observation: The observation to consider.
# Returns
True if the policy is defined for the given observation memory (False otherwise).
# _load Restorable
_load(
self,
path: str
)
Restore the solver state from given path.
# Parameters
- path: The path where the solver state was saved.
# _reset Solver
_reset(
self
) -> None
Reset whatever is needed on this solver before running a new episode.
This function does nothing by default but can be overridden if needed (e.g. to reset the hidden state of a LSTM policy network, which carries information about past observations seen in the previous episode).
# _sample_action Policies
_sample_action(
self,
observation: StrDict[D.T_observation]
) -> StrDict[list[D.T_event]]
Sample an action for the given observation (from the solver's current policy).
# Parameters
- observation: The observation for which an action must be sampled.
# Returns
The sampled action.
# _save Restorable
_save(
self,
path: str
) -> None
Save the solver state to given path.
# Parameters
- path: The path to store the saved state.
# _set_action_mask Maskable
_set_action_mask(
self,
action_mask: Optional[StrDict[Mask]]
) -> None
Set the action mask.
To be called during rollout before self.sample_action(), assuming that self.sample_action() knows what to do with it.
# _solve FromInitialState
_solve(
self
) -> None
Run the solving process.
TIP
The nature of the solutions produced here depends on other characteristics of the solver, such as policy and assessibility.
# as_gymnasium_env
as_gymnasium_env(
domain: Domain
) -> gym.Env
Wraps the domain into a gymnasium env.
To be fed to sb3 algorithms.
# as_masked_gymnasium_env
as_masked_gymnasium_env(
domain: Domain
) -> gym.Env
Wraps the domain into an action-masked gymnasium env.
This means that it exposes a method self.action_masks() as expected by algorithms like sb3_contrib.MaskablePPO.