开个新坑, pytorch源码阅读…从python代码开始读起.
继承自torch._C._TensorBase
, 包括各种操作,TODO:随后看cpp代码
__abs__, __iter__
之类的内建方法
requires_grad
属性是否需要求导
backward(self, gradient=None, retain_graph=None, create_graph=False)
retain_graph表示是否在backward之后free内存
register_hook(self, hook) 每次gradients
被计算的时候,这个hook
都被调用。返回的handle提供remove hook的能力
v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
h = v.register_hook(lambda grad: grad * 2) # double the gradient
v.backward(torch.Tensor([1, 1, 1]))
#先计算原始梯度,再进hook,获得一个新梯度。
print(v.grad.data) #output [2,2,2]
h.remove() # removes the hook, 返回的句柄
一系列优化方法的集合, 基类是optimizer.py, 其余op都是继承这个类, 基础上实现op.step(), 初始化默认参数由__init__
提供. 包括SGD, Adam, RMSProp等, 以SGD为例:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad() #初始化
loss_fn(model(input), target).backward()
optimizer.step()
state_dict() & load_state_dict()
更新state, param两个成员, 提供serialize的方法. 理解是可以训练到某个过程中进行op参数的存储, 下次可以继续, 避免训练失败重新训练
add_param_group()
transfer learning中将freeze固定层的参数加入训练时, 可以用该方法.
lr_scheduler
用来进行lr的调整, 动态decay
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(100):
scheduler.step()
train(...)
validate(...)
tips:
06ef05bc-004d-4561-b9ba-842076c9884b
手机扫一扫
移动阅读更方便
你可能感兴趣的文章