src.baseline.metabbo.networks

Module Contents

Classes

API

class src.baseline.metabbo.networks.MLP(config)

Bases: torch.nn.Module

Initialization

Parameters:

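config – a list of dicts, one per layer, such as [{'in': 2, 'out': 4, 'drop_out': 0.5, 'activation': 'ReLU'}, {'in': 4, 'out': 8, 'drop_out': 0, 'activation': 'Sigmoid'}, {'in': 8, 'out': 10, 'drop_out': 0, 'activation': 'None'}]. Each dict gives the layer's input size ('in'), output size ('out'), dropout rate ('drop_out') and activation name ('activation'); the number of dicts, and therefore of layers, is user-defined.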
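As an illustration of how such a config could map onto layers, here is a minimal sketch. It assumes (this page does not confirm it) that each dict becomes an nn.Linear, followed by the torch.nn activation class named in 'activation' (skipped for 'None') and an nn.Dropout when 'drop_out' is nonzero. ConfigMLP is a hypothetical name, not the module's own class.

```python
import torch
from torch import nn


class ConfigMLP(nn.Module):
    """Hypothetical stand-in for MLP: one Linear layer per config dict."""

    def __init__(self, config):
        super().__init__()
        layers = []
        for layer_cfg in config:
            layers.append(nn.Linear(layer_cfg['in'], layer_cfg['out']))
            # Assumption: 'activation' names a class in torch.nn; 'None' means no activation.
            if layer_cfg['activation'] != 'None':
                layers.append(getattr(nn, layer_cfg['activation'])())
            # Assumption: 'drop_out' is a dropout probability; 0 disables dropout.
            if layer_cfg['drop_out'] > 0:
                layers.append(nn.Dropout(layer_cfg['drop_out']))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


config = [
    {'in': 2, 'out': 4, 'drop_out': 0.5, 'activation': 'ReLU'},
    {'in': 4, 'out': 8, 'drop_out': 0, 'activation': 'Sigmoid'},
    {'in': 8, 'out': 10, 'drop_out': 0, 'activation': 'None'},
]
model = ConfigMLP(config)
out = model(torch.randn(5, 2))
print(out.shape)  # torch.Size([5, 10])
```

Consecutive dicts must chain, i.e. each 'in' must equal the previous 'out', or the stacked Linear layers will fail on a shape mismatch.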

forward(x)

class src.baseline.metabbo.networks.SkipConnection(module)

Bases: torch.nn.Module

Initialization

forward(input)
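
SkipConnection carries no docstring here. Judging by its name, it presumably wraps a module in a residual connection; the sketch below assumes the common form output = input + module(input). ResidualSketch is a hypothetical name, and the real forward may differ.

```python
import torch
from torch import nn


class ResidualSketch(nn.Module):
    """Hypothetical residual wrapper: returns input + module(input)."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, input):
        # The wrapped module must preserve the input's shape for the sum to work.
        return input + self.module(input)


block = ResidualSketch(nn.Linear(16, 16))
x = torch.randn(3, 16)
y = block(x)
```

The residual path lets gradients bypass the wrapped module, which is why this pattern usually surrounds attention and feed-forward blocks.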

class src.baseline.metabbo.networks.Normalization(embed_dim, normalization='batch')

Bases: torch.nn.Module

Initialization

forward(input)
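
The normalization argument selects the normalization scheme. Assuming (the page does not say) that 'batch' maps to nn.BatchNorm1d and 'layer' to nn.LayerNorm, applied to inputs of shape (batch, seq_len, embed_dim), a sketch could look like this; NormSketch is a hypothetical name.

```python
import torch
from torch import nn


class NormSketch(nn.Module):
    """Hypothetical normalization wrapper selected by name ('batch' or 'layer')."""

    def __init__(self, embed_dim, normalization='batch'):
        super().__init__()
        if normalization == 'batch':
            self.norm = nn.BatchNorm1d(embed_dim)
        elif normalization == 'layer':
            self.norm = nn.LayerNorm(embed_dim)
        else:
            raise ValueError(f"unknown normalization: {normalization!r}")
        self.normalization = normalization

    def forward(self, input):
        # input: (batch, seq_len, embed_dim)
        if self.normalization == 'batch':
            # BatchNorm1d treats dim 1 of a 2-D input as the feature axis, so merge
            # batch and sequence dims, normalize each feature, then restore the shape.
            flat = input.reshape(-1, input.size(-1))
            return self.norm(flat).reshape(input.shape)
        # LayerNorm normalizes each embedding vector over its last dimension.
        return self.norm(input)


x = torch.randn(4, 7, 32)
batch_out = NormSketch(32, 'batch')(x)
layer_out = NormSketch(32, 'layer')(x)
```

In training mode, batch normalization uses statistics gathered across every position in the batch, while layer normalization uses only the statistics of each individual vector.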

class src.baseline.metabbo.networks.MultiHeadAttentionLayerforCritic(n_heads, embed_dim, feed_forward_hidden, normalization='layer')

Bases: torch.nn.Sequential

Initialization

class src.baseline.metabbo.networks.MultiHeadAttention(n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None)

Bases: torch.nn.Module

Initialization

forward(h, q=None)
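
The constructor signature matches the usual multi-head attention layout: per-head query, key and value projections, scaled dot-product scores, and an output projection. The sketch below assumes that h supplies keys and values, q supplies queries (defaulting to h, i.e. self-attention), key_dim and val_dim default to embed_dim // n_heads, and embed_dim defaults to input_dim. MHASketch is a hypothetical name, not the module's implementation.

```python
import math

import torch
from torch import nn


class MHASketch(nn.Module):
    """Hypothetical multi-head attention mirroring the documented signature."""

    def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
        super().__init__()
        embed_dim = embed_dim if embed_dim is not None else input_dim  # assumption
        val_dim = val_dim if val_dim is not None else embed_dim // n_heads
        key_dim = key_dim if key_dim is not None else val_dim
        self.n_heads, self.key_dim, self.val_dim = n_heads, key_dim, val_dim
        self.norm_factor = 1 / math.sqrt(key_dim)
        # One projection per role, producing all heads at once.
        self.W_query = nn.Linear(input_dim, n_heads * key_dim, bias=False)
        self.W_key = nn.Linear(input_dim, n_heads * key_dim, bias=False)
        self.W_val = nn.Linear(input_dim, n_heads * val_dim, bias=False)
        self.W_out = nn.Linear(n_heads * val_dim, embed_dim, bias=False)

    def _split_heads(self, t, head_dim):
        # (batch, length, n_heads * head_dim) -> (batch, n_heads, length, head_dim)
        b, n, _ = t.shape
        return t.view(b, n, self.n_heads, head_dim).transpose(1, 2)

    def forward(self, h, q=None):
        # h: (batch, graph_size, input_dim); q defaults to h (self-attention).
        if q is None:
            q = h
        Q = self._split_heads(self.W_query(q), self.key_dim)
        K = self._split_heads(self.W_key(h), self.key_dim)
        V = self._split_heads(self.W_val(h), self.val_dim)
        # Scaled dot-product compatibility: (batch, n_heads, n_query, graph_size)
        compat = self.norm_factor * Q @ K.transpose(-2, -1)
        attn = torch.softmax(compat, dim=-1)
        heads = attn @ V  # (batch, n_heads, n_query, val_dim)
        b, _, n_query, _ = heads.shape
        merged = heads.transpose(1, 2).reshape(b, n_query, self.n_heads * self.val_dim)
        return self.W_out(merged)


mha = MHASketch(n_heads=4, input_dim=16, embed_dim=32)
h = torch.randn(2, 10, 16)
out = mha(h)                            # self-attention
cross = mha(h, q=torch.randn(2, 3, 16))  # 3 queries attending over 10 items
```

The output has one row per query and embed_dim columns, so passing a separate q changes the output length while h fixes what is attended over.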

class src.baseline.metabbo.networks.MultiHeadCompat(n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None)

Bases: torch.nn.Module

Initialization

forward(q, h=None, mask=None)

class src.baseline.metabbo.networks.MultiHeadEncoder(n_heads, embed_dim, feed_forward_hidden, normalization='layer')

Bases: torch.nn.Module

Initialization

forward(input, input_=None)

class src.baseline.metabbo.networks.MultiHeadAttentionsubLayer(n_heads, embed_dim, feed_forward_hidden, normalization='layer')

Bases: torch.nn.Module

Initialization

forward(input, input_=None)

class src.baseline.metabbo.networks.FFandNormsubLayer(n_heads, embed_dim, feed_forward_hidden, normalization='layer')

Bases: torch.nn.Module

Initialization

forward(input)
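
MultiHeadEncoder, MultiHeadAttentionsubLayer and FFandNormsubLayer share the signature (n_heads, embed_dim, feed_forward_hidden, normalization), which suggests a Transformer-style encoder layer built from an attention sub-layer and a feed-forward sub-layer. The sketch below assumes the common post-norm form norm(x + sublayer(x)) with layer normalization, and substitutes torch's built-in nn.MultiheadAttention for the attention step. All class names are hypothetical, and the documented forward's optional input_ argument is omitted.

```python
import torch
from torch import nn


class AttentionSubLayerSketch(nn.Module):
    """Hypothetical: norm(x + MHA(x)), a post-norm residual attention sub-layer."""

    def __init__(self, n_heads, embed_dim):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_out, _ = self.mha(x, x, x)  # self-attention; attention weights are discarded
        return self.norm(x + attn_out)


class FFSubLayerSketch(nn.Module):
    """Hypothetical: norm(x + FF(x)) with one hidden layer of width feed_forward_hidden."""

    def __init__(self, embed_dim, feed_forward_hidden):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, feed_forward_hidden),
            nn.ReLU(),
            nn.Linear(feed_forward_hidden, embed_dim),
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        return self.norm(x + self.ff(x))


class EncoderLayerSketch(nn.Module):
    """Hypothetical encoder layer: attention sub-layer, then feed-forward sub-layer."""

    def __init__(self, n_heads, embed_dim, feed_forward_hidden):
        super().__init__()
        self.attn = AttentionSubLayerSketch(n_heads, embed_dim)
        self.ff = FFSubLayerSketch(embed_dim, feed_forward_hidden)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim); the shape is preserved.
        return self.ff(self.attn(x))


layer = EncoderLayerSketch(n_heads=4, embed_dim=32, feed_forward_hidden=64)
x = torch.randn(2, 10, 32)
y = layer(x)
```

Because each sub-layer preserves the (batch, seq_len, embed_dim) shape, several such layers can be stacked with nn.Sequential; embed_dim must be divisible by n_heads for nn.MultiheadAttention.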

class src.baseline.metabbo.networks.EmbeddingNet(node_dim, embedding_dim)

Bases: torch.nn.Module

Initialization

forward(x)

class src.baseline.metabbo.networks.PositionalEncoding(d_model, max_len)

Computes sinusoidal positional encodings.

Initialization

Constructor of the sinusoidal encoding class.

Parameters:
  • d_model – model (embedding) dimension

  • max_len – maximum sequence length

get_PE(seq_len)
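
A minimal sketch of a sinusoidal encoding table, assuming the standard Transformer formulation PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); the page does not state the base or the column layout. It also assumes get_PE(seq_len) returns the rows for positions 0 to seq_len - 1. SinusoidPESketch is a hypothetical name.

```python
import math

import torch
from torch import nn


class SinusoidPESketch(nn.Module):
    """Hypothetical sinusoidal encoding table with a get_PE-style accessor."""

    def __init__(self, d_model, max_len):
        super().__init__()
        pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        two_i = torch.arange(0, d_model, 2, dtype=torch.float)       # even column indices
        div = torch.exp(-math.log(10000.0) * two_i / d_model)        # 1 / 10000^(2i/d_model)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(pos * div)
        # Trim for odd d_model, which has one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(pos * div)[:, : d_model // 2]
        # A buffer moves with .to(device) but is not a trainable parameter.
        self.register_buffer('encoding', pe)

    def get_PE(self, seq_len):
        # Encodings for positions 0 .. seq_len - 1, shape (seq_len, d_model).
        return self.encoding[:seq_len, :]


pe_table = SinusoidPESketch(d_model=16, max_len=100)
first_ten = pe_table.get_PE(10)
```

Precomputing the table once up to max_len means each call only slices it; requests longer than max_len would silently return just max_len rows in this sketch.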