src.rl.vdn
Module Contents
Classes
VDN_Agent: Value Decomposition Network (VDN) agent for cooperative multi-agent reinforcement learning.
API
- class src.rl.vdn.VDN_Agent(config, networks, learning_rates)
Bases: src.rl.basic_agent.Basic_Agent
Introduction
The VDN_Agent class implements a Value Decomposition Network (VDN) agent for multi-agent reinforcement learning. It targets cooperative multi-agent environments and decomposes the joint action-value function into a sum of individual per-agent value functions. It supports experience replay, target networks, epsilon-greedy exploration, and parallelized environments, and it provides methods for training, action selection, and evaluation.
Original paper
“Value-Decomposition Networks For Cooperative Multi-Agent Learning.”
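The core idea of VDN is additive: Q_tot(s, a) = Q_1(o_1, a_1) + ... + Q_n(o_n, a_n). The minimal sketch below illustrates three consequences of that sum:

- the joint Q-value;
- the greedy joint action, which each agent can pick independently because the max of a sum of independent terms is the sum of the maxes;
- the one-step TD target built with gamma.

It is not the VDN_Agent implementation. Plain Python lists stand in for the PyTorch network outputs, and the function names (q_tot, greedy_joint_action, td_target) are hypothetical.

```python
from typing import List, Sequence

def q_tot(per_agent_q: Sequence[Sequence[float]], joint_action: Sequence[int]) -> float:
    """VDN decomposition: the joint Q-value is the sum of each agent's Q for its own action."""
    return sum(q[a] for q, a in zip(per_agent_q, joint_action))

def greedy_joint_action(per_agent_q: Sequence[Sequence[float]]) -> List[int]:
    """Because Q_tot is a sum, each agent maximises its own Q independently."""
    return [max(range(len(q)), key=q.__getitem__) for q in per_agent_q]

def td_target(reward: float, next_q: Sequence[Sequence[float]], gamma: float, done: bool) -> float:
    """One-step target r + gamma * max_a' Q_tot(s', a'); no bootstrap at episode end."""
    if done:
        return reward
    return reward + gamma * q_tot(next_q, greedy_joint_action(next_q))

# Two agents with three actions each (made-up values for illustration only).
q_now = [[1.0, 2.0, 0.5], [0.0, -1.0, 3.0]]
q_next = [[0.5, 1.5, 1.0], [2.0, 0.0, 1.0]]

action = greedy_joint_action(q_now)                      # [1, 2]
predict = q_tot(q_now, action)                           # 2.0 + 3.0 = 5.0
target = td_target(1.0, q_next, gamma=0.99, done=False)  # 1.0 + 0.99 * (1.5 + 2.0)
loss = (predict - target) ** 2                           # squared TD error, as MSELoss would give for one sample
print(action, predict, target, loss)
```

In the real agent the same target would typically come from the target networks, and the loss would be averaged over a batch by the criterion.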
Args
config: Configuration object containing all parameters needed for the experiment. See config.py for details.
networks (dict): A dictionary of neural networks used by the agent, where keys are network names and values are the corresponding network objects.
learning_rates (float or list): A single learning rate, or a list of learning rates (one per network), for the optimizer(s).
Attributes
n_agent (int): Number of agents in the environment. (default: 4)
n_act (int): Number of actions available to each agent. (default: 4)
available_action (list): List of available actions for each agent. (default: 4)
memory_size (int): Size of the replay buffer. (default: 10000)
warm_up_size (int): Number of experiences required in the replay buffer before training starts. (default: 1000)
gamma (float): Discount factor for future rewards. (default: 0.99)
epsilon (float): Epsilon value for epsilon-greedy exploration. (default: 0.5)
epsilon_start (float): Initial epsilon value for exploration. (default: 1)
epsilon_end (float): Final epsilon value for exploration. (default: 0.1)
epsilon_decay_steps (int): Number of steps over which epsilon decays. (default: 10000)
max_grad_norm (float): Maximum gradient norm for gradient clipping. (default: 10.0)
batch_size (int): Batch size for training. (default: 64)
chunk_size (int): Chunk size for sampling trajectories from the replay buffer. (default: 1)
update_iter (int): Number of update iterations per training step. (default: 10)
device (str): Device used for computation (e.g., ‘cpu’ or ‘cuda’).
replay_buffer (MultiAgent_ReplayBuffer): Replay buffer for storing experiences.
network (list): List of network names used by the agent.
optimizer (torch.optim.Optimizer): Optimizer for training the networks. (default: Adam)
criterion (torch.nn.Module): Loss function used for training. (default: MSELoss)
learning_time (int): Counter for the number of training steps.
cur_checkpoint (int): Counter for the current checkpoint index.
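The attributes above list epsilon_start, epsilon_end, epsilon_decay_steps and warm_up_size but do not say what shape the decay follows. The sketch below assumes a linear schedule, which is a common choice but not confirmed by these docs. It also shows a simple warm-up gate. Both helper names (linear_epsilon, ready_to_train) are hypothetical, and the defaults are copied from the attribute list.

```python
def linear_epsilon(step: int, start: float = 1.0, end: float = 0.1, decay_steps: int = 10000) -> float:
    """ASSUMED linear annealing from `start` to `end` over `decay_steps`, then held at `end`."""
    if decay_steps <= 0:
        return end
    frac = min(max(step, 0) / decay_steps, 1.0)  # fraction of the decay completed, clamped to [0, 1]
    return start + frac * (end - start)

def ready_to_train(buffer_len: int, warm_up_size: int = 1000) -> bool:
    """Updates start only once the replay buffer holds at least warm_up_size experiences."""
    return buffer_len >= warm_up_size

for step in (0, 5000, 10000, 20000):
    print(step, round(linear_epsilon(step), 3))
```

With these defaults, epsilon falls from 1.0 to 0.55 halfway through and stays at 0.1 after 10000 steps.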
Methods
set_network(networks: dict, learning_rates: float): Sets up the networks, optimizer, and loss function for the agent.
get_step() -> int: Returns the current training step.
update_setting(config): Updates the agent’s configuration and resets training-related attributes.
get_action(state, epsilon_greedy=False) -> np.ndarray: Selects an action based on the current state and exploration strategy.
train_episode(...): Trains the agent for one episode in a parallelized environment.
rollout_episode(env, seed=None, required_info={}) -> dict: Executes a single episode in the environment and returns the results.
log_to_tb_train(...): Logs training metrics and information to TensorBoard.
Initialization
Initializes the VDN agent with the given configuration, networks, and learning rates, and stores the initial agent in the checkpoint directory.
Args:
config: Configuration object containing all necessary parameters for the experiment.
networks (dict): A dictionary of neural networks used by the agent.
learning_rates (float or list): Learning rate for the optimizer, or a list with one learning rate per network.
- set_network(networks: dict, learning_rates: float)
Sets up the networks, optimizer, and loss function for the agent.
Args:
networks (dict): A dictionary of neural networks used by the agent.
learning_rates (float or list): Learning rate for the optimizer, or a list with one learning rate per network.
Raises:
ValueError: If the length of the learning rates list does not match the number of networks.
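The ValueError above implies that a list of learning rates must line up one-to-one with the networks. The sketch below shows one way to normalise the argument. It assumes that a single float is broadcast to every network, which these docs do not state. The function name normalize_learning_rates is hypothetical.

```python
import numbers
from typing import Dict, List, Sequence, Union

def normalize_learning_rates(networks: Dict[str, object],
                             learning_rates: Union[float, Sequence[float]]) -> List[float]:
    """Return one learning rate per network.

    ASSUMPTION: a single number is broadcast to every network. A list must have
    exactly one entry per network, mirroring the documented ValueError."""
    if isinstance(learning_rates, numbers.Real):
        return [float(learning_rates)] * len(networks)
    rates = [float(lr) for lr in learning_rates]
    if len(rates) != len(networks):
        raise ValueError(f"got {len(rates)} learning rates for {len(networks)} networks")
    return rates

print(normalize_learning_rates({"q_net": None, "q_target": None}, 1e-3))  # [0.001, 0.001]
```

In PyTorch, such a list could feed per-parameter-group options of a single optimizer. An example is torch.optim.Adam([{"params": net.parameters(), "lr": lr}, ...]). Whether set_network does exactly this is not stated here.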
- update_setting(config)
Updates the agent’s configuration and resets training-related attributes.
Args:
config: Configuration object containing updated parameters.
- get_action(state, epsilon_greedy=False)
Selects an action based on the current state and exploration strategy.
Args:
state (torch.Tensor): The current state.
epsilon_greedy (bool): Whether to use epsilon-greedy exploration.
Returns:
np.ndarray: The selected action(s).
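Epsilon-greedy selection in a multi-agent setting is usually done per agent and restricted to each agent's available actions. The sketch below illustrates this. It takes plain lists of Q-values rather than a torch.Tensor state and returns a list rather than an np.ndarray.

Two further points are assumptions:

- The 0/1 mask layout of `available` is assumed, not taken from these docs.
- Passing epsilon=0 is assumed to mimic the greedy path (epsilon_greedy=False).

The function name is hypothetical.

```python
import random
from typing import List, Optional, Sequence

def epsilon_greedy_actions(per_agent_q: Sequence[Sequence[float]],
                           available: Sequence[Sequence[int]],
                           epsilon: float,
                           rng: Optional[random.Random] = None) -> List[int]:
    """One action per agent: a random available action with probability epsilon,
    otherwise the highest-Q action among that agent's available actions.
    ASSUMPTION: available[i][a] is 1 if agent i may take action a, else 0."""
    rng = rng or random.Random()
    actions = []
    for q, mask in zip(per_agent_q, available):
        legal = [a for a, ok in enumerate(mask) if ok]
        if not legal:
            raise ValueError("agent has no available actions")
        if rng.random() < epsilon:
            actions.append(rng.choice(legal))        # explore among legal actions only
        else:
            actions.append(max(legal, key=lambda a: q[a]))  # exploit: masked argmax
    return actions

q_values = [[5.0, 1.0, 2.0], [0.0, 4.0, 3.0]]
mask = [[0, 1, 1], [1, 1, 1]]   # agent 0 may not take action 0
print(epsilon_greedy_actions(q_values, mask, epsilon=0.0))  # [2, 1]
```

Masking before the argmax matters. Without it, agent 0 above would pick its highest raw Q (action 0) even though that action is unavailable.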
- train_episode(envs, seeds: Optional[Union[int, List[int], np.ndarray]], para_mode: Literal['dummy', 'subproc', 'ray', 'ray-subproc'] = 'dummy', compute_resource={}, tb_logger=None, required_info={})
Trains the agent for one episode in a parallelized environment.
Args:
envs: List of environments for training.
seeds: Seeds for reproducibility.
para_mode (str): Parallelization mode for the environments: 'dummy', 'subproc', 'ray', or 'ray-subproc'.
compute_resource (dict): Resources for computation (e.g., CPUs, GPUs).
tb_logger: TensorBoard logger for logging training metrics.
required_info (dict): Additional information required from the environment.
Returns:
tuple: A boolean indicating whether training has ended and a dictionary with training metrics.
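The seeds argument accepts None, an int, a list, or an array, while training runs over a list of environments. A helper that expands it into one seed per environment is therefore useful. The sketch below uses a hypothetical convention that these docs do not confirm: an int serves as a base seed offset by the environment index. The function name per_env_seeds is also hypothetical.

```python
import numbers
from typing import List, Optional, Sequence, Union

def per_env_seeds(seeds: Optional[Union[int, Sequence[int]]], n_envs: int) -> List[Optional[int]]:
    """Expand `seeds` into one seed per environment.

    HYPOTHETICAL convention: None -> unseeded, int -> base seed + env index,
    sequence (list or 1-D array) -> used as-is after a length check."""
    if seeds is None:
        return [None] * n_envs
    if isinstance(seeds, numbers.Integral):   # also matches NumPy integer scalars
        return [int(seeds) + i for i in range(n_envs)]
    expanded = [int(s) for s in seeds]
    if len(expanded) != n_envs:
        raise ValueError(f"expected {n_envs} seeds, got {len(expanded)}")
    return expanded

print(per_env_seeds(7, 3))  # [7, 8, 9]
```

Offsetting the seed keeps runs reproducible while preventing parallel copies of the same environment from producing identical trajectories.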
- rollout_episode(env, seed=None, required_info={})
Executes a single episode in the environment without training.
Args:
env: The environment for the rollout.
seed (int, optional): Seed for reproducibility.
required_info (dict): Additional information required from the environment.
Returns:
dict: A dictionary containing episode results such as return, cost, and metadata.
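An evaluation rollout simply runs the policy until the episode ends and accumulates reward, with no learning step. The sketch below assumes a Gym-style reset(seed=...) / step(action) interface that returns (obs, reward, done, info). These docs do not specify that interface. CountdownEnv, rollout, and the result keys are hypothetical stand-ins for the real environment and result dictionary.

```python
from typing import Any, Callable, Dict, Optional

class CountdownEnv:
    """HYPOTHETICAL toy environment: reward 1.0 per step, episode ends after `horizon` steps."""
    def __init__(self, horizon: int = 3):
        self.horizon = horizon
        self.t = 0

    def reset(self, seed: Optional[int] = None):
        self.t = 0
        return self.t

    def step(self, action: int):
        self.t += 1
        done = self.t >= self.horizon
        return self.t, 1.0, done, {}

def rollout(env, policy: Callable[[Any], int], seed: Optional[int] = None) -> Dict[str, Any]:
    """Run one episode with a fixed policy (no training) and summarise it."""
    obs = env.reset(seed=seed)
    total, steps, done = 0.0, 0, False
    while not done:
        obs, reward, done, _info = env.step(policy(obs))
        total += reward   # undiscounted episode return
        steps += 1
    return {"return": total, "steps": steps, "seed": seed}

print(rollout(CountdownEnv(horizon=3), policy=lambda obs: 0, seed=42))
# {'return': 3.0, 'steps': 3, 'seed': 42}
```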
- log_to_tb_train(tb_logger, mini_step, grad_norms, loss, Return, Reward, predict_Q, target_Q, extra_info={})
Logs training metrics to TensorBoard.
Args:
tb_logger: TensorBoard logger for logging training metrics.
mini_step (int): Current mini-batch step.
grad_norms (tuple): Gradient norms for the networks.
loss (torch.Tensor): Training loss.
Return (torch.Tensor): Episode return.
Reward (torch.Tensor): Target reward.
predict_Q (torch.Tensor): Predicted Q-values.
target_Q (torch.Tensor): Target Q-values.
extra_info (dict): Additional information to log.
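These docs do not say how grad_norms is produced. A common pattern in PyTorch is torch.nn.utils.clip_grad_norm_ with max_norm set to max_grad_norm. That function rescales all gradients jointly and returns their total norm before clipping. The pure-Python sketch below mirrors that global-norm behaviour (including a small epsilon in the scale). The function name clip_by_global_norm is hypothetical.

```python
import math
from typing import List, Sequence, Tuple

def clip_by_global_norm(grads: Sequence[Sequence[float]],
                        max_norm: float = 10.0) -> Tuple[List[List[float]], float]:
    """Rescale all gradients jointly so their global L2 norm is at most max_norm.
    Returns the clipped gradients and the pre-clipping norm (the value one would log)."""
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = max_norm / (total_norm + 1e-6)
    if scale < 1.0:
        clipped = [[g * scale for g in grad] for grad in grads]
    else:
        clipped = [list(grad) for grad in grads]   # already within the limit
    return clipped, total_norm

clipped, norm = clip_by_global_norm([[3.0, 4.0], [0.0, 12.0]], max_norm=6.5)
print(norm)  # 13.0
```

Suppose tb_logger is a torch.utils.tensorboard.SummaryWriter, which these docs do not confirm. The logged value could then be written with tb_logger.add_scalar("grad_norm", norm, mini_step), where the tag name is illustrative.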