Pytorch-如何在模型中引入可学习参数
阅读原文时间:2023年08月09日阅读:1

错误实例:

def init(self):
    self.w1 = torch.nn.Parameter(torch.FloatTensor(1),requires_grad=True).cuda()
    self.w2 = torch.nn.Parameter(torch.FloatTensor(1),requires_grad=True).cuda()
    self.w1.data.fill_(0.3)
    self.w2.data.fill_(0.3)
def forward(self, x):
    out = self.w1 * out1 + self.w2 * out2
    out = self.fc(out)复制

正确实例:

def init(self):
    self.w1 = torch.nn.Parameter(torch.FloatTensor(1),requires_grad=True)
    self.w2 = torch.nn.Parameter(torch.FloatTensor(1),requires_grad=True)
    self.w1.data.fill_(0.3)
    self.w2.data.fill_(0.3)
def forward(self, x):
    out = self.w1 * out1 + self.w2 * out2
    out = self.fc(out)复制