forked from PyTorchKorea/tutorials-kr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsaveloadrun_tutorial.py
70 lines (55 loc) ยท 3.61 KB
/
saveloadrun_tutorial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
`ํ์ดํ ์น(PyTorch) ๊ธฐ๋ณธ ์ตํ๊ธฐ <intro.html>`_ ||
`๋น ๋ฅธ ์์ <quickstart_tutorial.html>`_ ||
`ํ
์(Tensor) <tensorqs_tutorial.html>`_ ||
`Dataset๊ณผ Dataloader <data_tutorial.html>`_ ||
`๋ณํ(Transform) <transforms_tutorial.html>`_ ||
`์ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์ฑํ๊ธฐ <buildmodel_tutorial.html>`_ ||
`Autograd <autogradqs_tutorial.html>`_ ||
`์ต์ ํ(Optimization) <optimization_tutorial.html>`_ ||
**๋ชจ๋ธ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๊ธฐ**
๋ชจ๋ธ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๊ธฐ
==========================================================================
์ด๋ฒ ์ฅ์์๋ ์ ์ฅํ๊ธฐ๋ ๋ถ๋ฌ์ค๊ธฐ๋ฅผ ํตํด ๋ชจ๋ธ์ ์ํ๋ฅผ ์ ์ง(persist)ํ๊ณ ๋ชจ๋ธ์ ์์ธก์ ์คํํ๋ ๋ฐฉ๋ฒ์ ์์๋ณด๊ฒ ์ต๋๋ค.
"""
import torch
import torchvision.models as models
#######################################################################
# ๋ชจ๋ธ ๊ฐ์ค์น ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๊ธฐ
# ------------------------------------------------------------------------------------------
#
# PyTorch ๋ชจ๋ธ์ ํ์ตํ ๋งค๊ฐ๋ณ์๋ฅผ ``state_dict``\ ๋ผ๊ณ ๋ถ๋ฆฌ๋ ๋ด๋ถ ์ํ ์ฌ์ (internal state dictionary)์ ์ ์ฅํฉ๋๋ค.
# ์ด ์ํ ๊ฐ๋ค์ ``torch.save`` ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ์ ์ฅ(persist)ํ ์ ์์ต๋๋ค:
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
##########################
# ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ๋ถ๋ฌ์ค๊ธฐ ์ํด์๋, ๋จผ์ ๋์ผํ ๋ชจ๋ธ์ ์ธ์คํด์ค(instance)๋ฅผ ์์ฑํ ๋ค์์ ``load_state_dict()`` ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ
# ๋งค๊ฐ๋ณ์๋ค์ ๋ถ๋ฌ์ต๋๋ค.
model = models.vgg16() # ์ฌ๊ธฐ์๋ ``weights`` ๋ฅผ ์ง์ ํ์ง ์์์ผ๋ฏ๋ก, ํ์ต๋์ง ์์ ๋ชจ๋ธ์ ์์ฑํฉ๋๋ค.
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
###########################
# .. note:: ์ถ๋ก (inference)์ ํ๊ธฐ ์ ์ ``model.eval()`` ๋ฉ์๋๋ฅผ ํธ์ถํ์ฌ ๋๋กญ์์(dropout)๊ณผ ๋ฐฐ์น ์ ๊ทํ(batch normalization)๋ฅผ ํ๊ฐ ๋ชจ๋(evaluation mode)๋ก ์ค์ ํด์ผ ํฉ๋๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด ์ผ๊ด์ฑ ์๋ ์ถ๋ก ๊ฒฐ๊ณผ๊ฐ ์์ฑ๋ฉ๋๋ค.
#######################################################################
# ๋ชจ๋ธ์ ํํ๋ฅผ ํฌํจํ์ฌ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๊ธฐ
# ------------------------------------------------------------------------------------------
#
# ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ๋ถ๋ฌ์ฌ ๋, ์ ๊ฒฝ๋ง์ ๊ตฌ์กฐ๋ฅผ ์ ์ํ๊ธฐ ์ํด ๋ชจ๋ธ ํด๋์ค๋ฅผ ๋จผ์ ์์ฑ(instantiate)ํด์ผ ํ์ต๋๋ค.
# ์ด ํด๋์ค์ ๊ตฌ์กฐ๋ฅผ ๋ชจ๋ธ๊ณผ ํจ๊ป ์ ์ฅํ๊ณ ์ถ์ผ๋ฉด, (``model.state_dict()``\ ๊ฐ ์๋) ``model`` ์ ์ ์ฅ ํจ์์
# ์ ๋ฌํฉ๋๋ค:
torch.save(model, 'model.pth')
########################
# ๋ค์๊ณผ ๊ฐ์ด ๋ชจ๋ธ์ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค:
#
# `Saving and loading torch.nn.Modules <pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules>`__์์ ์ค๋ช
ํ ๊ฒ์ฒ๋ผ,
# ``state_dict``๋ฅผ ์ ์ฅํ๋ ๊ฒ์ด ๊ฐ์ฅ ์ข์ ๋ฐฉ๋ฒ์ผ๋ก ๊ฐ์ฃผ๋ฉ๋๋ค.
# ํ์ง๋ง ์๋์์๋ ``weights_only=False``๋ฅผ ์ฌ์ฉํ๋๋ฐ,
# ์ด๋ ๋ชจ๋ธ์ ๋ก๋ํ๋ ๊ฒ์ ํฌํจํ๊ธฐ ๋๋ฌธ์ด๋ฉฐ, ``torch.save``์ ๋ ๊ฑฐ์ ์ฌ์ฉ ์ฌ๋ก์
๋๋ค.
model = torch.load('model.pth', weights_only=False),
########################
# .. note:: ์ด ์ ๊ทผ ๋ฐฉ์์ Python `pickle <https://docs.python.org/3/library/pickle.html>`_ ๋ชจ๋์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ง๋ ฌํ(serialize)ํ๋ฏ๋ก, ๋ชจ๋ธ์ ๋ถ๋ฌ์ฌ ๋ ์ค์ ํด๋์ค ์ ์(definition)๋ฅผ ์ ์ฉ(rely on)ํฉ๋๋ค.
#######################
# ๊ด๋ จ ํํ ๋ฆฌ์ผ
# -----------------
# :doc:`/recipes/recipes/saving_and_loading_a_general_checkpoint`
# :doc:`/recipes/recipes/module_load_state_dict_tips`