requires_grad = Ложь
Если вы хотите заморозить часть вашей модели и обучить остальную, вы можете установить requires_grad
из параметров, которые вы хотите заморозить, на False
.
Например, если вы хотите сохранить фиксированной только сверточную часть VGG16:
model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
param.requires_grad = False
При переключении флагов requires_grad
на False
промежуточные буферы не будут сохраняться, пока вычисление не дойдет до некоторой точки, где для одного из входов операции требуется градиент.
torch.no_grad ()
Использование диспетчера контекста torch.no_grad
- это другой способ достижения этой цели: в контексте no_grad
все результаты вычислений будут иметь requires_grad=False
, даже если входные данные имеют requires_grad=True
. Обратите внимание, что вы не сможете распространять градиент на слои до no_grad
. Например:
x = torch.randn(2, 2)
x.requires_grad = True
lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)
x1 = lin0(x)
with torch.no_grad():
x2 = lin1(x1)
x3 = lin2(x2)
x3.sum().backward()
print(lin0.weight.grad, lin1.weight.grad, lin2.weight.grad)
выходы:
(None, None, tensor([[-1.4481, -1.1789],
[-1.4481, -1.1789]]))
Здесь lin1.weight.requires_grad
было True, но градиент не был вычислен, потому что операция была выполнена в контексте no_grad
.
model.eval ()
Если ваша цель не в точной настройке, а в том, чтобы установить модель в режим вывода, наиболее удобный способ - использовать torch.no_grad
диспетчер контекста. В этом случае вам также необходимо установить модель в режим оценки, это достигается путем вызова eval()
на nn.Module
, например:
model = torchvision.models.vgg16(pretrained=True)
model.eval()
Эта операция устанавливает для атрибута self.training
слоев значение False
, на практике это изменит поведение таких операций, как Dropout
или BatchNorm
, которые должны вести себя по-разному во время обучения и тестирования.
08.08.2018
torch.no_grad()
диспетчером контекста иt.requires_grad=False
, особенно когда речь идет об эффективности памяти? Как вы упомянули ранее,t.requires_grad=False
, никакие промежуточные буферы не будут сохраняться, будет ли это более эффективным? 09.08.2018