有時候我們在看別人的論文時會發現:常常會有一些「超參數」的出現,像是 ResNet shortcut 進入的權重值等等
這個時候就可以用 Pytorch 提供的 Parameter 和 buffer 來實作,想知道詳細差在哪裡就繼續往下看吧 ~
keywords: Parameter、buffer
Parameter 和 buffer
有時候我們想要在網路中新增一層或是一個參數時,就可以使用 Parameter 或是 buffer
- Parameter 在反向傳播時「會」隨著網路更新權重值
- Buffer 在反向傳播時「不會」隨著網硬更新權重值
建立方向:
1 2 3 4 5 6 7 8 9 10 11 12
| class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() buffer = torch.randn(2, 3) self.register_buffer('my_buffer', buffer) self.param = nn.Parameter(torch.randn(3, 3)) self.register_parameter("param", param)
def forward(self, x): self.my_buffer(x) self.param(x)
|
兩者的共同點就是,在使用 model.state_dict() 的方法來保存、讀取網路模型時,都會被存入到 OrderDict 中
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| torch.save(model.state_dict(), PATH)
model = MyModel(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
model = MyModel() for param in model.parameters(): print(param)
for buffer in model.buffers(): print(buffer)
|
在 ViT 的 Patch Embedding 中有使用到,用在 reletive positional encoding 上,因為相對位置編碼不會隨著網路而更新
1 2 3 4 5 6 7 8 9 10
| class Embeddings(nn.Module): def __init__(self, vocab_size, d_model, dropout=0.1, max_len=5000): super(Embeddings, self).__init__() self.embs = nn.Embedding(vocab_size, d_model) self.d_model = d_model self.dropout = nn.Dropout(dropout)
pe = self._build_position_encoding(max_len, d_model) self.register_buffer("pe", pe)
|
reference
https://zhuanlan.zhihu.com/p/89442276