PyTorch๋ Tensor์์ ๋์ํ๋ ๋๊ท๋ชจ ์ฐ์ฐ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ(์: torch.add
, torch.sum
๋ฑ)๋ฅผ ์ ๊ณตํฉ๋๋ค.
๊ทธ๋ฌ๋ PyTorch์ ์๋ก์ด ์ฌ์ฉ์ ์ง์ ์ฐ์ฐ์ ๋์
ํ์ฌ torch.compile
, autograd ๋ฐ torch.vmap
์ ๊ฐ์
์๋ธ์์คํ
์์ ๋์ํ๋๋ก ํ ์ ์์ต๋๋ค.
์ด๋ ๊ฒ ํ๋ ค๋ฉด, ํ์ด์ฌ torch.library docs ๋ฌธ์ ๋๋
C++ TORCH_LIBRARY
API๋ฅผ ํตํด PyTorch์ ์ฌ์ฉ์ ์ง์ ์ฐ์ฐ์ ๋ฑ๋กํด์ผ ํฉ๋๋ค.
:ref:`python-custom-ops-tutorial`๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
๋ค์๊ณผ ๊ฐ์ ๊ฒฝ์ฐ ํ์ด์ฌ(C++์ ๋ฐ๋)์์ ์ฌ์ฉ์ ์ง์ ์ฐ์ฐ์๋ฅผ ์์ฑํ ์ ์์ต๋๋ค.
- PyTorch๊ฐ ๋ถํฌ๋ช
ํ ํธ์ถ์ด ๊ฐ๋ฅํ ๊ฒ์ผ๋ก ์ทจ๊ธํ๋ ค๋ ํ์ด์ฌ ํจ์๊ฐ ์๋๋ฐ,
ํนํ
torch.compile
๋ฐ ``torch.export``๊ณผ ๊ฐ์ ๊ฒฝ์ฐ์๋ ๋์ฑ ๊ทธ๋ ์ต๋๋ค. - C++/CUDA ์ปค๋์ ๋ํ ํ์ด์ฌ ๋ฐ์ธ๋ฉ์ด ์์ผ๋ฉฐ PyTorch ์๋ธ์์คํ
(์:
torch.compile
๋๋torch.autograd
)์ผ๋ก ๊ตฌ์ฑ๋๊ธฐ๋ฅผ ์ํฉ๋๋ค.
:ref:`cpp-custom-ops-tutorial`๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
C++(ํ์ด์ฌ๊ณผ ๋ฐ๋)์์ ์ฌ์ฉ์ ์ง์ ์ฐ์ฐ์๋ฅผ ์์ฑํ ์ ์๋ ๊ฒฝ์ฐ: - ์ฌ์ฉ์ ์ง์ C++ ๋ฐ/๋๋ CUDA ์ฝ๋๋ฅผ ๊ฐ์ง๊ณ ์์ ๋ - ์ด ์ฝ๋๋ฅผ ``AOTInductor``์ ํจ๊ป ์ฌ์ฉํ์ฌ ํ์ด์ฌ ์์ด ์ถ๋ก ์ ์ํํ ๋
ํํ ๋ฆฌ์ผ๊ณผ ์ด ํ์ด์ง์์ ๋ค๋ฃจ์ง ์๋ ์ ๋ณด๋ ๋ค์์ ์ฐธ์กฐํ์๊ธฐ ๋ฐ๋๋๋ค. The Custom Operators Manual (์ ๋ณด๋ฅผ ๋ฌธ์ ์ฌ์ดํธ๋ก ์ฎ๊ธฐ๋ ์์ ์ ์งํ ์ค์ ๋๋ค.) ์์ ํํ ๋ฆฌ์ผ ์ค ํ๋๋ฅผ ๋จผ์ ์ฝ์ ๋ค์ ์ฌ์ฉ์ ์ง์ ์ฐ์ฐ์ ์ค๋ช ์๋ฅผ ์ฐธ์กฐ๋ก ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค; ์ฒ์๋ถํฐ ๋๊น์ง ์ฝ์ด์๋ ์ ๋ฉ๋๋ค.
์ฐ์ฐ์ด ๋ด์ฅ๋ ํ์ดํ ์น ์ฐ์ฐ์์ ๊ตฌ์ฑ์ผ๋ก ํํํ ์ ์๋ ๊ฒฝ์ฐ ์ฌ์ฉ์ ์ง์ ์ฐ์ฐ์๋ฅผ ๋ง๋๋ ๋์ ํ์ด์ฌ ํจ์๋ก ์์ฑํ์ฌ ํธ์ถํ์ธ์. ๋ค์๊ณผ ๊ฐ์ ๊ฒฝ์ฐ ์ฐ์ฐ์ ๋ฑ๋ก API๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ์ฉ์ ์ง์ ์ฐ์ฐ์๋ฅผ ์์ฑํฉ๋๋ค. PyTorch๊ฐ ์ดํดํ์ง ๋ชปํ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ(์: ์ฌ์ฉ์ ์ง์ C/C++ ์ฝ๋๋ฅผ ํธ์ถํ๋ ๊ฒฝ์ฐ, ์ปค์คํ CUDA ์ปค๋ ๋๋ C/C++/CUDA ํ์ฅ์ ๋ํ ํ์ด์ฌ ๋ฐ์ธ๋ฉ).
Tensor์ ๋ฐ์ดํฐ ํฌ์ธํฐ๋ฅผ ์ก์ C/C++/CUDA ์ปค๋์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ด๋ฅผ pybind ์ปค๋์ ์ ๋ฌํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด ์ ๊ทผ ๋ฐฉ์์ autograd, torch.compile, vmap ๋ฑ๊ณผ ๊ฐ์ PyTorch ์๋ธ์์คํ ๊ณผ๋ ๊ตฌ์ฑ๋์ง ์์ต๋๋ค. PyTorch ์๋ธ์์คํ ์ผ๋ก ์์ ์ ๊ตฌ์ฑํ๋ ค๋ฉด ์ฐ์ฐ์ ๋ฑ๋ก API๋ฅผ ํตํด ์์ ์ ๋ฑ๋กํด์ผ ํฉ๋๋ค.