# utils
This module contains utility functions.
# get_data_home
get_data_home(
data_home: Optional[str] = None
) -> str
Return the path of the scikit-decide data directory.
This folder is used by some large dataset loaders to avoid downloading the data several times, for instance the weather data used by the flight planning domain. By default, the data directory is a folder named 'skdecide_data' in the user home folder. Alternatively, it can be set by the 'SKDECIDE_DATA' environment variable or programmatically by passing an explicit folder path. The '~' symbol is expanded to the user home folder. If the folder does not already exist, it is automatically created.
# Parameters
- data_home: The path to the scikit-decide data directory. If None, the default path is ~/skdecide_data.
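A minimal usage sketch, assuming the environment variable is read at call time and that an explicit argument takes precedence over it, in line with the description above (the paths below are placeholders):

```python
import os

from skdecide.utils import get_data_home

# Default location: a 'skdecide_data' folder in the user home directory
print(get_data_home())

# Overriding through the SKDECIDE_DATA environment variable (placeholder path)
os.environ["SKDECIDE_DATA"] = "/tmp/skdecide_cache"
print(get_data_home())

# Overriding programmatically with an explicit folder path; '~' is expanded
print(get_data_home(data_home="~/my_skdecide_data"))
```

Note that every folder returned is created on the fly if it does not already exist.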
# ReplayOutOfActionMethod
Enumeration of the methods a ReplaySolver can use when it runs out of recorded actions.
# ERROR ReplayOutOfActionMethod
Raise a RuntimeError when running out of actions.
# LAST ReplayOutOfActionMethod
Keep returning the last action when running out of actions.
# LOOP ReplayOutOfActionMethod
Loop back to the first action when running out of actions.
# ReplaySolver
Wrapper around a list of actions mimicking a computed policy.
The goal is to be able to replay a rollout from a previous episode.
# Attributes
- actions: list of actions to wrap
- out_of_action_method: method to use when we run out of actions:
  - LOOP: loop over the actions, starting back from the first one,
  - LAST: keep returning the last action,
  - ERROR: raise a RuntimeError.
# Example
# rollout with actual solver
episodes = rollout(
domain,
solver,
return_episodes=True
)
# take the first episode
observations, actions, values = episodes[0]
# wrap the corresponding actions in a replay solver
replay_solver = ReplaySolver(actions)
# replay the rollout
replayed_episodes = rollout(
domain=domain,
solver=replay_solver,
return_episodes=True
)
# same outputs (for deterministic domain)
assert episodes == replayed_episodes
# Constructor ReplaySolver
ReplaySolver(
actions: list[StrDict[list[D.T_event]]],
out_of_action_method: ReplayOutOfActionMethod = ReplayOutOfActionMethod.LAST
)
Initialize the replay solver with the list of recorded actions to wrap and the method to use when running out of actions.
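A small sketch of the three out_of_action_method options, using placeholder actions (real actions would come from a previous rollout episode, as in the example above):

```python
from skdecide.utils import ReplayOutOfActionMethod, ReplaySolver

recorded_actions = ["up", "right"]  # placeholder actions from a previous episode

# LOOP: start back from the first action once the list is exhausted
looping_solver = ReplaySolver(
    recorded_actions, out_of_action_method=ReplayOutOfActionMethod.LOOP
)

# LAST (the default): keep returning the last recorded action
repeating_solver = ReplaySolver(recorded_actions)

# ERROR: raise a RuntimeError when more actions are requested than were recorded
strict_solver = ReplaySolver(
    recorded_actions, out_of_action_method=ReplayOutOfActionMethod.ERROR
)
```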
# get_next_action DeterministicPolicies
get_next_action(
self,
observation: StrDict[D.T_observation]
) -> StrDict[list[D.T_event]]
Get the next deterministic action (from the solver's current policy).
# Parameters
- observation: The observation for which next action is requested.
# Returns
The next deterministic action.
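A hedged sketch of driving a domain manually with this method, together with is_policy_defined_for (documented below); domain and solver are placeholders for any compatible instances, and the reset/step calls follow the standard scikit-decide environment interface:

```python
# Placeholder names: 'domain' is any scikit-decide Domain, 'solver' any
# DeterministicPolicies instance (e.g. a ReplaySolver).
observation = domain.reset()
while solver.is_policy_defined_for(observation):
    action = solver.get_next_action(observation)
    outcome = domain.step(action)
    observation = outcome.observation
    if outcome.termination:  # single-agent case: termination is a boolean
        break
```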
# get_next_action_distribution UncertainPolicies
get_next_action_distribution(
self,
observation: StrDict[D.T_observation]
) -> Distribution[StrDict[list[D.T_event]]]
Get the probabilistic distribution of next action for the given observation (from the solver's current policy).
# Parameters
- observation: The observation to consider.
# Returns
The probabilistic distribution of next action.
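A hedged sketch of consuming the returned distribution (Distribution objects in scikit-decide expose a sample() method; the variable names are placeholders):

```python
# Placeholder names: 'solver' implements UncertainPolicies,
# 'observation' is the current observation.
action_distribution = solver.get_next_action_distribution(observation)
action = action_distribution.sample()  # draw one action according to the policy
```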
# is_policy_defined_for Policies
is_policy_defined_for(
self,
observation: StrDict[D.T_observation]
) -> bool
Check whether the solver's current policy is defined for the given observation.
# Parameters
- observation: The observation to consider.
# Returns
True if the policy is defined for the given observation (False otherwise).
# sample_action Policies
sample_action(
self,
observation: StrDict[D.T_observation]
) -> StrDict[list[D.T_event]]
Sample an action for the given observation (from the solver's current policy).
# Parameters
- observation: The observation for which an action must be sampled.
# Returns
The sampled action.
# _get_next_action DeterministicPolicies
_get_next_action(
self,
observation: StrDict[D.T_observation]
) -> StrDict[list[D.T_event]]
Get the next deterministic action (from the solver's current policy).
# Parameters
- observation: The observation for which next action is requested.
# Returns
The next deterministic action.
# _get_next_action_distribution UncertainPolicies
_get_next_action_distribution(
self,
observation: StrDict[D.T_observation]
) -> Distribution[StrDict[list[D.T_event]]]
Get the probabilistic distribution of next action for the given observation (from the solver's current policy).
# Parameters
- observation: The observation to consider.
# Returns
The probabilistic distribution of next action.
# _is_policy_defined_for Policies
_is_policy_defined_for(
self,
observation: StrDict[D.T_observation]
) -> bool
Check whether the solver's current policy is defined for the given observation.
# Parameters
- observation: The observation to consider.
# Returns
True if the policy is defined for the given observation (False otherwise).
# _sample_action Policies
_sample_action(
self,
observation: StrDict[D.T_observation]
) -> StrDict[list[D.T_event]]
Sample an action for the given observation (from the solver's current policy).
# Parameters
- observation: The observation for which an action must be sampled.
# Returns
The sampled action.
# rollout
rollout(
domain: Domain,
solver: Optional[Union[Solver, Policies]] = None,
from_memory: Optional[Memory[D.T_state]] = None,
from_action: Optional[StrDict[list[D.T_event]]] = None,
num_episodes: int = 1,
max_steps: Optional[int] = None,
render: bool = True,
max_framerate: Optional[float] = None,
verbose: bool = True,
action_formatter: Optional[Callable[[D.T_event], str]] = <lambda function>,
outcome_formatter: Optional[Callable[[EnvironmentOutcome], str]] = <lambda function>,
return_episodes: bool = False,
goal_logging_level: int = 20,
rollout_callback: Optional[RolloutCallback] = None
) -> Optional[list[tuple[list[StrDict[D.T_observation]], list[StrDict[list[D.T_event]]], list[StrDict[Value[D.T_value]]]]]]
Run one or more episodes in a domain according to the policy of a solver.
# Parameters
- domain: The domain in which the episode(s) will be run.
- solver: The solver whose policy will select actions to take (if None, a random policy is used).
- from_memory: The memory or state to consider as rollout starting point (if None, the domain is reset first).
- from_action: The last applied action when from_memory is used (if necessary for initial observation computation).
- num_episodes: The number of episodes to run.
- max_steps: The maximum number of steps for each episode (if None, no limit is set).
- render: Whether to render the episode(s) during rollout if the domain is renderable.
- max_framerate: The maximum number of steps/renders per second (if None, steps/renders are never slowed down).
- verbose: Whether to print information to the console during rollout.
- action_formatter: The function transforming an action into the string to print (if None, actions are not printed).
- outcome_formatter: The function transforming an EnvironmentOutcome into the string to print (if None, outcomes are not printed).
- return_episodes: If True, return the list of episodes, each episode being a tuple of observations, actions, and values; otherwise return nothing.
- goal_logging_level: The logging level at which to log whether the goal has been reached or not.
- rollout_callback: Custom callback hooked into the different stages of the rollout (see RolloutCallback below).
# Returns
The list of episodes (each episode being a tuple of observations, actions, and values) if return_episodes is True, None otherwise.
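A hedged sketch of a typical call (domain and solver are placeholder instances; the keyword values are purely illustrative):

```python
from skdecide.utils import rollout

episodes = rollout(
    domain=domain,          # any scikit-decide Domain instance
    solver=solver,          # if None, a random policy is used instead
    num_episodes=3,         # run three episodes
    max_steps=100,          # cap each episode at 100 steps
    render=False,           # do not render
    verbose=False,          # no console printing
    return_episodes=True,   # get back the episodes instead of nothing
)

for observations, actions, values in episodes:
    print(f"episode finished after {len(actions)} steps")
```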
# RolloutCallback
Callback used during rollout to add custom behaviour.
One should derive from this class in order to hook into the different stages of the rollout.
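A hedged sketch of a minimal custom callback that only logs rollout and episode boundaries (the class name and messages are illustrative; it overrides only the hooks documented below):

```python
from skdecide.utils import RolloutCallback, rollout

class LoggingCallback(RolloutCallback):
    """Illustrative callback printing a message at each rollout stage."""

    def at_rollout_start(self):
        print("rollout starting")

    def at_rollout_end(self):
        print("rollout finished")

    def at_episode_start(self):
        print("episode starting")

    def at_episode_end(self):
        print("episode finished")

# Hypothetical usage: 'domain' and 'solver' are placeholders.
rollout(domain, solver, rollout_callback=LoggingCallback())
```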
# at_episode_end RolloutCallback
at_episode_end(
self
)
Called after each episode.
# at_episode_start RolloutCallback
at_episode_start(
self
)
Called before each episode.
# at_episode_step RolloutCallback
at_episode_step(
self,
i_episode: int,
step: int,
domain: Domain,
solver: Union[Solver, Policies],
action: StrDict[list[D.T_event]],
outcome: EnvironmentOutcome[StrDict[D.T_observation], StrDict[Value[D.T_value]], StrDict[D.T_predicate], StrDict[D.T_info]]
) -> bool
Called at each step of an episode, after the last sampled action has been applied to the domain.
# Parameters
- i_episode: current episode number
- step: current step number within the episode
- domain: domain considered
- solver: solver considered (or randomwalk policy if solver was None in rollout)
- action: last action sampled
- outcome: outcome of the last action applied to the domain
# Returns
stopping: if True, the rollout for the current episode stops and the next episode starts.
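For instance, a hedged sketch of a callback that uses this return value to cap the number of steps per episode (the class name and budget value are illustrative):

```python
from skdecide.utils import RolloutCallback

class StepBudgetCallback(RolloutCallback):
    """Illustrative callback stopping each episode after a fixed step budget."""

    def __init__(self, budget: int = 50):
        self.budget = budget

    def at_episode_step(self, i_episode, step, domain, solver, action, outcome) -> bool:
        # Returning True stops the current episode and moves on to the next one.
        return step >= self.budget
```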
# at_rollout_end RolloutCallback
at_rollout_end(
self
)
Called at rollout end.
# at_rollout_start RolloutCallback
at_rollout_start(
self
)
Called at rollout start.