Input dimension for pytorch transformer when the sequence length is one

In PyTorch, when we want to feed a 2-dimensional input of shape (batch_size, embedding_size) into a transformer, we need to be careful with how the dimensions are interpreted.
By default, PyTorch treats the first dimension of a 2-D input as the sequence dimension and the second as the embedding dimension, so a batch of feature vectors would be misread as a single long sequence.
So we should do two things:
(1) Set batch_first=True in TransformerEncoderLayer, so the encoder expects (batch_size, seq_len, d_model)
(2) Reshape the input to 3 dimensions to indicate it is a sequence of length 1
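Step (2) amounts to inserting a length-1 sequence axis with unsqueeze. A minimal illustration (the shape values below are arbitrary):

```python
import torch

x = torch.randn(8, 16)  # (batch_size=8, embedding_size=16)
x = x.unsqueeze(1)      # insert a sequence dimension of length 1
print(x.shape)          # torch.Size([8, 1, 16])
```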

A code example is as follows:

import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, input_dim, d_model, n_head, n_hid, n_layers, output_dim, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        # batch_first=True makes the encoder expect (batch_size, seq_len, d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model, n_head, n_hid, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, n_layers)
        self.input_linear = nn.Linear(input_dim, d_model)
        self.output_linear = nn.Linear(d_model, output_dim)

    def forward(self, src):
        src = self.input_linear(src)               # (batch_size, d_model)
        src = src.unsqueeze(1)                     # add a sequence dimension: (batch_size, 1, d_model)
        src = self.transformer_encoder(src)        # (batch_size, 1, d_model)
        output = self.output_linear(src[:, 0, :])  # take the single sequence position
        return output
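As a quick sanity check, the model can be run on a batch of feature vectors with no sequence dimension. The hyperparameter values below are arbitrary placeholders chosen for illustration:

```python
import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, input_dim, d_model, n_head, n_hid, n_layers, output_dim, dropout=0.5):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_head, n_hid, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, n_layers)
        self.input_linear = nn.Linear(input_dim, d_model)
        self.output_linear = nn.Linear(d_model, output_dim)

    def forward(self, src):
        src = self.input_linear(src)         # (batch_size, d_model)
        src = src.unsqueeze(1)               # (batch_size, 1, d_model)
        src = self.transformer_encoder(src)  # (batch_size, 1, d_model)
        return self.output_linear(src[:, 0, :])

model = TransformerModel(input_dim=10, d_model=32, n_head=4, n_hid=64, n_layers=2, output_dim=3)
x = torch.randn(8, 10)  # a batch of 8 feature vectors, no sequence dimension
y = model(x)
print(y.shape)          # torch.Size([8, 3])
```

The output has one row per batch element, confirming that the first dimension was treated as the batch rather than as a sequence of 8 tokens.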

Author: robot learner