PyTorch | How to save and load PyTorch models?
2022-06-25 15:13:00 【m0_ sixty-one million eight hundred and ninety-nine thousand on】
Preface
This article is reproduced from "PyTorch In-Depth Analysis: How to Save and Load PyTorch Models?"
It summarizes the ways to save and load PyTorch models.
Contents
1 Three important functions to master
2.2 Save and load state_dict (training finished, no further training)
2.3 Save and load the entire model (training finished, no further training)
2.4 Save and load state_dict (training not finished, to be continued)
2.5 Save multiple models in one file
2.6 Warm-start your own model with parameters from another model
2.7 Save on GPU, load on CPU
2.8 Save on GPU, load on GPU
2.9 Save on CPU, load on GPU
1 Three important functions to master
1) torch.save: saves a serialized object to disk. It uses Python's pickle module for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved with this function.
2) torch.load: deserializes a pickled object file into memory. It can also map the loaded data onto a chosen device (see map_location in 2.7 to 2.9).
3) torch.nn.Module.load_state_dict(): loads a model's parameters from a deserialized state_dict.
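As a minimal sketch of the first two functions, the round trip below saves a dictionary to disk and reads it back. The payload contents and the file name payload.pt are made up for illustration, and the temporary directory is only there so the example cleans up after itself.

```python
import os
import tempfile

import torch

# Hypothetical payload: a dict holding a tensor and a plain Python int.
payload = {
    "weights": torch.arange(6, dtype=torch.float32).reshape(2, 3),
    "step": 10,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "payload.pt")  # hypothetical file name
    torch.save(payload, path)  # serialize the object to disk
    restored = torch.load(path, map_location="cpu")  # deserialize onto the CPU

print(restored["step"])
print(torch.equal(restored["weights"], payload["weights"]))
```

The restored object is an ordinary dictionary, so its entries can be inspected or modified before being passed on to load_state_dict().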
2 state_dict
2.1 Introduction to state_dict
In PyTorch, the learnable parameters (weights and biases) of a torch.nn.Module are accessible through model.parameters(). A state_dict is a Python dictionary object that maps each layer to its parameter tensors. Note: only layers with learnable parameters (convolutional layers, linear layers) or with registered buffers (such as batchnorm's running_mean) have entries in the model's state_dict. Optimizer objects (torch.optim) also have a state_dict, which stores the optimizer's state and its hyperparameters.
Because a state_dict is a Python dictionary object, it is easy to save, load, and update.
The following example shows what a state_dict looks like:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten; 16 * 5 * 5 is the feature size for 3x32x32 inputs
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
Output:
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
2.2 Save and load state_dict (training finished, no further training)
Save:
torch.save(model.state_dict(), PATH)
Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
By convention the file is saved with a .pt or .pth extension.
Notes:
Call model.eval() before inference to set dropout and batch normalization layers to evaluation mode.
load_state_dict() expects a dictionary object, not the path to a saved file. That is why model.load_state_dict(PATH) is wrong and model.load_state_dict(torch.load(PATH)) is correct.
If you want to keep the model that performs best on the validation set, best_model_state = model.state_dict() is wrong. state_dict() returns references to the model's tensors rather than copies, so best_model_state keeps changing as training continues, and what you end up saving is the final, possibly overfitted, model. The correct approach is best_model_state = deepcopy(model.state_dict()), with deepcopy imported from Python's copy module.
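The aliasing issue with state_dict() can be seen in a short sketch. The tiny nn.Linear model and the in-place add_ call are stand-ins chosen for illustration: the in-place update plays the role of further training steps.

```python
from copy import deepcopy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 1)  # hypothetical stand-in for a real model

alias = model.state_dict()               # tensors share storage with the live parameters
snapshot = deepcopy(model.state_dict())  # independent copy of the current values

# Stand-in for further training: change the weights in place.
with torch.no_grad():
    model.weight.add_(1.0)

alias_tracks_model = torch.equal(alias["weight"], model.weight.detach())
snapshot_is_frozen = not torch.equal(snapshot["weight"], model.weight.detach())
print(alias_tracks_model, snapshot_is_frozen)
```

The aliased dictionary follows the model's current weights, while the deep copy keeps the values from the moment it was taken, which is what a "best model so far" snapshot needs.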
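2.3 Save and load the entire model (training finished, no further training)
Save:
torch.save(model, PATH)
Load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
By convention the file is saved with a .pt or .pth extension.
Notes:
Call model.eval() before inference to set dropout and batch normalization layers to evaluation mode.
In recent PyTorch releases (2.6 and later), torch.load defaults to weights_only=True, which refuses to unpickle a whole model. In that case pass weights_only=False, and only for files from a trusted source.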
2.4 Save and load state_dict (training not finished, to be continued)
Save:
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)
Unlike 2.2, besides model_state_dict you also need to save optimizer_state_dict, epoch, and loss, because resuming training requires the optimizer's state, the epoch you stopped at, and so on.
Load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
Unlike 2.2, besides model_state_dict you also need to load optimizer_state_dict, epoch, and loss. Call model.train() to continue training, or model.eval() to run inference.
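The checkpoint pattern above can be exercised end to end with a toy setup. Everything specific here is an assumption for illustration: the nn.Linear model, the random data, the helper names make_model_and_optimizer and train_one_epoch, and the file name checkpoint.pt. The loss is stored as a plain float via loss.item().

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def make_model_and_optimizer():
    # Hypothetical toy model; a real script would build its own architecture.
    model = nn.Linear(4, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    return model, optimizer


def train_one_epoch(model, optimizer, inputs, targets):
    model.train()
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pt")  # hypothetical file name

    # First run: train for three epochs (0, 1, 2), then checkpoint.
    model, optimizer = make_model_and_optimizer()
    for epoch in range(3):
        loss = train_one_epoch(model, optimizer, inputs, targets)
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }, path)

    # Later run: rebuild the objects, then restore their state.
    resumed_model, resumed_optimizer = make_model_and_optimizer()
    checkpoint = torch.load(path, map_location="cpu")
    resumed_model.load_state_dict(checkpoint["model_state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

# Continue training from the epoch after the saved one.
for epoch in range(start_epoch, start_epoch + 2):
    loss = train_one_epoch(resumed_model, resumed_optimizer, inputs, targets)
print(f"resumed at epoch {start_epoch}, finished epoch {epoch}")
```

Restoring the optimizer matters here because SGD with momentum keeps a momentum buffer per parameter; without it, the resumed run would start from zeroed buffers.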
2.5 Save multiple models in one file
Save:
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    ...
}, PATH)
This puts the state_dict of models A and B, together with the state_dict of each optimizer, into a single file.
Load:
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
2.6 Warm-start your own model with parameters from another model
When training a new, complex model, you sometimes want to load part of a pretrained model's weights. Even if only a few parameters can be reused, they help warm-start training and let the model converge faster.
If the state_dict you are loading is missing some keys or has extra keys, set the strict argument to False. The matching keys are loaded and the non-matching keys are ignored.
Save model A's state_dict:
torch.save(modelA.state_dict(), PATH)
Load into model B:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
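To see which keys were actually matched, one option is to inspect the value that load_state_dict() returns, which lists the missing and unexpected keys. The sketch below uses two hypothetical classes, ModelA and ModelB, that share a layer named backbone and differ in their heads; it skips the save-to-disk step and passes the state_dict directly.

```python
import torch
import torch.nn as nn


class ModelA(nn.Module):  # hypothetical source model
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head_a = nn.Linear(8, 2)


class ModelB(nn.Module):  # hypothetical target model
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head_b = nn.Linear(8, 3)


model_a, model_b = ModelA(), ModelB()

# strict=False loads the matching keys and reports the rest instead of raising.
result = model_b.load_state_dict(model_a.state_dict(), strict=False)

print("missing from the loaded state_dict:", result.missing_keys)
print("not used by model B:", result.unexpected_keys)

backbone_copied = torch.equal(model_b.backbone.weight, model_a.backbone.weight)
print("backbone copied:", backbone_copied)
```

Note that strict=False only tolerates differences in key names. A key that exists in both models but with different tensor shapes still raises an error, so such entries need to be removed from the dictionary before loading.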
2.7 Save on GPU, load on CPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
Here model.state_dict() was saved from the GPU, so a plain torch.load(PATH) loads the tensors back onto the GPU. To load them onto the CPU instead, pass map_location=torch.device('cpu').
2.8 Save on GPU, load on GPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
Here model.state_dict() was saved from the GPU, so a plain torch.load(PATH) already loads the tensors onto the GPU and map_location is not needed. The newly constructed model, however, lives on the CPU, and load_state_dict() copies the loaded values into those CPU parameters, so model.to(device) is still required.
2.9 Save on CPU, load on GPU
Save:
torch.save(model.state_dict(), PATH)
Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
Here model.state_dict() was saved from the CPU, so a plain torch.load(PATH) loads the tensors onto the CPU. To load them onto the GPU, pass map_location="cuda:0". The newly constructed model still lives on the CPU, so model.to(device) is required as well.
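The three device cases in 2.7 to 2.9 can be folded into one pattern by choosing the device at run time. This is a sketch rather than a recommendation from the original article: the nn.Linear model and the file name weights.pt are placeholders, and the code falls back to the CPU when no GPU is available.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(3, 2)  # hypothetical stand-in for TheModelClass
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weights.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)
    # map_location places the loaded tensors on the chosen device,
    # regardless of the device they were saved from.
    state_dict = torch.load(path, map_location=device)

restored = nn.Linear(3, 2)            # a freshly built model starts on the CPU
restored.load_state_dict(state_dict)  # values are copied into the existing parameters
restored.to(device)                   # move the model itself to the target device
restored.eval()

# Inputs must live on the same device as the model.
batch = torch.randn(5, 3).to(device)
with torch.no_grad():
    output = restored(batch)
print(output.shape, output.device.type)
```

With this pattern the same loading code runs unchanged on machines with and without a GPU.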