src.rl.a2c
Module Contents
Classes
Memory: A class to store and manage the memory required for reinforcement learning algorithms.
A2C_Agent: The A2C_Agent class implements an Advantage Actor-Critic (A2C) agent for reinforcement learning.
API
- class src.rl.a2c.Memory
Introduction
A class to store and manage the memory required for reinforcement learning algorithms. It keeps track of actions, states, log probabilities, and rewards during an episode and provides functionality to clear the stored memory.
Methods:
__init__(): Initializes the memory by creating empty lists for actions, states, log probabilities, and rewards.
clear_memory(): Clears the stored memory by deleting the lists of actions, states, log probabilities, and rewards.
Initialization
Initializes the memory by creating empty lists for actions, states, log probabilities, and rewards.
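A minimal pure-Python sketch of such a buffer is shown below. The attribute names actions, states, logprobs and rewards are assumptions inferred from the description, not necessarily the names used in src.rl.a2c. In the real agent the stored items would typically be tensors.

```python
from typing import Any, List


class Memory:
    """Sketch of a per-episode rollout buffer (attribute names are assumed)."""

    def __init__(self) -> None:
        # One list per quantity collected while stepping the environment.
        self.actions: List[Any] = []
        self.states: List[Any] = []
        self.logprobs: List[Any] = []
        self.rewards: List[Any] = []

    def clear_memory(self) -> None:
        # Empty each list in place (del lst[:]); the attributes stay usable
        # for the next episode.
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
```

Emptying the lists in place keeps the same list objects alive. Any code that still holds a reference to memory.rewards therefore sees the cleared state.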
- class src.rl.a2c.A2C_Agent(config, networks: dict, learning_rates: float)
Bases: src.rl.basic_agent.Basic_Agent
Introduction
The A2C_Agent class implements an Advantage Actor-Critic (A2C) agent for reinforcement learning. The agent uses actor and critic networks to optimize its policy, which guides the low-level optimizer during optimization.
Original paper
“Actor-Critic Algorithms.” Advances in Neural Information Processing Systems (NIPS), 1999
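The sketch below illustrates the actor/critic split. It computes the two losses a standard A2C update minimizes, using three per-step inputs:

- log-probabilities log π(a_t|s_t);
- critic values V(s_t);
- return targets R_t.

The advantage A_t = R_t - V(s_t) weights the policy-gradient (actor) loss. The critic is regressed onto R_t with a mean-squared error.

This is an assumed textbook formulation on plain floats. In a PyTorch implementation the advantage would be detached before it multiplies the log-probabilities. The exact loss terms in this module may differ.

```python
from typing import List, Tuple


def a2c_losses(
    logprobs: List[float], values: List[float], returns: List[float]
) -> Tuple[float, float]:
    """Return (actor_loss, critic_loss) for one batch of transitions."""
    if not (len(logprobs) == len(values) == len(returns)) or not returns:
        raise ValueError("expected non-empty inputs of equal length")
    n = len(returns)
    # Advantage: how much better the observed return was than the critic predicted.
    advantages = [r - v for r, v in zip(returns, values)]
    # REINFORCE-with-baseline loss; the advantage acts as a constant weight.
    actor_loss = -sum(lp * a for lp, a in zip(logprobs, advantages)) / n
    # Critic loss: mean squared error between return targets and value estimates.
    critic_loss = sum(a * a for a in advantages) / n
    return actor_loss, critic_loss
```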
Args
config: Configuration object containing all necessary parameters for the experiment. For details, see config.py.
networks (dict): A dictionary of neural networks used by the agent, with network names as keys (e.g., ‘actor’, ‘critic’) and the corresponding network instances as values.
learning_rates (float): Learning rate for the optimizer.
Attributes
gamma (float): Discount factor for future rewards.
n_step (int): Number of steps for multi-step returns.
max_grad_norm (float): Maximum gradient norm for gradient clipping.
device (str): Device to run the computations on (e.g., ‘cpu’ or ‘cuda’).
network (list): List of network names used by the agent.
optimizer (torch.optim.Optimizer): Optimizer for training the networks.
learning_time (int): Counter for the number of training steps completed.
cur_checkpoint (int): Index of the current checkpoint.
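gamma and n_step together determine the return targets used to train the critic. A common choice, assumed here rather than confirmed by this module, is the n-step bootstrapped return:

R_t = r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n*V(s_{t+n})

The sketch computes the targets for every step of one n-step segment by backward recursion on plain floats. Later steps in the segment use correspondingly fewer rewards before the bootstrap term.

```python
from typing import List, Sequence


def n_step_returns(
    rewards: Sequence[float], bootstrap_value: float, gamma: float
) -> List[float]:
    """Discounted return targets for one n-step segment.

    rewards: the n rewards r_t, ..., r_{t+n-1} collected in the segment.
    bootstrap_value: critic estimate V(s_{t+n}); pass 0.0 if the episode ended.
    """
    returns: List[float] = []
    running = bootstrap_value
    # Walk backwards: R_k = r_k + gamma * R_{k+1}, starting from V(s_{t+n}).
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()  # restore chronological order
    return returns
```

Passing bootstrap_value=0.0 at the end of an episode drops the critic term. The targets then reduce to ordinary discounted reward sums.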
Methods
set_network(networks, learning_rates): Initializes the networks and optimizer for the agent.
get_step(): Returns the current training step count.
update_setting(config): Updates the agent’s configuration and resets training-related attributes.
train_episode(envs, para_mode, compute_resource, tb_logger, required_info): Trains the agent for one episode in a parallelized environment.
log_to_tb_train(tb_logger, mini_step, grad_norms, reinforce_loss, baseline_loss, Return, Reward, memory_reward, critic_output, logprobs, entropy, approx_kl_divergence, extra_info): Logs training metrics to TensorBoard.
rollout_episode(env, seed, required_info): Executes a single rollout episode in the environment and returns the results.
rollout_batch_episode(envs, seeds, para_mode, compute_resource, required_info): Executes batch rollout episodes in parallelized environments and returns the results.
Initialization
Initializes the A2C_Agent with the given configuration, networks, and learning rates.
Args:
config: Configuration object containing all necessary parameters for the experiment.
networks (dict): A dictionary of neural networks used by the agent.
learning_rates (float): Learning rate for the optimizer.
- set_network(networks: dict, learning_rates: float)
Initializes the networks and optimizer for the agent.
Args:
networks (dict): A dictionary of neural networks used by the agent.
learning_rates (float): Learning rate for the optimizer.
Raises:
AssertionError: If required network attributes (e.g., ‘actor’, ‘critic’) are not set.
ValueError: If the length of the learning rates list does not match the number of networks.
AttributeError: If the specified optimizer in the configuration is not available in torch.optim.
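The checks listed under Raises can be sketched without PyTorch, as below. Several parts are assumptions:

- The helper names normalize_learning_rates and resolve_optimizer are hypothetical.
- Broadcasting a single float to every network is assumed. The ValueError case suggests a per-network list is also accepted.
- In the real method, the module passed to resolve_optimizer would be torch.optim.

```python
from typing import Any, Dict, List, Sequence, Union

REQUIRED_NETWORKS = ("actor", "critic")  # assumed names, taken from the examples above


def normalize_learning_rates(
    networks: Dict[str, Any], learning_rates: Union[float, Sequence[float]]
) -> List[float]:
    """Validate `networks` and return one learning rate per network."""
    for name in REQUIRED_NETWORKS:
        assert name in networks, f"network '{name}' is not set"
    if isinstance(learning_rates, (int, float)):
        # A single value is broadcast to every network (assumed behaviour).
        lrs = [float(learning_rates)] * len(networks)
    else:
        lrs = [float(lr) for lr in learning_rates]
    if len(lrs) != len(networks):
        raise ValueError(
            f"got {len(lrs)} learning rates for {len(networks)} networks"
        )
    return lrs


def resolve_optimizer(optim_module: Any, name: str) -> Any:
    """Look up an optimizer class by name, e.g. 'Adam' on torch.optim."""
    # getattr raises AttributeError when the module has no such attribute.
    return getattr(optim_module, name)
```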
- get_step()
Returns the current training step count.
Returns:
int: The current training step count.
- update_setting(config)
Updates the agent’s configuration, resets training-related attributes, and stores the initial agent in the checkpoint directory.
Args:
config: Configuration object containing updated parameters.
- train_episode(envs, para_mode: Literal['dummy', 'subproc', 'ray', 'ray-subproc'] = 'dummy', compute_resource={}, tb_logger=None, required_info={})
Trains the agent for one episode in a parallelized environment.
Args:
envs: List of environments for training.
para_mode (str): Parallelization mode for the environments; one of 'dummy', 'subproc', 'ray', or 'ray-subproc'.
compute_resource (dict): Resources for computation (e.g., CPUs, GPUs).
tb_logger: TensorBoard logger for logging training metrics.
required_info (dict): Additional information required from the environment.
Returns:
tuple: A boolean indicating whether training has ended and a dictionary with training metrics.
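The max_grad_norm attribute suggests that each training update clips gradients by their global norm, and log_to_tb_train receives grad_norms. The sketch below illustrates that rule on plain lists of floats. It mirrors the behaviour of torch.nn.utils.clip_grad_norm_: every gradient is rescaled by max_norm / (total_norm + 1e-6) whenever that factor is below 1, and the pre-clipping norm is returned. Whether this agent clips per network or over all parameters at once is an assumption.

```python
import math
from typing import List, Tuple


def clip_by_global_norm(
    grads: List[List[float]], max_norm: float
) -> Tuple[List[List[float]], float]:
    """Return (clipped_grads, total_norm_before_clipping)."""
    # L2 norm over all gradient entries of all parameters together.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # The scale factor is capped at 1.0, so small gradients are left untouched.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, total_norm
```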
- log_to_tb_train(tb_logger, mini_step, grad_norms, reinforce_loss, baseline_loss, Return, Reward, memory_reward, critic_output, logprobs, entropy, approx_kl_divergence, extra_info={})
Logs training metrics to TensorBoard.
Args:
tb_logger: TensorBoard logger for logging training metrics.
mini_step (int): Current mini-batch step.
grad_norms (tuple): Gradient norms for the networks.
reinforce_loss (torch.Tensor): Actor loss.
baseline_loss (torch.Tensor): Critic loss.
Return (torch.Tensor): Episode return.
Reward (torch.Tensor): Target reward.
memory_reward (list): List of rewards from memory.
critic_output (torch.Tensor): Critic network output.
logprobs (torch.Tensor): Log probabilities of actions.
entropy (torch.Tensor): Entropy of the policy.
approx_kl_divergence (torch.Tensor): Approximate KL divergence.
extra_info (dict): Additional information to log.
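Two of the logged diagnostics summarize how the policy behaves. The entropy shows how spread out the policy is, and approx_kl_divergence shows how far it moved in an update. The sketch below computes each from plain floats, assuming a discrete (categorical) policy. The KL estimator 0.5 * mean((log π_old - log π_new)^2) is one common cheap choice and is an assumption here. This module may use another estimator or a continuous action distribution.

```python
import math
from typing import Sequence


def categorical_entropy(probs: Sequence[float]) -> float:
    """Entropy H(p) = -sum_i p_i * log(p_i) of one categorical action distribution."""
    # Zero-probability actions contribute 0 (limit of p * log p as p -> 0).
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def approx_kl(old_logprobs: Sequence[float], new_logprobs: Sequence[float]) -> float:
    """0.5 * mean((log pi_old(a|s) - log pi_new(a|s))^2), a non-negative KL estimate."""
    if len(old_logprobs) != len(new_logprobs) or not old_logprobs:
        raise ValueError("expected two non-empty sequences of equal length")
    squared = [(o - n) ** 2 for o, n in zip(old_logprobs, new_logprobs)]
    return 0.5 * sum(squared) / len(squared)
```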
- rollout_episode(env, seed=None, required_info={})
Executes a single rollout episode in the environment and returns the results.
Args:
env: The environment for the rollout.
seed (int, optional): Seed for reproducibility.
required_info (dict): Additional information required from the environment.
Returns:
dict: A dictionary containing the total return and additional requested information.
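The rollout loop can be sketched as below. It shows only the general shape: reset with a seed, act until done, accumulate the return, then collect the requested extras. Everything in it is hypothetical:

- CountdownEnv is a toy environment whose step() returns (state, reward, done).
- policy stands in for the agent's actor network.
- The result key "return" is assumed.
- required_info is assumed to map result keys to environment attribute names.

```python
import random
from typing import Any, Callable, Dict, Optional, Tuple


class CountdownEnv:
    """Hypothetical toy environment: reward 1.0 per step, ends after `horizon` steps."""

    def __init__(self, horizon: int = 5) -> None:
        self.horizon = horizon
        self.t = 0
        self.rng = random.Random()
        self.best_cost = float("inf")  # example of extra info a caller may request

    def reset(self, seed: Optional[int] = None) -> float:
        self.rng = random.Random(seed)  # seeding makes the episode reproducible
        self.t = 0
        self.best_cost = float("inf")
        return 0.0

    def step(self, action: int) -> Tuple[float, float, bool]:
        self.t += 1
        self.best_cost = min(self.best_cost, self.rng.random())
        return float(self.t), 1.0, self.t >= self.horizon


def rollout_episode(
    env: CountdownEnv,
    policy: Callable[[float], int],
    seed: Optional[int] = None,
    required_info: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    required_info = required_info or {}  # avoid a shared mutable default
    state = env.reset(seed=seed)
    total_return = 0.0
    done = False
    while not done:
        action = policy(state)
        state, reward, done = env.step(action)
        total_return += reward
    results: Dict[str, Any] = {"return": total_return}
    # Copy each requested environment attribute into the result dict.
    for key, attr_name in required_info.items():
        results[key] = getattr(env, attr_name)
    return results
```

Consider the call rollout_episode(CountdownEnv(3), lambda s: 0, seed=7, required_info={"best": "best_cost"}). It returns a dict whose "return" entry is 3.0, plus the environment's best_cost under "best". Calling it again with the same seed reproduces the same dict.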