# hub.solver.ray_rllib.ray_rllib
# RayRLlib
This class wraps a Ray RLlib solver (ray[rllib]) as a scikit-decide solver.
WARNING
Using this class requires Ray RLlib to be installed.
# Constructor RayRLlib
RayRLlib(
domain_factory: Callable[[], Domain],
algo_class: type[Algorithm],
train_iterations: int,
config: Optional[AlgorithmConfig] = None,
policy_configs: Optional[dict[str, dict]] = None,
policy_mapping_fn: Optional[Callable[[str, Optional['EpisodeV2'], Optional['RolloutWorker']], str]] = None,
action_embed_sizes: Optional[dict[str, int]] = None,
callback: Callable[[RayRLlib], bool] = <lambda function>,
graph_feature_extractors_kwargs: Optional[dict[str, Any]] = None,
**kwargs
) -> None
Initialize the Ray RLlib solver.
# Parameters
- domain_factory: A callable with no arguments returning the domain to solve (can be a mere domain class). The resulting domain will be auto-cast to the level expected by the solver.
- algo_class: The Ray RLlib algorithm class (trainer/agent) to wrap.
- train_iterations: The number of times the algorithm's train() method is called.
- config: The algorithm configuration (an AlgorithmConfig instance).
- policy_configs: The mapping from policy id (str) to additional config (dict); leave the default for a single policy.
- policy_mapping_fn: The function mapping agent ids to policy ids; leave the default for a single policy.
- action_embed_sizes: The mapping from policy id (str) to action embedding size (only used with domains filtering allowed actions per state; defaults to 2).
- callback: A function called at each solver iteration. If it returns True, the solve process stops and exits the current train iteration. However, if train_iterations > 1, another train loop is entered afterwards. (The callback can be written so that any further training loops are also stopped right away.)
- graph_feature_extractors_kwargs: For graph observations, the kwargs passed to the GraphFeaturesExtractor model used to extract features (see skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn.GraphFeaturesExtractor).
- **kwargs: Used to update the algorithm config with kwargs automatically filled by Optuna.
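The callback contract described above takes the solver as its argument and returns a bool. The hypothetical class below is a minimal sketch of that contract. It requests a stop after a fixed number of calls. Once triggered, it keeps returning True, so any later train iteration is also cut short at its first callback call. StopAfterNCalls is an illustrative name and is not part of scikit-decide.

```python
from typing import Any


class StopAfterNCalls:
    """Hypothetical callback: request an early stop after max_calls calls.

    It follows the documented signature Callable[[RayRLlib], bool];
    the solver argument is accepted but not inspected here.
    """

    def __init__(self, max_calls: int) -> None:
        self.max_calls = max_calls
        self.n_calls = 0

    def __call__(self, solver: Any) -> bool:
        self.n_calls += 1
        # Once the budget is exceeded, keep returning True so that any
        # later train iteration is also stopped at its first call.
        return self.n_calls > self.max_calls


# Possible usage (assumes ray[rllib] and scikit-decide are installed;
# MyDomain is a hypothetical domain class):
# from ray.rllib.algorithms.ppo import PPO
# from skdecide.hub.solver.ray_rllib.ray_rllib import RayRLlib
# solver = RayRLlib(domain_factory=MyDomain, algo_class=PPO,
#                   train_iterations=10, callback=StopAfterNCalls(1000))

if __name__ == "__main__":
    cb = StopAfterNCalls(max_calls=2)
    print([cb(None) for _ in range(5)])  # [False, False, True, True, True]
```

A plain class is used here instead of a lambda because it can carry state (the call counter) across calls.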
# autocast Solver
autocast(
self,
domain_cls: Optional[type[Domain]] = None
) -> None
Autocast itself to the level corresponding to the given domain class.
# Parameters
- domain_cls: the domain class to whose level the solver needs to autocast itself. By default, the original domain factory passed to the constructor is used.
# check_domain Solver
check_domain(
domain: Domain
) -> bool
Check whether a domain is compliant with this solver type.
By default, Solver.check_domain() provides some boilerplate code and internally calls Solver._check_domain_additional() (which returns True by default but can be overridden to define specific checks in addition to the "domain requirements"). The boilerplate code automatically checks whether all domain requirements are met.
# Parameters
- domain: The domain to check.
# Returns
True if the domain is compliant with the solver type (False otherwise).
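The two-part check described above can be modeled as follows. The domain must inherit from every required class, and a solver-specific check must also pass. This is a hypothetical re-implementation for illustration only. RequirementA, RequirementB and the toy domains stand in for skdecide.builders.domain classes and real domains.

```python
from typing import Callable, Iterable


def check_domain_sketch(
    domain: object,
    requirements: Iterable[type],
    check_additional: Callable[[object], bool] = lambda domain: True,
) -> bool:
    """Hypothetical model of the documented check_domain logic."""
    # Boilerplate part: the domain must inherit from every required class.
    if not all(isinstance(domain, req) for req in requirements):
        return False
    # Solver-specific part, standing in for _check_domain_additional().
    return check_additional(domain)


# Hypothetical requirement classes and domains.
class RequirementA: ...
class RequirementB: ...
class FullDomain(RequirementA, RequirementB): ...
class PartialDomain(RequirementA): ...
```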
# complete_with_default_hyperparameters Hyperparametrizable
complete_with_default_hyperparameters(
kwargs: dict[str, Any],
names: Optional[list[str]] = None
)
Add missing hyperparameters to kwargs by using default values.
Args:
- kwargs: keyword arguments to complete (e.g. for __init__, init_model, or solve)
- names: names of the hyperparameters to add if missing. By default, all available hyperparameters.
Returns: a new dictionary, completion of kwargs.
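The completion behavior can be sketched with a plain dict of defaults. The sketch assumes that explicitly passed values take precedence and that the input dict is left untouched, since a new dictionary is returned. complete_with_defaults and the hyperparameter names in the test are hypothetical. The real method reads defaults from the class's hyperparameter definitions.

```python
from typing import Any, Optional


def complete_with_defaults(
    kwargs: dict[str, Any],
    defaults: dict[str, Any],
    names: Optional[list[str]] = None,
) -> dict[str, Any]:
    """Hypothetical sketch: fill in missing hyperparameters from defaults."""
    if names is None:
        names = list(defaults)  # by default, consider every hyperparameter
    completed = dict(kwargs)  # work on a copy: the input is left untouched
    for name in names:
        # Values explicitly passed by the caller always win.
        completed.setdefault(name, defaults[name])
    return completed
```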
# copy_and_update_hyperparameters Hyperparametrizable
copy_and_update_hyperparameters(
names: Optional[list[str]] = None,
**kwargs_by_name: dict[str, Any]
) -> list[Hyperparameter]
Copy hyperparameters definition of this class and update them with specified kwargs.
This is useful to define hyperparameters for a child class for which only choices of the hyperparameter change for instance.
Args:
- names: names of hyperparameters to copy. Defaults to all.
- **kwargs_by_name: for each hyperparameter specified by its name, the attributes to update. If a given hyperparameter name is not specified, the hyperparameter is copied without further update.
Returns: the list of copied hyperparameters, updated as specified.
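The copy-and-update pattern can be illustrated with frozen dataclasses and dataclasses.replace. Each copy is a new object, so the parent class's definitions are never modified. CategoricalHyperparameter here is a hypothetical stand-in, not scikit-decide's Hyperparameter class.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class CategoricalHyperparameter:
    """Hypothetical stand-in for a hyperparameter definition."""
    name: str
    choices: tuple


def copy_and_update(
    hyperparameters: list[CategoricalHyperparameter],
    names: Optional[list[str]] = None,
    **kwargs_by_name: dict[str, Any],
) -> list[CategoricalHyperparameter]:
    if names is None:
        names = [h.name for h in hyperparameters]  # default: copy all
    copied = []
    for h in hyperparameters:
        if h.name not in names:
            continue
        # replace() builds a new object; unspecified names are copied as-is.
        copied.append(replace(h, **kwargs_by_name.get(h.name, {})))
    return copied
```

A child class could then narrow only the choices it cares about, e.g. copy_and_update(parent_hyperparameters, activation={"choices": ("relu",)}).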
# forget_callback RayRLlib
forget_callback(
self
)
Forget about the actual callback to avoid serialization issues.
# get_action_mask Maskable
get_action_mask(
self
) -> Optional[StrDict[Mask]]
Retrieve stored action masks.
To be used by self.sample_action(). Returns None if self.set_action_mask() was not called.
# get_default_hyperparameters Hyperparametrizable
get_default_hyperparameters(
names: Optional[list[str]] = None
) -> dict[str, Any]
Get hyperparameters default values.
Args:
- names: names of the hyperparameters to choose. By default, all available hyperparameters are included.
Returns: a mapping between hyperparameter's name_in_kwargs and its default value (None if not specified)
# get_domain_requirements Solver
get_domain_requirements(
) -> list[type]
Get domain requirements for this solver class to be applicable.
Domain requirements are classes from the skdecide.builders.domain package that the domain needs to inherit from.
# Returns
A list of classes to inherit from.
# get_hyperparameter Hyperparametrizable
get_hyperparameter(
name: str
) -> Hyperparameter
Get hyperparameter from given name.
# get_hyperparameters_by_name Hyperparametrizable
get_hyperparameters_by_name(
) -> dict[str, Hyperparameter]
Mapping from name to corresponding hyperparameter.
# get_hyperparameters_names Hyperparametrizable
get_hyperparameters_names(
) -> list[str]
List of hyperparameter names.
# get_policy RayRLlib
get_policy(
self
) -> dict[str, Policy]
Return the computed policies, as a mapping from policy id to Ray RLlib Policy.
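For a multi-agent domain, the policy ids are declared through the constructor's policy_configs and selected per agent by policy_mapping_fn. The sketch below shows one way these arguments could fit together. It assumes hypothetical agent ids such as "predator_0" and "prey_1" and two hypothetical policy ids. The mapping function takes the documented (agent_id, episode, worker) arguments plus **kwargs, on the assumption that some Ray versions pass extra keyword arguments; check this against your Ray version.

```python
from typing import Any, Optional

# Hypothetical policy ids; the per-policy config dicts are left empty here.
policy_configs: dict[str, dict] = {"predator": {}, "prey": {}}

# Action embedding size per policy (only relevant for domains filtering
# allowed actions per state; the documented default is 2).
action_embed_sizes: dict[str, int] = {"predator": 2, "prey": 2}


def policy_mapping_fn(
    agent_id: str,
    episode: Optional[Any] = None,
    worker: Optional[Any] = None,
    **kwargs: Any,
) -> str:
    """Map a hypothetical agent id like 'predator_0' to the policy id 'predator'."""
    policy_id = agent_id.rsplit("_", 1)[0]
    if policy_id not in policy_configs:
        raise KeyError(f"no policy configured for agent {agent_id!r}")
    return policy_id

# These would then be passed to the constructor, e.g.
# RayRLlib(..., policy_configs=policy_configs,
#          policy_mapping_fn=policy_mapping_fn,
#          action_embed_sizes=action_embed_sizes)
# and get_policy() would presumably hold one entry per policy id.
```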
# is_policy_defined_for Policies
is_policy_defined_for(
self,
observation: StrDict[D.T_observation]
) -> bool
Check whether the solver's current policy is defined for the given observation.
# Parameters
- observation: The observation to consider.
# Returns
True if the policy is defined for the given observation memory (False otherwise).
# load Restorable
load(
self,
path: str
) -> None
Restore the solver state from given path.
After calling self._load(), autocast itself so that rollout methods apply to the original domain's characteristics.
# Parameters
- path: The path where the solver state was saved.
# reset Solver
reset(
self
) -> None
Reset whatever is needed on this solver before running a new episode.
This function does nothing by default but can be overridden if needed (e.g. to reset the hidden state of an LSTM policy network, which carries information about past observations seen in the previous episode).
# retrieve_applicable_actions ApplicableActions
retrieve_applicable_actions(
self,
domain: Domain
) -> None
Retrieve applicable actions and use them for future calls to self.step().
To be called during rollout to get the actual applicable actions from the actual domain used in rollout. Applicable actions are transformed into an action mask to be used when sampling actions.
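Turning applicable actions into per-agent masks can be sketched as follows. The sketch assumes each agent's action space is enumerable and that actions are already mapped to integer indices 0..n-1. Masks are plain lists of 0/1 instead of whatever array type Mask denotes in scikit-decide. Both helper names are hypothetical.

```python
from typing import Iterable


def applicable_actions_to_mask(n_actions: int, applicable: Iterable[int]) -> list[int]:
    """Hypothetical helper: 1 for applicable action indices, 0 otherwise."""
    mask = [0] * n_actions
    for action_index in applicable:
        mask[action_index] = 1
    return mask


def build_action_masks(
    n_actions_by_agent: dict[str, int],
    applicable_by_agent: dict[str, Iterable[int]],
) -> dict[str, list[int]]:
    """One mask per agent, mirroring the StrDict[Mask] shape used by the solver."""
    return {
        agent: applicable_actions_to_mask(n_actions_by_agent[agent], applicable)
        for agent, applicable in applicable_by_agent.items()
    }
```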
# sample_action Policies
sample_action(
self,
observation: StrDict[D.T_observation]
) -> StrDict[list[D.T_event]]
Sample an action for the given observation (from the solver's current policy).
# Parameters
- observation: The observation for which an action must be sampled.
# Returns
The sampled action.
# save Restorable
save(
self,
path: str
) -> None
Save the solver state to given path.
# Parameters
- path: The path to store the saved state.
# set_action_mask Maskable
set_action_mask(
self,
action_mask: Optional[StrDict[Mask]]
) -> None
Set the action mask.
To be called during rollout before self.sample_action(), assuming that self.sample_action() knows what to do with it.
It is autocastable so that it can use the action mask from the original domain during rollout.
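One common way a sampler can use a stored mask is to drop forbidden actions before the softmax over logits. The sketch below does this with plain Python. It illustrates the general technique and is not necessarily what RayRLlib does internally. sample_with_mask is a hypothetical name.

```python
import math
import random
from typing import Sequence


def sample_with_mask(
    logits: Sequence[float], mask: Sequence[int], rng: random.Random
) -> int:
    """Sample an action index from softmax(logits), restricted to mask == 1."""
    allowed = [i for i, m in enumerate(mask) if m]
    if not allowed:
        raise ValueError("the action mask forbids every action")
    # Subtract the max allowed logit for numerical stability before exp().
    max_logit = max(logits[i] for i in allowed)
    weights = [
        math.exp(logits[i] - max_logit) if mask[i] else 0.0
        for i in range(len(logits))
    ]
    # Forbidden actions get weight 0 and can never be drawn.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```

Even the highest-logit action is never returned when its mask entry is 0, which is the point of masking.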
# set_callback RayRLlib
set_callback(
self
)
Restore the callback.
Useful after serializing/deserializing, because of potential issues with:
- lambda functions
- dynamic classes
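Python's standard pickle serializes functions by reference to a module-level name, so a lambda (which has no such name) cannot be pickled. The hypothetical ToySolver below mimics the forget/set-back pattern. The callback is removed from the configuration that has to be serialized and is restored afterwards. Where and how RayRLlib actually serializes its state is not modeled here.

```python
import pickle
from typing import Any, Callable, Optional


def is_picklable(obj: Any) -> bool:
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


class ToySolver:
    """Hypothetical stand-in illustrating the forget/set-back pattern."""

    def __init__(self, callback: Callable[[Any], bool]) -> None:
        self.config: dict[str, Any] = {"lr": 1e-3, "callback": callback}
        self._stashed_callback: Optional[Callable[[Any], bool]] = None

    def forget_callback(self) -> None:
        # Remove the callback from the part that must be serialized.
        self._stashed_callback = self.config.pop("callback", None)

    def set_callback(self) -> None:
        # Put the callback back once serialization is over.
        if self._stashed_callback is not None:
            self.config["callback"] = self._stashed_callback
            self._stashed_callback = None
```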
# solve FromInitialState
solve(
self
) -> None
Run the solving process.
After solving by calling self._solve(), autocast itself so that rollout methods apply to the original domain's characteristics.
TIP
The nature of the solutions produced here depends on other solver characteristics like policy and assessibility.
# suggest_hyperparameter_with_optuna Hyperparametrizable
suggest_hyperparameter_with_optuna(
trial: optuna.trial.Trial,
name: str,
prefix: str,
**kwargs
) -> Any
Suggest hyperparameter value during an Optuna trial.
This can be used during Optuna hyperparameters tuning.
Args:
- trial: optuna trial during hyperparameters tuning
- name: name of the hyperparameter to choose
- prefix: prefix to add to the corresponding optuna parameter name (useful for disambiguating hyperparameters from subsolvers in case of meta-solvers)
- **kwargs: options for optuna hyperparameter suggestions
Returns: the suggested value for the hyperparameter.
kwargs can be used to pass relevant arguments to
- trial.suggest_float()
- trial.suggest_int()
- trial.suggest_categorical()
For instance, they can be used to:
- add low/high values if they do not exist for the hyperparameter, or override them to narrow the search (for float or int hyperparameters)
- add a step or log argument (for float or int hyperparameters, see optuna.trial.Trial.suggest_float())
- override choices for categorical or enum parameters to narrow the search
# suggest_hyperparameters_with_optuna Hyperparametrizable
suggest_hyperparameters_with_optuna(
trial: optuna.trial.Trial,
names: Optional[list[str]] = None,
kwargs_by_name: Optional[dict[str, dict[str, Any]]] = None,
fixed_hyperparameters: Optional[dict[str, Any]] = None,
prefix: str
) -> dict[str, Any]
Suggest hyperparameters values during an Optuna trial.
Args:
- trial: optuna trial during hyperparameters tuning
- names: names of the hyperparameters to choose. By default, all available hyperparameters will be suggested. If fixed_hyperparameters is provided, the corresponding names are removed from names.
- kwargs_by_name: options for optuna hyperparameter suggestions, by hyperparameter name. kwargs_by_name[some_name] will be passed as **kwargs to suggest_hyperparameter_with_optuna(name=some_name).
- fixed_hyperparameters: values of fixed hyperparameters, useful for suggesting subbrick hyperparameters if the subbrick class is not suggested by this method but already fixed. They will be added to the suggested hyperparameters.
- prefix: prefix to add to the corresponding optuna parameters (useful for disambiguating hyperparameters from subsolvers in case of meta-solvers)
Returns: a mapping between the hyperparameter name and its suggested value. If the hyperparameter has an attribute name_in_kwargs, this is used as the key in the mapping instead of the actual hyperparameter name. The mapping is updated with fixed_hyperparameters.
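The control flow described above can be sketched without Optuna:
- fixed names are skipped;
- per-name options are forwarded;
- fixed values are merged into the result.

In this hypothetical re-implementation, suggest_one stands in for suggest_hyperparameter_with_optuna and all_names stands in for the class's hyperparameter names. The real method also takes a trial and uses name_in_kwargs as the result key when available; both are omitted here.

```python
from typing import Any, Callable, Optional


def suggest_hyperparameters_sketch(
    all_names: list[str],
    suggest_one: Callable[..., Any],
    names: Optional[list[str]] = None,
    kwargs_by_name: Optional[dict[str, dict[str, Any]]] = None,
    fixed_hyperparameters: Optional[dict[str, Any]] = None,
    prefix: str = "",
) -> dict[str, Any]:
    """Hypothetical model of the documented control flow."""
    if names is None:
        names = list(all_names)  # default: every available hyperparameter
    kwargs_by_name = kwargs_by_name or {}
    fixed_hyperparameters = fixed_hyperparameters or {}
    # Fixed hyperparameters are not suggested...
    names = [name for name in names if name not in fixed_hyperparameters]
    suggested = {
        name: suggest_one(name=name, prefix=prefix, **kwargs_by_name.get(name, {}))
        for name in names
    }
    # ...but they are added to the returned mapping.
    suggested.update(fixed_hyperparameters)
    return suggested
```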
# using_applicable_actions ApplicableActions
using_applicable_actions(
self
)
Tell if the solver is able to use applicable actions information.
For instance, action masking may only be possible if the domain's action space is enumerable for each agent.
The default implementation always returns True.
# _check_domain_additional Solver
_check_domain_additional(
domain: Domain
) -> bool
Check whether the given domain is compliant with the specific requirements of this solver type (i.e. the ones in addition to "domain requirements").
This is a helper function called by default from Solver.check_domain(). It focuses on specific checks, whereas the latter also takes the domain requirements into account.
# Parameters
- domain: The domain to check.
# Returns
True if the domain is compliant with the specific requirements of this solver type (False otherwise).
# _cleanup Solver
_cleanup(
self
)
Runs cleanup code here, or code to be executed at the exit of a 'with' context statement.
# _initialize Solver
_initialize(
self
)
Runs long-lasting initialization code here.
# _is_policy_defined_for Policies
_is_policy_defined_for(
self,
observation: StrDict[D.T_observation]
) -> bool
Check whether the solver's current policy is defined for the given observation.
# Parameters
- observation: The observation to consider.
# Returns
True if the policy is defined for the given observation memory (False otherwise).
# _load Restorable
_load(
self,
path: str
)
Restore the solver state from given path.
# Parameters
- path: The path where the solver state was saved.
# _reset Solver
_reset(
self
) -> None
Reset whatever is needed on this solver before running a new episode.
This function does nothing by default but can be overridden if needed (e.g. to reset the hidden state of an LSTM policy network, which carries information about past observations seen in the previous episode).
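The override described above can be sketched with a hypothetical solver-like class that carries a hidden state across steps. A running sum over the observations stands in for an LSTM cell. Overriding _reset clears that memory so the next episode does not start with information from the previous one. ToyRecurrentPolicy is illustrative only; it does not subclass the real Solver.

```python
class ToyRecurrentPolicy:
    """Hypothetical solver-like class carrying a hidden state across steps."""

    def __init__(self, hidden_size: int = 2) -> None:
        self.hidden_size = hidden_size
        self.hidden_state = [0.0] * hidden_size

    def _reset(self) -> None:
        # Clear memory of the previous episode before a new one starts.
        self.hidden_state = [0.0] * self.hidden_size

    def _sample_action(self, observation: float) -> int:
        # Stand-in for an LSTM cell: accumulate observations in the hidden state.
        self.hidden_state = [h + observation for h in self.hidden_state]
        return 1 if self.hidden_state[0] > 1.0 else 0
```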
# _sample_action Policies
_sample_action(
self,
observation: StrDict[D.T_observation]
) -> StrDict[list[D.T_event]]
Sample an action for the given observation (from the solver's current policy).
# Parameters
- observation: The observation for which an action must be sampled.
# Returns
The sampled action.
# _save Restorable
_save(
self,
path: str
) -> None
Save the solver state to given path.
# Parameters
- path: The path to store the saved state.
# _set_action_mask Maskable
_set_action_mask(
self,
action_mask: Optional[StrDict[Mask]]
) -> None
Set the action mask.
To be called during rollout before self.sample_action(), assuming that self.sample_action() knows what to do with it.
# _solve FromInitialState
_solve(
self
) -> None
Run the solving process.
TIP
The nature of the solutions produced here depends on other solver characteristics like policy and assessibility.
# _CallbackWrapper
Wrapper to avoid surprises with lambda functions
# Constructor _CallbackWrapper
_CallbackWrapper(
callback: Callable[[RayRLlib], bool]
)
Initialize self. See help(type(self)) for accurate signature.
# SolveEarlyStop
Exception raised when a callback requests that the solve process stop.
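The interplay between the callback, this exception and train_iterations (as documented for the constructor) can be sketched as a nested loop. A True callback aborts only the current train iteration, and the outer loop then moves on to the next one. This is a hypothetical model, not the solver's actual code. SolveEarlyStop is redefined locally, and the callback receives (iteration, step) instead of the solver, purely for illustration.

```python
from typing import Callable


class SolveEarlyStop(Exception):
    """Local stand-in for the exception described above."""


def train_with_early_stop(
    train_iterations: int,
    steps_per_iteration: int,
    callback: Callable[[int, int], bool],
) -> list[tuple[int, int]]:
    """Hypothetical loop: a True callback aborts only the current iteration."""
    executed: list[tuple[int, int]] = []
    for iteration in range(train_iterations):
        try:
            for step in range(steps_per_iteration):
                executed.append((iteration, step))
                if callback(iteration, step):
                    raise SolveEarlyStop()
        except SolveEarlyStop:
            # Stop this train iteration; the outer loop moves on to the next one.
            continue
    return executed
```

With a callback returning True from step 2 onwards, each of three iterations runs steps 0 to 2 before stopping. The early stop therefore does not skip the remaining train iterations.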