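错误示例（对 `nn.Parameter` 直接调用 `.cuda()` 会返回一个新的普通张量——它不是 `nn.Parameter`，也不是叶子节点，因此不会被注册为模块参数，`model.parameters()` 中没有它，优化器也不会更新它）: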
def init(self):
    """WRONG example: create two scalar weights and move them to the GPU one by one.

    NOTE(review): in the original article this is most likely ``__init__``
    (the double underscores appear to have been stripped by markdown) — confirm.

    Why this is wrong: ``.cuda()`` on a CPU ``nn.Parameter`` performs a copy
    and returns a new, non-leaf tensor that is NOT an ``nn.Parameter``.
    ``nn.Module.__setattr__`` only registers attributes that are
    ``nn.Parameter`` instances, so ``self.w1`` / ``self.w2`` end up as plain
    tensor attributes. They are missing from ``model.parameters()`` and an
    optimizer built from those parameters never updates them.
    """
    # BUG (intentional, for illustration): the trailing .cuda() discards the
    # Parameter wrapper — see docstring.
    self.w1 = torch.nn.Parameter(torch.FloatTensor(1),requires_grad=True).cuda()
    self.w2 = torch.nn.Parameter(torch.FloatTensor(1),requires_grad=True).cuda()
    # torch.FloatTensor(1) is uninitialised memory, so both weights are
    # explicitly set to 0.3 here.
    self.w1.data.fill_(0.3)
    self.w2.data.fill_(0.3)
def forward(self, x):
    """Blend two intermediate outputs with the weights w1/w2, then apply self.fc.

    NOTE(review): this is an excerpt. ``out1``, ``out2`` and ``self.fc`` are
    not defined in the snippet; presumably they come from code the article
    omitted (``out1``/``out2`` computed from ``x``). No ``return`` is shown
    either; presumably the full method returns ``out``.

    With the WRONG ``init`` above, this line still runs without error. The
    bug only shows up in training, where w1/w2 are never updated by an
    optimizer built from ``model.parameters()``.
    """
    out = self.w1 * out1 + self.w2 * out2
    out = self.fc(out)
复制
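正确示例（只创建 `nn.Parameter`，不要单独调用 `.cuda()`；之后通过 `model.cuda()` 或 `model.to(device)` 整体移动模型，已注册的参数会随模块一起迁移到 GPU）: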
def init(self):
    """Create two learnable scalar weights, each initialised to 0.3.

    NOTE(review): in the original article this is most likely ``__init__``
    (the double underscores appear to have been stripped by markdown) — confirm.

    The weights are created as ``nn.Parameter`` objects and assigned directly,
    so ``nn.Module.__setattr__`` registers them as module parameters. Do NOT
    call ``.cuda()`` on them here: that returns a plain, non-leaf tensor that
    is no longer registered. Move the whole model with ``model.cuda()`` or
    ``model.to(device)`` instead; registered parameters move with it.
    """
    # torch.full builds an already-initialised float32 tensor. This replaces
    # the legacy torch.FloatTensor(1) constructor (uninitialised memory)
    # followed by an in-place write through the discouraged ``.data`` attribute.
    self.w1 = torch.nn.Parameter(
        torch.full((1,), 0.3, dtype=torch.float32), requires_grad=True
    )
    self.w2 = torch.nn.Parameter(
        torch.full((1,), 0.3, dtype=torch.float32), requires_grad=True
    )
def forward(self, x):
    """Blend two intermediate outputs with the learnable weights w1/w2, then apply self.fc.

    NOTE(review): this is an excerpt. ``out1``, ``out2`` and ``self.fc`` are
    not defined in the snippet; presumably they come from code the article
    omitted (``out1``/``out2`` computed from ``x``). No ``return`` is shown
    either; presumably the full method returns ``out``.

    Because w1/w2 are registered ``nn.Parameter`` objects here, gradients from
    this weighted sum reach parameters that an optimizer built from
    ``model.parameters()`` will update.
    """
    out = self.w1 * out1 + self.w2 * out2
    out = self.fc(out)
复制
手机扫一扫
移动阅读更方便
你可能感兴趣的文章