diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000000..f00a39bb176 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,442 @@ +version: 2.1 + +executors: + windows-cpu: + machine: + resource_class: windows.xlarge + image: windows-server-2019-vs2019:stable + shell: bash.exe + + windows-gpu: + machine: + resource_class: windows.gpu.nvidia.medium + image: windows-server-2019-nvidia:stable + shell: bash.exe + + +commands: + checkout_merge: + description: "checkout merge branch" + steps: + - checkout + designate_upload_channel: + description: "inserts the correct upload channel into ${BASH_ENV}" + steps: + - run: + name: adding UPLOAD_CHANNEL to BASH_ENV + command: | + our_upload_channel=nightly + # On tags upload to test instead + if [[ -n "${CIRCLE_TAG}" ]] || [[ ${CIRCLE_BRANCH} =~ release/* ]]; then + our_upload_channel=test + fi + echo "export UPLOAD_CHANNEL=${our_upload_channel}" >> ${BASH_ENV} + apt_install: + parameters: + args: + type: string + descr: + type: string + default: "" + update: + type: boolean + default: true + steps: + - run: + name: > + <<^ parameters.descr >> apt install << parameters.args >> <> + <<# parameters.descr >> << parameters.descr >> <> + command: | + <<# parameters.update >> sudo apt update -qy <> + sudo apt install << parameters.args >> + pip_install: + parameters: + args: + type: string + descr: + type: string + default: "" + user: + type: boolean + default: true + steps: + - run: + name: > + <<^ parameters.descr >> pip install << parameters.args >> <> + <<# parameters.descr >> << parameters.descr >> <> + command: > + pip install + <<# parameters.user >> --user <> + --progress-bar=off + << parameters.args >> + + install_torchrl: + parameters: + editable: + type: boolean + default: true + steps: + - pip_install: + args: --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + descr: Install PyTorch from nightly releases + - pip_install: + args: --no-build-isolation <<# parameters.editable >> --editable <> . 
+ descr: Install torchrl <<# parameters.editable >> in editable mode <> + + +binary_common: &binary_common + parameters: + # Edit these defaults to do a release + build_version: + description: "version number of release binary; by default, build a nightly" + type: string + default: "" + pytorch_version: + description: "PyTorch version to build against; by default, use a nightly" + type: string + default: "" + # Don't edit these + python_version: + description: "Python version to build against (e.g., 3.7)" + type: string + cu_version: + description: "CUDA version to build against, in CU format (e.g., cpu or cu100)" + type: string + default: "cpu" + unicode_abi: + description: "Python 2.7 wheel only: whether or not we are cp27mu (default: no)" + type: string + default: "" + wheel_docker_image: + description: "Wheel only: what docker image to use" + type: string + default: "pytorch/manylinux-cuda102" + conda_docker_image: + description: "Conda only: what docker image to use" + type: string + default: "pytorch/conda-builder:cpu" + environment: + PYTHON_VERSION: << parameters.python_version >> + PYTORCH_VERSION: << parameters.pytorch_version >> + UNICODE_ABI: << parameters.unicode_abi >> + CU_VERSION: << parameters.cu_version >> + +smoke_test_common: &smoke_test_common + <<: *binary_common + docker: + - image: torchrl/smoke_test:latest + +jobs: +# circleci_consistency: +# docker: +# - image: circleci/python:3.7 +# steps: +# - checkout +# - pip_install: +# args: jinja2 pyyaml +# - run: +# name: Check CircleCI config consistency +# command: | +# python .circleci/regenerate.py +# git diff --exit-code || (echo ".circleci/config.yml not in sync with config.yml.in! Run .circleci/regenerate.py to update config"; exit 1) + + lint_python_and_config: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - pip_install: + args: pre-commit + descr: Install lint utilities + - run: + name: Install pre-commit hooks + command: pre-commit install-hooks + - run: + name: Lint Python code and config files + command: pre-commit run --all-files + - run: + name: Required lint modifications + when: on_fail + command: git --no-pager diff + + lint_c: + docker: + - image: circleci/python:3.7 + steps: + - apt_install: + args: libtinfo5 + descr: Install additional system libraries + - checkout + - run: + name: Install lint utilities + command: | + curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o clang-format + chmod +x clang-format + sudo mv clang-format /opt/clang-format + - run: + name: Lint C code + command: ./.circleci/unittest/linux/scripts/run-clang-format.py -r torchrl/csrc --clang-format-executable /opt/clang-format + - run: + name: Required lint modifications + when: on_fail + command: git --no-pager diff + + type_check_python: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - pip_install: + args: cmake ninja + descr: Install CMake and Ninja + - install_torchrl: + editable: true + - pip_install: + args: mypy + descr: Install Python type check utilities + - run: + name: Check Python types statically + command: mypy --install-types --non-interactive --config-file mypy.ini + + binary_linux_wheel: + <<: *binary_common + docker: + - image: << parameters.wheel_docker_image >> + resource_class: 2xlarge+ + steps: + - checkout_merge + - designate_upload_channel + - run: packaging/build_wheel.sh + - store_artifacts: + path: dist + - persist_to_workspace: + root: dist + paths: + - "*" + + unittest_linux_cpu: + <<: *binary_common + + docker: + - image: 
"pytorch/manylinux-cuda102" + resource_class: 2xlarge+ + + environment: + TAR_OPTIONS: --no-same-owner + PYTHON_VERSION: << parameters.python_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + + keys: + - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + - run: + name: Setup + command: .circleci/unittest/linux/scripts/setup_env.sh + + - save_cache: + + key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + paths: + - conda + - env + - run: + name: Install torchrl + command: .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/linux/scripts/run_test.sh + - run: + name: Post process + command: .circleci/unittest/linux/scripts/post_process.sh + - store_test_results: + path: test-results + + unittest_linux_gpu: + <<: *binary_common + machine: + image: ubuntu-1604-cuda-10.2:202012-01 + resource_class: gpu.nvidia.medium + environment: + image_name: "pytorch/manylinux-cuda102" + TAR_OPTIONS: --no-same-owner + PYTHON_VERSION: << parameters.python_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + + keys: + - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + - run: + name: Setup + command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh + - save_cache: + + key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + paths: + - conda + - env + - run: + # Here we create an envlist file that contains some env variables that we want the docker container to be aware of. + # Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables. + # They're availble in all the other workflows (OSX and Windows). + # But here, we're running the unittest_linux_gpu workflows in a docker container, where those variables aren't accessible. + # So instead we dump the variables we need in env.list and we pass that file when invoking "docker run". 
+ name: export CIRCLECI env var + command: echo "CIRCLECI=true" >> ./env.list + - run: + name: Install torchrl + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh + - run: + name: Post Process + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/post_process.sh + - store_test_results: + path: test-results + + + unittest_linux_stable_cpu: + <<: *binary_common + + docker: + - image: "pytorch/manylinux-cuda102" + resource_class: 2xlarge+ + + environment: + TAR_OPTIONS: --no-same-owner + PYTHON_VERSION: << parameters.python_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + + keys: + - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + - run: + name: Setup + command: .circleci/unittest/linux_stable/scripts/setup_env.sh + + - save_cache: + + key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + paths: + - conda + - env + - run: + name: Install torchrl + command: .circleci/unittest/linux_stable/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/linux_stable/scripts/run_test.sh + - run: + name: Post process + command: .circleci/unittest/linux_stable/scripts/post_process.sh + - store_test_results: + path: test-results + + unittest_linux_stable_gpu: + <<: *binary_common + machine: + image: ubuntu-1604-cuda-10.2:202012-01 + resource_class: gpu.nvidia.medium + environment: + image_name: "pytorch/manylinux-cuda102" + TAR_OPTIONS: --no-same-owner + PYTHON_VERSION: << parameters.python_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + + keys: + - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + - run: + name: Setup + command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_stable/scripts/setup_env.sh + - save_cache: + + key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_stable/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + paths: + - conda + - env + - run: + # Here we create an envlist file that contains some env variables that we want the docker container to be aware of. + # Normally, the CIRCLECI variable is set and available on all CI workflows: https://circleci.com/docs/2.0/env-vars/#built-in-environment-variables. + # They're availble in all the other workflows (OSX and Windows). + # But here, we're running the unittest_linux_gpu workflows in a docker container, where those variables aren't accessible. 
+ # So instead we dump the variables we need in env.list and we pass that file when invoking "docker run". + name: export CIRCLECI env var + command: echo "CIRCLECI=true" >> ./env.list + - run: + name: Install torchrl + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux_stable/scripts/install.sh + - run: + name: Run tests + command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_stable/scripts/run_test.sh + - run: + name: Post Process + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_stable/scripts/post_process.sh + - store_test_results: + path: test-results + +workflows: + lint: + jobs: +# - circleci_consistency + - lint_python_and_config + - lint_c +# - type_check_python + + unittest: + jobs: + - unittest_linux_cpu: + cu_version: cpu + name: unittest_linux_cpu_py3.8 + python_version: '3.8' + + - unittest_linux_gpu: + cu_version: cu102 + name: unittest_linux_gpu_py3.8 + python_version: '3.8' + + - unittest_linux_stable_cpu: + cu_version: cpu + name: unittest_linux_stable_cpu_py3.8 + python_version: '3.8' + + - unittest_linux_stable_gpu: + cu_version: cu102 + name: unittest_linux_stable_gpu_py3.8 + python_version: '3.8' diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml new file mode 100644 index 00000000000..0af0ab2ca28 --- /dev/null +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -0,0 +1,26 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - cmake >= 3.18 + - pip: + - hypothesis + - protobuf + - future + - cloudpickle + - gym_retro + - gym + - pygame + - gym[accept-rom-license] + - gym[atari] + - moviepy + - tqdm + - pytest + - pytest-cov + - pytest-mock + - expecttest + - pyyaml + - scipy + - dm_control + - mujoco_py diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh new file mode 100755 index 00000000000..ea1cad8dd3c --- /dev/null +++ b/.circleci/unittest/linux/scripts/install.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
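+#
+# Note on CU_VERSION (set by the CI job matrix): it is expected to be either "cpu" or a
+# CUDA tag such as cu102; the digits after "cu" are parsed below into a dotted version
+# (e.g. cu102 -> 10.2) to select the matching cudatoolkit.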
+ +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + cudatoolkit="cpuonly" + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" + cudatoolkit="cudatoolkit=${version}" +fi + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# submodules +git submodule sync && git submodule update --init --recursive + +#printf "Installing PyTorch with %s\n" "${cudatoolkit}" +#if [ "${os}" == "MacOSX" ]; then +# conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest +#else +# conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch[build="*${version}*"] "${cudatoolkit}" pytest +#fi + +#printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release + pip install torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre +else + conda install -y pytorch torchvision cudatoolkit=10.2 -c pytorch-nightly +fi + +printf "Installing functorch\n" +pip install ninja # Makes the build go faster +pip install "git+https://github.com/pytorch/functorch.git" + +# smoke test +python -c "import functorch" + +printf "* Installing torchrl\n" +python setup.py develop + +# smoke test +python -c "import torchrl" diff --git a/.circleci/unittest/linux/scripts/post_process.sh b/.circleci/unittest/linux/scripts/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.circleci/unittest/linux/scripts/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.circleci/unittest/linux/scripts/run-clang-format.py b/.circleci/unittest/linux/scripts/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.circleci/unittest/linux/scripts/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux/scripts/run_test.sh b/.circleci/unittest/linux/scripts/run_test.sh new file mode 100755 index 00000000000..cfa9862f060 --- /dev/null +++ b/.circleci/unittest/linux/scripts/run_test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +root_dir="$(git rev-parse --show-toplevel)" +export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 +export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 +export DISPLAY=unix:0.0 +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/project/.mujoco/mujoco210/bin +#MUJOCO_GL=glfw pytest --cov=torchrl --junitxml=test-results/junit.xml -v --durations 20 +MUJOCO_GL=glfw pytest -v --durations 20 diff --git a/.circleci/unittest/linux/scripts/setup_env.sh b/.circleci/unittest/linux/scripts/setup_env.sh new file mode 100755 index 00000000000..18f4592ae42 --- /dev/null +++ b/.circleci/unittest/linux/scripts/setup_env.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! 
-d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install mujoco +printf "* Installing mujoco and related\n" +mkdir $root_dir/.mujoco +cd $root_dir/.mujoco/ +wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz +tar -xf mujoco-2.1.1-linux-x86_64.tar.gz +wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz +tar -xf mujoco210-linux-x86_64.tar.gz +cd $this_dir + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" +conda env update --file "${this_dir}/environment.yml" --prune + +#yum makecache +#yum -y install glfw-devel +#yum -y install libGLEW +#yum -y install gcc-c++ diff --git a/.circleci/unittest/linux_stable/scripts/environment.yml b/.circleci/unittest/linux_stable/scripts/environment.yml new file mode 100644 index 00000000000..b64ed32f221 --- /dev/null +++ b/.circleci/unittest/linux_stable/scripts/environment.yml @@ -0,0 +1,27 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - ninja + - cmake >= 3.18 + - pip: + - hypothesis + - protobuf + - future + - cloudpickle + - gym_retro + - gym + - pygame + - gym[accept-rom-license] + - gym[atari] + - moviepy + - tqdm + - pytest + - pytest-cov + - pytest-mock + - expecttest + - pyyaml + - scipy + - dm_control + - mujoco_py diff --git a/.circleci/unittest/linux_stable/scripts/install.sh b/.circleci/unittest/linux_stable/scripts/install.sh new file mode 100755 index 00000000000..bbca6096b30 --- /dev/null +++ b/.circleci/unittest/linux_stable/scripts/install.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+ +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + cudatoolkit="cpuonly" + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" + cudatoolkit="cudatoolkit=${version}" +fi + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release + pip install torch torchvision +else + conda install -y pytorch torchvision cudatoolkit=10.2 -c pytorch +fi + +printf "Installing functorch\n" +pip install functorch + +# smoke test +python -c "import functorch" + +printf "* Installing torchrl\n" +printf "g++ version: " +gcc --version + +python setup.py install diff --git a/.circleci/unittest/linux_stable/scripts/post_process.sh b/.circleci/unittest/linux_stable/scripts/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.circleci/unittest/linux_stable/scripts/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.circleci/unittest/linux_stable/scripts/run-clang-format.py b/.circleci/unittest/linux_stable/scripts/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.circleci/unittest/linux_stable/scripts/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. 
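+
+Example invocation (illustrative; it mirrors how the CI configuration calls this script):
+
+    run-clang-format.py -r torchrl/csrc --clang-format-executable /opt/clang-format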
+ +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux_stable/scripts/run_test.sh b/.circleci/unittest/linux_stable/scripts/run_test.sh new file mode 100755 index 00000000000..cfa9862f060 --- /dev/null +++ b/.circleci/unittest/linux_stable/scripts/run_test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +root_dir="$(git rev-parse --show-toplevel)" +export MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 +export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 +export DISPLAY=unix:0.0 +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/project/.mujoco/mujoco210/bin +#MUJOCO_GL=glfw pytest --cov=torchrl --junitxml=test-results/junit.xml -v --durations 20 +MUJOCO_GL=glfw pytest -v --durations 20 diff --git a/.circleci/unittest/linux_stable/scripts/setup_env.sh b/.circleci/unittest/linux_stable/scripts/setup_env.sh new file mode 100755 index 00000000000..18f4592ae42 --- /dev/null +++ b/.circleci/unittest/linux_stable/scripts/setup_env.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! 
-d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install mujoco +printf "* Installing mujoco and related\n" +mkdir $root_dir/.mujoco +cd $root_dir/.mujoco/ +wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz +tar -xf mujoco-2.1.1-linux-x86_64.tar.gz +wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz +tar -xf mujoco210-linux-x86_64.tar.gz +cd $this_dir + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" +conda env update --file "${this_dir}/environment.yml" --prune + +#yum makecache +#yum -y install glfw-devel +#yum -y install libGLEW +#yum -y install gcc-c++ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000000..388d28d7a76 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "third_party/vision"] + path = third_party/vision + url = https://github.com/pytorch/vision +[submodule "third_party/functorch"] + path = third_party/functorch + url = https://github.com/pytorch/functorch diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..dccc1c1a9b2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,30 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-docstring-first + - id: check-toml + - id: check-yaml + exclude: packaging/.* + - id: mixed-line-ending + args: [--fix=lf] + - id: end-of-file-fixer + + - repo: https://github.com/omnilib/ufmt + rev: v1.3.2 + hooks: + - id: ufmt + additional_dependencies: + - black == 21.9b0 + - usort == 0.6.4 + + - repo: https://gitlab.com/pycqa/flake8 + rev: 3.9.2 + hooks: + - id: flake8 + args: [--config=setup.cfg] + + - repo: https://github.com/PyCQA/pydocstyle + rev: 6.1.1 + hooks: + - id: pydocstyle diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000000..a443fb972e2 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,109 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +# Most of the configurations are taken from PyTorch +# https://github.com/pytorch/pytorch/blob/0c9fb4aff0d60eaadb04e4d5d099fb1e1d5701a9/CMakeLists.txt + +# Use compiler ID "AppleClang" instead of "Clang" for XCode. +# Not setting this sometimes makes XCode C compiler gets detected as "Clang", +# even when the C++ one is detected as "AppleClang". +cmake_policy(SET CMP0010 NEW) +cmake_policy(SET CMP0025 NEW) + +# Suppress warning flags in default MSVC configuration. It's not +# mandatory that we do this (and we don't if cmake is old), but it's +# nice when it's possible, and it's possible on our Windows configs. +if(NOT CMAKE_VERSION VERSION_LESS 3.15.0) + cmake_policy(SET CMP0092 NEW) +endif() + +project(torchrl) + +# check and set CMAKE_CXX_STANDARD +string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard) +if(env_cxx_standard GREATER -1) + message( + WARNING "C++ standard version definition detected in environment variable." + "PyTorch requires -std=c++14. 
Please remove -std=c++ settings in your environment.") +endif() + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_C_STANDARD 11) + +# https://developercommunity.visualstudio.com/t/VS-16100-isnt-compatible-with-CUDA-11/1433342 +if(MSVC) + if(USE_CUDA) + set(CMAKE_CXX_STANDARD 17) + endif() +endif() + + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# Apple specific +if(APPLE) + # Get clang version on macOS + execute_process( COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string ) + string(REGEX REPLACE "Apple LLVM version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION_STRING ${clang_full_version_string}) + message( STATUS "CLANG_VERSION_STRING: " ${CLANG_VERSION_STRING} ) + + # RPATH stuff + set(CMAKE_MACOSX_RPATH ON) + + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") +endif() + + +# Options +option(BUILD_TORCHRL_PYTHON_EXTENSION "Build Python extension" OFF) +option(USE_CUDA "Enable CUDA support" OFF) + +if(USE_CUDA) + enable_language(CUDA) +endif() + +find_package(Torch REQUIRED) + +# https://github.com/pytorch/pytorch/issues/54174 +function(CUDA_CONVERT_FLAGS EXISTING_TARGET) + get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS) + if(NOT "${old_flags}" STREQUAL "") + string(REPLACE ";" "," CUDA_flags "${old_flags}") + set_property(TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS + "$<$>:${old_flags}>$<$>:-Xcompiler=${CUDA_flags}>" + ) + endif() +endfunction() + +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4819") + if(USE_CUDA) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/wd4819") + foreach(diag cc_clobber_ignored integer_sign_change useless_using_declaration + set_but_not_used field_without_dll_interface + base_class_has_different_dll_interface + dll_interface_conflict_none_assumed + dll_interface_conflict_dllexport_assumed + implicit_return_from_non_void_function + unsigned_compare_with_zero + declared_but_not_referenced + bad_friend_decl) + string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe --diag_suppress=${diag}") + endforeach() + CUDA_CONVERT_FLAGS(torch_cpu) + if(TARGET torch_cuda) + CUDA_CONVERT_FLAGS(torch_cuda) + endif() + if(TARGET torch_cuda_cu) + CUDA_CONVERT_FLAGS(torch_cuda_cu) + endif() + if(TARGET torch_cuda_cpp) + CUDA_CONVERT_FLAGS(torch_cuda_cpp) + endif() + endif() +endif() + +# TORCH_CXX_FLAGS contains the same -D_GLIBCXX_USE_CXX11_ABI value as PyTorch +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}") + +add_subdirectory(torchrl/csrc) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000000..08b500a2218 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..2e1c3435820 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to rl +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. 
If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: 
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to rl, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000000..426ba7a3fc3
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 00000000000..f64a76fb30c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,127 @@
+# TorchRL
+
+TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.
+
+It provides PyTorch- and Python-first, low- and high-level abstractions for RL that are intended to be efficient, documented and properly tested.
+The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, so that researchers can easily swap components, transform them, or write new ones with little effort.
+
+This repo aims to align with the existing PyTorch ecosystem libraries in that it has a dataset pillar ([torchrl/envs](torchrl/envs)), [transforms](torchrl/envs/transforms), [models](torchrl/modules), data utilities (e.g. collectors and containers), and so on.
+TorchRL aims to have as few dependencies as possible (the Python standard library, NumPy and PyTorch). Common environment libraries (e.g. OpenAI gym) are only optional.
+
+On the low-level end, torchrl comes with a set of highly reusable functionals for [cost functions](torchrl/objectives/costs), [returns](torchrl/objectives/returns) and data processing.
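+
+As a rough illustration of what such a return functional computes (plain PyTorch, for illustration only; this is not torchrl's actual API), a one-step TD value target can be written as:
+
+```
+import torch
+
+def td0_value_target(reward, next_value, done, gamma=0.99):
+    # Bootstrapped one-step target: r + gamma * V(s'), masked where the episode ended.
+    return reward + gamma * next_value * (~done).float()
+
+reward = torch.randn(10, 1)
+next_value = torch.randn(10, 1)
+done = torch.zeros(10, 1, dtype=torch.bool)
+target = td0_value_target(reward, next_value, done)
+```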
+ +On the high-level end, it provides: +- multiprocess [data collectors](torchrl/collectors/collectors.py); +- a generic [agent class](torchrl/agents/agents.py); +- efficient and generic [replay buffers](torchrl/data/replay_buffers/replay_buffers.py); +- [TensorDict](torchrl/data/tensordict/tensordict.py), a convenient data structure to pass data from one object to another without friction; +- an associated [`TDModule` class](torchrl/modules/td_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible! +- [interfaces for environments](torchrl/envs) from common libraries (OpenAI gym, deepmind control suite, etc.) and [wrappers for parallel execution](torchrl/envs/vec_env.py), as well as a new pytorch-first [tensor-specification class](torchrl/data/tensor_specs.py); +- [environment transforms](torchrl/envs/transforms/transforms.py), which process and prepare the data coming out of the environments to be used by the agent; +- various tools for distributed learning (e.g. [memory mapped tensors](torchrl/data/tensordict/memmap.py)); +- various [architectures](torchrl/modules/models/) and models (e.g. [actor-critic](torchrl/modules/td_module/actors.py)); +- [exploration wrappers](torchrl/modules/td_module/exploration.py); +- various [recipes](torchrl/agents/helpers/models.py) to build models that correspond to the environment being deployed. + +A series of [examples](examples/) is provided for illustrative purposes: +- [DQN (and add-ons up to Rainbow)](examples/dqn/dqn.py) +- [DDPG](examples/ddpg/ddpg.py) +- [PPO](examples/ppo/ppo.py) +- [SAC](examples/sac/sac.py) +- [REDQ](examples/redq/redq.py) + +and many more to come! + +## Installation +Create a conda environment where the packages will be installed. +Before installing anything, make sure you have the latest versions of the `cmake` and `ninja` libraries: + +``` +conda create --name torch_rl python=3.9 +conda activate torch_rl +conda install cmake -c conda-forge +pip install ninja +``` + +Depending on how you intend to use functorch, you may want to install either the latest (nightly) pytorch release or the latest stable version of pytorch: + +**Stable** + +``` +conda install pytorch torchvision cudatoolkit=10.2 -c pytorch # refer to pytorch official website for cudatoolkit installation +pip install functorch +``` + +**Nightly** +``` +# For CUDA 10.2 +pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html --upgrade +# For CUDA 11.1 +pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html --upgrade +# For CPU-only build +pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade +``` + +and functorch: +``` +pip install --user "git+https://github.com/pytorch/functorch.git" +``` + +**Torchrl** + +Go to the directory where you have cloned the torchrl repo and install it: +``` +cd /path/to/torchrl/ +python setup.py install +``` +To run a quick sanity check, leave that directory and try to import the library.
+``` +python -c "import torchrl" +``` + +**Optional dependencies** + +The following libraries can be installed depending on how you intend to use torchrl: +``` +# diverse +pip install tqdm pyyaml configargparse + +# rendering +pip install moviepy + +# deepmind control suite +pip install dm_control + +# gym, atari games +pip install gym gym[accept-rom-license] pygame gym_retro + +# tests +pip install pytest +``` + +## Running examples +Examples are coded in a very similar way, but the configuration may change from one algorithm to the other (e.g. async/sync data collection, hyperparameters, ratio of model updates per frame, etc.). +To train an algorithm, it is therefore advised to use the predefined configurations found in the `configs` sub-folder of each algorithm directory: +``` +python examples/ppo/ppo.py --config=examples/ppo/configs/humanoid.txt +``` +Note that using the config files requires the [configargparse](https://pypi.org/project/ConfigArgParse/) library. + +One can also override the config parameters using flags, e.g. +``` +python examples/ppo/ppo.py --config=examples/ppo/configs/humanoid.txt --frame_skip=2 --collection_devices=cuda:1 +``` + +Each example will write a tensorboard log in a dedicated folder, e.g. `ppo_logging/...`. + +## Contributing +Contributions to torchrl are welcome! Feel free to fork, submit issues and PRs. + +## Upcoming features +In the near future, we plan to: +- provide tutorials on how to design new actors or environment wrappers; +- implement IMPALA (as a distributed RL example) and Meta-RL algorithms; +- improve the tests, documentation and nomenclature. + +## License +TorchRL is licensed under the MIT License. See [LICENSE](LICENSE) for details. diff --git a/build_tools/__init__.py b/build_tools/__init__.py new file mode 100644 index 00000000000..7bec24cb17b --- /dev/null +++ b/build_tools/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/build_tools/setup_helpers/__init__.py b/build_tools/setup_helpers/__init__.py new file mode 100644 index 00000000000..167a9787362 --- /dev/null +++ b/build_tools/setup_helpers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .extension import * # noqa diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py new file mode 100644 index 00000000000..482691c8819 --- /dev/null +++ b/build_tools/setup_helpers/extension.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
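+
+# Build helpers used by setup.py: `get_ext_modules` declares the placeholder
+# `torchrl._torchrl` extension, and `CMakeBuild` drives the CMake/Ninja build
+# that compiles and installs the C++ library into the package directory.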
+ +import distutils.sysconfig +import os +import platform +import subprocess +from pathlib import Path +from subprocess import check_output, STDOUT, CalledProcessError + +import torch +from setuptools import Extension +from setuptools.command.build_ext import build_ext + +__all__ = [ + "get_ext_modules", + "CMakeBuild", +] + +_THIS_DIR = Path(__file__).parent.resolve() +_ROOT_DIR = _THIS_DIR.parent.parent.resolve() +_TORCHRL_DIR = _ROOT_DIR / "torchrl" + + +def _get_build(var, default=False): + if var not in os.environ: + return default + + val = os.environ.get(var, "0") + trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"] + falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"] + if val in trues: + return True + if val not in falses: + print( + f"WARNING: Unexpected environment variable value `{var}={val}`. " + f"Expected one of {trues + falses}" + ) + return False + + +_BUILD_SOX = False if platform.system() == "Windows" else _get_build("BUILD_SOX", True) +_BUILD_KALDI = ( + False if platform.system() == "Windows" else _get_build("BUILD_KALDI", True) +) +_BUILD_RNNT = _get_build("BUILD_RNNT", True) +_USE_ROCM = _get_build( + "USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None +) +_USE_CUDA = _get_build( + "USE_CUDA", torch.cuda.is_available() and torch.version.hip is None +) +_USE_OPENMP = ( + _get_build("USE_OPENMP", True) + and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info() +) +_TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None) + + +def get_ext_modules(): + return [ + Extension(name="torchrl._torchrl", sources=[]), + ] + + +# Based off of +# https://github.com/pybind/cmake_example/blob/580c5fd29d4651db99d8874714b07c0c49a53f8a/setup.py +class CMakeBuild(build_ext): + def run(self): + try: + subprocess.check_output(["cmake", "--version"]) + except OSError: + raise RuntimeError("CMake is not available.") from None + super().run() + + def build_extension(self, ext): + # Since two library files (libtorchrl and _torchrl) need to be + # recognized by setuptools, we instantiate `Extension` twice. (see `get_ext_modules`) + # This leads to the situation where this `build_extension` method is called twice. + # However, the following `cmake` command will build all of them at the same time, + # so, we do not need to perform `cmake` twice. + # Therefore we call `cmake` only for `torchrl._torchrl`. 
+ if ext.name != "torchrl._torchrl": + return + + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + # required for auto-detection of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + cfg = "Debug" if self.debug else "Release" + + cmake_args = [ + f"-DCMAKE_BUILD_TYPE={cfg}", + f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}", + f"-DCMAKE_INSTALL_PREFIX={extdir}", + "-DCMAKE_VERBOSE_MAKEFILE=ON", + f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}", + "-DBUILD_TORCHRL_PYTHON_EXTENSION:BOOL=ON", + f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}", + ] + build_args = ["--target", "install"] + # Pass CUDA architecture to cmake + if _TORCH_CUDA_ARCH_LIST is not None: + # Convert MAJOR.MINOR[+PTX] list to new style one + # defined at https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html + _arches = _TORCH_CUDA_ARCH_LIST.replace(".", "").split(";") + _arches = [ + arch[:-4] if arch.endswith("+PTX") else f"{arch}-real" + for arch in _arches + ] + cmake_args += [f"-DCMAKE_CUDA_ARCHITECTURES={';'.join(_arches)}"] + + # Default to Ninja + if "CMAKE_GENERATOR" not in os.environ or platform.system() == "Windows": + cmake_args += ["-GNinja"] + if platform.system() == "Windows": + import sys + + python_version = sys.version_info + cmake_args += [ + "-DCMAKE_C_COMPILER=cl", + "-DCMAKE_CXX_COMPILER=cl", + f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", + ] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. + build_args += ["-j{}".format(self.parallel)] + + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + + print(" ".join(["cmake", str(_ROOT_DIR)] + cmake_args)) + try: + check_output( + ["cmake", str(_ROOT_DIR)] + cmake_args, + cwd=self.build_temp, + stderr=STDOUT, + ) + except CalledProcessError as exc: + print(exc.output) + + try: + check_output( + ["cmake", "--build", "."] + build_args, + cwd=self.build_temp, + stderr=STDOUT, + ) + except CalledProcessError as exc: + print(exc.output) + + def get_ext_filename(self, fullname): + ext_filename = super().get_ext_filename(fullname) + ext_filename_parts = ext_filename.split(".") + without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:] + ext_filename = ".".join(without_abi) + return ext_filename diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000000..d0c3cbf1020 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000000..061f32f91b9 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000000..3d1939261ac --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,6 @@ +matplotlib +numpy +sphinx-copybutton>=0.3.1 +sphinx-gallery>=0.9.0 +sphinx==3.5.4 +-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme diff --git a/docs/source/_templates/class.rst b/docs/source/_templates/class.rst new file mode 100644 index 00000000000..64f573535f9 --- /dev/null +++ b/docs/source/_templates/class.rst @@ -0,0 +1,7 @@ +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :members: diff --git a/docs/source/_templates/function.rst b/docs/source/_templates/function.rst new file mode 100644 index 00000000000..819aa05601a --- /dev/null +++ b/docs/source/_templates/function.rst @@ -0,0 +1,6 @@ +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autofunction:: {{ name }} diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html new file mode 100644 index 00000000000..4d45c2103dd --- /dev/null +++ b/docs/source/_templates/layout.html @@ -0,0 +1,5 @@ +{% extends "!layout.html" %} + +{% block sidebartitle %} + {% include "searchbox.html" %} +{% endblock %} diff --git a/docs/source/_templates/rl_template.rst b/docs/source/_templates/rl_template.rst new file mode 100644 index 00000000000..9e15d3a90ad --- /dev/null +++ b/docs/source/_templates/rl_template.rst @@ -0,0 +1,8 @@ +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :members: + :inherited-members: diff --git a/docs/source/_templates/rl_template_fun.rst b/docs/source/_templates/rl_template_fun.rst new file mode 100644 index 00000000000..819aa05601a --- /dev/null +++ b/docs/source/_templates/rl_template_fun.rst @@ -0,0 +1,6 @@ +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autofunction:: {{ name }} diff --git a/docs/source/_templates/rl_template_noinherit.rst b/docs/source/_templates/rl_template_noinherit.rst new file mode 100644 index 00000000000..64f573535f9 --- /dev/null +++ b/docs/source/_templates/rl_template_noinherit.rst @@ -0,0 +1,7 @@ +.. currentmodule:: {{ module }} + + +{{ name | underline}} + +.. autoclass:: {{ name }} + :members: diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 00000000000..7e3640f8ea0 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- +import pytorch_sphinx_theme +import torchrl + +project = "torchrl" +copyright = "2022-present, Torch Contributors" +author = "Torch Contributors" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = "main (" + torchrl.__version__ + ")" +# The full version, including alpha/beta/rc tags. +# TODO: verify this works as expected +release = "main" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ +    "sphinx.ext.autodoc", +    "sphinx.ext.autosummary", +    "sphinx.ext.doctest", +    "sphinx.ext.intersphinx", +    "sphinx.ext.todo", +    "sphinx.ext.mathjax", +    "sphinx.ext.napoleon", +    "sphinx.ext.viewcode", +    "sphinx.ext.duration", +    "sphinx_gallery.gen_gallery", +    "sphinx_autodoc_typehints", +    "sphinxcontrib.aafig", +] + +sphinx_gallery_conf = { +    "examples_dirs": "../../gallery/",  # path to your example scripts +    "gallery_dirs": "auto_examples",  # path to where to save gallery generated output +    "backreferences_dir": "gen_modules/backreferences", +    "doc_module": ("torchrl",), +} + +napoleon_use_ivar = True +napoleon_numpy_docstring = False +napoleon_google_docstring = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffixes as a list of strings: +# +source_suffix = { +    ".rst": "restructuredtext", +} + +# The master toctree document. +master_doc = "index" + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes.
+# +html_theme = "pytorch_sphinx_theme" +html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] + +html_theme_options = { + "collapse_navigation": False, + "display_version": True, + "logo_only": True, + "pytorch_project": "docs", + "navigation_with_keys": True, + "analytics_id": "UA-117752657-2", +} + +# Output file base name for HTML help builder. +htmlhelp_basename = "PyTorchdoc" + +autosummary_generate = True + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# -- Options for LaTeX output --------------------------------------------- +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +# latex_documents = [ +# (master_doc, "pytorch.tex", "torchrl Documentation", "Torch Contributors", "manual"), +# ] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [(master_doc, "torchvision", "torchrl Documentation", [author], 1)] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + master_doc, + "torchrl", + "torchrl Documentation", + author, + "torchrl", + "TorchRL doc.", + "Miscellaneous", + ), +] + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), +} + + +from docutils import nodes +from sphinx import addnodes +from sphinx.util.docfields import TypedField + + +def patched_make_field(self, types, domain, items, **kw): + # `kw` catches `env=None` needed for newer sphinx while maintaining + # backwards compatibility when passed along further down! 
+ + # type: (list, unicode, tuple) -> nodes.field # noqa: F821 + def handle_item(fieldarg, content): + par = nodes.paragraph() + par += addnodes.literal_strong("", fieldarg) # Patch: this line added + # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, + # addnodes.literal_strong)) + if fieldarg in types: + par += nodes.Text(" (") + # NOTE: using .pop() here to prevent a single type node to be + # inserted twice into the doctree, which leads to + # inconsistencies later when references are resolved + fieldtype = types.pop(fieldarg) + if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): + typename = "".join(n.astext() for n in fieldtype) + typename = typename.replace("int", "python:int") + typename = typename.replace("long", "python:long") + typename = typename.replace("float", "python:float") + typename = typename.replace("type", "python:type") + par.extend( + self.make_xrefs( + self.typerolename, + domain, + typename, + addnodes.literal_emphasis, + **kw + ) + ) + else: + par += fieldtype + par += nodes.Text(")") + par += nodes.Text(" -- ") + par += content + return par + + fieldname = nodes.field_name("", self.label) + if len(items) == 1 and self.can_collapse: + fieldarg, content = items[0] + bodynode = handle_item(fieldarg, content) + else: + bodynode = self.list_type() + for fieldarg, content in items: + bodynode += nodes.list_item("", handle_item(fieldarg, content)) + fieldbody = nodes.field_body("", bodynode) + return nodes.field("", fieldname, fieldbody) + + +TypedField.make_field = patched_make_field + +aafig_default_options = dict(scale=1.5, aspect=1.0, proportional=True) diff --git a/docs/source/docutils.conf b/docs/source/docutils.conf new file mode 100644 index 00000000000..00b6db82694 --- /dev/null +++ b/docs/source/docutils.conf @@ -0,0 +1,2 @@ +[html writers] +table_style: colwidths-auto # Necessary for the table generated by autosummary to look decent diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 00000000000..daf5b972b1f --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,20 @@ +.. torchrl documentation master file, created by + sphinx-quickstart on Mon Mar 7 13:23:20 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to torchrl's documentation! +=================================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + reference/index + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/source/reference/agents.rst b/docs/source/reference/agents.rst new file mode 100644 index 00000000000..bff96b9d6b6 --- /dev/null +++ b/docs/source/reference/agents.rst @@ -0,0 +1,74 @@ +.. currentmodule:: torchrl.agents + +torchrl.agents package +====================== + +Agents +------- + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Agent + EnvCreator + + +Builders +-------- + +.. currentmodule:: torchrl.agents.helpers + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + make_agent + sync_sync_collector + sync_async_collector + make_collector_offpolicy + make_collector_onpolicy + make_sac_loss + make_dqn_loss + make_ddpg_loss + make_target_updater + make_ppo_loss + make_redq_loss + make_dqn_actor + make_ddpg_actor + make_ppo_model + make_sac_model + make_redq_model + make_replay_buffer + transformed_env_constructor + parallel_env_constructor + +Utils +----- + + +.. 
autosummary:: +    :toctree: generated/ +    :template: rl_template_fun.rst + +    correct_for_frame_skip +    get_stats_random_rollout + +Argument parser +--------------- + + +.. autosummary:: +    :toctree: generated/ +    :template: rl_template_fun.rst + +    parser_agent_args +    parser_collector_args_offpolicy +    parser_collector_args_onpolicy +    parser_env_args +    parser_loss_args +    parser_loss_args_ppo +    parser_model_args_continuous +    parser_model_args_discrete +    parser_recorder_args +    parser_replay_args diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst new file mode 100644 index 00000000000..afbded37ba0 --- /dev/null +++ b/docs/source/reference/collectors.rst @@ -0,0 +1,29 @@ +.. currentmodule:: torchrl.collectors + +torchrl.collectors package +========================== + +Data collectors +--------------- + + +.. autosummary:: +    :toctree: generated/ +    :template: rl_template.rst + +    MultiSyncDataCollector +    MultiaSyncDataCollector +    SyncDataCollector +    aSyncDataCollector + + +Helper functions +---------------- + +.. currentmodule:: torchrl.collectors.utils + +.. autosummary:: +    :toctree: generated/ +    :template: rl_template_fun.rst + +    split_trajectories diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst new file mode 100644 index 00000000000..c8efa18c5fe --- /dev/null +++ b/docs/source/reference/data.rst @@ -0,0 +1,91 @@ +.. currentmodule:: torchrl.data + +torchrl.data package +==================== + +Replay Buffers +-------------- + +Replay buffers are a central part of off-policy RL algorithms. TorchRL provides an efficient implementation of a few +widely used replay buffers: + + +.. autosummary:: +    :toctree: generated/ +    :template: rl_template.rst + +    ReplayBuffer +    PrioritizedReplayBuffer +    TensorDictReplayBuffer +    TensorDictPrioritizedReplayBuffer + + +TensorDict +---------- + +Passing data across objects can become a burdensome task when designing high-level classes: for instance it can be +hard to design an actor class that can take an arbitrary number of inputs and return an arbitrary number of outputs. The +`TensorDict` class simplifies this process by packing together a bag of tensors in a dictionary-like object. This +class supports a set of basic operations on tensors to facilitate the manipulation of entire batches of data (e.g. +`torch.cat`, `torch.stack`, `.to(device)` etc.). + + +.. autosummary:: +    :toctree: generated/ +    :template: rl_template.rst + +    TensorDict +    SubTensorDict +    LazyStackedTensorDict + +TensorSpec +---------- + +The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such +as shape, device, dtype and domain. + + +.. autosummary:: +    :toctree: generated/ +    :template: rl_template.rst + +    TensorSpec +    BoundedTensorSpec +    OneHotDiscreteTensorSpec +    UnboundedContinuousTensorSpec +    NdBoundedTensorSpec +    NdUnboundedContinuousTensorSpec +    BinaryDiscreteTensorSpec +    MultOneHotDiscreteTensorSpec +    CompositeSpec + +Transforms +---------- + +In most cases, the raw output of an environment must be processed before being passed to another object (such as a +policy or a value operator). To do this, TorchRL provides a set of transforms that aim at reproducing the transform +logic of `torch.distributions.Transform` and `torchvision.transforms`. + + +..
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + Transform + TransformedEnv + Compose + CatTensors + CatFrames + RewardClipping + Resize + GrayScale + ToTensorImage + ObservationNorm + RewardScaling + ObservationTransform + FiniteTensorDictCheck + DoubleToFloat + NoopResetEnv + BinerizeReward + PinMemoryTransform + VecNorm diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst new file mode 100644 index 00000000000..7b3a56017f9 --- /dev/null +++ b/docs/source/reference/envs.rst @@ -0,0 +1,29 @@ +.. currentmodule:: torchrl.envs + +torchrl.envs package +==================== + + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + GymLikeEnv + GymEnv + DMControlEnv + SerialEnv + ParallelEnv + +Helpers +------- +.. currentmodule:: torchrl.envs.utils + +.. autosummary:: + :toctree: generated/ + :template: rl_template_fun.rst + + step_tensor_dict + get_available_libraries + set_exploration_mode + exploration_mode + make_tensor_dict diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst new file mode 100644 index 00000000000..15facebbbfe --- /dev/null +++ b/docs/source/reference/index.rst @@ -0,0 +1,12 @@ +API Reference +============= + +.. toctree:: + :maxdepth: 1 + + agents + collectors + data + envs + modules + objectives diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst new file mode 100644 index 00000000000..42f6a7e30bf --- /dev/null +++ b/docs/source/reference/modules.rst @@ -0,0 +1,66 @@ +.. currentmodule:: torchrl.modules + +torchrl.modules package +======================= + +TensorDict modules +------------------ + + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + TDModule + ProbabilisticTDModule + TDSequence + TDModuleWrapper + Actor + ProbabilisticActor + ValueOperator + QValueActor + DistributionalQValueActor + ActorValueOperator + ActorCriticOperator + ActorCriticWrapper + +Hooks +----- +.. currentmodule:: torchrl.modules.td_module.actors + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + QValueHook + DistributionalQValueHook + +Models +------ +.. currentmodule:: torchrl.modules + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + MLP + ConvNet + DuelingCnnDQNet + DistributionalDQNnet + DdpgCnnActor + DdpgCnnQNet + DdpgMlpActor + DdpgMlpQNet + LSTMNet + +Distributions +------------- +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + Delta + TanhNormal + TruncatedNormal + TanhDelta + OneHotCategorical diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst new file mode 100644 index 00000000000..d648ff23d52 --- /dev/null +++ b/docs/source/reference/objectives.rst @@ -0,0 +1,71 @@ +.. currentmodule:: torchrl.objectives + +torchrl.objectives package +========================== + +DQN +--- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DQNLoss + DoubleDQNLoss + DistributionalDQNLoss + DoubleDistributionalDQNLoss + +DDPG +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DDPGLoss + DoubleDDPGLoss + +SAC +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + SACLoss + DoubleSACLoss + +REDQ +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + REDQLoss + DoubleREDQLoss + +PPO +---- + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + PPOLoss + ClipPPOLoss + KLPENPPOLoss + +Utils +----- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + distance_loss + hold_out_net + hold_out_params + next_state_value + SoftUpdate + HardUpdate diff --git a/examples/ddpg/configs/cheetah.txt b/examples/ddpg/configs/cheetah.txt new file mode 100644 index 00000000000..8719009ecca --- /dev/null +++ b/examples/ddpg/configs/cheetah.txt @@ -0,0 +1,6 @@ +env_name=cheetah +env_task=run +env_library=dm_control +async_collection +record_video +normalize_rewards_online diff --git a/examples/ddpg/configs/halfcheetah.txt b/examples/ddpg/configs/halfcheetah.txt new file mode 100644 index 00000000000..ca629abfd1d --- /dev/null +++ b/examples/ddpg/configs/halfcheetah.txt @@ -0,0 +1,6 @@ +env_name=HalfCheetah-v2 +env_task= +env_library=gym +async_collection +record_video +normalize_rewards_online diff --git a/examples/ddpg/configs/humanoid.txt b/examples/ddpg/configs/humanoid.txt new file mode 100644 index 00000000000..e206430ada7 --- /dev/null +++ b/examples/ddpg/configs/humanoid.txt @@ -0,0 +1,10 @@ +env_name=humanoid +env_task=walk +env_library=dm_control +async_collection +record_video +normalize_rewards_online +frame_skip=2 +prb +multi_step +exp_name=humanoid_stats_wd diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py new file mode 100644 index 00000000000..97a273cd835 --- /dev/null +++ b/examples/ddpg/ddpg.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import uuid +from datetime import datetime + +try: + import configargparse as argparse + + _configargparse = True +except ImportError: + import argparse + + _configargparse = False +import torch.cuda +from torch.utils.tensorboard import SummaryWriter +from torchrl.agents.helpers.agents import make_agent, parser_agent_args +from torchrl.agents.helpers.collectors import ( + make_collector_offpolicy, + parser_collector_args_offpolicy, +) +from torchrl.agents.helpers.envs import ( + correct_for_frame_skip, + get_stats_random_rollout, + parallel_env_constructor, + parser_env_args, + transformed_env_constructor, +) +from torchrl.agents.helpers.losses import make_ddpg_loss, parser_loss_args +from torchrl.agents.helpers.models import ( + make_ddpg_actor, + parser_model_args_continuous, +) +from torchrl.agents.helpers.recorder import parser_recorder_args +from torchrl.agents.helpers.replay_buffer import ( + make_replay_buffer, + parser_replay_args, +) +from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.modules import OrnsteinUhlenbeckProcessWrapper + + +def make_args(): + parser = argparse.ArgumentParser() + if _configargparse: + parser.add_argument( + "-c", + "--config", + required=True, + is_config_file=True, + help="config file path", + ) + parser_agent_args(parser) + parser_collector_args_offpolicy(parser) + parser_env_args(parser) + parser_loss_args(parser) + parser_model_args_continuous(parser, "DDPG") + parser_recorder_args(parser) + parser_replay_args(parser) + return parser + + +parser = make_args() + +DEFAULT_REWARD_SCALING = { + "Hopper-v1": 5, + "Walker2d-v1": 5, + "HalfCheetah-v1": 5, + "cheetah": 5, + "Ant-v2": 5, + "Humanoid-v2": 20, + "humanoid": 100, +} + +if __name__ == "__main__": + args = parser.parse_args() + + args = correct_for_frame_skip(args) + + if not 
isinstance(args.reward_scaling, float): + args.reward_scaling = DEFAULT_REWARD_SCALING.get(args.env_name, 5.0) + + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda:0") + ) + + exp_name = "_".join( + [ + "DDPG", + args.exp_name, + str(uuid.uuid4())[:8], + datetime.now().strftime("%y_%m_%d-%H_%M_%S"), + ] + ) + writer = SummaryWriter(f"ddpg_logging/{exp_name}") + video_tag = exp_name if args.record_video else "" + + proof_env = transformed_env_constructor(args=args, use_env_creator=False)() + model = make_ddpg_actor( + proof_env, + args.from_pixels, + noisy=args.noisy, + device=device, + ) + loss_module, target_net_updater = make_ddpg_loss(model, args) + actor_model_explore = model[0] + if args.ou_exploration: + actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor_model_explore, annealing_num_steps=args.annealing_frames + ).to(device) + if device == torch.device("cpu"): + # mostly for debugging + actor_model_explore.share_memory() + + stats = None + if not args.vecnorm: + stats = get_stats_random_rollout(args, proof_env) + # make sure proof_env is closed + proof_env.close() + + create_env_fn = parallel_env_constructor(args=args, stats=stats) + + collector = make_collector_offpolicy( + make_env=create_env_fn, + actor_model_explore=actor_model_explore, + args=args, + ) + + replay_buffer = make_replay_buffer(device, args) + + recorder = transformed_env_constructor( + args, + video_tag=video_tag, + norm_obs_only=True, + stats=stats, + writer=writer, + use_env_creator=False, + )() + + # remove video recorder from recorder to have matching state_dict keys + if args.record_video: + recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:]) + else: + recorder_rm = recorder + recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) + # reset reward scaling + for t in recorder.transform: + if isinstance(t, RewardScaling): + t.scale.fill_(1.0) + + agent = make_agent( + collector, + loss_module, + recorder, + target_net_updater, + actor_model_explore, + replay_buffer, + writer, + args, + ) + + agent.train() diff --git a/examples/dqn/configs/pong.txt b/examples/dqn/configs/pong.txt new file mode 100644 index 00000000000..a8c4f8aaced --- /dev/null +++ b/examples/dqn/configs/pong.txt @@ -0,0 +1,19 @@ +frames_per_batch=500 +frame_skip=4 +optim_steps_per_collection=125 +env_library=gym +env_name=PongNoFrameskip-v4 +noops=30 +max_frames_per_traj=-1 +exp_name=pong +record_interval=10000 +batch_size=32 +async_collection +distributional +prb +multi_step +annealing_frames=50000000 +record_frames=50000 +normalize_rewards_online +from_pixels +record_video diff --git a/examples/dqn/dqn.py b/examples/dqn/dqn.py new file mode 100644 index 00000000000..7c6f60f37a0 --- /dev/null +++ b/examples/dqn/dqn.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
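+
+# Example training script for DQN (optionally distributional, with prioritized
+# replay and multi-step returns via the config flags): it builds the environment
+# constructors, the Q-value actor and its loss, an e-greedy exploration wrapper,
+# an off-policy data collector, a replay buffer and a recorder, then hands
+# everything to the Agent helper which runs the training loop.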
+ +import uuid +from datetime import datetime + +try: + import configargparse as argparse + + _configargparse = True +except ImportError: + import argparse + + _configargparse = False +import torch.cuda +from torch.utils.tensorboard import SummaryWriter +from torchrl.agents.helpers.agents import make_agent, parser_agent_args +from torchrl.agents.helpers.collectors import ( + make_collector_offpolicy, + parser_collector_args_offpolicy, +) +from torchrl.agents.helpers.envs import ( + correct_for_frame_skip, + get_stats_random_rollout, + parallel_env_constructor, + parser_env_args, + transformed_env_constructor, +) +from torchrl.agents.helpers.losses import make_dqn_loss, parser_loss_args +from torchrl.agents.helpers.models import ( + make_dqn_actor, + parser_model_args_discrete, +) +from torchrl.agents.helpers.recorder import parser_recorder_args +from torchrl.agents.helpers.replay_buffer import ( + make_replay_buffer, + parser_replay_args, +) +from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.modules import EGreedyWrapper + + +def make_args(): + parser = argparse.ArgumentParser() + if _configargparse: + parser.add_argument( + "-c", + "--config", + required=True, + is_config_file=True, + help="config file path", + ) + parser_agent_args(parser) + parser_collector_args_offpolicy(parser) + parser_env_args(parser) + parser_loss_args(parser) + parser_model_args_discrete(parser) + parser_recorder_args(parser) + parser_replay_args(parser) + return parser + + +parser = make_args() + +if __name__ == "__main__": + args = parser.parse_args() + + args = correct_for_frame_skip(args) + + if not isinstance(args.reward_scaling, float): + args.reward_scaling = 1.0 + + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda:0") + ) + + exp_name = "_".join( + [ + "DQN", + args.exp_name, + str(uuid.uuid4())[:8], + datetime.now().strftime("%y_%m_%d-%H_%M_%S"), + ] + ) + writer = SummaryWriter(f"dqn_logging/{exp_name}") + video_tag = exp_name if args.record_video else "" + + proof_env = transformed_env_constructor(args=args, use_env_creator=False)() + model = make_dqn_actor( + proof_environment=proof_env, + args=args, + device=device, + ) + + loss_module, target_net_updater = make_dqn_loss(model, args) + model_explore = EGreedyWrapper(model, annealing_num_steps=args.annealing_frames).to( + device + ) + + stats = None + if not args.vecnorm: + stats = get_stats_random_rollout(args, proof_env) + # make sure proof_env is closed + proof_env.close() + + create_env_fn = parallel_env_constructor(args=args, stats=stats) + + collector = make_collector_offpolicy( + make_env=create_env_fn, + actor_model_explore=model_explore, + args=args, + ) + + replay_buffer = make_replay_buffer(device, args) + + recorder = transformed_env_constructor( + args, + video_tag=video_tag, + norm_obs_only=True, + stats=stats, + writer=writer, + )() + + # remove video recorder from recorder to have matching state_dict keys + if args.record_video: + recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:]) + else: + recorder_rm = recorder + + recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) + # reset reward scaling + for t in recorder.transform: + if isinstance(t, RewardScaling): + t.scale.fill_(1.0) + + agent = make_agent( + collector, + loss_module, + recorder, + target_net_updater, + model_explore, + replay_buffer, + writer, + args, + ) + + agent.train() diff --git a/examples/ppo/configs/cheetah.txt b/examples/ppo/configs/cheetah.txt new file mode 
100644 index 00000000000..e9fc6e208d5 --- /dev/null +++ b/examples/ppo/configs/cheetah.txt @@ -0,0 +1,12 @@ +env_name=cheetah +env_task=run +env_library=dm_control +optim_steps_per_collection=10 +lamda=0.95 +normalize_rewards_online +record_video +max_frames_per_traj=1000 +record_interval=200 +lr=3e-4 +tanh_loc +init_with_lag diff --git a/examples/ppo/configs/cheetah_pixels.txt b/examples/ppo/configs/cheetah_pixels.txt new file mode 100644 index 00000000000..3749c7236bb --- /dev/null +++ b/examples/ppo/configs/cheetah_pixels.txt @@ -0,0 +1,11 @@ +env_name=cheetah +env_task=run +env_library=dm_control +optim_steps_per_collection=10 +lamda=0.95 +normalize_rewards_online +record_video +max_frames_per_traj=1000 +record_interval=200 +tanh_loc +init_with_lag diff --git a/examples/ppo/configs/humanoid.txt b/examples/ppo/configs/humanoid.txt new file mode 100644 index 00000000000..95499ed740e --- /dev/null +++ b/examples/ppo/configs/humanoid.txt @@ -0,0 +1,14 @@ +env_name=humanoid +env_task=walk +env_library=dm_control +optim_steps_per_collection=10 +lamda=0.95 +normalize_rewards_online +record_video +max_frames_per_traj=1000 +record_interval=200 +lr=3e-4 +entropy_factor=1e-4 +frame_skip=4 +tanh_loc +init_with_lag diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py new file mode 100644 index 00000000000..9df4533a084 --- /dev/null +++ b/examples/ppo/ppo.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import uuid +from datetime import datetime + +try: + import configargparse as argparse + + _configargparse = True +except ImportError: + import argparse + + _configargparse = False +import torch.cuda +from torch.utils.tensorboard import SummaryWriter +from torchrl.agents.helpers.agents import make_agent, parser_agent_args +from torchrl.agents.helpers.collectors import ( + make_collector_onpolicy, + parser_collector_args_onpolicy, +) +from torchrl.agents.helpers.envs import ( + correct_for_frame_skip, + get_stats_random_rollout, + parallel_env_constructor, + parser_env_args, + transformed_env_constructor, +) +from torchrl.agents.helpers.losses import make_ppo_loss, parser_loss_args_ppo +from torchrl.agents.helpers.models import ( + make_ppo_model, + parser_model_args_continuous, +) +from torchrl.agents.helpers.recorder import parser_recorder_args +from torchrl.envs.transforms import RewardScaling, TransformedEnv + + +def make_args(): + parser = argparse.ArgumentParser() + if _configargparse: + parser.add_argument( + "-c", + "--config", + required=True, + is_config_file=True, + help="config file path", + ) + parser_agent_args(parser) + parser_collector_args_onpolicy(parser) + parser_env_args(parser) + parser_loss_args_ppo(parser) + parser_model_args_continuous(parser, "PPO") + + parser_recorder_args(parser) + return parser + + +parser = make_args() + +if __name__ == "__main__": + args = parser.parse_args() + + args = correct_for_frame_skip(args) + + if not isinstance(args.reward_scaling, float): + args.reward_scaling = 1.0 + + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda:0") + ) + + exp_name = "_".join( + [ + "PPO", + args.exp_name, + str(uuid.uuid4())[:8], + datetime.now().strftime("%y_%m_%d-%H_%M_%S"), + ] + ) + writer = SummaryWriter(f"ppo_logging/{exp_name}") + video_tag = exp_name if args.record_video else "" + + proof_env = transformed_env_constructor(args=args, use_env_creator=False)() + 
model = make_ppo_model(proof_env, args, device) + actor_model = model.get_policy_operator() + + loss_module = make_ppo_loss(model, args) + + stats = None + if not args.vecnorm: + stats = get_stats_random_rollout(args, proof_env) + # make sure proof_env is closed + proof_env.close() + + create_env_fn = parallel_env_constructor(args=args, stats=stats) + + collector = make_collector_onpolicy( + make_env=create_env_fn, + actor_model_explore=actor_model, + args=args, + ) + + recorder = transformed_env_constructor( + args, + video_tag=video_tag, + norm_obs_only=True, + stats=stats, + writer=writer, + )() + + # remove video recorder from recorder to have matching state_dict keys + if args.record_video: + recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:]) + else: + recorder_rm = recorder + + recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) + # reset reward scaling + for t in recorder.transform: + if isinstance(t, RewardScaling): + t.scale.fill_(1.0) + + agent = make_agent( + collector, loss_module, recorder, None, actor_model, None, writer, args + ) + + agent.train() diff --git a/examples/redq/configs/humanoid.txt b/examples/redq/configs/humanoid.txt new file mode 100644 index 00000000000..4e150d8b71b --- /dev/null +++ b/examples/redq/configs/humanoid.txt @@ -0,0 +1,11 @@ +env_name=humanoid +env_task=walk +env_library=dm_control +async_collection +record_video +normalize_rewards_online +frame_skip=2 +prb +multi_step +exp_name=humanoid_stats +tanh_loc diff --git a/examples/redq/redq.py b/examples/redq/redq.py new file mode 100644 index 00000000000..cea20504e88 --- /dev/null +++ b/examples/redq/redq.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
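+
+# Example training script for REDQ: it builds the environment constructors, the
+# REDQ model and loss, optional Ornstein-Uhlenbeck exploration, an off-policy
+# data collector, a replay buffer and a recorder, then hands everything to the
+# Agent helper which runs the training loop.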
+ +import uuid +from datetime import datetime + +try: + import configargparse as argparse + + _configargparse = True +except ImportError: + import argparse + + _configargparse = False + +import torch.cuda +from torch.utils.tensorboard import SummaryWriter +from torchrl.agents.helpers.agents import make_agent, parser_agent_args +from torchrl.agents.helpers.collectors import ( + make_collector_offpolicy, + parser_collector_args_offpolicy, +) +from torchrl.agents.helpers.envs import ( + correct_for_frame_skip, + get_stats_random_rollout, + parallel_env_constructor, + parser_env_args, + transformed_env_constructor, +) +from torchrl.agents.helpers.losses import make_redq_loss, parser_loss_args +from torchrl.agents.helpers.models import ( + make_redq_model, + parser_model_args_continuous, +) +from torchrl.agents.helpers.recorder import parser_recorder_args +from torchrl.agents.helpers.replay_buffer import ( + make_replay_buffer, + parser_replay_args, +) +from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.modules import OrnsteinUhlenbeckProcessWrapper + + +def make_args(): + parser = argparse.ArgumentParser() + if _configargparse: + parser.add_argument( + "-c", + "--config", + required=True, + is_config_file=True, + help="config file path", + ) + parser_agent_args(parser) + parser_collector_args_offpolicy(parser) + parser_env_args(parser) + parser_loss_args(parser, algorithm="REDQ") + parser_model_args_continuous(parser, "REDQ") + parser_recorder_args(parser) + parser_replay_args(parser) + return parser + + +parser = make_args() + +DEFAULT_REWARD_SCALING = { + "Hopper-v1": 5, + "Walker2d-v1": 5, + "HalfCheetah-v1": 5, + "cheetah": 5, + "Ant-v2": 5, + "Humanoid-v2": 20, + "humanoid": 100, +} + +if __name__ == "__main__": + args = parser.parse_args() + + args = correct_for_frame_skip(args) + + if not isinstance(args.reward_scaling, float): + args.reward_scaling = DEFAULT_REWARD_SCALING.get(args.env_name, 5.0) + + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda:0") + ) + + exp_name = "_".join( + [ + "REDQ", + args.exp_name, + str(uuid.uuid4())[:8], + datetime.now().strftime("%y_%m_%d-%H_%M_%S"), + ] + ) + writer = SummaryWriter(f"redq_logging/{exp_name}") + video_tag = exp_name if args.record_video else "" + + proof_env = transformed_env_constructor(args=args, use_env_creator=False)() + model = make_redq_model( + proof_env, + device=device, + tanh_loc=args.tanh_loc, + default_policy_scale=args.default_policy_scale, + ) + loss_module, target_net_updater = make_redq_loss(model, args) + + actor_model_explore = model[0] + if args.ou_exploration: + actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor_model_explore, annealing_num_steps=args.annealing_frames + ).to(device) + if device == torch.device("cpu"): + # mostly for debugging + actor_model_explore.share_memory() + + stats = None + if not args.vecnorm: + stats = get_stats_random_rollout(args, proof_env) + # make sure proof_env is closed + proof_env.close() + + create_env_fn = parallel_env_constructor(args=args, stats=stats) + + collector = make_collector_offpolicy( + make_env=create_env_fn, + actor_model_explore=actor_model_explore, + args=args, + ) + + replay_buffer = make_replay_buffer(device, args) + + recorder = transformed_env_constructor( + args, + video_tag=video_tag, + norm_obs_only=True, + stats=stats, + writer=writer, + )() + + # remove video recorder from recorder to have matching state_dict keys + if args.record_video: + recorder_rm = 
TransformedEnv(recorder.env, recorder.transform[1:]) + else: + recorder_rm = recorder + + recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) + # reset reward scaling + for t in recorder.transform: + if isinstance(t, RewardScaling): + t.scale.fill_(1.0) + + agent = make_agent( + collector, + loss_module, + recorder, + target_net_updater, + actor_model_explore, + replay_buffer, + writer, + args, + ) + + agent.train() diff --git a/examples/sac/configs/humanoid.txt b/examples/sac/configs/humanoid.txt new file mode 100644 index 00000000000..4e150d8b71b --- /dev/null +++ b/examples/sac/configs/humanoid.txt @@ -0,0 +1,11 @@ +env_name=humanoid +env_task=walk +env_library=dm_control +async_collection +record_video +normalize_rewards_online +frame_skip=2 +prb +multi_step +exp_name=humanoid_stats +tanh_loc diff --git a/examples/sac/sac.py b/examples/sac/sac.py new file mode 100644 index 00000000000..24006e1b8d1 --- /dev/null +++ b/examples/sac/sac.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import uuid +from datetime import datetime + +try: + import configargparse as argparse + + _configargparse = True +except ImportError: + import argparse + + _configargparse = False +import torch.cuda +from torch.utils.tensorboard import SummaryWriter +from torchrl.agents.helpers.agents import make_agent, parser_agent_args +from torchrl.agents.helpers.collectors import ( + make_collector_offpolicy, + parser_collector_args_offpolicy, +) +from torchrl.agents.helpers.envs import ( + correct_for_frame_skip, + get_stats_random_rollout, + parallel_env_constructor, + parser_env_args, + transformed_env_constructor, +) +from torchrl.agents.helpers.losses import make_sac_loss, parser_loss_args +from torchrl.agents.helpers.models import ( + make_sac_model, + parser_model_args_continuous, +) +from torchrl.agents.helpers.recorder import parser_recorder_args +from torchrl.agents.helpers.replay_buffer import ( + make_replay_buffer, + parser_replay_args, +) +from torchrl.envs.transforms import RewardScaling, TransformedEnv +from torchrl.modules import OrnsteinUhlenbeckProcessWrapper + + +def make_args(): + parser = argparse.ArgumentParser() + if _configargparse: + parser.add_argument( + "-c", + "--config", + required=True, + is_config_file=True, + help="config file path", + ) + parser_agent_args(parser) + parser_collector_args_offpolicy(parser) + parser_env_args(parser) + parser_loss_args(parser, algorithm="SAC") + parser_model_args_continuous(parser, "SAC") + parser_recorder_args(parser) + parser_replay_args(parser) + return parser + + +parser = make_args() + +DEFAULT_REWARD_SCALING = { + "Hopper-v1": 5, + "Walker2d-v1": 5, + "HalfCheetah-v1": 5, + "cheetah": 5, + "Ant-v2": 5, + "Humanoid-v2": 20, + "humanoid": 100, +} + +if __name__ == "__main__": + args = parser.parse_args() + + args = correct_for_frame_skip(args) + + if not isinstance(args.reward_scaling, float): + args.reward_scaling = DEFAULT_REWARD_SCALING.get(args.env_name, 5.0) + + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda:0") + ) + + exp_name = "_".join( + [ + "SAC", + args.exp_name, + str(uuid.uuid4())[:8], + datetime.now().strftime("%y_%m_%d-%H_%M_%S"), + ] + ) + writer = SummaryWriter(f"sac_logging/{exp_name}") + video_tag = exp_name if args.record_video else "" + + proof_env = transformed_env_constructor(args=args, use_env_creator=False)() 
+ model = make_sac_model( + proof_env, + device=device, + ) + loss_module, target_net_updater = make_sac_loss(model, args) + + actor_model_explore = model[0] + if args.ou_exploration: + actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor_model_explore, annealing_num_steps=args.annealing_frames + ).to(device) + if device == torch.device("cpu"): + # mostly for debugging + actor_model_explore.share_memory() + + stats = None + if not args.vecnorm: + stats = get_stats_random_rollout(args, proof_env) + # make sure proof_env is closed + proof_env.close() + + create_env_fn = parallel_env_constructor(args=args, stats=stats) + + collector = make_collector_offpolicy( + make_env=create_env_fn, + actor_model_explore=actor_model_explore, + args=args, + ) + + replay_buffer = make_replay_buffer(device, args) + + recorder = transformed_env_constructor( + args, + video_tag=video_tag, + norm_obs_only=True, + stats=stats, + writer=writer, + )() + + # remove video recorder from recorder to have matching state_dict keys + if args.record_video: + recorder_rm = TransformedEnv(recorder.env, recorder.transform[1:]) + else: + recorder_rm = recorder + + recorder_rm.load_state_dict(create_env_fn.state_dict()["worker0"]) + # reset reward scaling + for t in recorder.transform: + if isinstance(t, RewardScaling): + t.scale.fill_(1.0) + + agent = make_agent( + collector, + loss_module, + recorder, + target_net_updater, + actor_model_explore, + replay_buffer, + writer, + args, + ) + + agent.train() diff --git a/examples/torchrl_features/memmap_speed_distributed.py b/examples/torchrl_features/memmap_speed_distributed.py new file mode 100644 index 00000000000..5d5cf58ab5d --- /dev/null +++ b/examples/torchrl_features/memmap_speed_distributed.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
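+
+# Micro-benchmark comparing a MemmapTensor with a regular tensor when shared
+# across processes through torch.distributed.rpc: rank 0 sends the tensor to a
+# worker, repeatedly asks the worker to increment one row, and times the
+# round-trips (memory-mapped storage avoids shipping the full tensor back on
+# every call).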
+ +import os +import time + +import configargparse +import torch +import torch.distributed.rpc as rpc +from torchrl.data.tensordict import MemmapTensor + +parser = configargparse.ArgumentParser() +parser.add_argument("--rank", default=-1, type=int) +parser.add_argument("--world_size", default=2, type=int) +parser.add_argument("--tensortype", default="memmap", type=str) + +AGENT_NAME = "main" +OBSERVER_NAME = "worker{}" + +str_init_method = "tcp://localhost:10000" +options = rpc.TensorPipeRpcBackendOptions( + _transports=["uv"], num_worker_threads=16, init_method=str_init_method +) + +global tensor + + +def send_tensor(t): + global tensor + tensor = t + print(tensor) + + +def op_on_tensor(idx): + tensor[idx] += 1 + if isinstance(tensor, torch.Tensor): + return tensor + + +if __name__ == "__main__": + args = parser.parse_args() + rank = args.rank + world_size = args.world_size + tensortype = args.tensortype + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + + if rank == 0: + rpc.init_rpc( + AGENT_NAME, + rank=rank, + world_size=world_size, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=options, + ) + # create tensor + tensor = torch.zeros(10000, 10000) + if tensortype == "memmap": + tensor = MemmapTensor(tensor) + elif tensortype == "tensor": + pass + else: + raise NotImplementedError + + #  send tensor + w = 1 + fut0 = rpc.remote(f"worker{w}", send_tensor, args=(tensor,)) + fut0.to_here() + + #  execute + t0 = time.time() + idx = 10 + for i in range(100): + fut1 = rpc.remote(f"worker{w}", op_on_tensor, args=(idx,)) + tensor_out = fut1.to_here() + + if tensortype == "memmap": + assert (tensor[idx] == i + 1).all() + else: + assert (tensor_out[idx] == i + 1).all() + print(f"{tensortype}, time spent: {time.time() - t0: 4.4f}") + + else: + rpc.init_rpc( + OBSERVER_NAME.format(rank), + rank=rank, + world_size=world_size, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=options, + ) + + rpc.shutdown() diff --git a/examples/torchrl_features/memmap_td_distributed.py b/examples/torchrl_features/memmap_td_distributed.py new file mode 100644 index 00000000000..31878b4efa6 --- /dev/null +++ b/examples/torchrl_features/memmap_td_distributed.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time + +import configargparse +import torch +import torch.distributed.rpc as rpc +from torchrl.data import TensorDict +from torchrl.data.tensordict.memmap import set_transfer_ownership + +parser = configargparse.ArgumentParser() +parser.add_argument("--world_size", default=2, type=int) +parser.add_argument("--rank", default=-1, type=int) +parser.add_argument("--task", default=1, type=int) +parser.add_argument("--rank_var", default="SLURM_JOB_ID", type=str) +parser.add_argument( + "--master_addr", + type=str, + default="localhost", + help="""Address of master, will default to localhost if not provided. + Master must be able to accept network traffic on the address + port.""", +) +parser.add_argument( + "--master_port", + type=str, + default="29500", + help="""Port that master is listening on, will default to 29500 if not + provided. 
Master must be able to accept network traffic on the host and port.""", +) +parser.add_argument("--memmap", action="store_true") +parser.add_argument("--cuda", action="store_true") +parser.add_argument("--shared_mem", action="store_true") + +AGENT_NAME = "main" +OBSERVER_NAME = "worker{}" + + +def get_tensordict(): + return tensordict + + +def tensordict_add(): + tensordict.set_("a", tensordict.get("a") + 1) + tensordict.set("b", torch.zeros(*SIZE)) + if tensordict.is_memmap(): + td = tensordict.clone().apply_(set_transfer_ownership) + return td + return tensordict + + +def tensordict_add_noreturn(): + tensordict.set_("a", tensordict.get("a") + 1) + tensordict.set("b", torch.zeros(*SIZE)) + + +SIZE = (32, 50, 3, 84, 84) + +if __name__ == "__main__": + args = parser.parse_args() + rank = args.rank + if rank < 0: + rank = int(os.environ[args.rank_var]) + print("rank: ", rank) + world_size = args.world_size + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + + str_init_method = "tcp://localhost:10000" + options = rpc.TensorPipeRpcBackendOptions( + _transports=["uv"], num_worker_threads=16, init_method=str_init_method + ) + + if rank == 0: + # rank0 is the agent + rpc.init_rpc( + AGENT_NAME, + rank=rank, + world_size=world_size, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=options, + ) + + if args.task == 0: + time.sleep(1) + t0 = time.time() + for w in range(1, args.world_size): + fut0 = rpc.rpc_async(f"worker{w}", get_tensordict, args=tuple()) + fut0.wait() + fut1 = rpc.rpc_async(f"worker{w}", tensordict_add, args=tuple()) + tensordict2 = fut1.wait() + tensordict2.clone() + print("time: ", time.time() - t0) + elif args.task == 1: + time.sleep(1) + t0 = time.time() + waiters = [ + rpc.remote(f"worker{w}", get_tensordict, args=tuple()) + for w in range(1, args.world_size) + ] + td = torch.stack([waiter.to_here() for waiter in waiters], 0).contiguous() + print("time: ", time.time() - t0) + + t0 = time.time() + waiters = [ + rpc.remote(f"worker{w}", tensordict_add, args=tuple()) + for w in range(1, args.world_size) + ] + td = torch.stack([waiter.to_here() for waiter in waiters], 0).contiguous() + print("time: ", time.time() - t0) + assert (td[:, 3].get("a") == 1).all() + assert (td[:, 3].get("b") == 0).all() + + elif args.task == 2: + time.sleep(1) + t0 = time.time() + # waiters = [rpc.rpc_async(f"worker{w}", get_tensordict, args=tuple()) for w in range(1, args.world_size)] + waiters = [ + rpc.remote(f"worker{w}", get_tensordict, args=tuple()) + for w in range(1, args.world_size) + ] + # td = torch.stack([waiter.wait() for waiter in waiters], 0).clone() + td = torch.stack([waiter.to_here() for waiter in waiters], 0) + print("time to receive objs: ", time.time() - t0) + t0 = time.time() + if args.memmap: + waiters = [ + rpc.remote(f"worker{w}", tensordict_add_noreturn, args=tuple()) + for w in range(1, args.world_size) + ] + print("temp t: ", time.time() - t0) + [ + waiter.to_here() for waiter in waiters + ] # the previous stack will track the original files + print("temp t: ", time.time() - t0) + else: + waiters = [ + rpc.remote(f"worker{w}", tensordict_add, args=tuple()) + for w in range(1, args.world_size) + ] + print("temp t: ", time.time() - t0) + td = torch.stack([waiter.to_here() for waiter in waiters], 0) + print("temp t: ", time.time() - t0) + assert (td[:, 3].get("a") == 1).all() + assert (td[:, 3].get("b") == 0).all() + print("time to receive updates: ", time.time() - t0) + + elif args.task == 3: + time.sleep(1) + t0 = time.time() + 
waiters = [ + rpc.remote(f"worker{w}", get_tensordict, args=tuple()) + for w in range(1, args.world_size) + ] + td = torch.stack([waiter.to_here() for waiter in waiters], 0) + print("time to receive objs: ", time.time() - t0) + t0 = time.time() + waiters = [ + rpc.remote(f"worker{w}", tensordict_add, args=tuple()) + for w in range(1, args.world_size) + ] + print("temp t: ", time.time() - t0) + td = torch.stack([waiter.to_here() for waiter in waiters], 0) + print("temp t: ", time.time() - t0) + if args.memmap: + print(td[0].get("a").filename) + print(td[0].get("a").file) + print(td[0].get("a")._has_ownership) + + print("time to receive updates: ", time.time() - t0) + assert (td[:, 3].get("a") == 1).all() + assert (td[:, 3].get("b") == 0).all() + print("time to read one update: ", time.time() - t0) + + else: + + global tensordict + # other ranks are the observer + tensordict = TensorDict( + { + "a": torch.zeros(*SIZE), + "b": torch.randn(*SIZE), + }, + batch_size=SIZE[:1], + ) + if args.memmap: + tensordict.memmap_() + if rank == 1: + print(tensordict.get("a").filename) + print(tensordict.get("a").file) + if args.shared_mem: + tensordict.share_memory_() + elif args.cuda: + tensordict = tensordict.cuda() + rpc.init_rpc( + OBSERVER_NAME.format(rank), + rank=rank, + world_size=world_size, + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=options, + ) + + rpc.shutdown() diff --git a/gallery/README.rst b/gallery/README.rst new file mode 100644 index 00000000000..868afe74351 --- /dev/null +++ b/gallery/README.rst @@ -0,0 +1,4 @@ +Example gallery +=============== + +Below is a gallery of examples diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000000..e52974fbfae --- /dev/null +++ b/mypy.ini @@ -0,0 +1,56 @@ +[mypy] + +files = torchrl +show_error_codes = True +pretty = True +allow_redefinition = True +warn_redundant_casts = True + +[mypy-torchvision.*] + +ignore_errors = True +ignore_missing_imports = True + +[mypy-numpy.*] + +ignore_missing_imports = True + +[mypy-scipy.*] + +ignore_missing_imports = True + +[mypy-pycocotools.*] + +ignore_missing_imports = True + +[mypy-lmdb.*] + +ignore_missing_imports = True + +[mypy-tqdm.*] + +ignore_missing_imports = True + +[mypy-moviepy.*] + +ignore_missing_imports = True + +[mypy-dm_control.*] + +ignore_missing_imports = True + +[mypy-dm_env.*] + +ignore_missing_imports = True + +[mypy-retro.*] + +ignore_missing_imports = True + +[mypy-gym.*] + +ignore_missing_imports = True + +[mypy-torchrl._torchrl.*] + +ignore_missing_imports = True diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000000..36d047d3055 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +addopts = + # show summary of all tests that did not pass + -ra + # Make tracebacks shorter + --tb=native +testpaths = + test +xfail_strict = True diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000000..83a605a2a0b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,22 @@ +[bdist_wheel] +universal=1 + +[metadata] +license_file = LICENSE + +[pep8] +max-line-length = 120 + +[flake8] +# note: we ignore all 501s (line too long) anyway as they're taken care of by black +max-line-length = 79 +ignore = E203, E402, W503, W504, F821, E501 +per-file-ignores = + __init__.py: F401, F403, F405 + ./hubconf.py: F401 + test/smoke_test.py: F401 + test_*.py: F841, E731, E266 +exclude = venv + +[pydocstyle] +select = D417 # Missing argument descriptions in the docstring diff --git a/setup.py b/setup.py new file mode 100644 index 00000000000..3747a43fea2 --- 
/dev/null +++ b/setup.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import distutils.command.clean +import os +import shutil +import subprocess +from pathlib import Path + +from build_tools import setup_helpers +from setuptools import setup, find_packages + + +def _get_pytorch_version(): + if "PYTORCH_VERSION" in os.environ: + return f"torch=={os.environ['PYTORCH_VERSION']}" + return "torch" + + +def _get_packages(): + exclude = [ + "build*", + "test*", + "torchrl.csrc*", + "third_party*", + "tools*", + ] + return find_packages(exclude=exclude) + + +ROOT_DIR = Path(__file__).parent.resolve() + + +class clean(distutils.command.clean.clean): + def run(self): + # Run default behavior first + distutils.command.clean.clean.run(self) + + # Remove torchrl extension + for path in (ROOT_DIR / "torchrl").glob("**/*.so"): + print(f"removing '{path}'") + path.unlink() + # Remove build directory + build_dirs = [ + ROOT_DIR / "build", + ] + for path in build_dirs: + if path.exists(): + print(f"removing '{path}' (and everything under it)") + shutil.rmtree(str(path), ignore_errors=True) + + +def _run_cmd(cmd): + try: + return subprocess.check_output(cmd, cwd=ROOT_DIR).decode("ascii").strip() + except Exception: + return None + + +def _main(): + pytorch_package_dep = _get_pytorch_version() + print("-- PyTorch dependency:", pytorch_package_dep) + # branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + # tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"]) + + setup( + name="torchrl", + version="0.1", + author="torchrl contributors", + author_email="vmoens@fb.com", + packages=_get_packages(), + ext_modules=setup_helpers.get_ext_modules(), + cmdclass={ + "build_ext": setup_helpers.CMakeBuild, + "clean": clean, + }, + install_requires=[pytorch_package_dep], + ) + + +if __name__ == "__main__": + _main() diff --git a/test/_utils_internal.py b/test/_utils_internal.py new file mode 100644 index 00000000000..7833c527451 --- /dev/null +++ b/test/_utils_internal.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os + +# Get relative file path +# this returns relative path from current file. +import torch.cuda + + +def get_relative_path(curr_file, *path_components): + return os.path.join(os.path.dirname(curr_file), *path_components) + + +def get_available_devices(): + devices = [torch.device("cpu")] + n_cuda = torch.cuda.device_count() + if n_cuda > 0: + for i in range(n_cuda): + devices += [torch.device(f"cuda:{i}")] + return devices diff --git a/test/mocking_classes.py b/test/mocking_classes.py new file mode 100644 index 00000000000..bfbe9b9803a --- /dev/null +++ b/test/mocking_classes.py @@ -0,0 +1,249 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
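The `get_available_devices` helper defined in `test/_utils_internal.py` above is consumed by the test modules in this PR through pytest parametrization. A minimal sketch of that pattern (the test name here is hypothetical; the real usages appear in `test_cost.py` further down):

```python
import pytest
import torch
from _utils_internal import get_available_devices


@pytest.mark.parametrize("device", get_available_devices())
def test_runs_on_every_device(device):
    # One test instance per device: CPU always, plus one per visible GPU.
    x = torch.zeros(3, device=device)
    assert x.device.type in ("cpu", "cuda")
```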
+ +import torch +from torchrl.data.tensor_specs import ( + NdUnboundedContinuousTensorSpec, + NdBoundedTensorSpec, + CompositeSpec, + MultOneHotDiscreteTensorSpec, + BinaryDiscreteTensorSpec, + BoundedTensorSpec, + UnboundedContinuousTensorSpec, + OneHotDiscreteTensorSpec, +) +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.envs.common import _EnvClass + +spec_dict = { + "bounded": BoundedTensorSpec, + "one_hot": OneHotDiscreteTensorSpec, + "unbounded": UnboundedContinuousTensorSpec, + "ndbounded": NdBoundedTensorSpec, + "ndunbounded": NdUnboundedContinuousTensorSpec, + "binary": BinaryDiscreteTensorSpec, + "mult_one_hot": MultOneHotDiscreteTensorSpec, + "composite": CompositeSpec, +} + +default_spec_kwargs = { + BoundedTensorSpec: {"minimum": -1.0, "maximum": 1.0}, + OneHotDiscreteTensorSpec: {"n": 7}, + UnboundedContinuousTensorSpec: {}, + NdBoundedTensorSpec: {"minimum": -torch.ones(4), "maxmimum": torch.ones(4)}, + NdUnboundedContinuousTensorSpec: { + "shape": [ + 7, + ] + }, + BinaryDiscreteTensorSpec: {"n": 7}, + MultOneHotDiscreteTensorSpec: {"nvec": [7, 3, 5]}, + CompositeSpec: {}, +} + + +def make_spec(spec_str): + target_class = spec_dict[spec_str] + return target_class(**default_spec_kwargs[target_class]) + + +class _MockEnv(_EnvClass): + def __init__(self, seed: int = 100): + super().__init__( + device="cpu", + dtype=torch.float, + ) + self.set_seed(seed) + + @property + def maxstep(self): + return self.counter + + def set_seed(self, seed: int) -> int: + self.seed = seed + self.counter = seed - 1 + return seed + + +class DiscreteActionVecMockEnv(_MockEnv): + size = 7 + observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([size])) + action_spec = OneHotDiscreteTensorSpec(7) + reward_spec = UnboundedContinuousTensorSpec() + from_pixels = False + + out_key = "observation" + + def _get_in_obs(self, obs): + return obs + + def _get_out_obs(self, obs): + return obs + + def _reset(self, tensor_dict: _TensorDict) -> _TensorDict: + self.counter += 1 + state = torch.zeros(self.size) + self.counter + tensor_dict = tensor_dict.select().set(self.out_key, self._get_out_obs(state)) + tensor_dict.set("done", torch.zeros(*tensor_dict.shape, 1, dtype=torch.bool)) + return tensor_dict + + def _step( + self, + tensor_dict: _TensorDict, + ) -> _TensorDict: + tensor_dict = tensor_dict.to(self.device) + a = tensor_dict.get("action") + assert (a.sum(-1) == 1).all() + assert not self.is_done, "trying to execute step in done env" + + obs = ( + self._get_in_obs(self.current_tensordict.get(self.out_key)) + + a / self.maxstep + ) + tensor_dict = tensor_dict.select() # empty tensordict + tensor_dict.set("next_" + self.out_key, self._get_out_obs(obs)) + done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1)) + reward = done.any(-1).unsqueeze(-1) + done = done.all(-1).unsqueeze(-1) + tensor_dict.set("reward", reward.to(torch.float)) + tensor_dict.set("done", done) + return tensor_dict + + +class ContinuousActionVecMockEnv(_MockEnv): + size = 7 + observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([size])) + action_spec = NdBoundedTensorSpec(-1, 1, (7,)) + reward_spec = UnboundedContinuousTensorSpec() + from_pixels = False + + out_key = "observation" + + def _get_in_obs(self, obs): + return obs + + def _get_out_obs(self, obs): + return obs + + def _reset(self, tensor_dict: _TensorDict) -> _TensorDict: + self.counter += 1 + state = torch.zeros(self.size) + self.counter + tensor_dict = tensor_dict.select().set(self.out_key, 
self._get_out_obs(state)) + tensor_dict.set("done", torch.zeros(*tensor_dict.shape, 1, dtype=torch.bool)) + return tensor_dict + + def _step( + self, + tensor_dict: _TensorDict, + ) -> _TensorDict: + tensor_dict = tensor_dict.to(self.device) + a = tensor_dict.get("action") + assert not self.is_done, "trying to execute step in done env" + + obs = self._obs_step( + self._get_in_obs(self.current_tensordict.get(self.out_key)), a + ) + tensor_dict = tensor_dict.select() # empty tensordict + tensor_dict.set("next_" + self.out_key, self._get_out_obs(obs)) + done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1)) + reward = done.any(-1).unsqueeze(-1) + done = done.all(-1).unsqueeze(-1) + tensor_dict.set("reward", reward.to(torch.float)) + tensor_dict.set("done", done) + return tensor_dict + + def _obs_step(self, obs, a): + return obs + a / self.maxstep + + +class DiscreteActionVecPolicy: + in_keys = ["observation"] + out_keys = ["action"] + + def _get_in_obs(self, tensor_dict): + obs = tensor_dict.get(*self.in_keys) + return obs + + def __call__(self, tensor_dict): + obs = self._get_in_obs(tensor_dict) + max_obs = (obs == obs.max(dim=-1, keepdim=True)[0]).cumsum(-1).argmax(-1) + k = tensor_dict.get(*self.in_keys).shape[-1] + max_obs = (max_obs + 1) % k + action = torch.nn.functional.one_hot(max_obs, k) + tensor_dict.set(*self.out_keys, action) + return tensor_dict + + +class DiscreteActionConvMockEnv(DiscreteActionVecMockEnv): + observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])) + action_spec = OneHotDiscreteTensorSpec(7) + reward_spec = UnboundedContinuousTensorSpec() + from_pixels = True + + out_key = "observation_pixels" + + def _get_out_obs(self, obs): + obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(0) + return obs + + def _get_in_obs(self, obs): + return obs.diagonal(0, -1, -2).squeeze() + + +class DiscreteActionConvMockEnvNumpy(DiscreteActionConvMockEnv): + observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])) + from_pixels = True + + def _get_out_obs(self, obs): + obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(-1) + obs = obs.expand(*obs.shape[:-1], 3) + return obs + + def _get_in_obs(self, obs): + return obs.diagonal(0, -2, -3)[..., 0] + + def _obs_step(self, obs, a): + return obs + a.unsqueeze(-1) / self.maxstep + + +class ContinuousActionConvMockEnv(ContinuousActionVecMockEnv): + observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])) + action_spec = NdBoundedTensorSpec(-1, 1, (7,)) + reward_spec = UnboundedContinuousTensorSpec() + from_pixels = True + + out_key = "observation_pixels" + + def _get_out_obs(self, obs): + obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(0) + return obs + + def _get_in_obs(self, obs): + return obs.diagonal(0, -1, -2).squeeze() + + +class ContinuousActionConvMockEnvNumpy(ContinuousActionConvMockEnv): + observation_spec = NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])) + from_pixels = True + + def _get_out_obs(self, obs): + obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(-1) + obs = obs.expand(*obs.shape[:-1], 3) + return obs + + def _get_in_obs(self, obs): + return obs.diagonal(0, -2, -3)[..., 0] + + def _obs_step(self, obs, a): + return obs + a.unsqueeze(-1) / self.maxstep + + +class DiscreteActionConvPolicy(DiscreteActionVecPolicy): + in_keys = ["observation_pixels"] + out_keys = ["action"] + + def _get_in_obs(self, tensor_dict): + obs = tensor_dict.get(*self.in_keys).diagonal(0, -1, -2).squeeze() + return obs diff --git a/test/test_agent.py 
b/test/test_agent.py new file mode 100644 index 00000000000..5387a23f503 --- /dev/null +++ b/test/test_agent.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +import pytest + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_collector.py b/test/test_collector.py new file mode 100644 index 00000000000..b8ffbbbf3cc --- /dev/null +++ b/test/test_collector.py @@ -0,0 +1,511 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +import numpy as np +import pytest +import torch +from mocking_classes import ( + DiscreteActionConvMockEnv, + DiscreteActionVecMockEnv, + DiscreteActionVecPolicy, + DiscreteActionConvPolicy, +) +from torch import nn +from torchrl.agents.env_creator import EnvCreator +from torchrl.collectors import SyncDataCollector, aSyncDataCollector +from torchrl.collectors.collectors import ( + RandomPolicy, + MultiSyncDataCollector, + MultiaSyncDataCollector, +) +from torchrl.data.tensordict.tensordict import assert_allclose_td +from torchrl.envs import ParallelEnv +from torchrl.envs.libs.gym import _has_gym +from torchrl.envs.transforms import TransformedEnv, VecNorm + + +def make_make_env(env_name="conv"): + def make_transformed_env(): + if env_name == "conv": + return DiscreteActionConvMockEnv() + elif env_name == "vec": + return DiscreteActionVecMockEnv() + + return make_transformed_env + + +def dummypolicy_vec(): + policy = DiscreteActionVecPolicy() + return policy + + +def dummypolicy_conv(): + policy = DiscreteActionConvPolicy() + return policy + + +def make_policy(env): + if env == "conv": + return dummypolicy_conv() + elif env == "vec": + return dummypolicy_vec() + else: + raise NotImplementedError + + +@pytest.mark.parametrize("num_env", [1, 3]) +@pytest.mark.parametrize("env_name", ["conv", "vec"]) +def test_concurrent_collector_consistency(num_env, env_name, seed=100): + if num_env == 1: + + def env_fn(seed): + env = make_make_env(env_name)() + env.set_seed(seed) + return env + + else: + + def env_fn(seed): + env = ParallelEnv( + num_workers=num_env, create_env_fn=make_make_env(env_name) + ) + env.set_seed(seed) + return env + + policy = make_policy(env_name) + + collector = SyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + device="cpu", + pin_memory=False, + ) + for i, d in enumerate(collector): + if i == 0: + b1 = d + elif i == 1: + b2 = d + else: + break + with pytest.raises(AssertionError): + assert_allclose_td(b1, b2) + collector.shutdown() + + ccollector = aSyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + pin_memory=False, + ) + for i, d in enumerate(ccollector): + if i == 0: + b1c = d + elif i == 1: + b2c = d + else: + break + with pytest.raises(AssertionError): + assert_allclose_td(b1c, b2c) + + assert_allclose_td(b1c, b1) + assert_allclose_td(b2c, b2) + + ccollector.shutdown() + + +@pytest.mark.parametrize("num_env", [1, 3]) +@pytest.mark.parametrize("env_name", 
["vec", "conv"]) +def test_collector_batch_size(num_env, env_name, seed=100): + if num_env == 1: + + def env_fn(): + env = make_make_env(env_name)() + return env + + else: + + def env_fn(): + env = ParallelEnv( + num_workers=num_env, create_env_fn=make_make_env(env_name) + ) + return env + + policy = make_policy(env_name) + + torch.manual_seed(0) + np.random.seed(0) + num_workers = 4 + frames_per_batch = 20 + ccollector = MultiaSyncDataCollector( + create_env_fn=[env_fn for _ in range(num_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + max_frames_per_traj=1000, + total_frames=frames_per_batch * 100, + pin_memory=False, + ) + ccollector.set_seed(seed) + for i, b in enumerate(ccollector): + assert b.numel() == -(-frames_per_batch // num_env) * num_env + if i == 5: + break + ccollector.shutdown() + + ccollector = MultiSyncDataCollector( + create_env_fn=[env_fn for _ in range(num_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + max_frames_per_traj=1000, + total_frames=frames_per_batch * 100, + pin_memory=False, + ) + ccollector.set_seed(seed) + for i, b in enumerate(ccollector): + assert ( + b.numel() + == -(-frames_per_batch // num_env // num_workers) * num_env * num_workers + ) + if i == 5: + break + ccollector.shutdown() + + +@pytest.mark.parametrize("num_env", [1, 3]) +@pytest.mark.parametrize("env_name", ["vec", "conv"]) +def test_concurrent_collector_seed(num_env, env_name, seed=100): + if num_env == 1: + + def env_fn(): + env = make_make_env(env_name)() + return env + + else: + + def env_fn(): + env = ParallelEnv( + num_workers=num_env, create_env_fn=make_make_env(env_name) + ) + return env + + policy = make_policy(env_name) + + torch.manual_seed(0) + np.random.seed(0) + ccollector = aSyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={}, + policy=policy, + frames_per_batch=20, + max_frames_per_traj=20, + total_frames=300, + pin_memory=False, + ) + ccollector.set_seed(seed) + for i, data in enumerate(ccollector): + if i == 0: + b1 = data + ccollector.set_seed(seed) + elif i == 1: + b2 = data + elif i == 2: + b3 = data + else: + break + assert_allclose_td(b1, b2) + with pytest.raises(AssertionError): + assert_allclose_td(b1, b3) + ccollector.shutdown() + + +@pytest.mark.parametrize("num_env", [1, 3]) +@pytest.mark.parametrize("env_name", ["conv", "vec"]) +def test_collector_consistency(num_env, env_name, seed=100): + if num_env == 1: + + def env_fn(seed): + env = make_make_env(env_name)() + env.set_seed(seed) + return env + + else: + + def env_fn(seed): + env = ParallelEnv( + num_workers=num_env, create_env_fn=make_make_env(env_name) + ) + env.set_seed(seed) + return env + + policy = make_policy(env_name) + + torch.manual_seed(0) + np.random.seed(0) + + # Get a single rollout with dummypolicy + env = env_fn(seed) + rollout1a = env.rollout(policy=policy, n_steps=20, auto_reset=True) + env.set_seed(seed) + rollout1b = env.rollout(policy=policy, n_steps=20, auto_reset=True) + rollout2 = env.rollout(policy=policy, n_steps=20, auto_reset=True) + assert assert_allclose_td(rollout1a, rollout1b) + with pytest.raises(AssertionError): + assert_allclose_td(rollout1a, rollout2) + env.close() + + collector = SyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=20 * num_env, + max_frames_per_traj=20, + total_frames=200, + device="cpu", + pin_memory=False, + ) + collector = iter(collector) + b1 = next(collector) + b2 = next(collector) + with pytest.raises(AssertionError): + 
assert_allclose_td(b1, b2) + + if num_env == 1: + # rollouts collected through DataCollector are padded using pad_sequence, which introduces a first dimension + rollout1a = rollout1a.unsqueeze(0) + assert ( + rollout1a.batch_size == b1.batch_size + ), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}" + + assert_allclose_td(rollout1a, b1.select(*rollout1a.keys())) + + +@pytest.mark.parametrize("num_env", [1, 3]) +@pytest.mark.parametrize("collector_class", [SyncDataCollector, aSyncDataCollector]) +@pytest.mark.parametrize("env_name", ["conv", "vec"]) +def test_traj_len_consistency(num_env, env_name, collector_class, seed=100): + """ + Tests that various frames_per_batch lead to the same results + """ + if num_env == 1: + + def env_fn(seed): + env = make_make_env(env_name)() + env.set_seed(seed) + return env + + else: + + def env_fn(seed): + env = ParallelEnv( + num_workers=num_env, create_env_fn=make_make_env(env_name) + ) + env.set_seed(seed) + return env + + max_frames_per_traj = 20 + + policy = make_policy(env_name) + + def make_frames_per_batch(frames_per_batch): + return -(-frames_per_batch // num_env) * num_env + + collector1 = collector_class( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=1 * num_env, + max_frames_per_traj=2000, + total_frames=2 * num_env * max_frames_per_traj, + device="cpu", + seed=seed, + pin_memory=False, + ) + count = 0 + data1 = [] + for d in collector1: + data1.append(d) + count += d.shape[1] + if count > max_frames_per_traj: + break + + data1 = torch.cat(data1, 1) + data1 = data1[:, :max_frames_per_traj] + + collector1.shutdown() + del collector1 + + collector10 = collector_class( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=10 * num_env, + max_frames_per_traj=20, + total_frames=2 * num_env * max_frames_per_traj, + device="cpu", + seed=seed, + pin_memory=False, + ) + count = 0 + data10 = [] + for d in collector10: + data10.append(d) + count += d.shape[1] + if count > max_frames_per_traj: + break + + data10 = torch.cat(data10, 1) + data10 = data10[:, :max_frames_per_traj] + + collector10.shutdown() + del collector10 + + collector20 = collector_class( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=20 * num_env, + max_frames_per_traj=2000, + total_frames=2 * num_env * max_frames_per_traj, + device="cpu", + seed=seed, + pin_memory=False, + ) + count = 0 + data20 = [] + for d in collector20: + data20.append(d) + count += d.shape[1] + if count > max_frames_per_traj: + break + + collector20.shutdown() + del collector20 + data20 = torch.cat(data20, 1) + data20 = data20[:, :max_frames_per_traj] + + assert_allclose_td(data1, data20) + assert_allclose_td(data10, data20) + + +@pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") +def test_collector_vecnorm_envcreator(): + """ + High level test of the following pipeline: + (1) Design a function that creates an environment with VecNorm + (2) Wrap that function in an EnvCreator to instantiate the shared tensordict + (3) Create a ParallelEnv that dispatches this env across workers + (4) Run several ParallelEnv synchronously + The function tests that the tensordict gathered from the workers match at certain moments in time, and that they + are modified after the collector is run for more steps. 
+ + """ + from torchrl.envs import GymEnv + + env_make = EnvCreator(lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())) + env_make = ParallelEnv(4, env_make) + + policy = RandomPolicy(env_make.action_spec) + c = MultiSyncDataCollector( + [env_make, env_make], policy=policy, total_frames=int(1e6) + ) + final_seed = c.set_seed(0) + assert final_seed == 7 + + c_iter = iter(c) + next(c_iter) + next(c_iter) + + s = c.state_dict() + + td1 = s["worker0"]["env_state_dict"]["worker3"]["_extra_state"].clone() + td2 = s["worker1"]["env_state_dict"]["worker0"]["_extra_state"].clone() + assert (td1 == td2).all() + + next(c_iter) + next(c_iter) + + s = c.state_dict() + + td3 = s["worker0"]["env_state_dict"]["worker3"]["_extra_state"].clone() + td4 = s["worker1"]["env_state_dict"]["worker0"]["_extra_state"].clone() + assert (td3 == td4).all() + assert (td1 != td4).any() + + del c + + +@pytest.mark.parametrize("use_async", [False, True]) +@pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="no cuda device found") +def test_update_weights(use_async): + policy = torch.nn.Linear(3, 4).cuda(1) + policy.share_memory() + collector_class = ( + MultiSyncDataCollector if not use_async else MultiaSyncDataCollector + ) + collector = collector_class( + [lambda: DiscreteActionVecMockEnv()] * 3, + policy=policy, + devices=[torch.device("cuda:0")] * 3, + passing_devices=[torch.device("cuda:0")] * 3, + ) + # collect state_dict + state_dict = collector.state_dict() + policy_state_dict = policy.state_dict() + for worker in range(3): + for k in state_dict[f"worker{worker}"]["policy_state_dict"]: + torch.testing.assert_allclose( + state_dict[f"worker{worker}"]["policy_state_dict"][k], + policy_state_dict[k].cpu(), + ) + + # change policy weights + for p in policy.parameters(): + p.data += torch.randn_like(p) + + # collect state_dict + state_dict = collector.state_dict() + policy_state_dict = policy.state_dict() + # check they don't match + for worker in range(3): + for k in state_dict[f"worker{worker}"]["policy_state_dict"]: + with pytest.raises(AssertionError): + torch.testing.assert_allclose( + state_dict[f"worker{worker}"]["policy_state_dict"][k], + policy_state_dict[k].cpu(), + ) + + # update weights + collector.update_policy_weights_() + + # collect state_dict + state_dict = collector.state_dict() + policy_state_dict = policy.state_dict() + for worker in range(3): + for k in state_dict[f"worker{worker}"]["policy_state_dict"]: + torch.testing.assert_allclose( + state_dict[f"worker{worker}"]["policy_state_dict"][k], + policy_state_dict[k].cpu(), + ) + + collector.shutdown() + del collector + + +def weight_reset(m): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + m.reset_parameters() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_cost.py b/test/test_cost.py new file mode 100644 index 00000000000..15bed4b79ee --- /dev/null +++ b/test/test_cost.py @@ -0,0 +1,1324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
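The loss tests in `test/test_cost.py` below repeatedly back-propagate one loss term at a time and assert that only the parameters of the network that term is meant to train receive non-zero gradients. A self-contained sketch of that check, with toy linear modules standing in for the actual actor/value networks (both stand-ins are assumptions for illustration):

```python
import torch
from torch import nn

actor = nn.Linear(3, 4)   # stand-in for the actor network
value = nn.Linear(3, 1)   # stand-in for the value network
obs = torch.randn(8, 3)

# Back-propagate the "actor" term alone ...
actor(obs).sum().backward()
# ... then verify only the actor's parameters picked up gradients, mirroring
# the per-term checks in the DDPG/SAC/REDQ tests below.
assert all(p.grad is None or (p.grad == 0).all() for p in value.parameters())
assert not any(p.grad is None or (p.grad == 0).all() for p in actor.parameters())
```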
+ +import argparse +from copy import deepcopy + +import numpy as np +import pytest +import torch +from _utils_internal import get_available_devices +from torch import nn +from torchrl.data import TensorDict, NdBoundedTensorSpec, MultOneHotDiscreteTensorSpec +from torchrl.data.postprocs.postprocs import MultiStep + +# from torchrl.data.postprocs.utils import expand_as_right +from torchrl.data.tensordict.tensordict import assert_allclose_td +from torchrl.data.utils import expand_as_right +from torchrl.modules import DistributionalQValueActor, QValueActor +from torchrl.modules.distributions.continuous import TanhNormal, NormalParamWrapper +from torchrl.modules.models.models import MLP +from torchrl.modules.td_module.actors import ValueOperator, Actor, ProbabilisticActor +from torchrl.objectives import ( + DQNLoss, + DoubleDQNLoss, + DistributionalDQNLoss, + DistributionalDoubleDQNLoss, + DDPGLoss, + DoubleDDPGLoss, + SACLoss, + DoubleSACLoss, + PPOLoss, + ClipPPOLoss, + KLPENPPOLoss, + GAE, +) +from torchrl.objectives.costs.common import _LossModule +from torchrl.objectives.costs.redq import ( + REDQLoss, + DoubleREDQLoss, + REDQLoss_deprecated, + DoubleREDQLoss_deprecated, +) +from torchrl.objectives.costs.utils import hold_out_net, HardUpdate, SoftUpdate + + +class _check_td_steady: + def __init__(self, td): + self.td_clone = td.clone() + self.td = td + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + assert (self.td.select(*self.td_clone.keys()) == self.td_clone).all() + + +def get_devices(): + devices = [torch.device("cpu")] + for i in range(torch.cuda.device_count()): + devices += [torch.device(f"cuda:{i}")] + return devices + + +class TestDQN: + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = NdBoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + module = nn.Linear(obs_dim, action_dim) + actor = QValueActor( + spec=action_spec, + module=module, + ).to(device) + return actor + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + # Actor + action_spec = MultOneHotDiscreteTensorSpec([atoms] * action_dim) + support = torch.linspace(vmin, vmax, atoms, dtype=torch.float) + module = MLP(obs_dim, (atoms, action_dim)) + actor = DistributionalQValueActor( + spec=action_spec, + module=module, + support=support, + ) + return actor + + def _create_mock_data_dqn( + self, batch=2, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, obs_dim) + next_obs = torch.randn(batch, obs_dim) + if atoms: + action_value = torch.randn(batch, atoms, action_dim).softmax(-2) + action = ( + action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0] + ).to(torch.long) + else: + action_value = torch.randn(batch, action_dim) + action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + reward = torch.randn(batch, 1) + done = torch.zeros(batch, 1, dtype=torch.bool) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "next_observation": next_obs, + "done": done, + "reward": reward, + "action": action, + "action_value": action_value, + }, + ).to(device) + return td + + def _create_seq_mock_data_dqn( + self, batch=2, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: 
+ action_value = torch.randn( + batch, T, atoms, action_dim, device=device + ).softmax(-2) + action = ( + action_value[..., 0, :] == action_value[..., 0, :].max(-1, True)[0] + ).to(torch.long) + else: + action_value = torch.randn(batch, T, action_dim, device=device) + action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + "next_observation": next_obs * mask.to(obs.dtype), + "done": done, + "mask": mask, + "reward": reward * mask.to(obs.dtype), + "action": action * mask.to(obs.dtype), + "action_value": action_value + * expand_as_right(mask.to(obs.dtype).squeeze(-1), action_value), + }, + ) + return td + + @pytest.mark.parametrize("loss_class", (DQNLoss, DoubleDQNLoss)) + @pytest.mark.parametrize("device", get_available_devices()) + def test_dqn(self, loss_class, device): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + td = self._create_mock_data_dqn(device=device) + loss_fn = loss_class(actor, gamma=0.9, loss_function="l2") + with _check_td_steady(td): + loss = loss_fn(td) + assert loss_fn.priority_key in td.keys() + + sum([item for _, item in loss.items()]).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = [p.clone() for p in loss_fn.target_value_network_params] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + if loss_fn.delay_value: + assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) + ) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize("n", range(4)) + @pytest.mark.parametrize("loss_class", (DQNLoss, DoubleDQNLoss)) + @pytest.mark.parametrize("device", get_available_devices()) + def test_dqn_batcher(self, n, loss_class, device, gamma=0.9): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + + td = self._create_seq_mock_data_dqn(device=device) + loss_fn = loss_class(actor, gamma=gamma, loss_function="l2") + + ms = MultiStep(gamma=gamma, n_steps_max=n).to(device) + ms_td = ms(td.clone()) + + with _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.priority_key in ms_td.keys() + + with torch.no_grad(): + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*list(td.keys()))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum([item for _, item in loss_ms.items()]).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = [p.clone() for p in loss_fn.target_value_network_params] + for p in loss_fn.parameters(): + 
p.data += torch.randn_like(p) + target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + if loss_fn.delay_value: + assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) + ) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize("atoms", range(4, 10)) + @pytest.mark.parametrize( + "loss_class", (DistributionalDQNLoss, DistributionalDoubleDQNLoss) + ) + @pytest.mark.parametrize("device", get_devices()) + def test_distributional_dqn(self, atoms, loss_class, device, gamma=0.9): + torch.manual_seed(self.seed) + actor = self._create_mock_distributional_actor(atoms=atoms).to(device) + + td = self._create_mock_data_dqn(atoms=atoms).to(device) + loss_fn = loss_class(actor, gamma=gamma) + + with _check_td_steady(td): + loss = loss_fn(td) + assert loss_fn.priority_key in td.keys() + + sum([item for _, item in loss.items()]).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = [p.clone() for p in loss_fn.target_value_network_params] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + if loss_fn.delay_value: + assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) + ) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + +class TestDDPG: + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = NdBoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + module = nn.Linear(obs_dim, action_dim) + actor = Actor( + spec=action_spec, + module=module, + ) + return actor.to(device) + + def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + value = ValueOperator( + module=module, + in_keys=["observation", "action"], + ) + return value.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_ddpg( + self, batch=8, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "next_observation": next_obs, + "done": done, + "reward": reward, + "action": 
action, + }, + ) + return td + + def _create_seq_mock_data_ddpg( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + "next_observation": next_obs * mask.to(obs.dtype), + "done": done, + "mask": mask, + "reward": reward * mask.to(obs.dtype), + "action": action * mask.to(obs.dtype), + }, + ) + return td + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("loss_class", (DDPGLoss, DoubleDDPGLoss)) + def test_ddpg(self, loss_class, device): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_mock_data_ddpg(device=device) + loss_fn = loss_class(actor, value, gamma=0.9, loss_function="l2") + with _check_td_steady(td): + loss = loss_fn(td) + + # check that loss are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params + ) + elif k == "loss_value": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + # check overall grad + sum([item for _, item in loss.items()]).backward() + parameters = list(actor.parameters()) + list(value.parameters()) + for p in parameters: + assert p.grad.norm() > 0.0 + + # Check param update effect on targets + target_actor = [p.clone() for p in loss_fn.target_actor_network_params] + target_value = [p.clone() for p in loss_fn.target_value_network_params] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params] + target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_value: + assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) + ) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("loss_class", (DDPGLoss, DoubleDDPGLoss)) + def test_ddpg_batcher(self, n, 
loss_class, device, gamma=0.9): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_seq_mock_data_ddpg(device=device) + loss_fn = loss_class(actor, value, gamma=gamma, loss_function="l2") + + ms = MultiStep(gamma=gamma, n_steps_max=n).to(device) + ms_td = ms(td.clone()) + with _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + with torch.no_grad(): + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*list(td.keys()))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum([item for _, item in loss_ms.items()]).backward() + parameters = list(actor.parameters()) + list(value.parameters()) + for p in parameters: + assert p.grad.norm() > 0.0 + + +class TestSAC: + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = NdBoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + module = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + actor = ProbabilisticActor( + spec=action_spec, + module=module, + distribution_class=TanhNormal, + ) + return actor.to(device) + + def _create_mock_qvalue(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + qvalue = ValueOperator( + module=module, + in_keys=["observation", "action"], + ) + return qvalue.to(device) + + def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + module = nn.Linear(obs_dim, 1) + value = ValueOperator( + module=module, + in_keys=["observation"], + ) + return value.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_sac( + self, batch=16, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "next_observation": next_obs, + "done": done, + "reward": reward, + "action": action, + }, + ) + return td + + def _create_seq_mock_data_sac( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + 
"observation": obs * mask.to(obs.dtype), + "next_observation": next_obs * mask.to(obs.dtype), + "done": done, + "mask": mask, + "reward": reward * mask.to(obs.dtype), + "action": action * mask.to(obs.dtype), + }, + ) + return td + + @pytest.mark.parametrize("loss_class", (SACLoss, DoubleSACLoss)) + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_sac(self, loss_class, delay_actor, delay_qvalue, num_qvalue, device): + if (delay_actor or delay_qvalue) and loss_class is not DoubleSACLoss: + pytest.skip("incompatible config") + + torch.manual_seed(self.seed) + td = self._create_mock_data_sac(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + value = self._create_mock_value(device=device) + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + if delay_qvalue: + kwargs["delay_qvalue"] = True + + loss_fn = loss_class( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + gamma=0.9, + loss_function="l2", + **kwargs, + ) + + with _check_td_steady(td): + loss = loss_fn(td) + assert loss_fn.priority_key in td.keys() + + # check that loss are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params + ) + elif k == "loss_value": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.value_network_params + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum([item for _, item in loss.items()]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len(set(p for n, p in named_parameters)) == len(list(named_parameters)) + assert len(set(p for n, p in named_buffers)) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("loss_class", (SACLoss, DoubleSACLoss)) + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("delay_qvalue", (True, False)) + 
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_sac_batcher( + self, n, loss_class, delay_actor, delay_qvalue, num_qvalue, device, gamma=0.9 + ): + if (delay_actor or delay_qvalue) and (loss_class is not DoubleSACLoss): + pytest.skip("incompatible config") + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_sac(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + value = self._create_mock_value(device=device) + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + if delay_qvalue: + kwargs["delay_qvalue"] = True + + loss_fn = loss_class( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=num_qvalue, + gamma=0.9, + loss_function="l2", + **kwargs, + ) + + ms = MultiStep(gamma=gamma, n_steps_max=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + with _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.priority_key in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*list(td.keys()))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum([item for _, item in loss_ms.items()]).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" + + # Check param update effect on targets + target_actor = [p.clone() for p in loss_fn.target_actor_network_params] + target_qvalue = [p.clone() for p in loss_fn.target_qvalue_network_params] + target_value = [p.clone() for p in loss_fn.target_value_network_params] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params] + target_qvalue2 = [p.clone() for p in loss_fn.target_qvalue_network_params] + target_value2 = [p.clone() for p in loss_fn.target_value_network_params] + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_qvalue: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + if loss_fn.delay_value: + assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_value, target_value2) + ) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + +class TestREDQ: + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = NdBoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + module = NormalParamWrapper(nn.Linear(obs_dim, 2 * 
action_dim)) + actor = ProbabilisticActor( + spec=action_spec, + module=module, + distribution_class=TanhNormal, + return_log_prob=True, + ) + return actor.to(device) + + def _create_mock_qvalue(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + qvalue = ValueOperator( + module=module, + in_keys=["observation", "action"], + ) + return qvalue.to(device) + + def _create_mock_data_redq( + self, batch=16, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "next_observation": next_obs, + "done": done, + "reward": reward, + "action": action, + }, + ) + return td + + def _create_seq_mock_data_redq( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + "next_observation": next_obs * mask.to(obs.dtype), + "done": done, + "mask": mask, + "reward": reward * mask.to(obs.dtype), + "action": action * mask.to(obs.dtype), + }, + ) + return td + + @pytest.mark.parametrize("loss_class", (REDQLoss, DoubleREDQLoss)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_redq(self, loss_class, num_qvalue, device): + + torch.manual_seed(self.seed) + td = self._create_mock_data_redq(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = loss_class( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + gamma=0.9, + loss_function="l2", + ) + + with _check_td_steady(td): + loss = loss_fn(td) + + # check td is left untouched + assert loss_fn.priority_key in td.keys() + + # check that loss are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in 
loss_fn.actor_network_params + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum([item for _, item in loss.items()]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len(set(p for n, p in named_parameters)) == len(list(named_parameters)) + assert len(set(p for n, p in named_buffers)) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("loss_class", (REDQLoss, DoubleREDQLoss)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_redq_batched(self, loss_class, num_qvalue, device): + + torch.manual_seed(self.seed) + td = self._create_mock_data_redq(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = loss_class( + actor_network=deepcopy(actor), + qvalue_network=deepcopy(qvalue), + num_qvalue_nets=num_qvalue, + gamma=0.9, + loss_function="l2", + ) + + loss_class_deprec = ( + REDQLoss_deprecated if loss_class is REDQLoss else DoubleREDQLoss_deprecated + ) + loss_fn_deprec = loss_class_deprec( + actor_network=deepcopy(actor), + qvalue_network=deepcopy(qvalue), + num_qvalue_nets=num_qvalue, + gamma=0.9, + loss_function="l2", + ) + + td_clone1 = td.clone() + td_clone2 = td.clone() + torch.manual_seed(0) + with _check_td_steady(td_clone1): + loss1 = loss_fn(td_clone1) + + torch.manual_seed(0) + with _check_td_steady(td_clone2): + loss2 = loss_fn_deprec(td_clone2) + + # TODO: find a way to compare the losses: problem is that we sample actions either sequentially or in batch, + # so setting seed has little impact + + @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("loss_class", (REDQLoss, DoubleREDQLoss)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_redq_batcher(self, n, loss_class, num_qvalue, device, gamma=0.9): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_redq(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + loss_fn = loss_class( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + gamma=0.9, + loss_function="l2", + ) + + ms = MultiStep(gamma=gamma, n_steps_max=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + + with _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.priority_key in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*list(td.keys()))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum([item for _, item in loss_ms.items()]).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has null gradient" + + # Check param update effect on targets + 
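+ # With delayed (double) losses the target actor/qvalue parameters are frozen
+ # copies: perturbing the loss-module parameters must leave them unchanged,
+ # while without delay they track the online parameters and must change.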
target_actor = [p.clone() for p in loss_fn.target_actor_network_params] + target_qvalue = [p.clone() for p in loss_fn.target_qvalue_network_params] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params] + target_qvalue2 = [p.clone() for p in loss_fn.target_qvalue_network_params] + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_qvalue: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + + # check that policy is updated after parameter update + actorp_set = set(actor.parameters()) + loss_fnp_set = set(loss_fn.parameters()) + assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set) + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + +class TestPPO: + seed = 0 + + def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = NdBoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + module = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + actor = ProbabilisticActor( + spec=action_spec, + module=module, + distribution_class=TanhNormal, + save_dist_params=True, + ) + return actor.to(device) + + def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + module = nn.Linear(obs_dim, 1) + value = ValueOperator( + module=module, + in_keys=["observation"], + ) + return value.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_ppo( + self, batch=2, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "next_observation": next_obs, + "done": done, + "reward": reward, + "action": action, + "action_log_prob": torch.randn_like(action[..., :1]) / 10, + }, + ) + return td + + def _create_seq_mock_data_ppo( + self, batch=2, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + params_mean = torch.randn_like(action.repeat(1, 1, 2)) / 10 + params_scale = torch.rand_like(action.repeat(1, 1, 2)) / 10 + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + 
"next_observation": next_obs * mask.to(obs.dtype), + "done": done, + "mask": mask, + "reward": reward * mask.to(obs.dtype), + "action": action * mask.to(obs.dtype), + "action_log_prob": torch.randn_like(action[..., :1]) + / 10 + * mask.to(obs.dtype), + "action_dist_param_0": params_mean * mask.to(obs.dtype), + "action_dist_param_1": params_scale * mask.to(obs.dtype), + }, + ) + return td + + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) + @pytest.mark.parametrize("device", get_available_devices()) + def test_ppo(self, loss_class, device): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_ppo(device=device) + + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + gae = GAE(gamma=0.9, lamda=0.9, critic=value) + loss_fn = loss_class( + actor, value, advantage_module=gae, gamma=0.9, loss_critic_type="l2" + ) + + loss = loss_fn(td) + sum([item for _, item in loss.items()]).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + +def test_hold_out(): + net = torch.nn.Linear(3, 4) + x = torch.randn(1, 3) + x_rg = torch.randn(1, 3, requires_grad=True) + y = net(x) + assert y.requires_grad + with hold_out_net(net): + y = net(x) + assert not y.requires_grad + y = net(x_rg) + assert y.requires_grad + + y = net(x) + assert y.requires_grad + + # nested case + with hold_out_net(net): + y = net(x) + assert not y.requires_grad + with hold_out_net(net): + y = net(x) + assert not y.requires_grad + y = net(x_rg) + assert y.requires_grad + + y = net(x) + assert y.requires_grad + + # exception + with pytest.raises( + RuntimeError, + match="hold_out_net requires the network parameter set to be non-empty.", + ): + net = torch.nn.Sequential() + with hold_out_net(net): + pass + + +@pytest.mark.parametrize("mode", ["hard", "soft"]) +@pytest.mark.parametrize("value_network_update_interval", [100, 1000]) +@pytest.mark.parametrize("device", get_available_devices()) +def test_updater(mode, value_network_update_interval, device): + torch.manual_seed(100) + + class custom_module_error(nn.Module): + def __init__(self): + super().__init__() + self._target_params = [torch.randn(3, 4)] + self._target_error_params = [torch.randn(3, 4)] + self.params = nn.ParameterList( + [nn.Parameter(torch.randn(3, 4, requires_grad=True))] + ) + + module = custom_module_error().to(device) + with pytest.raises( + RuntimeError, match="Your module seems to have a _target tensor list " + ): + if mode == "hard": + upd = HardUpdate(module, value_network_update_interval) + elif mode == "soft": + upd = SoftUpdate(module, 1 - 1 / value_network_update_interval) + + class custom_module(_LossModule): + def __init__(self): + super().__init__() + module1 = torch.nn.BatchNorm2d(10).eval() + self.convert_to_functional(module1, "module1", create_target_params=True) + module2 = torch.nn.BatchNorm2d(10).eval() + self.module2 = module2 + for target in self.target_module1_params: + target.data.normal_() + for target in self.target_module1_buffers: + if target.dtype is not torch.int64: + target.data.normal_() + else: + target.data += 10 + + module = custom_module().to(device) + if mode == "hard": + upd = HardUpdate( + module, value_network_update_interval=value_network_update_interval + ) + elif mode == "soft": + upd = SoftUpdate(module, 1 - 1 / value_network_update_interval) + upd.init_() + for _, v in upd._targets.items(): + for _v in v: + if _v.dtype is not 
torch.int64: + _v.copy_(torch.randn_like(_v)) + else: + _v += 10 + + # total dist + d0 = sum( + [ + (target_val[0] - val[0]).norm().item() + for (_, target_val), (_, val) in zip( + upd._targets.items(), upd._sources.items() + ) + ] + ) + assert d0 > 0 + if mode == "hard": + for i in range(value_network_update_interval + 1): + d1 = sum( + [ + (target_val[0] - val[0]).norm().item() + for (_, target_val), (_, val) in zip( + upd._targets.items(), upd._sources.items() + ) + ] + ) + assert d1 == d0, i + assert upd.counter == i + upd.step() + assert upd.counter == 0 + d1 = sum( + [ + (target_val[0] - val[0]).norm().item() + for (_, target_val), (_, val) in zip( + upd._targets.items(), upd._sources.items() + ) + ] + ) + assert d1 < d0 + + elif mode == "soft": + upd.step() + d1 = sum( + [ + (target_val[0] - val[0]).norm().item() + for (_, target_val), (_, val) in zip( + upd._targets.items(), upd._sources.items() + ) + ] + ) + assert d1 < d0 + + upd.init_() + upd.step() + d2 = sum( + [ + (target_val[0] - val[0]).norm().item() + for (_, target_val), (_, val) in zip( + upd._targets.items(), upd._sources.items() + ) + ] + ) + assert d2 < 1e-6 + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_distributions.py b/test/test_distributions.py new file mode 100644 index 00000000000..55cf4af3c43 --- /dev/null +++ b/test/test_distributions.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +import pytest +import torch +from _utils_internal import get_available_devices +from torch import nn +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.modules import ( + TanhNormal, + NormalParamWrapper, + TruncatedNormal, + OneHotCategorical, +) +from torchrl.modules.distributions import TanhDelta, Delta + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_delta(device): + x = torch.randn(1000000, 4, device=device) + d = Delta(x) + assert d.log_prob(d.mode).shape == x.shape[:-1] + assert (d.log_prob(d.mode) == float("inf")).all() + + x = torch.randn(1000000, 4, device=device) + d = TanhDelta(x, -1, 1.0, atol=1e-4, rtol=1e-4) + xinv = d.transforms[0].inv(d.mode) + assert d.base_dist._is_equal(xinv).all() + assert d.log_prob(d.mode).shape == x.shape[:-1] + assert (d.log_prob(d.mode) == float("inf")).all() + + +def _map_all(*tensors_or_other, device): + for t in tensors_or_other: + if isinstance(t, (torch.Tensor, _TensorDict)): + yield t.to(device) + else: + yield t + + +@pytest.mark.parametrize( + "min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1] +) +@pytest.mark.parametrize( + "max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1] +) +@pytest.mark.parametrize( + "vecs", + [ + (torch.tensor([0.1, 10.0, 5.0]), torch.tensor([0.1, 10.0, 5.0])), + (torch.zeros(7, 3), torch.ones(7, 3)), + ], +) +@pytest.mark.parametrize( + "upscale", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 3] +) +@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])]) +@pytest.mark.parametrize("device", get_available_devices()) +def test_tanhnormal(min, max, vecs, upscale, shape, device): + min, max, vecs, upscale, shape = _map_all( + min, max, vecs, upscale, shape, device=device + ) + torch.manual_seed(0) + d = TanhNormal( + *vecs, + upscale=upscale, 
+ min=min, + max=max, + ) + for _ in range(100): + a = d.rsample(shape) + assert a.shape[: len(shape)] == shape + assert (a >= d.min).all() + assert (a <= d.max).all() + lp = d.log_prob(a) + assert torch.isfinite(lp).all() + + +@pytest.mark.parametrize( + "min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1] +) +@pytest.mark.parametrize( + "max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1] +) +@pytest.mark.parametrize( + "vecs", + [ + (torch.tensor([0.1, 10.0, 5.0]), torch.tensor([0.1, 10.0, 5.0])), + (torch.zeros(7, 3), torch.ones(7, 3)), + ], +) +@pytest.mark.parametrize( + "upscale", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 3] +) +@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])]) +@pytest.mark.parametrize("device", get_available_devices()) +def test_truncnormal(min, max, vecs, upscale, shape, device): + torch.manual_seed(0) + min, max, vecs, upscale, shape = _map_all( + min, max, vecs, upscale, shape, device=device + ) + d = TruncatedNormal( + *vecs, + upscale=upscale, + min=min, + max=max, + ) + for _ in range(100): + a = d.rsample(shape) + assert a.shape[: len(shape)] == shape + assert (a >= d.min).all() + assert (a <= d.max).all() + lp = d.log_prob(a) + assert torch.isfinite(lp).all() + + +@pytest.mark.parametrize( + "batch_size", + [ + (3,), + ( + 5, + 7, + ), + ], +) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "scale_mapping", + [ + "exp", + "biased_softplus_1.0", + "biased_softplus_0.11", + "expln", + "relu", + "softplus", + "raise_error", + ], +) +def test_normal_mapping(batch_size, device, scale_mapping, action_dim=11, state_dim=3): + torch.manual_seed(0) + for _ in range(100): + module = nn.LazyLinear(2 * action_dim).to(device) + module = NormalParamWrapper(module, scale_mapping=scale_mapping).to(device) + if scale_mapping != "raise_error": + loc, scale = module(torch.randn(*batch_size, state_dim, device=device)) + assert (scale > 0).all() + else: + with pytest.raises( + NotImplementedError, match="Unknown mapping " "raise_error" + ): + loc, scale = module(torch.randn(*batch_size, state_dim, device=device)) + + +@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])]) +@pytest.mark.parametrize("device", get_available_devices()) +def test_categorical(shape, device): + torch.manual_seed(0) + for i in range(100): + logits = i * torch.randn(10) + dist = OneHotCategorical(logits=logits) + s = dist.sample(shape) + assert s.shape[: len(shape)] == shape + assert s.shape[-1] == logits.shape[-1] + assert (s.sum(-1) == 1).all() + assert torch.isfinite(dist.log_prob(s)).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_env.py b/test/test_env.py new file mode 100644 index 00000000000..2fab5e40d7b --- /dev/null +++ b/test/test_env.py @@ -0,0 +1,409 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
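+# Tests for Gym-backed environments: seeding and rollout reproducibility,
+# serial vs. parallel batched execution, and tensor-spec sampling/encoding.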
+ +import argparse +import os.path +from collections import defaultdict + +import numpy as np +import pytest +import torch +import yaml +from scipy.stats import chisquare +from torchrl.agents import EnvCreator +from torchrl.data.tensor_specs import ( + OneHotDiscreteTensorSpec, + MultOneHotDiscreteTensorSpec, + BoundedTensorSpec, + NdBoundedTensorSpec, +) +from torchrl.data.tensordict.tensordict import assert_allclose_td, TensorDict +from torchrl.envs import gym, GymEnv +from torchrl.envs.transforms import ( + TransformedEnv, + Compose, + ToTensorImage, + RewardClipping, +) +from torchrl.envs.utils import step_tensor_dict +from torchrl.envs.vec_env import ParallelEnv, SerialEnv + +try: + this_dir = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(this_dir, "configs", "atari.yaml"), "r") as file: + atari_confs = yaml.load(file, Loader=yaml.FullLoader) + _atari_found = True +except FileNotFoundError: + _atari_found = False + atari_confs = defaultdict(lambda: "") + + +## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between +## the serial and parallel batched envs +# def _make_atari_env(atari_env): +# action_spec = GymEnv(atari_env + "-ram-v0").action_spec +# n_act = action_spec.shape[-1] +# return lambda **kwargs: TransformedEnv( +# GymEnv(atari_env + "-ram-v0", **kwargs), +# DiscreteActionProjection(max_N=18, M=n_act), +# ) +# +# +# @pytest.mark.skipif( +# "ALE/Pong-v5" not in _get_gym_envs(), reason="no Atari OpenAI Gym env available" +# ) +# def test_composite_env(): +# num_workers = 10 +# frameskip = 2 +# create_env_fn = [ +# _make_atari_env(atari_env) +# for atari_env in atari_confs["atari_envs"][:num_workers] +# ] +# kwargs = {"frame_skip": frameskip} +# +# random_policy = lambda td: td.set( +# "action", torch.nn.functional.one_hot(torch.randint(18, (*td.batch_size,)), 18) +# ) +# p = SerialEnv(num_workers, create_env_fn, create_env_kwargs=kwargs) +# seed = p.set_seed(0) +# p.reset() +# torch.manual_seed(seed) +# rollout1 = p.rollout(n_steps=100, policy=random_policy, auto_reset=False) +# p.close() +# del p +# +# p = ParallelEnv(num_workers, create_env_fn, create_env_kwargs=kwargs) +# seed = p.set_seed(0) +# p.reset() +# torch.manual_seed(seed) +# rollout0 = p.rollout(n_steps=100, policy=random_policy, auto_reset=False) +# p.close() +# del p +# +# assert_allclose_td(rollout1, rollout0) + + +@pytest.mark.parametrize("env_name", ["Pendulum-v1", "CartPole-v1"]) +@pytest.mark.parametrize("frame_skip", [1, 4]) +def test_env_seed(env_name, frame_skip, seed=0): + env = gym.GymEnv(env_name, frame_skip=frame_skip) + action = env.action_spec.rand() + + env.set_seed(seed) + td0a = env.reset() + td1a = env.step(td0a.clone().set("action", action)) + + env.set_seed(seed) + td0b = env.specs.build_tensor_dict() + td0b = env.reset(tensor_dict=td0b) + td1b = env.step(td0b.clone().set("action", action)) + + assert_allclose_td(td0a, td0b.select(*td0a.keys())) + assert_allclose_td(td1a, td1b) + + env.set_seed( + seed=seed + 10, + ) + td0c = env.reset() + td1c = env.step(td0c.clone().set("action", action)) + + with pytest.raises(AssertionError): + assert_allclose_td(td0a, td0c.select(*td0a.keys())) + with pytest.raises(AssertionError): + assert_allclose_td(td1a, td1c) + + +@pytest.mark.parametrize("env_name", ["Pendulum-v1", "ALE/Pong-v5"]) +@pytest.mark.parametrize("frame_skip", [1, 4]) +def test_rollout(env_name, frame_skip, seed=0): + env = gym.GymEnv(env_name, frame_skip=frame_skip) + + torch.manual_seed(seed) + 
np.random.seed(seed) + env.set_seed(seed) + env.reset() + rollout1 = env.rollout(n_steps=100) + + torch.manual_seed(seed) + np.random.seed(seed) + env.set_seed(seed) + env.reset() + rollout2 = env.rollout(n_steps=100) + + assert_allclose_td(rollout1, rollout2) + + torch.manual_seed(seed) + env.set_seed(seed + 10) + env.reset() + rollout3 = env.rollout(n_steps=100) + with pytest.raises(AssertionError): + assert_allclose_td(rollout1, rollout3) + + +def _make_envs(env_name, frame_skip, transformed, N): + torch.manual_seed(0) + if not transformed: + create_env_fn = lambda: GymEnv(env_name, frame_skip=frame_skip) + else: + if env_name == "ALE/Pong-v5": + create_env_fn = lambda: TransformedEnv( + GymEnv(env_name, frame_skip=frame_skip), + Compose(*[ToTensorImage(), RewardClipping(0, 0.1)]), + ) + else: + create_env_fn = lambda: TransformedEnv( + GymEnv(env_name, frame_skip=frame_skip), + Compose(*[RewardClipping(0, 0.1)]), + ) + env0 = create_env_fn() + env_parallel = ParallelEnv(N, create_env_fn) + env_serial = SerialEnv(N, create_env_fn) + return env_parallel, env_serial, env0 + + +@pytest.mark.parametrize("env_name", ["ALE/Pong-v5", "Pendulum-v1"]) +@pytest.mark.parametrize("frame_skip", [4, 1]) +@pytest.mark.parametrize("transformed", [True, False]) +def test_parallel_env(env_name, frame_skip, transformed, T=10, N=5): + env_parallel, env_serial, env0 = _make_envs(env_name, frame_skip, transformed, N) + + td = TensorDict( + source={"action": env0.action_spec.rand((N,))}, + batch_size=[ + N, + ], + ) + td1 = env_parallel.step(td) + assert not td1.is_shared() + assert "done" in td1.keys() + assert "reward" in td1.keys() + + with pytest.raises(RuntimeError): + # number of actions does not match number of workers + td = TensorDict( + source={"action": env0.action_spec.rand((N - 1,))}, batch_size=[N - 1] + ) + td1 = env_parallel.step(td) + + td_reset = TensorDict( + source={"reset_workers": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, + batch_size=[ + N, + ], + ) + env_parallel.reset(tensor_dict=td_reset) + + td = env_parallel.rollout(policy=None, n_steps=T) + assert ( + td.shape == torch.Size([N, T]) or td.get("done").sum(1).all() + ), f"{td.shape}, {td.get('done').sum(1)}" + + +@pytest.mark.parametrize("env_name", ["ALE/Pong-v5", "Pendulum-v1"]) +@pytest.mark.parametrize("frame_skip", [4, 1]) +@pytest.mark.parametrize( + "transformed", + [ + False, + True, + ], +) +def test_parallel_env_seed(env_name, frame_skip, transformed): + env_parallel, env_serial, env0 = _make_envs(env_name, frame_skip, transformed, 5) + + out_seed_serial = env_serial.set_seed(0) + env_serial.reset() + td0_serial = env_serial.current_tensordict + torch.manual_seed(0) + + td_serial = env_serial.rollout(n_steps=10, auto_reset=False).contiguous() + key = "observation_pixels" if "observation_pixels" in td_serial else "observation" + torch.testing.assert_allclose( + td_serial[:, 0].get("next_" + key), td_serial[:, 1].get(key) + ) + + out_seed_parallel = env_parallel.set_seed(0) + env_parallel.reset() + td0_parallel = env_parallel.current_tensordict + + torch.manual_seed(0) + assert out_seed_parallel == out_seed_serial + td_parallel = env_parallel.rollout(n_steps=10, auto_reset=False).contiguous() + torch.testing.assert_allclose( + td_parallel[:, 0].get("next_" + key), td_parallel[:, 1].get(key) + ) + + assert_allclose_td(td0_serial, td0_parallel) + assert_allclose_td(td_serial[:, 0], td_parallel[:, 0]) # first step + assert_allclose_td(td_serial[:, 1], td_parallel[:, 1]) # second step + assert_allclose_td(td_serial, 
td_parallel) + + +def test_parallel_env_shutdown(): + env_make = EnvCreator(lambda: GymEnv("Pendulum-v1")) + env = ParallelEnv(4, env_make) + env.reset() + assert not env.is_closed + env.rand_step() + assert not env.is_closed + env.close() + assert env.is_closed + env.reset() + assert not env.is_closed + env.shutdown() + assert env.is_closed + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="no cuda device detected") +@pytest.mark.parametrize("env_name", ["ALE/Pong-v5", "Pendulum-v1"]) +@pytest.mark.parametrize("frame_skip", [4, 1]) +@pytest.mark.parametrize("transformed", [True, False]) +@pytest.mark.parametrize("device", [0, "cuda:0"]) +def test_parallel_env_device(env_name, frame_skip, transformed, device): + torch.manual_seed(0) + N = 5 + if not transformed: + create_env_fn = lambda: GymEnv("ALE/Pong-v5", frame_skip=frame_skip) + else: + if env_name == "ALE/Pong-v5": + create_env_fn = lambda: TransformedEnv( + GymEnv(env_name, frame_skip=frame_skip), + Compose(*[ToTensorImage(), RewardClipping(0, 0.1)]), + ) + else: + create_env_fn = lambda: TransformedEnv( + GymEnv(env_name, frame_skip=frame_skip), + Compose(*[RewardClipping(0, 0.1)]), + ) + env_parallel = ParallelEnv(N, create_env_fn, device=device) + out = env_parallel.rollout(n_steps=20) + + +class TestSpec: + def test_discrete_action_spec_reconstruct(self): + torch.manual_seed(0) + action_spec = OneHotDiscreteTensorSpec(10) + + actions_tensors = [action_spec.rand() for _ in range(10)] + actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors] + actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy] + assert all( + [(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)] + ) + + actions_numpy = [int(np.random.randint(0, 10, (1,))) for a in actions_tensors] + actions_tensors = [action_spec.encode(a) for a in actions_numpy] + actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors] + assert all([(a1 == a2) for a1, a2 in zip(actions_numpy, actions_numpy_2)]) + + def test_mult_discrete_action_spec_reconstruct(self): + torch.manual_seed(0) + action_spec = MultOneHotDiscreteTensorSpec((10, 5)) + + actions_tensors = [action_spec.rand() for _ in range(10)] + actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors] + actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy] + assert all( + [(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)] + ) + + actions_numpy = [ + np.concatenate( + [np.random.randint(0, 10, (1,)), np.random.randint(0, 5, (1,))], 0 + ) + for a in actions_tensors + ] + actions_tensors = [action_spec.encode(a) for a in actions_numpy] + actions_numpy_2 = [action_spec.to_numpy(a) for a in actions_tensors] + assert all([(a1 == a2).all() for a1, a2 in zip(actions_numpy, actions_numpy_2)]) + + def test_discrete_action_spec_rand(self): + torch.manual_seed(0) + action_spec = OneHotDiscreteTensorSpec(10) + + sample = torch.stack([action_spec.rand() for _ in range(10000)], 0) + + sample_list = sample.argmax(-1) + sample_list = list([sum(sample_list == i).item() for i in range(10)]) + assert chisquare(sample_list).pvalue > 0.1 + + sample = action_spec.to_numpy(sample) + sample = [sum(sample == i) for i in range(10)] + assert chisquare(sample).pvalue > 0.1 + + def test_mult_discrete_action_spec_rand(self): + torch.manual_seed(0) + ns = (10, 5) + N = 100000 + action_spec = MultOneHotDiscreteTensorSpec((10, 5)) + + actions_tensors = [action_spec.rand() for _ in range(10)] + actions_numpy = [action_spec.to_numpy(a) for a in 
actions_tensors] + actions_tensors_2 = [action_spec.encode(a) for a in actions_numpy] + assert all( + [(a1 == a2).all() for a1, a2 in zip(actions_tensors, actions_tensors_2)] + ) + + sample = np.stack( + [action_spec.to_numpy(action_spec.rand()) for _ in range(N)], 0 + ) + assert sample.shape[0] == N + assert sample.shape[1] == 2 + assert sample.ndim == 2, f"found shape: {sample.shape}" + + sample0 = sample[:, 0] + sample_list = list([sum(sample0 == i) for i in range(ns[0])]) + assert chisquare(sample_list).pvalue > 0.1 + + sample1 = sample[:, 1] + sample_list = list([sum(sample1 == i) for i in range(ns[1])]) + assert chisquare(sample_list).pvalue > 0.1 + + def test_bounded_rand(self): + spec = BoundedTensorSpec(-3, 3) + sample = torch.stack([spec.rand() for _ in range(100)]) + assert (-3 <= sample).all() and (3 >= sample).all() + + def test_ndbounded_shape(self): + spec = NdBoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5]) + sample = torch.stack([spec.rand() for _ in range(100)], 0) + assert (-3 <= sample).all() and (3 >= sample).all() + assert sample.shape == torch.Size([100, 10, 5]) + + +def test_seed(): + torch.manual_seed(0) + env1 = GymEnv("Pendulum-v1") + env1.set_seed(0) + state0_1 = env1.reset() + state1_1 = env1.step(state0_1.set("action", env1.action_spec.rand())) + + torch.manual_seed(0) + env2 = GymEnv("Pendulum-v1") + env2.set_seed(0) + state0_2 = env2.reset() + state1_2 = env2.step(state0_2.set("action", env2.action_spec.rand())) + + assert_allclose_td(state0_1, state0_2) + assert_allclose_td(state1_1, state1_2) + + +def test_current_tensordict(): + torch.manual_seed(0) + env = GymEnv("Pendulum-v1") + env.set_seed(0) + tensor_dict = env.reset() + assert_allclose_td(tensor_dict, env.current_tensordict) + tensor_dict = env.step( + TensorDict(source={"action": env.action_spec.rand()}, batch_size=[]) + ) + assert_allclose_td(step_tensor_dict(tensor_dict), env.current_tensordict) + + +# TODO: test for frame-skip + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_exploration.py b/test/test_exploration.py new file mode 100644 index 00000000000..0decd7add48 --- /dev/null +++ b/test/test_exploration.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
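+# Tests for exploration modules: the Ornstein-Uhlenbeck process (standalone and
+# as a policy wrapper) and gSDE noise injection through ProbabilisticActor.
+# Typical pattern exercised below (sketch):
+#   exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
+#   td = exploratory_policy(td)  # perturbs the "action" entry with OU noise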
+ +import argparse + +import pytest +import torch +from _utils_internal import get_available_devices +from scipy.stats import ttest_1samp +from torch import nn +from torchrl.data import NdBoundedTensorSpec +from torchrl.data.tensordict.tensordict import TensorDict +from torchrl.envs.transforms.transforms import gSDENoise +from torchrl.modules import ProbabilisticActor +from torchrl.modules.distributions import TanhNormal +from torchrl.modules.distributions.continuous import ( + IndependentNormal, + NormalParamWrapper, +) +from torchrl.modules.models.exploration import gSDEWrapper +from torchrl.modules.td_module.exploration import ( + _OrnsteinUhlenbeckProcess, + OrnsteinUhlenbeckProcessWrapper, +) + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_ou(device, seed=0): + torch.manual_seed(seed) + td = TensorDict({"action": torch.randn(3, device=device) / 10}, batch_size=[]) + ou = _OrnsteinUhlenbeckProcess(10.0, mu=2.0, x0=-4, sigma=0.1, sigma_min=0.01) + + tds = [] + for i in range(2000): + td = ou.add_sample(td) + tds.append(td.clone()) + td.set_("action", torch.randn(3) / 10) + if i % 1000 == 0: + td.zero_() + + tds = torch.stack(tds, 0) + + tset, pval_acc = ttest_1samp(tds.get("action")[950:1000, 0].cpu().numpy(), 2.0) + tset, pval_reg = ttest_1samp(tds.get("action")[:50, 0].cpu().numpy(), 2.0) + assert pval_acc > 0.05 + assert pval_reg < 0.1 + + tset, pval_acc = ttest_1samp(tds.get("action")[1950:2000, 0].cpu().numpy(), 2.0) + tset, pval_reg = ttest_1samp(tds.get("action")[1000:1050, 0].cpu().numpy(), 2.0) + assert pval_acc > 0.05 + assert pval_reg < 0.1 + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_ou_wrapper(device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0): + torch.manual_seed(seed) + module = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) + action_spec = NdBoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) + policy = ProbabilisticActor( + spec=action_spec, + module=module, + distribution_class=TanhNormal, + default_interaction_mode="random", + ).to(device) + exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + + tensor_dict = TensorDict( + batch_size=[batch], + source={"observation": torch.randn(batch, d_obs, device=device)}, + device=device, + ) + out_noexp = [] + out = [] + for i in range(n_steps): + tensor_dict_noexp = policy(tensor_dict.select("observation")) + tensor_dict = exploratory_policy(tensor_dict) + out.append(tensor_dict.clone()) + out_noexp.append(tensor_dict_noexp.clone()) + tensor_dict.set_("observation", torch.randn(batch, d_obs, device=device)) + out = torch.stack(out, 0) + out_noexp = torch.stack(out_noexp, 0) + assert (out_noexp.get("action") != out.get("action")).all() + assert (out.get("action") <= 1.0).all(), out.get("action").min() + assert (out.get("action") >= -1.0).all(), out.get("action").max() + + +@pytest.mark.parametrize("state_dim", [7]) +@pytest.mark.parametrize("action_dim", [5, 11]) +@pytest.mark.parametrize("gSDE", [True, False]) +@pytest.mark.parametrize("safe", [True, False]) +@pytest.mark.parametrize("device", get_available_devices()) +def test_gsde(state_dim, action_dim, gSDE, device, safe, batch=16, bound=0.1): + torch.manual_seed(0) + if gSDE: + model = torch.nn.LazyLinear(action_dim) + wrapper = gSDEWrapper(model, action_dim, state_dim).to(device) + exploration_mode = "net_output" + distribution_class = IndependentNormal + distribution_kwargs = {} + in_keys = ["observation", "_eps_gSDE"] + else: + model = torch.nn.LazyLinear(action_dim * 2) + 
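+ # NormalParamWrapper splits the 2 * action_dim output into loc and scale
+ # halves; TanhNormal then keeps samples within [-bound, bound].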
wrapper = NormalParamWrapper(model).to(device) + exploration_mode = "random" + distribution_class = TanhNormal + distribution_kwargs = {"min": -bound, "max": bound} + in_keys = ["observation"] + spec = NdBoundedTensorSpec( + -torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,) + ).to(device) + actor = ProbabilisticActor( + module=wrapper, + spec=spec, + in_keys=in_keys, + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + default_interaction_mode=exploration_mode, + safe=safe, + ) + + td = TensorDict( + {"observation": torch.randn(batch, state_dim, device=device)}, + [ + batch, + ], + ) + if gSDE: + gSDENoise(action_dim, state_dim).reset(td) + assert "_eps_gSDE" in td.keys() + assert td.get("_eps_gSDE").device == device + actor(td) + assert "action" in td.keys() + if not safe and gSDE: + assert not spec.is_in(td.get("action")) + elif safe and gSDE: + assert spec.is_in(td.get("action")) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_gym.py b/test/test_gym.py new file mode 100644 index 00000000000..f7fe3e9715e --- /dev/null +++ b/test/test_gym.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from collections import defaultdict + +import pytest +import yaml +from torchrl.envs import GymEnv +from torchrl.envs.libs.gym import _has_gym + +try: + this_dir = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(this_dir, "configs", "atari.yaml"), "r") as file: + atari_confs = yaml.load(file, Loader=yaml.FullLoader) + _atari_found = True +except FileNotFoundError: + _atari_found = False + atari_confs = defaultdict(lambda: "") + + +@pytest.mark.skipif(not _atari_found, reason="no _atari_found found") +@pytest.mark.skipif(not _has_gym, reason="no gym library found") +@pytest.mark.parametrize("env_name", atari_confs["atari_envs"]) +@pytest.mark.parametrize("env_suffix", atari_confs["version"]) +@pytest.mark.parametrize("frame_skip", [1, 2, 3, 4]) +def test_atari(env_name, env_suffix, frame_skip): + env = GymEnv("-".join([env_name, env_suffix]), frame_skip=frame_skip) + env.rollout(n_steps=50) + + +# TODO: check gym envs in a smart, efficient way +# @pytest.mark.skipif(not _has_gym, reason="no gym library found") +# @pytest.mark.parametrize("env_name", _get_envs_gym()) +# @pytest.mark.parametrize("from_pixels", [False, True]) +# def test_gym(env_name, from_pixels): +# print(f"testing {env_name} with from_pixels={from_pixels}") +# torch.manual_seed(0) +# env = GymEnv(env_name, frame_skip=4, from_pixels=from_pixels) +# env.set_seed(0) +# td1 = env.rollout(n_steps=10, auto_reset=True) +# tdb = env.rollout(n_steps=10, auto_reset=True) +# if not tdb.get("done").sum(): +# tdc = env.rollout(n_steps=10, auto_reset=False) +# torch.manual_seed(0) +# env = GymEnv(env_name, frame_skip=4, from_pixels=from_pixels) +# env.set_seed(0) +# td2 = env.rollout(n_steps=10, auto_reset=True) +# assert_allclose_td(td1, td2) + + +if __name__ == "__main__": + pytest.main([__file__, "--capture", "no"]) diff --git a/test/test_helpers.py b/test/test_helpers.py new file mode 100644 index 00000000000..a078658026c --- /dev/null +++ b/test/test_helpers.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +import pytest +import torch +from _utils_internal import get_available_devices +from mocking_classes import ( + ContinuousActionConvMockEnvNumpy, + ContinuousActionVecMockEnv, + DiscreteActionVecMockEnv, + DiscreteActionConvMockEnvNumpy, +) +from torchrl.agents.helpers import parser_env_args, transformed_env_constructor +from torchrl.agents.helpers.models import ( + make_dqn_actor, + parser_model_args_discrete, + parser_model_args_continuous, + make_ddpg_actor, + make_ppo_model, + make_sac_model, + make_redq_model, +) +from torchrl.envs.libs.gym import _has_gym + + +## these tests aren't truly unitary but setting up a fake env for the +# purpose of building a model with args is a lot of unstable scaffoldings +# with unclear benefits + + +def _assert_keys_match(td, expeceted_keys): + td_keys = list(td.keys()) + d = set(td_keys) - set(expeceted_keys) + assert len(d) == 0, f"{d} is in tensordict but unexpected" + d = set(expeceted_keys) - set(td_keys) + assert len(d) == 0, f"{d} is expecter but not in tensordict" + assert len(td_keys) == len(expeceted_keys) + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("noisy", [tuple(), ("--noisy",)]) +@pytest.mark.parametrize("distributional", [tuple(), ("--distributional",)]) +@pytest.mark.parametrize("from_pixels", [tuple(), ("--from_pixels",)]) +def test_dqn_maker(device, noisy, distributional, from_pixels): + flags = list(noisy + distributional + from_pixels) + ["--env_name=CartPole-v1"] + parser = argparse.ArgumentParser() + parser = parser_env_args(parser) + parser = parser_model_args_discrete(parser) + args = parser.parse_args(flags) + + env_maker = ( + DiscreteActionConvMockEnvNumpy if from_pixels else DiscreteActionVecMockEnv + ) + env_maker = transformed_env_constructor( + args, use_env_creator=False, custom_env_maker=env_maker + ) + proof_environment = env_maker() + + actor = make_dqn_actor(proof_environment, args, device) + td = proof_environment.reset().to(device) + actor(td) + + expected_keys = ["done", "action", "action_value"] + if from_pixels: + expected_keys += ["observation_pixels"] + else: + expected_keys += ["observation_vector"] + + if not distributional: + expected_keys += ["chosen_action_value"] + try: + _assert_keys_match(td, expected_keys) + except AssertionError: + proof_environment.close() + raise + proof_environment.close() + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("from_pixels", [tuple(), ("--from_pixels",)]) +def test_ddpg_maker(device, from_pixels): + device = torch.device("cpu") + flags = list(from_pixels) + parser = argparse.ArgumentParser() + parser = parser_env_args(parser) + parser = parser_model_args_continuous(parser, algorithm="DDPG") + args = parser.parse_args(flags) + + env_maker = ( + ContinuousActionConvMockEnvNumpy if from_pixels else ContinuousActionVecMockEnv + ) + env_maker = transformed_env_constructor( + args, use_env_creator=False, custom_env_maker=env_maker + ) + proof_environment = env_maker() + actor, value = make_ddpg_actor(proof_environment, device=device, args=args) + td = proof_environment.reset().to(device) + actor(td) + expected_keys = ["done", "action"] + if from_pixels: + expected_keys += ["observation_pixels"] + else: + expected_keys += ["observation_vector"] + + try: + _assert_keys_match(td, expected_keys) + except AssertionError: + 
proof_environment.close() + raise + + value(td) + expected_keys += ["state_action_value"] + try: + _assert_keys_match(td, expected_keys) + except AssertionError: + proof_environment.close() + raise + + proof_environment.close() + del proof_environment + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("from_pixels", [tuple(), ("--from_pixels",)]) +@pytest.mark.parametrize("gsde", [tuple(), ("--gSDE",)]) +@pytest.mark.parametrize("shared_mapping", [tuple(), ("--shared_mapping",)]) +def test_ppo_maker(device, from_pixels, shared_mapping, gsde): + flags = list(from_pixels + shared_mapping + gsde) + if gsde and from_pixels: + pytest.skip("gsde and from_pixels are incompatible") + parser = argparse.ArgumentParser() + parser = parser_env_args(parser) + parser = parser_model_args_continuous(parser, algorithm="PPO") + args = parser.parse_args(flags) + + env_maker = ( + ContinuousActionConvMockEnvNumpy if from_pixels else ContinuousActionVecMockEnv + ) + env_maker = transformed_env_constructor( + args, use_env_creator=False, custom_env_maker=env_maker + ) + proof_environment = env_maker() + + actor_value = make_ppo_model( + proof_environment, + device=device, + args=args, + ) + actor = actor_value.get_policy_operator() + expected_keys = [ + "done", + "observation_pixels" if len(from_pixels) else "observation_vector", + "action_dist_param_0", + "action_dist_param_1", + "action", + "action_log_prob", + ] + if shared_mapping: + expected_keys += ["hidden"] + if len(gsde): + expected_keys += ["_eps_gSDE", "_action_duplicate", "action_dist_param_2"] + td = proof_environment.reset().to(device) + td_clone = td.clone() + actor(td_clone) + try: + _assert_keys_match(td_clone, expected_keys) + except AssertionError: + proof_environment.close() + raise + + value = actor_value.get_value_operator() + expected_keys = [ + "done", + "observation_pixels" if len(from_pixels) else "observation_vector", + "state_value", + ] + if shared_mapping: + expected_keys += ["hidden"] + if len(gsde): + expected_keys += ["_eps_gSDE"] + + td_clone = td.clone() + value(td_clone) + try: + _assert_keys_match(td_clone, expected_keys) + except AssertionError: + proof_environment.close() + raise + proof_environment.close() + del proof_environment + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("gsde", [tuple(), ("--gSDE",)]) +@pytest.mark.parametrize("from_pixels", [tuple()]) +@pytest.mark.parametrize("tanh_loc", [tuple(), ("--tanh_loc",)]) +@pytest.mark.skipif(not _has_gym, reason="No gym library found") +def test_sac_make(device, gsde, tanh_loc, from_pixels): + flags = list(gsde + tanh_loc + from_pixels) + if gsde and from_pixels: + pytest.skip("gsde and from_pixels are incompatible") + + parser = argparse.ArgumentParser() + parser = parser_env_args(parser) + parser = parser_model_args_continuous(parser, algorithm="SAC") + args = parser.parse_args(flags) + + env_maker = ( + ContinuousActionConvMockEnvNumpy if from_pixels else ContinuousActionVecMockEnv + ) + env_maker = transformed_env_constructor( + args, use_env_creator=False, custom_env_maker=env_maker + ) + proof_environment = env_maker() + + model = make_sac_model( + proof_environment, + device=device, + args=args, + ) + + actor, qvalue, value = model + td = proof_environment.reset().to(device) + td_clone = td.clone() + actor(td_clone) + expected_keys = [ + "done", + "observation_pixels" if len(from_pixels) else "observation_vector", + "action", + ] + if len(gsde): + expected_keys += ["_eps_gSDE"] + 
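+ # gSDE keeps its exploration noise in the tensordict under "_eps_gSDE",
+ # hence the extra expected key whenever the flag is set.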
+ try: + _assert_keys_match(td_clone, expected_keys) + except AssertionError: + proof_environment.close() + raise + + qvalue(td_clone) + expected_keys = ["done", "observation_vector", "action", "state_action_value"] + if len(gsde): + expected_keys += ["_eps_gSDE"] + + try: + _assert_keys_match(td_clone, expected_keys) + except AssertionError: + proof_environment.close() + raise + + value(td) + expected_keys = ["done", "observation_vector", "state_value"] + if len(gsde): + expected_keys += ["_eps_gSDE"] + + try: + _assert_keys_match(td, expected_keys) + except AssertionError: + proof_environment.close() + raise + proof_environment.close() + del proof_environment + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("from_pixels", [tuple()]) +@pytest.mark.parametrize("gsde", [tuple(), ("--gSDE",)]) +@pytest.mark.skipif(not _has_gym, reason="No gym library found") +def test_redq_make(device, from_pixels, gsde): + flags = list(from_pixels + gsde) + if gsde and from_pixels: + pytest.skip("gsde and from_pixels are incompatible") + + parser = argparse.ArgumentParser() + parser = parser_env_args(parser) + parser = parser_model_args_continuous(parser, algorithm="REDQ") + args = parser.parse_args(flags) + + env_maker = ( + ContinuousActionConvMockEnvNumpy if from_pixels else ContinuousActionVecMockEnv + ) + env_maker = transformed_env_constructor( + args, use_env_creator=False, custom_env_maker=env_maker + ) + proof_environment = env_maker() + + model = make_redq_model( + proof_environment, + device=device, + args=args, + ) + actor, qvalue = model + td = proof_environment.reset().to(device) + actor(td) + expected_keys = ["done", "observation_vector", "action", "action_log_prob"] + if len(gsde): + expected_keys += ["_eps_gSDE"] + try: + _assert_keys_match(td, expected_keys) + except AssertionError: + proof_environment.close() + raise + + qvalue(td) + expected_keys = [ + "done", + "observation_vector", + "action", + "action_log_prob", + "state_action_value", + ] + if len(gsde): + expected_keys += ["_eps_gSDE"] + try: + _assert_keys_match(td, expected_keys) + except AssertionError: + proof_environment.close() + raise + proof_environment.close() + del proof_environment + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_memmap.py b/test/test_memmap.py new file mode 100644 index 00000000000..f919ddb5932 --- /dev/null +++ b/test/test_memmap.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
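+# Tests for MemmapTensor: input validation, metadata preservation, file
+# deletion on garbage collection, ownership transfer across pickling, and cloning.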
+ +import os.path +import pickle +import tempfile + +import numpy as np +import pytest +import torch +from torchrl.data.tensordict.memmap import MemmapTensor + + +def test_memmap_type(): + array = np.random.rand(1) + with pytest.raises( + TypeError, match="convert input to torch.Tensor before calling MemmapTensor" + ): + MemmapTensor(array) + + +def test_grad(): + t = torch.tensor([1.0]) + MemmapTensor(t) + t = t.requires_grad_() + with pytest.raises( + RuntimeError, match="MemmapTensor is incompatible with tensor.requires_grad" + ): + MemmapTensor(t) + with pytest.raises( + RuntimeError, match="MemmapTensor is incompatible with tensor.requires_grad" + ): + MemmapTensor(t + 1) + + +@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.double, torch.bool]) +@pytest.mark.parametrize( + "shape", + [ + [ + 2, + ], + [1, 2], + ], +) +def test_memmap_metadata(dtype, shape): + t = torch.tensor([1, 0]).reshape(shape) + m = MemmapTensor(t) + assert m.dtype == t.dtype + assert (m == t).all() + assert m.shape == t.shape + + assert m.contiguous().dtype == t.dtype + assert (m.contiguous() == t).all() + assert m.contiguous().shape == t.shape + + assert m.clone().dtype == t.dtype + assert (m.clone() == t).all() + assert m.clone().shape == t.shape + + +def test_memmap_del(): + t = torch.tensor([1]) + m = MemmapTensor(t) + filename = m.filename + assert os.path.isfile(filename) + del m + with pytest.raises(AssertionError): + assert os.path.isfile(filename) + + +@pytest.mark.parametrize("value", [True, False]) +def test_memmap_ownership(value): + t = torch.tensor([1]) + m = MemmapTensor(t, transfer_ownership=value) + assert m.file.delete + with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp: + pickle.dump(m, tmp) + m2 = pickle.load(open(tmp.name, "rb")) + assert m2._memmap_array is None # assert data is not actually loaded + assert isinstance(m2, MemmapTensor) + assert m2.filename == m.filename + assert m2.file.name == m2.filename + assert m2.file._closer.name == m2.filename + assert ( + m.file.delete is not m2.file.delete + ) # delete attributes must have changed + assert ( + m.file._closer.delete is not m2.file._closer.delete + ) # delete attributes must have changed + del m + if value: + assert os.path.isfile(m2.filename) + else: + # m2 should point to a non-existing file + assert not os.path.isfile(m2.filename) + with pytest.raises(FileNotFoundError): + m2.contiguous() + + +@pytest.mark.parametrize("value", [True, False]) +def test_memmap_ownership_2pass(value): + t = torch.tensor([1]) + m1 = MemmapTensor(t, transfer_ownership=value) + with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp2: + pickle.dump(m1, tmp2) + m2 = pickle.load(open(tmp2.name, "rb")) + with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp3: + pickle.dump(m2, tmp3) + m3 = pickle.load(open(tmp3.name, "rb")) + assert m1._has_ownership + m2._has_ownership + m3._has_ownership == 1 + + del m1, m2, m3 + m1 = MemmapTensor(t, transfer_ownership=value) + with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp2: + pickle.dump(m1, tmp2) + m2 = pickle.load(open(tmp2.name, "rb")) + with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp3: + pickle.dump(m1, tmp3) + m3 = pickle.load(open(tmp3.name, "rb")) + assert m1._has_ownership + m2._has_ownership + m3._has_ownership == 1 + + +def test_memmap_clone(): + t = torch.tensor([1]) + m1 = MemmapTensor(t) + m2 = m1.clone() + assert isinstance(m2, MemmapTensor) + assert m2.filename != m1.filename + assert m2.filename == m2.file.name + assert m2.filename == m2.file._closer.name + m2c = 
m2.contiguous() + assert isinstance(m2c, torch.Tensor) + assert m2c == m1 + + +if __name__ == "__main__": + pytest.main([__file__, "--capture", "no"]) diff --git a/test/test_modules.py b/test/test_modules.py new file mode 100644 index 00000000000..791654e1bd3 --- /dev/null +++ b/test/test_modules.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from numbers import Number + +import pytest +import torch +from _utils_internal import get_available_devices +from torch import nn +from torchrl.data import TensorDict +from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec +from torchrl.modules import ( + QValueActor, + ActorValueOperator, + TDModule, + ValueOperator, + ProbabilisticActor, +) +from torchrl.modules.models import NoisyLinear, MLP, NoisyLazyLinear + + +@pytest.mark.parametrize("in_features", [3, 10, None]) +@pytest.mark.parametrize("out_features", [3, (3, 10)]) +@pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))]) +@pytest.mark.parametrize("activation_kwargs", [{"inplace": True}, {}]) +@pytest.mark.parametrize( + "norm_class, norm_kwargs", + [(nn.LazyBatchNorm1d, {}), (nn.BatchNorm1d, {"num_features": 32})], +) +@pytest.mark.parametrize("bias_last_layer", [True, False]) +@pytest.mark.parametrize("single_bias_last_layer", [True, False]) +@pytest.mark.parametrize("layer_class", [nn.Linear, NoisyLinear]) +@pytest.mark.parametrize("device", get_available_devices()) +def test_mlp( + in_features, + out_features, + depth, + num_cells, + activation_kwargs, + bias_last_layer, + norm_class, + norm_kwargs, + single_bias_last_layer, + layer_class, + device, + seed=0, +): + torch.manual_seed(seed) + batch = 2 + mlp = MLP( + in_features=in_features, + out_features=out_features, + depth=depth, + num_cells=num_cells, + activation_class=nn.ReLU, + activation_kwargs=activation_kwargs, + norm_class=norm_class, + norm_kwargs=norm_kwargs, + bias_last_layer=bias_last_layer, + single_bias_last_layer=False, + layer_class=layer_class, + ).to(device) + if in_features is None: + in_features = 5 + x = torch.randn(batch, in_features, device=device) + y = mlp(x) + out_features = [out_features] if isinstance(out_features, Number) else out_features + assert y.shape == torch.Size([batch, *out_features]) + + +@pytest.mark.parametrize( + "layer_class", + [ + NoisyLinear, + NoisyLazyLinear, + ], +) +@pytest.mark.parametrize("device", get_available_devices()) +def test_noisy(layer_class, device, seed=0): + torch.manual_seed(seed) + layer = layer_class(3, 4).to(device) + x = torch.randn(10, 3, device=device) + y1 = layer(x) + layer.reset_noise() + y2 = layer(x) + y3 = layer(x) + torch.testing.assert_allclose(y2, y3) + with pytest.raises(AssertionError): + torch.testing.assert_allclose(y1, y2) + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_value_based_policy(device): + torch.manual_seed(0) + obs_dim = 4 + action_dim = 5 + action_spec = OneHotDiscreteTensorSpec(action_dim) + + def make_net(): + net = MLP(in_features=obs_dim, out_features=action_dim, depth=2) + for mod in net.modules(): + if hasattr(mod, "bias") and mod.bias is not None: + mod.bias.data.zero_() + return net + + actor = QValueActor(spec=action_spec, module=make_net(), safe=True).to(device) + obs = torch.zeros(2, obs_dim, device=device) + td = TensorDict(batch_size=[2], source={"observation": obs}) + action = actor(td).get("action") + assert 
(action.sum(-1) == 1).all() + + actor = QValueActor(spec=action_spec, module=make_net(), safe=False).to(device) + obs = torch.randn(2, obs_dim, device=device) + td = TensorDict(batch_size=[2], source={"observation": obs}) + action = actor(td).get("action") + assert (action.sum(-1) == 1).all() + + actor = QValueActor(spec=action_spec, module=make_net(), safe=False).to(device) + obs = torch.zeros(2, obs_dim, device=device) + td = TensorDict(batch_size=[2], source={"observation": obs}) + action = actor(td).get("action") + with pytest.raises(AssertionError): + assert (action.sum(-1) == 1).all() + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_actorcritic(device): + common_module = TDModule( + spec=None, module=nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"] + ).to(device) + policy_operator = ProbabilisticActor( + spec=None, module=nn.Linear(4, 5), in_keys=["hidden"], return_log_prob=True + ).to(device) + value_operator = ValueOperator(nn.Linear(4, 1), in_keys=["hidden"]).to(device) + op = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_operator, + value_operator=value_operator, + ).to(device) + td = TensorDict( + source={"obs": torch.randn(4, 3)}, + batch_size=[ + 4, + ], + ).to(device) + td_total = op(td.clone()) + policy_op = op.get_policy_operator() + td_policy = policy_op(td.clone()) + value_op = op.get_value_operator() + td_value = value_op(td) + torch.testing.assert_allclose(td_total.get("action"), td_policy.get("action")) + torch.testing.assert_allclose( + td_total.get("action_log_prob"), td_policy.get("action_log_prob") + ) + torch.testing.assert_allclose( + td_total.get("state_value"), td_value.get("state_value") + ) + + value_params = set( + list(op.get_value_operator().parameters()) + list(op.module[0].parameters()) + ) + value_params2 = set(value_op.parameters()) + assert len(value_params.difference(value_params2)) == 0 and len( + value_params.intersection(value_params2) + ) == len(value_params) + + policy_params = set( + list(op.get_policy_operator().parameters()) + list(op.module[0].parameters()) + ) + policy_params2 = set(policy_op.parameters()) + assert len(policy_params.difference(policy_params2)) == 0 and len( + policy_params.intersection(policy_params2) + ) == len(policy_params) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/test_postprocs.py b/test/test_postprocs.py new file mode 100644 index 00000000000..a5889677385 --- /dev/null +++ b/test/test_postprocs.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
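+# Tests for post-processing utilities: MultiStep n-step returns and trajectory
+# splitting. Roughly, MultiStep is expected to rewrite, for each step t,
+#   reward_t <- sum_k gamma**k * r_{t+k}  (k capped by n_steps_max and episode end)
+# and to store gamma**steps_to_next_obs for bootstrapping on the shifted next_obs.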
+ +import pytest +import torch +from torchrl.collectors.utils import split_trajectories +from torchrl.data.postprocs.postprocs import MultiStep +from torchrl.data.tensordict.tensordict import TensorDict, assert_allclose_td + + +@pytest.mark.parametrize("n", range(13)) +@pytest.mark.parametrize( + "key", ["observation", "observation_pixels", "observation_whatever"] +) +def test_multistep(n, key, T=11): + torch.manual_seed(0) + + # mock data + b = 5 + + done = torch.zeros(b, T, 1, dtype=torch.bool) + done[0, -1] = True + done[1, -2] = True + done[2, -3] = True + done[3, -4] = True + + terminal = done.clone() + terminal[:, -1] = done.sum(1) != 1 + + mask = done.clone().cumsum(1).cumsum(1) >= 2 + mask = ~mask + + total_obs = torch.randn(1, T + 1, 1).expand(b, T + 1, 1) + tensor_dict = TensorDict( + source={ + key: total_obs[:, :T] * mask.to(torch.float), + "next_" + key: total_obs[:, 1:] * mask.to(torch.float), + "done": done, + "reward": torch.randn(1, T, 1).expand(b, T, 1) * mask.to(torch.float), + "mask": mask, + }, + batch_size=(b, T), + ) + + ms = MultiStep( + 0.9, + n, + ) + ms_tensor_dict = ms(tensor_dict.clone()) + + assert ms_tensor_dict.get("done").max() == 1 + + if n == 0: + assert_allclose_td( + tensor_dict, ms_tensor_dict.select(*list(tensor_dict.keys())) + ) + + # assert that done at last step is similar to unterminated traj + assert (ms_tensor_dict.get("gamma")[4] == ms_tensor_dict.get("gamma")[0]).all() + assert ( + ms_tensor_dict.get("next_" + key)[4] == ms_tensor_dict.get("next_" + key)[0] + ).all() + assert ( + ms_tensor_dict.get("steps_to_next_obs")[4] + == ms_tensor_dict.get("steps_to_next_obs")[0] + ).all() + + # check that next obs is properly replaced, or that it is terminated + next_obs = ms_tensor_dict.get(key)[:, (1 + ms.n_steps_max) :] + true_next_obs = ms_tensor_dict.get("next_" + key)[:, : -(1 + ms.n_steps_max)] + terminated = ~ms_tensor_dict.get("nonterminal") + assert ((next_obs == true_next_obs) | terminated[:, (1 + ms.n_steps_max) :]).all() + + # test gamma computation + torch.testing.assert_allclose( + ms_tensor_dict.get("gamma"), ms.gamma ** ms_tensor_dict.get("steps_to_next_obs") + ) + + # test reward + if n > 0: + assert ( + ms_tensor_dict.get("reward") != ms_tensor_dict.get("original_reward") + ).any() + else: + assert ( + ms_tensor_dict.get("reward") == ms_tensor_dict.get("original_reward") + ).all() + + +class TestSplits: + @staticmethod + def create_fake_trajs( + num_workers=32, + traj_len=200, + ): + traj_ids = torch.arange(num_workers).unsqueeze(-1) + steps_count = torch.zeros(num_workers).unsqueeze(-1) + workers = torch.arange(num_workers) + + out = [] + for i in range(traj_len): + done = steps_count == traj_ids # traj_id 0 has 0 steps, 1 has 1 step etc. 
+ + td = TensorDict( + source={ + "traj_ids": traj_ids, + "a": traj_ids.clone(), + "steps_count": steps_count, + "workers": workers, + "done": done, + }, + batch_size=[num_workers], + ) + out.append(td.clone()) + steps_count += 1 + + traj_ids[done] = traj_ids.max() + torch.arange(1, done.sum() + 1) + steps_count[done] = 0 + + out = torch.stack(out, 1) + return out + + @pytest.mark.parametrize("num_workers", range(4, 35)) + @pytest.mark.parametrize("traj_len", [10, 17, 50, 97, 200]) + def test_splits(self, num_workers, traj_len): + + trajs = TestSplits.create_fake_trajs(num_workers, traj_len) + assert trajs.shape[0] == num_workers + assert trajs.shape[1] == traj_len + split_trajs = split_trajectories(trajs) + + assert split_trajs.shape[0] == split_trajs.get("traj_ids").max() + 1 + assert split_trajs.shape[1] == split_trajs.get("steps_count").max() + 1 + + split_trajs.get("mask").sum() == num_workers * traj_len + + assert split_trajs.get("done").sum(1).max() == 1 + out_mask = split_trajs[split_trajs.get("mask")] + for i in range(split_trajs.shape[0]): + traj_id_split = split_trajs[i].get("traj_ids")[split_trajs[i].get("mask")] + assert 1 == len(traj_id_split.unique()) + + for w in range(num_workers): + assert (out_mask.get("workers") == w).sum() == traj_len + + # Assert that either the chain is not done XOR if it is it must have the desired length (equal to traj id by design) + for i in range(split_trajs.get("traj_ids").max()): + idx_traj_id = out_mask.get("traj_ids") == i + # (!=) == (xor) + c1 = (idx_traj_id.sum() - 1 == i) and ( + out_mask.get("done")[idx_traj_id].sum() == 1 + ) # option 1: trajectory is complete + c2 = out_mask.get("done")[idx_traj_id].sum() == 0 + assert c1 != c2, ( + f"traj_len={traj_len}, " + f"num_workers={num_workers}, " + f"traj_id={i}, " + f"idx_traj_id.sum()={idx_traj_id.sum()}, " + f"done={out_mask.get('done')[idx_traj_id].sum()}" + ) + + assert ( + split_trajs.get("traj_ids").unique().numel() + == split_trajs.get("traj_ids").max() + 1 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "--capture", "no"]) diff --git a/test/test_rb.py b/test/test_rb.py new file mode 100644 index 00000000000..8735f8e2c74 --- /dev/null +++ b/test/test_rb.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
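+
+# Tests for TensorDictPrioritizedReplayBuffer: extending, sampling, priority
+# updates and propagation of in-place changes to the stored tensordicts.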
+ +import argparse + +import numpy as np +import pytest +import torch +from _utils_internal import get_available_devices +from torchrl.data import TensorDict +from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer +from torchrl.data.tensordict.tensordict import assert_allclose_td + + +@pytest.mark.parametrize("priority_key", ["pk", "td_error"]) +@pytest.mark.parametrize("contiguous", [True, False]) +@pytest.mark.parametrize("device", get_available_devices()) +def test_prb(priority_key, contiguous, device): + torch.manual_seed(0) + np.random.seed(0) + rb = TensorDictPrioritizedReplayBuffer( + 5, + alpha=0.7, + beta=0.9, + collate_fn=None if contiguous else lambda x: torch.stack(x, 0), + priority_key=priority_key, + ) + td1 = TensorDict( + source={ + "a": torch.randn(3, 1), + priority_key: torch.rand(3, 1) / 10, + "_idx": torch.arange(3).view(3, 1), + }, + batch_size=[3], + ).to(device) + rb.extend(td1) + s = rb.sample(2) + assert s.batch_size == torch.Size( + [ + 2, + ] + ) + assert (td1[s.get("_idx").squeeze()].get("a") == s.get("a")).all() + assert_allclose_td(td1[s.get("_idx").squeeze()].select("a"), s.select("a")) + + # test replacement + td2 = TensorDict( + source={ + "a": torch.randn(5, 1), + priority_key: torch.rand(5, 1) / 10, + "_idx": torch.arange(5).view(5, 1), + }, + batch_size=[5], + ).to(device) + rb.extend(td2) + s = rb.sample(5) + assert s.batch_size == torch.Size( + [ + 5, + ] + ) + assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all() + assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a")) + + # test strong update + # get all indices that match first item + idx = s.get("_idx") + idx_match = (idx == idx[0]).nonzero()[:, 0] + s.set_at_( + priority_key, + torch.ones( + idx_match.numel(), + 1, + device=device, + ) + * 100000000, + idx_match, + ) + val = s.get("a")[0] + + idx0 = s.get("_idx")[0] + rb.update_priority(s) + s = rb.sample(5) + assert (val == s.get("a")).sum() >= 1 + torch.testing.assert_allclose( + td2[idx0].get("a").view(1), s.get("a").unique().view(1) + ) + + # test updating values of original td + td2.set_("a", torch.ones_like(td2.get("a"))) + s = rb.sample(5) + torch.testing.assert_allclose( + td2[idx0].get("a").view(1), s.get("a").unique().view(1) + ) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_recipe.py b/test/test_recipe.py new file mode 100644 index 00000000000..5387a23f503 --- /dev/null +++ b/test/test_recipe.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +import pytest + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_shared.py b/test/test_shared.py new file mode 100644 index 00000000000..c52f1ea7100 --- /dev/null +++ b/test/test_shared.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
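+
+# Tests for TensorDicts backed by shared memory, memory-mapped tensors and
+# SavedTensorDict, including cross-process reads/writes and rough timing checks.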
+ +import time +import warnings + +import pytest +import torch +from torch import multiprocessing as mp +from torchrl.data import SavedTensorDict +from torchrl.data import TensorDict + + +class TestShared: + @staticmethod + def remote_process(command_pipe_child, command_pipe_parent, tensordict): + command_pipe_parent.close() + assert tensordict.is_shared() + t0 = time.time() + tensordict.zero_() + print(f"zeroing time: {time.time() - t0}") + command_pipe_child.send("done") + + @staticmethod + def driver_func(subtd, td): + assert subtd.is_shared() + command_pipe_parent, command_pipe_child = mp.Pipe() + proc = mp.Process( + target=TestShared.remote_process, + args=(command_pipe_child, command_pipe_parent, subtd), + ) + proc.start() + command_pipe_child.close() + command_pipe_parent.recv() + for key, item in subtd.items(): + assert (item == 0).all() + + for key, item in td[0].items(): + assert (item == 0).all() + proc.join() + command_pipe_parent.close() + + def test_shared(self): + torch.manual_seed(0) + tensordict = TensorDict( + source={ + "a": torch.randn(1000, 200), + "b": torch.randn(1000, 100), + "done": torch.zeros(1000, 100, dtype=torch.bool).bernoulli_(), + }, + batch_size=[1000], + ) + + td1 = tensordict.clone().share_memory_() + td2 = tensordict.clone().share_memory_() + td3 = tensordict.clone().share_memory_() + subtd2 = TensorDict( + source={key: item[0] for key, item in td2.items()}, batch_size=[] + ) + assert subtd2.is_shared() + print("sub td2 is shared: ", subtd2.is_shared()) + + subtd1 = td1.get_sub_tensor_dict(0) + t0 = time.time() + self.driver_func(subtd1, td1) + t_elapsed = time.time() - t0 + print(f"execution on subtd: {t_elapsed}") + + t0 = time.time() + self.driver_func(subtd2, td2) + t_elapsed = time.time() - t0 + print(f"execution on plain td: {t_elapsed}") + + subtd3 = td3[0] + t0 = time.time() + self.driver_func(subtd3, td3) + t_elapsed = time.time() - t0 + print(f"execution on regular indexed td: {t_elapsed}") + + +class TestStack: + @staticmethod + def remote_process(command_pipe_child, command_pipe_parent, tensordict): + command_pipe_parent.close() + assert isinstance(tensordict, TensorDict), f"td is of type {type(tensordict)}" + assert tensordict.is_shared() or tensordict.is_memmap() + new_tensor_dict = torch.stack( + [ + tensordict[i].contiguous().clone().zero_() + for i in range(tensordict.shape[0]) + ], + 0, + ) + cmd = command_pipe_child.recv() + t0 = time.time() + if cmd == "stack": + tensordict.copy_(new_tensor_dict) + elif cmd == "serial": + for i, td in enumerate(new_tensor_dict.tensor_dicts): + tensordict.update_at_(td, i) + time_spent = time.time() - t0 + command_pipe_child.send(time_spent) + + @staticmethod + def driver_func(td, stack): + + command_pipe_parent, command_pipe_child = mp.Pipe() + proc = mp.Process( + target=TestStack.remote_process, + args=(command_pipe_child, command_pipe_parent, td), + ) + proc.start() + command_pipe_child.close() + command_pipe_parent.send("stack" if stack else "serial") + time_spent = command_pipe_parent.recv() + print(f"stack {stack}: time={time_spent}") + for key, item in td.items(): + assert (item == 0).all() + proc.join() + command_pipe_parent.close() + return time_spent + + @pytest.mark.parametrize("shared", ["shared", "memmap"]) + def test_shared(self, shared): + print(f"test_shared: shared={shared}") + torch.manual_seed(0) + tensordict = TensorDict( + source={ + "a": torch.randn(100, 2), + "b": torch.randn(100, 1), + "done": torch.zeros(100, 1, dtype=torch.bool).bernoulli_(), + }, + batch_size=[100], + ) + 
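+        # back the tensordict with shared memory or a memory-mapped file,
+        # depending on the parametrization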
if shared == "shared": + tensordict.share_memory_() + else: + tensordict.memmap_() + t_true = self.driver_func(tensordict, True) + t_false = self.driver_func(tensordict, False) + if t_true > t_false: + warnings.warn( + "Updating each element of the tensordict did " + "not take longer than updating the stack." + ) + + +@pytest.mark.parametrize( + "idx", + [ + torch.tensor( + [ + 3, + 5, + 7, + 8, + ] + ), + slice(200), + ], +) +@pytest.mark.parametrize("dtype", [torch.float, torch.bool]) +def test_memmap(idx, dtype, large_scale=False): + N = 5000 if large_scale else 10 + H = 128 if large_scale else 8 + td = TensorDict( + source={ + "a": torch.zeros(N, 3, H, H, dtype=dtype), + "b": torch.zeros(N, 3, H, H, dtype=dtype), + "c": torch.zeros(N, 3, H, H, dtype=dtype), + }, + batch_size=[ + N, + ], + ) + + td_sm = td.clone().share_memory_() + td_memmap = td.clone().memmap_() + td_saved = td.to(SavedTensorDict) + + print("\nTesting reading from TD") + for i in range(2): + t0 = time.time() + td_sm[idx].clone() + if i == 1: + print(f"sm: {time.time() - t0:4.4f} sec") + + t0 = time.time() + td_memmap[idx].clone() + if i == 1: + print(f"memmap: {time.time() - t0:4.4f} sec") + + t0 = time.time() + td_saved[idx].clone() + if i == 1: + print(f"saved td: {time.time() - t0:4.4f} sec") + + td_to_copy = td[idx].contiguous() + for k in td_to_copy.keys(): + td_to_copy.set(k, torch.ones_like(td_to_copy.get(k))) + + print("\nTesting writing to TD") + for i in range(2): + t0 = time.time() + td_sm[idx].update_(td_to_copy) + if i == 1: + print(f"sm td: {time.time() - t0:4.4f} sec") + torch.testing.assert_allclose(td_sm[idx].get("a"), td_to_copy.get("a")) + + t0 = time.time() + td_memmap[idx].update_(td_to_copy) + if i == 1: + print(f"memmap td: {time.time() - t0:4.4f} sec") + torch.testing.assert_allclose(td_memmap[idx].get("a"), td_to_copy.get("a")) + + t0 = time.time() + td_saved[idx].update_(td_to_copy) + if i == 1: + print(f"saved td: {time.time() - t0:4.4f} sec") + torch.testing.assert_allclose(td_saved[idx].get("a"), td_to_copy.get("a")) + + +if __name__ == "__main__": + pytest.main([__file__, "--capture", "no"]) diff --git a/test/test_tdmodules.py b/test/test_tdmodules.py new file mode 100644 index 00000000000..de546776462 --- /dev/null +++ b/test/test_tdmodules.py @@ -0,0 +1,569 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse + +import pytest +import torch +from functorch import make_functional, make_functional_with_buffers +from torch import nn +from torchrl.data import TensorDict +from torchrl.data.tensor_specs import ( + NdUnboundedContinuousTensorSpec, + NdBoundedTensorSpec, +) +from torchrl.modules import ( + TDModule, + ProbabilisticTDModule, + TanhNormal, + TDSequence, + NormalParamWrapper, +) + + +class TestTDModule: + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("probabilistic", [True, False]) + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful(self, safe, spec_type, probabilistic, lazy): + torch.manual_seed(0) + param_multiplier = 2 if probabilistic else 1 + if lazy: + net = nn.LazyLinear(4 * param_multiplier) + else: + net = nn.Linear(3, 4 * param_multiplier) + + if probabilistic: + net = NormalParamWrapper(net) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = NdUnboundedContinuousTensorSpec(4) + + if probabilistic: + tdclass = ProbabilisticTDModule + kwargs = {"distribution_class": TanhNormal} + else: + tdclass = TDModule + kwargs = {} + + if safe and spec is None: + with pytest.raises( + RuntimeError, + match="is not a valid configuration as the tensor specs are not " + "specified", + ): + tdmodule = tdclass( + module=net, + spec=spec, + in_keys=["in"], + out_keys=["out"], + safe=safe, + **kwargs + ) + return + else: + tdmodule = tdclass( + module=net, + spec=spec, + in_keys=["in"], + out_keys=["out"], + safe=safe, + **kwargs + ) + + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + tdmodule(td) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("probabilistic", [True, False]) + def test_functional(self, safe, spec_type, probabilistic): + torch.manual_seed(0) + param_multiplier = 2 if probabilistic else 1 + + net = nn.Linear(3, 4 * param_multiplier) + if probabilistic: + net = NormalParamWrapper(net) + + fnet, params = make_functional(net) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = NdUnboundedContinuousTensorSpec(4) + + if probabilistic: + tdclass = ProbabilisticTDModule + kwargs = {"distribution_class": TanhNormal} + else: + tdclass = TDModule + kwargs = {} + + if safe and spec is None: + with pytest.raises( + RuntimeError, + match="is not a valid configuration as the tensor specs are not " + "specified", + ): + tdmodule = tdclass( + spec=spec, + module=fnet, + in_keys=["in"], + out_keys=["out"], + safe=safe, + **kwargs + ) + return + else: + tdmodule = tdclass( + spec=spec, + module=fnet, + in_keys=["in"], + out_keys=["out"], + safe=safe, + **kwargs + ) + + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + tdmodule(td, params=params) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < 
-0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("probabilistic", [True, False]) + def test_functional_with_buffer(self, safe, spec_type, probabilistic): + torch.manual_seed(0) + param_multiplier = 2 if probabilistic else 1 + + net = nn.BatchNorm1d(32 * param_multiplier) + if probabilistic: + net = NormalParamWrapper(net) + + fnet, params, buffers = make_functional_with_buffers(net) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = NdBoundedTensorSpec(-0.1, 0.1, 32) + elif spec_type == "unbounded": + spec = NdUnboundedContinuousTensorSpec(32) + + if probabilistic: + tdclass = ProbabilisticTDModule + kwargs = {"distribution_class": TanhNormal} + else: + tdclass = TDModule + kwargs = {} + + if safe and spec is None: + with pytest.raises( + RuntimeError, + match="is not a valid configuration as the tensor specs are not " + "specified", + ): + tdmodule = tdclass( + spec=spec, + module=fnet, + in_keys=["in"], + out_keys=["out"], + safe=safe, + **kwargs + ) + return + else: + tdmodule = tdclass( + spec=spec, + module=fnet, + in_keys=["in"], + out_keys=["out"], + safe=safe, + **kwargs + ) + + td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) + tdmodule(td, params=params, buffers=buffers) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 32]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("probabilistic", [True, False]) + def test_vmap(self, safe, spec_type, probabilistic): + torch.manual_seed(0) + param_multiplier = 2 if probabilistic else 1 + + net = nn.Linear(3, 4 * param_multiplier) + if probabilistic: + net = NormalParamWrapper(net) + + fnet, params = make_functional(net) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = NdUnboundedContinuousTensorSpec(4) + + if probabilistic: + tdclass = ProbabilisticTDModule + kwargs = {"distribution_class": TanhNormal} + else: + tdclass = TDModule + kwargs = {} + + if safe and spec is None: + with pytest.raises( + RuntimeError, + match="is not a valid configuration as the tensor specs are not " + "specified", + ): + tdmodule = tdclass( + spec=spec, + module=fnet, + in_keys=["in"], + out_keys=["out"], + safe=safe, + **kwargs + ) + return + else: + tdmodule = tdclass( + spec=spec, + module=fnet, + in_keys=["in"], + out_keys=["out"], + safe=safe, + **kwargs + ) + + # vmap = True + params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_out = tdmodule(td, params=params, vmap=True) + assert td_out is not td + assert td_out.shape == torch.Size([10, 3]) + assert td_out.get("out").shape == torch.Size([10, 3, 4]) + # test bounds + if not safe and spec_type == "bounded": + assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + + # vmap = (0, None) + 
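+        # expected: params batched along their leading dim, tensordict broadcast,
+        # so the output gains a leading dim of 10 (checked below)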
td_out = tdmodule(td, params=params, vmap=(0, None)) + assert td_out is not td + assert td_out.shape == torch.Size([10, 3]) + assert td_out.get("out").shape == torch.Size([10, 3, 4]) + # test bounds + if not safe and spec_type == "bounded": + assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + + # vmap = (0, 0) + td_repeat = td.expand(10).clone() + td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) + assert td_out is not td + assert td_out.shape == torch.Size([10, 3]) + assert td_out.get("out").shape == torch.Size([10, 3, 4]) + # test bounds + if not safe and spec_type == "bounded": + assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + + +class TestTDSequence: + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("probabilistic", [True, False]) + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful(self, safe, spec_type, probabilistic, lazy): + torch.manual_seed(0) + param_multiplier = 2 if probabilistic else 1 + if lazy: + net1 = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + net2 = nn.Linear(4, 4 * param_multiplier) + if probabilistic: + net2 = NormalParamWrapper(net2) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = NdUnboundedContinuousTensorSpec(4) + + if probabilistic: + tdclass = ProbabilisticTDModule + kwargs = {"distribution_class": TanhNormal} + else: + tdclass = TDModule + kwargs = {} + + if safe and spec is None: + pytest.skip("safe and spec is None is checked elsewhere") + else: + tdmodule1 = TDModule( + net1, + None, + in_keys=["in"], + out_keys=["hidden"], + safe=False, + ) + tdmodule2 = tdclass( + spec=spec, + module=net2, + in_keys=["hidden"], + out_keys=["out"], + safe=False, + **kwargs + ) + tdmodule = TDSequence(tdmodule1, tdmodule2) + + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + tdmodule(td) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("probabilistic", [True, False]) + def test_functional(self, safe, spec_type, probabilistic): + torch.manual_seed(0) + param_multiplier = 2 if probabilistic else 1 + + net1 = nn.Linear(3, 4) + net2 = nn.Linear(4, 4 * param_multiplier) + if probabilistic: + net2 = NormalParamWrapper(net2) + + fnet1, params1 = make_functional(net1) + fnet2, params2 = make_functional(net2) + params = list(params1) + list(params2) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = NdUnboundedContinuousTensorSpec(4) + + if probabilistic: + tdclass = ProbabilisticTDModule + kwargs = {"distribution_class": TanhNormal} + else: + tdclass = TDModule + kwargs = {} + + if safe and spec is None: + pytest.skip("safe 
and spec is None is checked elsewhere") + else: + tdmodule1 = TDModule( + fnet1, + None, + in_keys=["in"], + out_keys=["hidden"], + safe=False, + ) + tdmodule2 = tdclass( + fnet2, spec, in_keys=["hidden"], out_keys=["out"], safe=safe, **kwargs + ) + tdmodule = TDSequence(tdmodule1, tdmodule2) + + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + tdmodule(td, params=params) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("probabilistic", [True, False]) + def test_functional_with_buffer(self, safe, spec_type, probabilistic): + torch.manual_seed(0) + param_multiplier = 2 if probabilistic else 1 + + net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) + net2 = nn.Sequential( + nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) + ) + if probabilistic: + net2 = NormalParamWrapper(net2) + + fnet1, params1, buffers1 = make_functional_with_buffers(net1) + fnet2, params2, buffers2 = make_functional_with_buffers(net2) + + params = list(params1) + list(params2) + buffers = list(buffers1) + list(buffers2) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = NdBoundedTensorSpec(-0.1, 0.1, 7) + elif spec_type == "unbounded": + spec = NdUnboundedContinuousTensorSpec(7) + + if probabilistic: + tdclass = ProbabilisticTDModule + kwargs = {"distribution_class": TanhNormal} + else: + tdclass = TDModule + kwargs = {} + + if safe and spec is None: + pytest.skip("safe and spec is None is checked elsewhere") + else: + tdmodule1 = TDModule( + fnet1, + None, + in_keys=["in"], + out_keys=["hidden"], + safe=False, + ) + tdmodule2 = tdclass( + fnet2, spec, in_keys=["hidden"], out_keys=["out"], safe=safe, **kwargs + ) + tdmodule = TDSequence(tdmodule1, tdmodule2) + + td = TensorDict({"in": torch.randn(3, 7)}, [3]) + tdmodule(td, params=params, buffers=buffers) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 7]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("probabilistic", [True, False]) + def test_vmap(self, safe, spec_type, probabilistic): + torch.manual_seed(0) + param_multiplier = 2 if probabilistic else 1 + + net1 = nn.Linear(3, 4) + net2 = nn.Linear(4, 4 * param_multiplier) + if probabilistic: + net2 = NormalParamWrapper(net2) + + fnet1, params1 = make_functional(net1) + fnet2, params2 = make_functional(net2) + params = params1 + params2 + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = NdBoundedTensorSpec(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = NdUnboundedContinuousTensorSpec(4) + + if probabilistic: + tdclass = ProbabilisticTDModule + kwargs = {"distribution_class": TanhNormal} + else: + tdclass = TDModule + kwargs = {} + + if safe and spec is None: + pytest.skip("safe and spec is None is checked elsewhere") + else: + tdmodule1 = TDModule( + 
fnet1, + None, + in_keys=["in"], + out_keys=["hidden"], + safe=False, + ) + tdmodule2 = tdclass( + fnet2, spec, in_keys=["hidden"], out_keys=["out"], safe=safe, **kwargs + ) + tdmodule = TDSequence(tdmodule1, tdmodule2) + + # vmap = True + params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + td_out = tdmodule(td, params=params, vmap=True) + assert td_out is not td + assert td_out.shape == torch.Size([10, 3]) + assert td_out.get("out").shape == torch.Size([10, 3, 4]) + # test bounds + if not safe and spec_type == "bounded": + assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + + # vmap = (0, None) + td_out = tdmodule(td, params=params, vmap=(0, None)) + assert td_out is not td + assert td_out.shape == torch.Size([10, 3]) + assert td_out.get("out").shape == torch.Size([10, 3, 4]) + # test bounds + if not safe and spec_type == "bounded": + assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + + # vmap = (0, 0) + td_repeat = td.expand(10).clone() + td_out = tdmodule(td_repeat, params=params, vmap=(0, 0)) + assert td_out is not td + assert td_out.shape == torch.Size([10, 3]) + assert td_out.get("out").shape == torch.Size([10, 3, 4]) + # test bounds + if not safe and spec_type == "bounded": + assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py new file mode 100644 index 00000000000..a7ce40fe045 --- /dev/null +++ b/test/test_tensor_spec.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
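+
+# Tests for TensorSpec subclasses: random sampling, is_in/encode/to_numpy
+# round-trips, dtype handling and CompositeSpec.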
+ +import numpy as np +import pytest +import torch +from torchrl.data.tensor_specs import ( + NdUnboundedContinuousTensorSpec, + NdBoundedTensorSpec, + CompositeSpec, + MultOneHotDiscreteTensorSpec, + BinaryDiscreteTensorSpec, + BoundedTensorSpec, + UnboundedContinuousTensorSpec, + OneHotDiscreteTensorSpec, +) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) +def test_bounded(dtype): + torch.manual_seed(0) + np.random.seed(0) + for _ in range(100): + bounds = torch.randn(2).sort()[0] + ts = BoundedTensorSpec(bounds[0].item(), bounds[1].item(), dtype=dtype) + _dtype = dtype + if dtype is None: + _dtype = torch.get_default_dtype() + + r = ts.rand() + assert ts.is_in(r) + assert r.dtype is _dtype + ts.is_in(ts.encode(bounds.mean())) + ts.is_in(ts.encode(bounds.mean().item())) + assert (ts.encode(ts.to_numpy(r)) == r).all() + + +def test_onehot(): + torch.manual_seed(0) + np.random.seed(0) + + ts = OneHotDiscreteTensorSpec(10) + for _ in range(100): + r = ts.rand() + ts.to_numpy(r) + ts.encode(torch.tensor([5])) + ts.encode(torch.tensor([5]).numpy()) + ts.encode(9) + with pytest.raises(RuntimeError): + ts.encode(torch.tensor([11])) # out of bounds + assert ts.is_in(r) + assert (ts.encode(ts.to_numpy(r)) == r).all() + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) +def test_unbounded(dtype): + torch.manual_seed(0) + np.random.seed(0) + ts = UnboundedContinuousTensorSpec(dtype=dtype) + + if dtype is None: + dtype = torch.get_default_dtype() + for _ in range(100): + r = ts.rand() + ts.to_numpy(r) + assert ts.is_in(r) + assert r.dtype is dtype + assert (ts.encode(ts.to_numpy(r)) == r).all() + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) +@pytest.mark.parametrize( + "shape", + [ + [], + torch.Size( + [ + 3, + ] + ), + ], +) +def test_ndbounded(dtype, shape): + torch.manual_seed(0) + np.random.seed(0) + + for _ in range(100): + lb = torch.rand(10) - 1 + ub = torch.rand(10) + 1 + ts = NdBoundedTensorSpec(lb, ub, dtype=dtype) + _dtype = dtype + if dtype is None: + _dtype = torch.get_default_dtype() + + r = ts.rand(shape) + assert r.dtype is _dtype + assert r.shape == torch.Size([*shape, 10]) + assert (r >= lb.to(dtype)).all() and ( + r <= ub.to(dtype) + ).all(), f"{r[r <= lb] - lb.expand_as(r)[r <= lb]} -- {r[r >= ub] - ub.expand_as(r)[r >= ub]} " + ts.to_numpy(r) + assert ts.is_in(r) + ts.encode(lb + torch.rand(10) * (ub - lb)) + ts.encode((lb + torch.rand(10) * (ub - lb)).numpy()) + assert (ts.encode(ts.to_numpy(r)) == r).all() + with pytest.raises(AssertionError): + ts.encode(torch.rand(10) + 3) # out of bounds + with pytest.raises(AssertionError): + ts.to_numpy(torch.rand(10) + 3) # out of bounds + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) +@pytest.mark.parametrize("n", range(3, 10)) +@pytest.mark.parametrize( + "shape", + [ + [], + torch.Size( + [ + 3, + ] + ), + ], +) +def test_ndunbounded(dtype, n, shape): + torch.manual_seed(0) + np.random.seed(0) + + ts = NdUnboundedContinuousTensorSpec( + shape=[ + n, + ], + dtype=dtype, + ) + + if dtype is None: + dtype = torch.get_default_dtype() + + for _ in range(100): + r = ts.rand(shape) + assert r.shape == torch.Size( + [ + *shape, + n, + ] + ) + ts.to_numpy(r) + assert ts.is_in(r) + assert r.dtype is dtype + assert (ts.encode(ts.to_numpy(r)) == r).all() + + +@pytest.mark.parametrize("n", range(3, 10)) +@pytest.mark.parametrize( + "shape", + [ + [], + torch.Size( + [ + 3, + ] 
+ ), + ], +) +def test_binary(n, shape): + torch.manual_seed(0) + np.random.seed(0) + + ts = BinaryDiscreteTensorSpec(n) + for _ in range(100): + r = ts.rand(shape) + assert r.shape == torch.Size( + [ + *shape, + n, + ] + ) + assert ts.is_in(r) + assert ((r == 0) | (r == 1)).all() + assert (ts.encode(r.numpy()) == r).all() + assert (ts.encode(ts.to_numpy(r)) == r).all() + + +@pytest.mark.parametrize( + "ns", + [ + [ + 5, + ], + [5, 2, 3], + [4, 4, 1], + ], +) +@pytest.mark.parametrize( + "shape", + [ + [], + torch.Size( + [ + 3, + ] + ), + ], +) +def test_mult_onehot(shape, ns): + torch.manual_seed(0) + np.random.seed(0) + ts = MultOneHotDiscreteTensorSpec(nvec=ns) + for _ in range(100): + r = ts.rand(shape) + assert r.shape == torch.Size( + [ + *shape, + sum(ns), + ] + ) + assert ts.is_in(r) + assert ((r == 0) | (r == 1)).all() + rsplit = r.split(ns, dim=-1) + for _r, _n in zip(rsplit, ns): + assert (_r.sum(-1) == 1).all() + assert _r.shape[-1] == _n + np_r = ts.to_numpy(r) + assert (ts.encode(np_r) == r).all() + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) +@pytest.mark.parametrize( + "shape", + [ + [], + torch.Size( + [ + 3, + ] + ), + ], +) +def test_composite(shape, dtype): + torch.manual_seed(0) + np.random.seed(0) + + ts = CompositeSpec( + obs=NdBoundedTensorSpec( + torch.zeros(3, 32, 32), torch.ones(3, 32, 32), dtype=dtype + ), + act=NdUnboundedContinuousTensorSpec((7,), dtype=dtype), + ) + if dtype is None: + dtype = torch.get_default_dtype() + + rand_td = ts.rand(shape) + assert rand_td.shape == torch.Size(shape) + assert rand_td.get("obs").shape == torch.Size([*shape, 3, 32, 32]) + assert rand_td.get("obs").dtype == dtype + assert rand_td.get("act").shape == torch.Size([*shape, 7]) + assert rand_td.get("act").dtype == dtype + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/test_tensordict.py b/test/test_tensordict.py new file mode 100644 index 00000000000..5c28f1fef71 --- /dev/null +++ b/test/test_tensordict.py @@ -0,0 +1,912 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
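+
+# Tests for TensorDict and its views (sub-, stacked, saved, memmap, unsqueezed):
+# set/get semantics, indexing, stacking, device casts and multiprocessing.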
+ +import argparse +import os.path + +import numpy as np +import pytest +import torch +from _utils_internal import get_available_devices +from torch import multiprocessing as mp +from torchrl.data import TensorDict, SavedTensorDict +from torchrl.data.tensordict.tensordict import LazyStackedTensorDict, assert_allclose_td +from torchrl.data.tensordict.utils import _getitem_batch_size + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_tensor_dict_set(device): + torch.manual_seed(1) + td = TensorDict({}, batch_size=(4, 5)) + td.set("key1", torch.randn(4, 5, device=device)) + assert td.device == torch.device(device) + # by default inplace: + with pytest.raises(RuntimeError): + td.set("key1", torch.randn(5, 5, device=device)) + + # robust to dtype casting + td.set_("key1", torch.ones(4, 5, device=device, dtype=torch.double)) + assert (td.get("key1") == 1).all() + + # robust to device casting + td.set("key_device", torch.ones(4, 5, device="cpu", dtype=torch.double)) + assert td.get("key_device").device == torch.device(device) + + with pytest.raises( + AttributeError, match="for populating tensordict with new key-value pair" + ): + td.set_("smartypants", torch.ones(4, 5, device="cpu", dtype=torch.double)) + # test set_at_ + td.set("key2", torch.randn(4, 5, 6, device=device)) + x = torch.randn(6, device=device) + td.set_at_("key2", x, (2, 2)) + assert (td.get("key2")[2, 2] == x).all() + + # test set_at_ with dtype casting + x = torch.randn(6, dtype=torch.double, device=device) + td.set_at_("key2", x, (2, 2)) # robust to dtype casting + torch.testing.assert_allclose(td.get("key2")[2, 2], x.to(torch.float)) + + td.set("key1", torch.zeros(4, 5, dtype=torch.double, device=device), inplace=True) + assert (td.get("key1") == 0).all() + td.set( + "key1", + torch.randn(4, 5, 1, 2, dtype=torch.double, device=device), + inplace=False, + ) + assert td._tensor_dict_meta["key1"].shape == td._tensor_dict["key1"].shape + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_stack(device): + torch.manual_seed(1) + tds_list = [TensorDict(source={}, batch_size=(4, 5)) for _ in range(3)] + tds = torch.stack(tds_list, 0) + assert tds[0] is tds_list[0] + + td = TensorDict( + source={"a": torch.randn(4, 5, 3, device=device)}, batch_size=(4, 5) + ) + td_list = list(td) + td_reconstruct = torch.stack(td_list, 0) + assert td_reconstruct.batch_size == td.batch_size + assert (td_reconstruct == td).all() + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_tensor_dict_indexing(device): + torch.manual_seed(1) + td = TensorDict({}, batch_size=(4, 5)) + td.set("key1", torch.randn(4, 5, 1, device=device)) + td.set("key2", torch.randn(4, 5, 6, device=device, dtype=torch.double)) + + td_select = td[2, 2] + td_select._check_batch_size() + + td_select = td[2, :2] + td_select._check_batch_size() + + td_select = td[None, :2] + td_select._check_batch_size() + + td_reconstruct = torch.stack([_td for _td in td], 0) + assert ( + td_reconstruct == td + ).all(), f"td and td_reconstruct differ, got {td} and {td_reconstruct}" + + superlist = [torch.stack([__td for __td in _td], 0) for _td in td] + td_reconstruct = torch.stack(superlist, 0) + assert ( + td_reconstruct == td + ).all(), f"td and td_reconstruct differ, got {td == td_reconstruct}" + + x = torch.randn(4, 5, device=device) + td = TensorDict( + source={"key1": torch.zeros(3, 4, 5, device=device)}, + batch_size=[3, 4], + ) + td[0].set_("key1", x) + torch.testing.assert_allclose(td.get("key1")[0], x) + 
torch.testing.assert_allclose(td.get("key1")[0], td[0].get("key1")) + + y = torch.randn(3, 5, device=device) + td[:, 0].set_("key1", y) + torch.testing.assert_allclose(td.get("key1")[:, 0], y) + torch.testing.assert_allclose(td.get("key1")[:, 0], td[:, 0].get("key1")) + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_subtensor_dict_construction(device): + torch.manual_seed(1) + td = TensorDict({}, batch_size=(4, 5)) + td.set("key1", torch.randn(4, 5, 1, device=device)) + td.set("key2", torch.randn(4, 5, 6, dtype=torch.double, device=device)) + std1 = td.get_sub_tensor_dict(2) + std2 = std1.get_sub_tensor_dict(2) + std_control = td.get_sub_tensor_dict((2, 2)) + assert (std_control.get("key1") == std2.get("key1")).all() + assert (std_control.get("key2") == std2.get("key2")).all() + + # write values + std_control.set("key1", torch.randn(1, device=device)) + std_control.set("key2", torch.randn(6, device=device, dtype=torch.double)) + + assert (std_control.get("key1") == std2.get("key1")).all() + assert (std_control.get("key2") == std2.get("key2")).all() + + assert std_control.get_parent_tensor_dict() is td + assert ( + std_control.get_parent_tensor_dict() + is std2.get_parent_tensor_dict().get_parent_tensor_dict() + ) + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_mask_td(device): + torch.manual_seed(1) + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + mask = torch.zeros(4, 5, dtype=torch.bool, device=device).bernoulli_() + td = TensorDict(batch_size=(4, 5), source=d) + td_masked = td.masked_select(mask) + assert len(td_masked.get("key1")) == td_masked.shape[0] + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_unbind_td(device): + torch.manual_seed(1) + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + td = TensorDict(batch_size=(4, 5), source=d) + td_unbind = td.unbind(1) + assert ( + td_unbind[0].batch_size == td[:, 0].batch_size + ), f"got {td_unbind[0].batch_size} and {td[:, 0].batch_size}" + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_cat_td(device): + torch.manual_seed(1) + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + td1 = TensorDict(batch_size=(4, 5), source=d) + d = { + "key1": torch.randn(4, 10, 6, device=device), + "key2": torch.randn(4, 10, 10, device=device), + } + td2 = TensorDict(batch_size=(4, 10), source=d) + + td_cat = torch.cat([td1, td2], 1) + assert td_cat.batch_size == torch.Size([4, 15]) + d = {"key1": torch.randn(4, 15, 6), "key2": torch.randn(4, 15, 10)} + td_out = TensorDict(batch_size=(4, 15), source=d) + torch.cat([td1, td2], 1, out=td_out) + assert td_out.batch_size == torch.Size([4, 15]) + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_expand(device): + torch.manual_seed(1) + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + td1 = TensorDict(batch_size=(4, 5), source=d) + td2 = td1.expand(3, 7) + assert td2.batch_size == torch.Size([3, 7, 4, 5]) + assert td2.get("key1").shape == torch.Size([3, 7, 4, 5, 6]) + assert td2.get("key2").shape == torch.Size([3, 7, 4, 5, 10]) + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_squeeze(device): + torch.manual_seed(1) + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + td1 = 
TensorDict(batch_size=(4, 5), source=d) + td2 = td1.unsqueeze(1) + assert td2.batch_size == torch.Size([4, 1, 5]) + + td1b = td2.squeeze(1) + assert td1b.batch_size == td1.batch_size + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("stack_dim", [0, 1]) +def test_stacked_td(stack_dim, device): + tensor_dicts = [ + TensorDict( + batch_size=[11, 12], + source={ + "key1": torch.randn(11, 12, 5, device=device), + "key2": torch.zeros( + 11, 12, 50, device=device, dtype=torch.bool + ).bernoulli_(), + }, + ) + for _ in range(10) + ] + + tensor_dicts0 = tensor_dicts[0] + tensor_dicts1 = tensor_dicts[1] + tensor_dicts2 = tensor_dicts[2] + tensor_dicts3 = tensor_dicts[3] + sub_td = LazyStackedTensorDict(*tensor_dicts, stack_dim=stack_dim) + + std_bis = torch.stack(tensor_dicts, dim=stack_dim) + assert (sub_td == std_bis).all() + + item = tuple([*[slice(None) for _ in range(stack_dim)], 0]) + tensor_dicts0.zero_() + assert (sub_td[item].get("key1") == sub_td.get("key1")[item]).all() + assert ( + sub_td.contiguous()[item].get("key1") == sub_td.contiguous().get("key1")[item] + ).all() + assert (sub_td.contiguous().get("key1")[item] == 0).all() + + item = tuple([*[slice(None) for _ in range(stack_dim)], 1]) + std2 = sub_td[:5] + tensor_dicts1.zero_() + assert (std2[item].get("key1") == std2.get("key1")[item]).all() + assert ( + std2.contiguous()[item].get("key1") == std2.contiguous().get("key1")[item] + ).all() + assert (std2.contiguous().get("key1")[item] == 0).all() + + std3 = sub_td[:5, :, :5] + tensor_dicts2.zero_() + item = tuple([*[slice(None) for _ in range(stack_dim)], 2]) + assert (std3[item].get("key1") == std3.get("key1")[item]).all() + assert ( + std3.contiguous()[item].get("key1") == std3.contiguous().get("key1")[item] + ).all() + assert (std3.contiguous().get("key1")[item] == 0).all() + + std4 = sub_td.select("key1") + tensor_dicts3.zero_() + item = tuple([*[slice(None) for _ in range(stack_dim)], 3]) + assert (std4[item].get("key1") == std4.get("key1")[item]).all() + assert ( + std4.contiguous()[item].get("key1") == std4.contiguous().get("key1")[item] + ).all() + assert (std4.contiguous().get("key1")[item] == 0).all() + + std5 = sub_td.unbind(1)[0] + assert (std5.contiguous() == sub_td.contiguous().unbind(1)[0]).all() + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_savedtensordict(device): + vals = [torch.randn(3, 1, device=device) for _ in range(4)] + ss_list = [ + SavedTensorDict( + source=TensorDict( + source={"a": vals[i]}, + batch_size=[ + 3, + ], + ) + ) + for i in range(4) + ] + ss = torch.stack(ss_list, 0) + assert ss_list[1] is ss[1] + torch.testing.assert_allclose(ss_list[1].get("a"), vals[1]) + torch.testing.assert_allclose(ss_list[1].get("a"), ss[1].get("a")) + torch.testing.assert_allclose(ss[1].get("a"), ss.get("a")[1]) + assert ss.get("a").device == device + + +class TestTensorDicts: + @property + def td(self): + return TensorDict( + source={ + "a": torch.randn(3, 1, 5), + "b": torch.randn(3, 1, 10), + "c": torch.randint(10, (3, 1, 3)), + }, + batch_size=[3, 1], + ) + + @property + def stacked_td(self): + return torch.stack([self.td for _ in range(2)], 0) + + @property + def idx_td(self): + return self.td[0] + + @property + def sub_td(self): + return self.td.get_sub_tensor_dict(0) + + @property + def saved_td(self): + return SavedTensorDict(source=self.td) + + @property + def unsqueezed_td(self): + return self.td.unsqueeze(0) + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", 
"idx_td", "saved_td", "unsqueezed_td"] + ) + def test_select(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + td2 = td.select("a") + assert td2 is not td + assert len(list(td2.keys())) == 1 and "a" in td2.keys() + assert len(list(td2.clone().keys())) == 1 and "a" in td2.clone().keys() + + td2 = td.select("a", inplace=True) + assert td2 is td + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_expand(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + batch_size = td.batch_size + new_td = td.expand(3) + assert new_td.batch_size == torch.Size([3, *batch_size]) + assert all((_new_td == td).all() for _new_td in new_td) + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_cast(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + td_td = td.to(TensorDict) + assert (td == td_td).all() + + td = getattr(self, td_name) + td_saved = td.to(SavedTensorDict) + assert (td == td_saved).all() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_remove(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + td = td.del_("a") + assert td is not None + assert "a" not in td.keys() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_set_unexisting(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + td.set("z", torch.ones_like(td.get("a"))) + assert (td.get("z") == 1).all() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_fill_(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + new_td = td.fill_("a", 0.1) + assert (td.get("a") == 0.1).all() + assert new_td is td + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_masked_fill_(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + mask = torch.zeros(td.shape, dtype=torch.bool).bernoulli_() + new_td = td.masked_fill_(mask, -10.0) + assert new_td is td + for k, item in td.items(): + assert (item[mask] == -10).all() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_zero_(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + new_td = td.zero_() + assert new_td is td + for k in td.keys(): + assert (td.get(k) == 0).all() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_from_empty(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + new_td = TensorDict({}, batch_size=td.batch_size) + for key, item in td.items(): + new_td.set(key, item) + assert_allclose_td(td, new_td) + assert td.device == new_td.device + assert td.shape == new_td.shape + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_masking(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + mask = torch.zeros(td.batch_size, dtype=torch.bool).bernoulli_(0.8) + td_masked = td[mask] + td_masked2 = td.masked_select(mask) + assert_allclose_td(td_masked, td_masked2) + assert td_masked.batch_size[0] == mask.sum() + assert td_masked.batch_dims == 1 + + 
@pytest.mark.skipif( + torch.cuda.device_count() == 0, reason="No cuda device detected" + ) + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + @pytest.mark.parametrize("device", [0, "cuda:0", "cuda", torch.device("cuda:0")]) + def test_pin_memory(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name) + if td_name != "saved_td": + td.pin_memory() + td_device = td.to(device) + _device = torch.device("cuda:0") + assert td_device.device == _device + assert td_device.clone().device == _device + assert td_device is not td + for k, item in td_device.items(): + assert item.device == _device + for k, item in td_device.clone().items(): + assert item.device == _device + # assert type(td_device) is type(td) + assert_allclose_td(td, td_device.to("cpu")) + else: + with pytest.raises( + RuntimeError, + match="pin_memory requires tensordicts that live in memory", + ): + td.pin_memory() + + @pytest.mark.skipif( + torch.cuda.device_count() == 0, reason="No cuda device detected" + ) + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + @pytest.mark.parametrize("device", get_available_devices()) + def test_cast_device(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name) + td_device = td.to(device) + + for k, item in td_device.items_meta(): + assert item.device == device + for k, item in td_device.items(): + assert item.device == device + for k, item in td_device.clone().items(): + assert item.device == device + + assert td_device.device == device, ( + f"td_device first tensor device is " f"{next(td_device.items())[1].device}" + ) + assert td_device.clone().device == device + if device != td.device: + assert td_device is not td + assert td_device.to(device) is td_device + assert td.to("cpu") is td + # assert type(td_device) is type(td) + assert_allclose_td(td, td_device.to("cpu")) + + @pytest.mark.skipif( + torch.cuda.device_count() == 0, reason="No cuda device detected" + ) + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_cpu_cuda(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + td_device = td.cuda() + td_back = td_device.cpu() + assert td_device.device == torch.device("cuda:0") + assert td_back.device == torch.device("cpu") + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "saved_td", "unsqueezed_td"] + ) + def test_unbind(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + td_unbind = td.unbind(0) + assert (td == torch.stack(td_unbind, 0)).all() + assert (td[0] == td_unbind[0]).all() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + @pytest.mark.parametrize("squeeze_dim", [0, 1]) + def test_unsqueeze(self, td_name, squeeze_dim): + torch.manual_seed(1) + td = getattr(self, td_name) + td_unsqueeze = td.unsqueeze(squeeze_dim) + tensor = torch.ones_like(td.get("a").unsqueeze(squeeze_dim)) + td_unsqueeze.set("a", tensor) + assert (td_unsqueeze.get("a") == tensor).all() + assert (td.get("a") == tensor.squeeze(squeeze_dim)).all() + assert td_unsqueeze.squeeze(squeeze_dim) is td + assert (td_unsqueeze.get("a") == 1).all() + assert (td.get("a") == 1).all() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_squeeze(self, td_name, squeeze_dim=-1): + torch.manual_seed(1) + 
td = getattr(self, td_name) + td_squeeze = td.squeeze(-1) + tensor_squeeze_dim = td.batch_dims + squeeze_dim + tensor = torch.ones_like(td.get("a").squeeze(tensor_squeeze_dim)) + td_squeeze.set("a", tensor) + assert (td_squeeze.get("a") == tensor).all() + assert (td.get("a") == tensor.unsqueeze(tensor_squeeze_dim)).all() + assert td_squeeze.unsqueeze(squeeze_dim) is td + assert (td_squeeze.get("a") == 1).all() + assert (td.get("a") == 1).all() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_view(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + td_view = td.view(-1) + tensor = td.get("a") + tensor = tensor.view(-1, tensor.numel() // np.prod(td.batch_size)) + tensor = torch.ones_like(tensor) + td_view.set("a", tensor) + assert (td_view.get("a") == tensor).all() + assert (td.get("a") == tensor.view(td.get("a").shape)).all() + assert td_view.view(td.shape) is td + assert td_view.view(*td.shape) is td + assert (td_view.get("a") == 1).all() + assert (td.get("a") == 1).all() + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_clone_td(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + assert (td.clone() == td).all() + assert td.batch_size == td.clone().batch_size + if td_name in ("stacked_td", "saved_td", "unsqueezed_td", "sub_td"): + with pytest.raises(AssertionError): + assert td.clone(recursive=False).get("a") is td.get("a") + else: + assert td.clone(recursive=False).get("a") is td.get("a") + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_rename_key(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + with pytest.raises(KeyError, match="already present in TensorDict"): + td.rename_key("a", "b", safe=True) + a = td.get("a") + td.rename_key("a", "z") + with pytest.raises(KeyError): + td.get("a") + assert "a" not in td.keys() + + z = td.get("z") + torch.testing.assert_allclose(a, z) + + new_z = torch.randn_like(z) + td.set("z", new_z) + torch.testing.assert_allclose(new_z, td.get("z")) + + new_z = torch.randn_like(z) + td.set_("z", new_z) + torch.testing.assert_allclose(new_z, td.get("z")) + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + def test_set_nontensor(self, td_name): + torch.manual_seed(1) + td = getattr(self, td_name) + r = torch.randn_like(td.get("a")) + td.set("numpy", r.numpy()) + torch.testing.assert_allclose(td.get("numpy"), r) + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + @pytest.mark.parametrize("idx", [slice(1), torch.tensor([0]), torch.tensor([0, 1])]) + def test_setitem(self, td_name, idx): + torch.manual_seed(1) + td = getattr(self, td_name) + if isinstance(idx, torch.Tensor) and idx.numel() > 1 and td.shape[0] == 1: + pytest.mark.skip("cannot index tensor with desired index") + return + + td_clone = td[idx].clone().zero_() + td[idx] = td_clone + assert (td[idx].get("a") == 0).all() + + td_clone = torch.cat([td_clone, td_clone], 0) + with pytest.raises(RuntimeError, match="differs from the source batch size"): + td[idx] = td_clone + + @pytest.mark.parametrize( + "td_name", ["td", "stacked_td", "sub_td", "idx_td", "saved_td", "unsqueezed_td"] + ) + @pytest.mark.parametrize("dim", [0, 1]) + @pytest.mark.parametrize("chunks", [1, 2]) + def 
test_chunk(self, td_name, dim, chunks): + torch.manual_seed(1) + td = getattr(self, td_name) + if len(td.shape) - 1 < dim: + pytest.mark.skip(f"no dim {dim} in td") + return + + chunks = min(td.shape[dim], chunks) + td_chunks = td.chunk(chunks, dim) + assert len(td_chunks) == chunks + assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim] + assert (torch.cat(td_chunks, dim) == td).all() + + +@pytest.mark.parametrize("index0", [None, slice(None)]) +def test_set_sub_key(index0): + # tests that parent tensordict is affected when subtensordict is set with a new key + batch_size = [10, 10] + source = {"a": torch.randn(10, 10, 10), "b": torch.ones(10, 10, 2)} + td = TensorDict(source, batch_size=batch_size) + idx0 = (index0, 0) if index0 is not None else 0 + td0 = td.get_sub_tensor_dict(idx0) + idx = (index0, slice(2, 4)) if index0 is not None else slice(2, 4) + sub_td = td.get_sub_tensor_dict(idx) + if index0 is None: + c = torch.randn(2, 10, 10) + else: + c = torch.randn(10, 2, 10) + sub_td.set("c", c) + assert (td.get("c")[idx] == sub_td.get("c")).all() + assert (sub_td.get("c") == c).all() + assert (td.get("c")[idx0] == 0).all() + assert (td.get_sub_tensor_dict(idx0).get("c") == 0).all() + assert (td0.get("c") == 0).all() + + +def _remote_process(worker_id, command_pipe_child, command_pipe_parent, tensordict): + command_pipe_parent.close() + while True: + cmd, val = command_pipe_child.recv() + if cmd == "recv": + b = tensordict.get("b") + assert (b == val).all() + command_pipe_child.send("done") + elif cmd == "send": + a = torch.ones(2) * val + tensordict.set_("a", a) + assert ( + tensordict.get("a") == a + ).all(), f'found {a} and {tensordict.get("a")}' + command_pipe_child.send("done") + elif cmd == "set_done": + tensordict.set_("done", torch.ones(1, dtype=torch.bool)) + command_pipe_child.send("done") + elif cmd == "set_undone_": + tensordict.set_("done", torch.zeros(1, dtype=torch.bool)) + command_pipe_child.send("done") + elif cmd == "update": + tensordict.update_( + TensorDict( + source={"a": tensordict.get("a").clone() + 1}, + batch_size=tensordict.batch_size, + ) + ) + command_pipe_child.send("done") + elif cmd == "update_": + tensordict.update_( + TensorDict( + source={"a": tensordict.get("a").clone() - 1}, + batch_size=tensordict.batch_size, + ) + ) + command_pipe_child.send("done") + + elif cmd == "close": + command_pipe_child.close() + break + + +def _driver_func(tensordict, tensor_dict_unbind): + procs = [] + children = [] + parents = [] + + for i in range(2): + command_pipe_parent, command_pipe_child = mp.Pipe() + proc = mp.Process( + target=_remote_process, + args=(i, command_pipe_child, command_pipe_parent, tensor_dict_unbind[i]), + ) + proc.start() + command_pipe_child.close() + parents.append(command_pipe_parent) + children.append(command_pipe_child) + procs.append(proc) + + b = torch.ones(2, 1) * 10 + tensordict.set_("b", b) + for i in range(2): + parents[i].send(("recv", 10)) + is_done = parents[i].recv() + assert is_done == "done" + + for i in range(2): + parents[i].send(("send", i)) + is_done = parents[i].recv() + assert is_done == "done" + a = tensordict.get("a").clone() + assert (a[0] == 0).all() + assert (a[1] == 1).all() + + assert not tensordict.get("done").any() + for i in range(2): + parents[i].send(("set_done", i)) + is_done = parents[i].recv() + assert is_done == "done" + assert tensordict.get("done").all() + + for i in range(2): + parents[i].send(("set_undone_", i)) + is_done = parents[i].recv() + assert is_done == "done" + assert not 
tensordict.get("done").any() + + a_prev = tensordict.get("a").clone().contiguous() + for i in range(2): + parents[i].send(("update_", i)) + is_done = parents[i].recv() + assert is_done == "done" + new_a = tensordict.get("a").clone().contiguous() + torch.testing.assert_allclose(a_prev - 1, new_a) + + a_prev = tensordict.get("a").clone().contiguous() + for i in range(2): + parents[i].send(("update", i)) + is_done = parents[i].recv() + assert is_done == "done" + new_a = tensordict.get("a").clone().contiguous() + torch.testing.assert_allclose(a_prev + 1, new_a) + + for i in range(2): + parents[i].send(("close", None)) + procs[i].join() + + +@pytest.mark.parametrize( + "td_type", ["contiguous", "stack", "saved", "memmap", "memmap_stack"] +) +def test_mp(td_type): + tensordict = TensorDict( + source={ + "a": torch.randn(2, 2), + "b": torch.randn(2, 1), + "done": torch.zeros(2, 1, dtype=torch.bool), + }, + batch_size=[2], + ) + if td_type == "contiguous": + tensordict = tensordict.share_memory_() + elif td_type == "stack": + tensordict = torch.stack( + [ + tensordict[0].clone().share_memory_(), + tensordict[1].clone().share_memory_(), + ], + 0, + ) + elif td_type == "saved": + tensordict = tensordict.clone().to(SavedTensorDict) + elif td_type == "memmap": + tensordict = tensordict.memmap_() + elif td_type == "memmap_stack": + tensordict = torch.stack( + [tensordict[0].clone().memmap_(), tensordict[1].clone().memmap_()], 0 + ) + else: + raise NotImplementedError + _driver_func(tensordict, tensordict.unbind(0)) + + +def test_saved_delete(): + td = TensorDict(source={"a": torch.randn(3)}, batch_size=[]) + td = td.to(SavedTensorDict) + file = td.file.name + assert os.path.isfile(file) + del td + assert not os.path.isfile(file) + + +def test_stack_keys(): + td1 = TensorDict(source={"a": torch.randn(3)}, batch_size=[]) + td2 = TensorDict( + source={ + "a": torch.randn(3), + "b": torch.randn(3), + "c": torch.randn(4), + "d": torch.randn(5), + }, + batch_size=[], + ) + td = torch.stack([td1, td2], 0) + assert "a" in td.keys() + assert "b" not in td.keys() + assert "b" in td[1].keys() + td.set("b", torch.randn(2, 10), inplace=False) # overwrites + with pytest.raises(KeyError): + td.set_("c", torch.randn(2, 10)) # overwrites + td.set_("b", torch.randn(2, 10)) # b has been set before + + td1.set("c", torch.randn(4)) + assert "c" in td.keys() # now all tds have the key c + td.get("c") + + td1.set("d", torch.randn(6)) + with pytest.raises(RuntimeError): + td.get("d") + + +def test_getitem_batch_size(): + shape = [ + 10, + 7, + 11, + 5, + ] + mocking_tensor = torch.zeros(*shape) + for idx in [ + (slice(None),), + slice(None), + (3, 4), + (3, slice(None), slice(2, 2, 2)), + (torch.tensor([1, 2, 3]),), + ([1, 2, 3]), + ( + torch.tensor([1, 2, 3]), + torch.tensor([2, 3, 4]), + torch.tensor([0, 10, 2]), + torch.tensor([2, 4, 1]), + ), + ]: + expected_shape = mocking_tensor[idx].shape + resulting_shape = _getitem_batch_size(shape, idx) + assert expected_shape == resulting_shape, idx + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_transforms.py b/test/test_transforms.py new file mode 100644 index 00000000000..ef75e86fdb9 --- /dev/null +++ b/test/test_transforms.py @@ -0,0 +1,261 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch +from torch import multiprocessing as mp +from torchrl.agents.env_creator import EnvCreator +from torchrl.data import TensorDict +from torchrl.envs import GymEnv, ParallelEnv +from torchrl.envs.transforms import VecNorm, TransformedEnv + +TIMEOUT = 10.0 + + +def _test_vecnorm_subproc(idx, queue_out: mp.Queue, queue_in: mp.Queue): + td = queue_in.get(timeout=TIMEOUT) + env = GymEnv("Pendulum-v1") + env.set_seed(idx) + t = VecNorm(shared_td=td) + env = TransformedEnv(env, t) + for _ in range(10): + env.rand_step() + queue_out.put(True) + msg = queue_in.get(timeout=TIMEOUT) + assert msg == "all_done" + obs_sum = t._td.get("next_observation_sum").clone() + obs_ssq = t._td.get("next_observation_ssq").clone() + obs_count = t._td.get("next_observation_ssq").clone() + reward_sum = t._td.get("reward_sum").clone() + reward_ssq = t._td.get("reward_ssq").clone() + reward_count = t._td.get("reward_ssq").clone() + + td_out = TensorDict( + { + "obs_sum": obs_sum, + "obs_ssq": obs_ssq, + "obs_count": obs_count, + "reward_sum": reward_sum, + "reward_ssq": reward_ssq, + "reward_count": reward_count, + }, + [], + ).share_memory_() + queue_out.put(td_out) + msg = queue_in.get(timeout=TIMEOUT) + assert msg == "all_done" + + +@pytest.mark.parametrize("nprc", [2, 5]) +def test_vecnorm_parallel(nprc): + queues = [] + prcs = [] + td = VecNorm.build_td_for_shared_vecnorm(GymEnv("Pendulum-v1")) + for idx in range(nprc): + prc_queue_in = mp.Queue(1) + prc_queue_out = mp.Queue(1) + p = mp.Process( + target=_test_vecnorm_subproc, + args=( + idx, + prc_queue_in, + prc_queue_out, + ), + ) + p.start() + prc_queue_out.put(td) + prcs.append(p) + queues.append((prc_queue_in, prc_queue_out)) + + dones = [queue[0].get(timeout=TIMEOUT) for queue in queues] + assert all(dones) + msg = "all_done" + for idx in range(nprc): + queues[idx][1].put(msg) + + obs_sum = td.get("next_observation_sum").clone() + obs_ssq = td.get("next_observation_ssq").clone() + obs_count = td.get("next_observation_ssq").clone() + reward_sum = td.get("reward_sum").clone() + reward_ssq = td.get("reward_ssq").clone() + reward_count = td.get("reward_ssq").clone() + + for idx in range(nprc): + td_out = queues[idx][0].get(timeout=TIMEOUT) + _obs_sum = td_out.get("obs_sum") + _obs_ssq = td_out.get("obs_ssq") + _obs_count = td_out.get("obs_count") + _reward_sum = td_out.get("reward_sum") + _reward_ssq = td_out.get("reward_ssq") + _reward_count = td_out.get("reward_count") + assert (obs_sum == _obs_sum).all() + assert (obs_ssq == _obs_ssq).all() + assert (obs_count == _obs_count).all() + assert (reward_sum == _reward_sum).all() + assert (reward_ssq == _reward_ssq).all() + assert (reward_count == _reward_count).all() + + obs_sum, obs_ssq, obs_count, reward_sum, reward_ssq, reward_count = ( + _obs_sum, + _obs_ssq, + _obs_count, + _reward_sum, + _reward_ssq, + _reward_count, + ) + + msg = "all_done" + for idx in range(nprc): + queues[idx][1].put(msg) + + +def _test_vecnorm_subproc_auto(idx, make_env, queue_out: mp.Queue, queue_in: mp.Queue): + env = make_env() + env.set_seed(idx) + for _ in range(10): + env.rand_step() + queue_out.put(True) + msg = queue_in.get(timeout=TIMEOUT) + assert msg == "all_done" + t = env.transform + obs_sum = t._td.get("next_observation_sum").clone() + obs_ssq = t._td.get("next_observation_ssq").clone() + obs_count = t._td.get("next_observation_ssq").clone() + reward_sum = t._td.get("reward_sum").clone() + reward_ssq = t._td.get("reward_ssq").clone() + reward_count = t._td.get("reward_ssq").clone() + + 
queue_out.put((obs_sum, obs_ssq, obs_count, reward_sum, reward_ssq, reward_count)) + msg = queue_in.get(timeout=TIMEOUT) + assert msg == "all_done" + + +@pytest.mark.parametrize("nprc", [2, 5]) +def test_vecnorm_parallel_auto(nprc): + queues = [] + prcs = [] + make_env = EnvCreator(lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())) + for idx in range(nprc): + prc_queue_in = mp.Queue(1) + prc_queue_out = mp.Queue(1) + p = mp.Process( + target=_test_vecnorm_subproc_auto, + args=( + idx, + make_env, + prc_queue_in, + prc_queue_out, + ), + ) + p.start() + prcs.append(p) + queues.append((prc_queue_in, prc_queue_out)) + + td = list(make_env.state_dict().values())[0] + dones = [queue[0].get() for queue in queues] + assert all(dones) + msg = "all_done" + for idx in range(nprc): + queues[idx][1].put(msg) + + obs_sum = td.get("next_observation_sum").clone() + obs_ssq = td.get("next_observation_ssq").clone() + obs_count = td.get("next_observation_ssq").clone() + reward_sum = td.get("reward_sum").clone() + reward_ssq = td.get("reward_ssq").clone() + reward_count = td.get("reward_ssq").clone() + + for idx in range(nprc): + tup = queues[idx][0].get(timeout=TIMEOUT) + _obs_sum, _obs_ssq, _obs_count, _reward_sum, _reward_ssq, _reward_count = tup + assert (obs_sum == _obs_sum).all(), (_obs_sum, obs_sum) + assert (obs_ssq == _obs_ssq).all() + assert (obs_count == _obs_count).all() + assert (reward_sum == _reward_sum).all() + assert (reward_ssq == _reward_ssq).all() + assert (reward_count == _reward_count).all() + + obs_sum, obs_ssq, obs_count, reward_sum, reward_ssq, reward_count = ( + _obs_sum, + _obs_ssq, + _obs_count, + _reward_sum, + _reward_ssq, + _reward_count, + ) + msg = "all_done" + for idx in range(nprc): + queues[idx][1].put(msg) + + +def _run_parallelenv(parallel_env, queue_in, queue_out): + parallel_env.reset() + + msg = queue_in.get(timeout=TIMEOUT) + assert msg == "start" + for _ in range(10): + parallel_env.rand_step() + queue_out.put("first round") + msg = queue_in.get(timeout=TIMEOUT) + assert msg == "start" + for _ in range(10): + parallel_env.rand_step() + queue_out.put("second round") + del parallel_env + + +def test_parallelenv_vecnorm(): + make_env = EnvCreator(lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())) + parallel_env = ParallelEnv(3, make_env) + queue_out = mp.Queue(1) + queue_in = mp.Queue(1) + proc = mp.Process(target=_run_parallelenv, args=(parallel_env, queue_out, queue_in)) + proc.start() + parallel_sd = parallel_env.state_dict() + assert "worker0" in parallel_sd + worker_sd = parallel_sd["worker0"] + td = list(worker_sd.values())[0] + queue_out.put("start") + msg = queue_in.get(timeout=TIMEOUT) + assert msg == "first round" + values = td.clone() + queue_out.put("start") + msg = queue_in.get(timeout=TIMEOUT) + assert msg == "second round" + new_values = td.clone() + for k, item in values.items(): + assert (item != new_values.get(k)).any(), k + proc.join() + + +@pytest.mark.parametrize("parallel", [False, True]) +def test_vecnorm(parallel, thr=0.2, N=200): # 10000): + torch.manual_seed(0) + + if parallel: + env = ParallelEnv(num_workers=5, create_env_fn=lambda: GymEnv("Pendulum-v1")) + else: + env = GymEnv("Pendulum-v1") + env.set_seed(0) + t = VecNorm() + env = TransformedEnv(env, t) + env.reset() + tds = [] + for _ in range(N): + td = env.rand_step() + if td.get("done").any(): + env.reset() + tds.append(td) + tds = torch.stack(tds, 0) + obs = tds.get("next_observation") + obs = obs.view(-1, obs.shape[-1]) + mean = obs.mean(0) + assert (abs(mean) < 
thr).all() + std = obs.std(0) + assert (abs(std - 1) < thr).all() + + +if __name__ == "__main__": + pytest.main([__file__, "--capture", "no"]) diff --git a/torchrl/__init__.py b/torchrl/__init__.py new file mode 100644 index 00000000000..842ab35a82c --- /dev/null +++ b/torchrl/__init__.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import time +from warnings import warn + +from torch import multiprocessing as mp + +from ._extension import _init_extension + +__version__ = "0.1" + +_init_extension() + +# if not HAS_OPS: +# print("could not load C++ libraries") + +try: + mp.set_start_method("spawn") +except RuntimeError as err: + if str(err).startswith("context has already been set"): + mp_start_method = mp.get_start_method() + if mp_start_method != "spawn": + warn( + f"failed to set start method to spawn, " + f"and current start method for mp is {mp_start_method}." + ) + + +class timeit: + """ + A dirty but easy to use decorator for profiling code + """ + + _REG = {} + + def __init__(self, name): + self.name = name + + def __call__(self, fn): + def decorated_fn(*args, **kwargs): + with self: + out = fn(*args, **kwargs) + return out + + return decorated_fn + + def __enter__(self): + self.t0 = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + t = time.time() - self.t0 + self._REG.setdefault(self.name, [0.0, 0]) + + count = self._REG[self.name][1] + self._REG[self.name][0] = (self._REG[self.name][0] * count + t) / (count + 1) + self._REG[self.name][1] = count + 1 + + @staticmethod + def print(): + keys = list(timeit._REG) + keys.sort() + for name in keys: + print(f"{name} took {timeit._REG[name][0] * 1000:4.4} msec") diff --git a/torchrl/_extension.py b/torchrl/_extension.py new file mode 100644 index 00000000000..0047a23b6b1 --- /dev/null +++ b/torchrl/_extension.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import warnings + + +def is_module_available(*modules: str) -> bool: + r"""Returns if a top-level module with :attr:`name` exists *without** + importing it. + + This is generally safer than try-catch block around a + `import X`. It avoids third party libraries breaking assumptions of some of + our tests, e.g., setting multiprocessing start method when imported + (see librosa/#747, torchvision/#544). + """ + return all(importlib.util.find_spec(m) is not None for m in modules) + + +def _init_extension(): + if not is_module_available("torchrl._torchrl"): + warnings.warn("torchrl C++ extension is not available.") + return diff --git a/torchrl/agents/__init__.py b/torchrl/agents/__init__.py new file mode 100644 index 00000000000..3a260afa4b1 --- /dev/null +++ b/torchrl/agents/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .agents import * +from .env_creator import * diff --git a/torchrl/agents/agents.py b/torchrl/agents/agents.py new file mode 100644 index 00000000000..6a412880664 --- /dev/null +++ b/torchrl/agents/agents.py @@ -0,0 +1,591 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import pathlib +import warnings +from collections import OrderedDict +from textwrap import indent +from typing import Callable, Dict, Optional, Union, Sequence + +import numpy as np +import torch.nn +from torch import nn, optim + +try: + from tqdm import tqdm + + _has_tqdm = True +except ImportError: + _has_tqdm = False + +from torchrl.collectors.collectors import _DataCollector +from torchrl.data import ( + ReplayBuffer, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.utils import expand_right +from torchrl.envs.common import _EnvClass +from torchrl.envs.transforms import TransformedEnv +from torchrl.envs.utils import set_exploration_mode +from torchrl.modules import reset_noise, TDModuleWrapper +from torchrl.objectives.costs.common import _LossModule +from torchrl.objectives.costs.utils import _TargetNetUpdate + +REPLAY_BUFFER_CLASS = { + "prioritized": TensorDictPrioritizedReplayBuffer, + "circular": TensorDictReplayBuffer, +} + +WRITER_METHODS = { + "grad_norm": "add_scalar", + "loss": "add_scalar", +} + +__all__ = ["Agent"] + + +class Agent: + """A generic Agent class. + + An agent is responsible of collecting data and training the model. + To keep the class as versatile as possible, Agent does not construct any + of its components: they all must be provided as argument when + initializing the object. + To build an Agent, one needs a iterable data source (a `collector`), a + loss module, an optimizer. Optionally, a recorder (i.e. an environment + instance used for testing purposes) and a policy can be provided for + evaluating the training progress. + + Args: + collector (Sequence[_TensorDict]): An iterable returning batches of + data in a TensorDict form of shape [batch x time steps]. + total_frames (int): Total number of frames to be collected during + training. + loss_module (_LossModule): A module that reads TensorDict batches + (possibly sampled from a replay buffer) and return a loss + TensorDict where every key points to a different loss component. + optimizer (optim.Optimizer): An optimizer that trains the parameters + of the model. + recorder (_EnvClass, optional): An environment instance to be used + for testing. + optim_scheduler (optim.lr_scheduler._LRScheduler, optional): + learning rate scheduler. + target_net_updater (_TargetNetUpdate, optional): + a target network updater. + policy_exploration (ProbabilisticTDModule, optional): a policy + instance used for + + (1) updating the exploration noise schedule; + + (2) testing the policy on the recorder. + + Given that this instance is supposed to both explore and render + the performance of the policy, it should be possible to turn off + the explorative behaviour by calling the + `set_exploration_mode('mode')` context manager. + replay_buffer (ReplayBuffer, optional): a replay buffer for offline + learning. + writer (SummaryWriter, optional): a Tensorboard summary writer for + logging purposes. + update_weights_interval (int, optional): interval between two updates + of the weights of a model living on another device. By default, + the weights will be updated after every collection of data. + record_interval (int, optional): total number of optimisation steps + between two calls to the recorder for testing. Default is 10000. 
+ record_frames (int, optional): number of frames to be recorded during + testing. Default is 1000. + frame_skip (int, optional): frame_skip used in the environment. It is + important to let the agent know the number of frames skipped at + each iteration, otherwise the frame count can be underestimated. + For logging, this parameter is important to normalize the reward. + Finally, to compare different runs with different frame_skip, + one must normalize the frame count and rewards. Default is 1. + optim_steps_per_batch (int, optional): number of optimization steps + per collection of data. An agent works as follows: a main loop + collects batches of data (epoch loop), and a sub-loop (training + loop) performs model updates in between two collections of data. + Default is 500 + batch_size (int, optional): batch size when sampling data from the + latest collection or from the replay buffer, if it is present. + If no replay buffer is present, the sub-sampling will be + achieved over the latest collection with a resulting batch of + size (batch_size x sub_traj_len). + Default is 256 + clip_grad_norm (bool, optional): If True, the gradients will be clipped + based on the total norm of the model parameters. If False, + all the partial derivatives will be clamped to + (-clip_norm, clip_norm). Default is `True`. + clip_norm (Number, optional): value to be used for clipping gradients. + Default is 100.0. + progress_bar (bool, optional): If True, a progress bar will be + displayed using tqdm. If tqdm is not installed, this option + won't have any effect. Default is `True` + seed (int, optional): Seed to be used for the collector, pytorch and + numpy. Default is 42. + save_agent_interval (int, optional): How often the agent should be + saved to disk. Default is 10000. + save_agent_file (path, optional): path where to save the agent. + Default is None (no saving) + normalize_rewards_online (bool, optional): if True, the running + statistics of the rewards are computed and the rewards used for + training will be normalized based on these. + Default is `False` + sub_traj_len (int, optional): length of the trajectories that + sub-samples must have in online settings. Default is -1 (i.e. + takes the full length of the trajectory) + min_sub_traj_len (int, optional): minimum value of `sub_traj_len`, in + case some elements of the batch contain few steps. + Default is -1 (i.e. no minimum value) + selected_keys (iterable of str, optional): a list of strings that + indicate the data that should be kept from the data collector. + Since storing and retrieving information from the replay buffer + does not come for free, limiting the amount of data passed to + it can improve the algorithm performance. Default is None, + i.e. all keys are kept. 
+ + """ + + # trackers + _optim_count: int = 0 + _collected_frames: int = 0 + _last_log: dict = {} + _last_save: int = 0 + _log_interval: int = 10000 + _reward_stats: dict = {"decay": 0.999} + + def __init__( + self, + collector: _DataCollector, + total_frames: int, + loss_module: Union[_LossModule, Callable[[_TensorDict], _TensorDict]], + optimizer: optim.Optimizer, + recorder: Optional[_EnvClass] = None, + optim_scheduler: Optional[optim.lr_scheduler._LRScheduler] = None, + target_net_updater: Optional[_TargetNetUpdate] = None, + policy_exploration: Optional[TDModuleWrapper] = None, + replay_buffer: Optional[ReplayBuffer] = None, + writer: Optional["SummaryWriter"] = None, + update_weights_interval: int = -1, + record_interval: int = 10000, + record_frames: int = 1000, + frame_skip: int = 1, + optim_steps_per_batch: int = 500, + batch_size: int = 256, + clip_grad_norm: bool = True, + clip_norm: float = 100.0, + progress_bar: bool = True, + seed: int = 42, + save_agent_interval: int = 10000, + save_agent_file: Optional[Union[str, pathlib.Path]] = None, + normalize_rewards_online: bool = False, + sub_traj_len: int = -1, + min_sub_traj_len: int = -1, + selected_keys: Optional[Sequence[str]] = None, + ) -> None: + + # objects + self.collector = collector + self.loss_module = loss_module + self.recorder = recorder + self.optimizer = optimizer + self.optim_scheduler = optim_scheduler + self.replay_buffer = replay_buffer + self.policy_exploration = policy_exploration + self.target_net_updater = target_net_updater + self.writer = writer + self._params = [] + for p in self.optimizer.param_groups: + self._params += p["params"] + + # seeding + self.seed = seed + self.set_seed() + + # constants + self.update_weights_interval = update_weights_interval + self.optim_steps_per_batch = optim_steps_per_batch + self.batch_size = batch_size + self.total_frames = total_frames + self.frame_skip = frame_skip + self.clip_grad_norm = clip_grad_norm + self.clip_norm = clip_norm + if progress_bar and not _has_tqdm: + warnings.warn( + "tqdm library not found. Consider installing tqdm to use the Agent progress bar." 
+ ) + self.progress_bar = progress_bar and _has_tqdm + self.record_interval = record_interval + self.record_frames = record_frames + self.save_agent_interval = save_agent_interval + self.save_agent_file = save_agent_file + self.normalize_rewards_online = normalize_rewards_online + self.sub_traj_len = sub_traj_len + self.min_sub_traj_len = min_sub_traj_len + self.selected_keys = selected_keys + + def save_agent(self) -> None: + _save = False + if self.save_agent_file is not None: + if (self._collected_frames - self._last_save) > self.save_agent_interval: + self._last_save = self._collected_frames + _save = True + if _save: + torch.save(self.state_dict(), self.save_agent_file) + + def load_from_file(self, file: Union[str, pathlib.Path]) -> Agent: + loaded_dict: OrderedDict = torch.load(file) + + # checks that keys match + expected_keys = { + "env", + "loss_module", + "_collected_frames", + "_last_log", + "_last_save", + "_optim_count", + } + actual_keys = set(loaded_dict.keys()) + if len(actual_keys.difference(expected_keys)) or len( + expected_keys.difference(actual_keys) + ): + raise RuntimeError( + f"Expected keys {expected_keys} in the loaded file but got" + f" {actual_keys}" + ) + self.collector.load_state_dict(loaded_dict["env"]) + self.loss_module.load_state_dict(loaded_dict["loss_module"]) + for key in [ + "_collected_frames", + "_last_log", + "_last_save", + "_optim_count", + ]: + setattr(self, key, loaded_dict[key]) + return self + + def set_seed(self): + seed = self.collector.set_seed(self.seed) + torch.manual_seed(seed) + np.random.seed(seed) + + def state_dict(self) -> Dict: + state_dict = OrderedDict( + env=self.collector.state_dict(), + loss_module=self.loss_module.state_dict(), + _collected_frames=self._collected_frames, + _last_log=self._last_log, + _last_save=self._last_save, + _optim_count=self._optim_count, + ) + return state_dict + + def load_state_dict(self, state_dict: Dict) -> None: + model_state_dict = state_dict["loss_module"] + env_state_dict = state_dict["env"] + self.loss_module.load_state_dict(model_state_dict) + self.collector.load_state_dict(env_state_dict) + + @property + def collector(self) -> _DataCollector: + return self._collector + + @collector.setter + def collector(self, collector: _DataCollector) -> None: + self._collector = collector + + def train(self): + if self.progress_bar: + self._pbar = tqdm(total=self.total_frames) + self._pbar_str = OrderedDict() + + collected_frames = 0 + for i, batch in enumerate(self.collector): + if self.selected_keys: + batch = batch.select(*self.selected_keys, "mask") + + if "mask" in batch.keys(): + current_frames = batch.get("mask").sum().item() * self.frame_skip + else: + current_frames = batch.numel() * self.frame_skip + collected_frames += current_frames + self._collected_frames = collected_frames + + if self.replay_buffer is not None: + if "mask" in batch.keys(): + batch = batch[batch.get("mask").squeeze(-1)] + else: + batch = batch.reshape(-1) + reward_training = batch.get("reward").mean().item() + batch = batch.cpu() + self.replay_buffer.extend(batch) + else: + if "mask" in batch.keys(): + reward_training = batch.get("reward") + mask = batch.get("mask").squeeze(-1) + reward_training = reward_training[mask].mean().item() + else: + reward_training = batch.get("reward").mean().item() + + if self.normalize_rewards_online: + reward = batch.get("reward") + self._update_reward_stats(reward) + + if collected_frames > self.collector.init_random_frames: + self.steps(batch) + self._collector_scheduler_step(i, current_frames) + + 
self._log(reward_training=reward_training) + if self.progress_bar: + self._pbar.update(current_frames) + self._pbar_description() + + if collected_frames > self.total_frames: + break + + self.collector.shutdown() + + @torch.no_grad() + def _update_reward_stats(self, reward: torch.Tensor) -> None: + decay = self._reward_stats.get("decay", 0.999) + sum = self._reward_stats["sum"] = ( + decay * self._reward_stats.get("sum", 0.0) + reward.sum() + ) + ssq = self._reward_stats["ssq"] = ( + decay * self._reward_stats.get("ssq", 0.0) + reward.pow(2).sum() + ) + count = self._reward_stats["count"] = ( + decay * self._reward_stats.get("count", 0.0) + reward.numel() + ) + + mean = self._reward_stats["mean"] = sum / count + var = self._reward_stats["var"] = ssq / count - mean.pow(2) + self._reward_stats["std"] = var.clamp_min(1e-6).sqrt() + + def _normalize_reward(self, tensordict: _TensorDict) -> None: + reward = tensordict.get("reward") + reward = reward - self._reward_stats["mean"] + reward = reward / self._reward_stats["std"] + tensordict.set_("reward", reward) + + def _collector_scheduler_step(self, step: int, current_frames: int): + """Runs entropy annealing steps for exploration, policy weights update + across workers etc. + + """ + + if self.policy_exploration is not None and hasattr( + self.policy_exploration, "step" + ): + self.policy_exploration.step(current_frames) + + if step % self.update_weights_interval == 0: + self.collector.update_policy_weights_() + + def steps(self, batch: _TensorDict) -> None: + average_grad_norm = 0.0 + average_losses = None + + self.loss_module.apply(reset_noise) # TODO: group in loss_module.reset? + self.loss_module.reset() + + for j in range(self.optim_steps_per_batch): + self._optim_count += 1 + if self.replay_buffer is not None: + sub_batch = self.replay_buffer.sample(self.batch_size) + else: + sub_batch = self._sub_sample_batch(batch) + + if self.normalize_rewards_online: + self._normalize_reward(sub_batch) + + sub_batch_device = sub_batch.to(self.loss_module.device) + losses_td = self.loss_module(sub_batch_device) + if isinstance(self.replay_buffer, TensorDictPrioritizedReplayBuffer): + self.replay_buffer.update_priority(sub_batch_device) + + # sum all keys that start with 'loss_' + loss = sum( + [item for key, item in losses_td.items() if key.startswith("loss")] + ) + loss.backward() + if average_losses is None: + average_losses: _TensorDict = losses_td.detach() + else: + for key, item in losses_td.items(): + val = average_losses.get(key) + average_losses.set(key, val * j / (j + 1) + item / (j + 1)) + + grad_norm = self._grad_clip() + average_grad_norm = average_grad_norm * j / (j + 1) + grad_norm / (j + 1) + self.optimizer.step() + self.optimizer.zero_grad() + + self._optim_schedule_step() + + if self._optim_count % self.record_interval == 0: + self.record() + + if self.optim_steps_per_batch > 0: + self._log( + grad_norm=average_grad_norm, + optim_steps=self._optim_count, + **average_losses, + ) + + def _optim_schedule_step(self) -> None: + """Runs scheduler steps, target network update steps etc. + Returns: + """ + if self.optim_scheduler is not None: + self.optim_scheduler.step() + if self.target_net_updater is not None: + self.target_net_updater.step() + + def _sub_sample_batch(self, batch: _TensorDict) -> _TensorDict: + """Sub-sampled part of a batch randomly. + + If the batch has one dimension, a random subsample of length + self.bach_size will be returned. 
If the batch has two or more + dimensions, it is assumed that the first dimension represents the + batch, and the second the time. If so, the resulting subsample will + contain consecutive samples across time. + """ + + if batch.ndimension() == 1: + return batch[torch.randperm(batch.shape[0])[: self.batch_size]] + + sub_traj_len = self.sub_traj_len if self.sub_traj_len > 0 else batch.shape[1] + if "mask" in batch.keys(): + # if a valid mask is present, it's important to sample only + # valid steps + traj_len = batch.get("mask").sum(1).squeeze() + sub_traj_len = max( + self.min_sub_traj_len, + min(sub_traj_len, traj_len.min().int().item()), + ) + else: + traj_len = ( + torch.ones(batch.shape[0], device=batch.device, dtype=torch.bool) + * batch.shape[1] + ) + len_mask = traj_len >= sub_traj_len + valid_trajectories = torch.arange(batch.shape[0])[len_mask] + + batch_size = self.batch_size // sub_traj_len + traj_idx = valid_trajectories[ + torch.randint( + valid_trajectories.numel(), (batch_size,), device=batch.device + ) + ] + + if sub_traj_len < batch.shape[1]: + _traj_len = traj_len[traj_idx] + seq_idx = ( + torch.rand_like(_traj_len, dtype=torch.float) + * (_traj_len - sub_traj_len) + ).int() + seq_idx = seq_idx.unsqueeze(-1).expand(-1, sub_traj_len) + elif sub_traj_len == batch.shape[1]: + seq_idx = torch.zeros( + batch_size, sub_traj_len, device=batch.device, dtype=torch.long + ) + else: + raise ValueError( + f"sub_traj_len={sub_traj_len} is not allowed. Accepted values " + f"are in the range [1, {batch.shape[1]}]." + ) + + seq_idx = seq_idx + torch.arange(sub_traj_len, device=seq_idx.device) + td = batch[traj_idx].clone() + td = td.apply( + lambda t: t.gather( + dim=1, + index=expand_right(seq_idx, (batch_size, sub_traj_len, *t.shape[2:])), + ), + batch_size=(batch_size, sub_traj_len), + ) + if "mask" in batch.keys() and not td.get("mask").all(): + raise RuntimeError("Sampled invalid steps") + return td + + def _grad_clip(self) -> float: + if self.clip_grad_norm: + gn = nn.utils.clip_grad_norm_(self._params, self.clip_norm) + else: + gn = sum([p.grad.pow(2).sum() for p in self._params]).sqrt() + nn.utils.clip_grad_value_(self._params, self.clip_norm) + return float(gn) + + def _log(self, **kwargs) -> None: + collected_frames = self._collected_frames + for key, item in kwargs.items(): + if (collected_frames - self._last_log.get(key, 0)) > self._log_interval: + self._last_log[key] = collected_frames + _log = True + else: + _log = False + method = WRITER_METHODS.get(key, "add_scalar") + if _log and self.writer is not None: + getattr(self.writer, method)(key, item, global_step=collected_frames) + if method == "add_scalar" and self.progress_bar: + self._pbar_str[key] = float(item) + + def _pbar_description(self) -> None: + if self.progress_bar: + self._pbar.set_description( + ", ".join( + [ + f"{key}: {float(item):4.4f}" + for key, item in self._pbar_str.items() + ] + ) + ) + + @torch.no_grad() + @set_exploration_mode("mode") + def record(self) -> None: + if self.recorder is not None: + self.policy_exploration.eval() + self.recorder.eval() + if isinstance(self.recorder, TransformedEnv): + self.recorder.transform.eval() + td_record = self.recorder.rollout( + policy=self.policy_exploration, + n_steps=self.record_frames, + ) + self.policy_exploration.train() + self.recorder.train() + reward = td_record.get("reward").mean() / self.frame_skip + self._log(reward_evaluation=reward) + self.recorder.transform.dump() + + def __repr__(self) -> str: + loss_str = indent(f"loss={self.loss_module}", 4 * " 
") + policy_str = indent(f"policy_exploration={self.policy_exploration}", 4 * " ") + collector_str = indent(f"collector={self.collector}", 4 * " ") + buffer_str = indent(f"buffer={self.replay_buffer}", 4 * " ") + optimizer_str = indent(f"optimizer={self.optimizer}", 4 * " ") + target_net_updater = indent( + f"target_net_updater={self.target_net_updater}", 4 * " " + ) + writer = indent(f"writer={self.writer}", 4 * " ") + + string = "\n".join( + [ + loss_str, + policy_str, + collector_str, + buffer_str, + optimizer_str, + target_net_updater, + writer, + ] + ) + string = f"Agent(\n{string})" + return string diff --git a/torchrl/agents/env_creator.py b/torchrl/agents/env_creator.py new file mode 100644 index 00000000000..75c50fc1af0 --- /dev/null +++ b/torchrl/agents/env_creator.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections import OrderedDict +from typing import Callable, Dict, Optional + +import torch + +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.utils import CloudpickleWrapper +from torchrl.envs.common import _EnvClass + +__all__ = ["EnvCreator"] + + +class EnvCreator: + """Environment creator class. + + EnvCreator is a generic environment creator class that can substitute + lambda functions when creating environments in multiprocessing contexts. + If the environment created on a subprocess must share information with the + main process (e.g. for the VecNorm transform), EnvCreator will pass the + pointers to the tensordicts in shared memory to each process such that + all of them are synchronised. + + Args: + create_env_fn (callable): a callable that returns an _EnvClass + instance. + create_env_kwargs (dict, optional): the kwargs of the env creator. + share_memory (bool, optional): if False, the resulting tensordict + from the environment won't be placed in shared memory. 
+ + Examples: + >>> # We create the same environment on 2 processes using VecNorm + >>> # and check that the discounted count of observations matches on + >>> # both workers, even if one has not executed any step + >>> import time + >>> from torchrl.envs import GymEnv + >>> from torchrl.data import VecNorm, TransformedEnv + >>> from torchrl.agents import EnvCreator + >>> from torch import multiprocessing as mp + >>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm()) + >>> env_creator = EnvCreator(env_fn) + >>> + >>> def test_env1(env_creator): + >>> env = env_creator() + >>> for _ in range(10): + >>> env.rand_step() + >>> if env.is_done: + >>> env.reset() + >>> print("env 1: ", env.transform._td.get("next_observation_count")) + >>> + >>> def test_env2(env_creator): + >>> env = env_creator() + >>> time.sleep(5) + >>> print("env 2: ", env.transform._td.get("next_observation_count")) + >>> + >>> if __name__ == "__main__": + >>> ps = [] + >>> p1 = mp.Process(target=test_env1, args=(env_creator,)) + >>> p1.start() + >>> ps.append(p1) + >>> p2 = mp.Process(target=test_env2, args=(env_creator,)) + >>> p2.start() + >>> ps.append(p2) + >>> for p in ps: + >>> p.join() + env 1: tensor([11.9934]) + env 2: tensor([11.9934]) + """ + + def __init__( + self, + create_env_fn: Callable[..., _EnvClass], + create_env_kwargs: Optional[Dict] = None, + share_memory: bool = True, + ) -> None: + if not isinstance(create_env_fn, EnvCreator): + self.create_env_fn = CloudpickleWrapper(create_env_fn) + else: + self.create_env_fn = create_env_fn + + self.create_env_kwargs = ( + create_env_kwargs if isinstance(create_env_kwargs, dict) else dict() + ) + self.initialized = False + self._share_memory = share_memory + self.init_() + + def share_memory(self, state_dict: OrderedDict) -> None: + for key, item in list(state_dict.items()): + if isinstance(item, (_TensorDict,)): + if not item.is_shared(): + print(f"{self.env_type}: sharing mem of {item}") + item.share_memory_() + else: + print( + f"{self.env_type}: {item} is already shared" + ) # , deleting key') + del state_dict[key] + elif isinstance(item, OrderedDict): + self.share_memory(item) + elif isinstance(item, torch.Tensor): + del state_dict[key] + + def init_(self) -> EnvCreator: + shadow_env = self.create_env_fn(**self.create_env_kwargs) + shadow_env.reset() + shadow_env.rand_step() + self.env_type = type(shadow_env) + self._transform_state_dict = shadow_env.state_dict() + if self._share_memory: + self.share_memory(self._transform_state_dict) + self.initialized = True + return self + + def __call__(self) -> _EnvClass: + if not self.initialized: + raise RuntimeError("EnvCreator must be initialized before being called.") + env = self.create_env_fn(**self.create_env_kwargs) + env.load_state_dict(self._transform_state_dict, strict=False) + return env + + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + if self._transform_state_dict is None: + return destination if destination is not None else OrderedDict() + if destination is not None: + destination.update(self._transform_state_dict) + return destination + return self._transform_state_dict + + def load_state_dict(self, state_dict: OrderedDict) -> None: + if self._transform_state_dict is not None: + for key, item in state_dict.items(): + item_to_update = self._transform_state_dict[key] + item_to_update.copy_(item) + + def __repr__(self) -> str: + substr = ", ".join( + [f"{key}: {type(item)}" for key, item in self.create_env_kwargs.items()] + ) + return 
f"EnvCreator({self.create_env_fn}({substr}))" + + +def env_creator(fun: Callable) -> EnvCreator: + return EnvCreator(fun) diff --git a/torchrl/agents/helpers/__init__.py b/torchrl/agents/helpers/__init__.py new file mode 100644 index 00000000000..82572feebb1 --- /dev/null +++ b/torchrl/agents/helpers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .agents import * +from .collectors import * +from .envs import * +from .losses import * +from .models import * +from .recorder import * +from .replay_buffer import * diff --git a/torchrl/agents/helpers/agents.py b/torchrl/agents/helpers/agents.py new file mode 100644 index 00000000000..a832cb4fbdd --- /dev/null +++ b/torchrl/agents/helpers/agents.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from argparse import ArgumentParser, Namespace +from typing import Optional, Union +from warnings import warn + +from torch import optim + +from torchrl.agents.agents import Agent +from torchrl.collectors.collectors import _DataCollector +from torchrl.data import ReplayBuffer +from torchrl.envs.common import _EnvClass +from torchrl.modules import TDModule, TDModuleWrapper +from torchrl.objectives.costs.common import _LossModule +from torchrl.objectives.costs.utils import _TargetNetUpdate + +OPTIMIZERS = { + "adam": optim.Adam, + "sgd": optim.SGD, + "adamax": optim.Adamax, +} + +__all__ = [ + "make_agent", + "parser_agent_args", +] + + +def make_agent( + collector: _DataCollector, + loss_module: _LossModule, + recorder: Optional[_EnvClass] = None, + target_net_updater: Optional[_TargetNetUpdate] = None, + policy_exploration: Optional[Union[TDModuleWrapper, TDModule]] = None, + replay_buffer: Optional[ReplayBuffer] = None, + writer: Optional["SummaryWriter"] = None, + args: Optional[Namespace] = None, +) -> Agent: + """Creates an Agent instance given its constituents. + + Args: + collector (_DataCollector): A data collector to be used to collect data. + loss_module (_LossModule): A TorchRL loss module + recorder (_EnvClass, optional): a recorder environment. If None, the agent will train the policy without + testing it. + target_net_updater (_TargetNetUpdate, optional): A target network update object. + policy_exploration (TDModule or TDModuleWrapper, optional): a policy to be used for recording and exploration + updates (should be synced with the learnt policy). + replay_buffer (ReplayBuffer, optional): a replay buffer to be used to collect data. + writer (SummaryWriter, optional): a tensorboard SummaryWriter to be used for logging. + args (argparse.Namespace, optional): a Namespace containing the arguments of the script. If None, the default + arguments are used. + + Returns: + An agent built with the input objects. The optimizer is built by this helper function using the args provided. 
+ + Examples: + >>> import torch + >>> import tempfile + >>> from torch.utils.tensorboard import SummaryWriter + >>> from torchrl.agents import Agent, EnvCreator + >>> from torchrl.collectors.collectors import SyncDataCollector + >>> from torchrl.data import TensorDictReplayBuffer + >>> from torchrl.envs import GymEnv + >>> from torchrl.modules import TDModuleWrapper, TDModule, ValueOperator, EGreedyWrapper + >>> from torchrl.objectives.costs.common import _LossModule + >>> from torchrl.objectives.costs.utils import _TargetNetUpdate + >>> from torchrl.objectives import DDPGLoss + >>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0")) + >>> env_proof = env_maker() + >>> obs_spec = env_proof.observation_spec + >>> action_spec = env_proof.action_spec + >>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1]) + >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1) # for the purpose of testing + >>> policy = TDModule(action_spec, net, in_keys=["observation"], out_keys=["action"]) + >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"]) + >>> collector = SyncDataCollector(env_maker, policy, total_frames=100) + >>> loss_module = DDPGLoss(policy, value, gamma=0.99) + >>> recorder = env_proof + >>> target_net_updater = None + >>> policy_exploration = EGreedyWrapper(policy) + >>> replay_buffer = TensorDictReplayBuffer(1000) + >>> dir = tempfile.gettempdir() + >>> writer = SummaryWriter(log_dir=dir) + >>> agent = make_agent(collector, loss_module, recorder, target_net_updater, policy_exploration, + ... replay_buffer, writer) + >>> print(agent) + + """ + if args is None: + warn( + "Getting default args for the agent. This should be only used for debugging." + ) + parser = parser_agent_args(argparse.ArgumentParser()) + parser.add_argument("--frame_skip", default=1) + parser.add_argument("--total_frames", default=1000) + parser.add_argument("--record_frames", default=10) + parser.add_argument("--record_interval", default=10) + args = parser.parse_args([]) + + optimizer = OPTIMIZERS[args.optimizer]( + loss_module.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + optim_scheduler = None + + print( + f"collector = {collector}; \n" + f"loss_module = {loss_module}; \n" + f"recorder = {recorder}; \n" + f"target_net_updater = {target_net_updater}; \n" + f"policy_exploration = {policy_exploration}; \n" + f"replay_buffer = {replay_buffer}; \n" + f"writer = {writer}; \n" + f"args = {args}; \n" + ) + + if writer is not None: + # log hyperparams + txt = "\n\t".join([f"{k}: {val}" for k, val in sorted(vars(args).items())]) + writer.add_text("hparams", txt) + + return Agent( + collector=collector, + total_frames=args.total_frames * args.frame_skip, + loss_module=loss_module, + optimizer=optimizer, + recorder=recorder, + optim_scheduler=optim_scheduler, + target_net_updater=target_net_updater, + policy_exploration=policy_exploration, + replay_buffer=replay_buffer, + writer=writer, + update_weights_interval=1, + frame_skip=args.frame_skip, + optim_steps_per_batch=args.optim_steps_per_collection, + batch_size=args.batch_size, + clip_grad_norm=args.clip_grad_norm, + clip_norm=args.clip_norm, + record_interval=args.record_interval, + record_frames=args.record_frames, + normalize_rewards_online=args.normalize_rewards_online, + sub_traj_len=args.sub_traj_len, + selected_keys=args.selected_keys, + ) + + +def parser_agent_args(parser: ArgumentParser) -> ArgumentParser: + """ + Populates the argument parser to build the 
agent. + + Args: + parser (ArgumentParser): parser to be populated. + + """ + parser.add_argument( + "--optim_steps_per_collection", + type=int, + default=500, + help="Number of optimization steps in between two collections of data. See frames_per_batch " + "below. " + "Default=500", + ) + parser.add_argument( + "--optimizer", type=str, default="adam", help="Optimizer to be used." + ) + parser.add_argument( + "--selected_keys", + nargs="+", + default=None, + help="a list of strings that indicate the data that should be kept from the data collector. Since storing and " + "retrieving information from the replay buffer does not come for free, limiting the amount of data " + "passed to it can improve the algorithm performance. " + "Default is None, i.e. all keys are kept.", + ) + + parser.add_argument( + "--batch_size", + type=int, + default=256, + help="batch size of the TensorDict retrieved from the replay buffer. Default=256.", + ) + parser.add_argument( + "--log_interval", + type=int, + default=10000, + help="logging interval, in terms of optimization steps. Default=10000.", + ) + parser.add_argument( + "--lr", + type=float, + default=3e-4, + help="Learning rate used for the optimizer. Default=3e-4.", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=2e-5, + help="Weight-decay to be used with the optimizer. Default=2e-5.", + ) + parser.add_argument( + "--clip_norm", + type=float, + default=1.0, + help="value at which the total gradient norm should be clipped. Default=1.0", + ) + parser.add_argument( + "--clip_grad_norm", + action="store_true", + help="if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient " + "values will be clipped to the desired threshold.", + ) + parser.add_argument( + "--normalize_rewards_online", + "--normalize-rewards-online", + action="store_true", + help="Computes the running statistics of the rewards and normalizes them before they are " + "passed to the loss module.", + ) + parser.add_argument( + "--sub_traj_len", + "--sub-traj-len", + type=int, + default=-1, + help="length of the trajectories that sub-samples must have in online settings.", + ) + return parser diff --git a/torchrl/agents/helpers/collectors.py b/torchrl/agents/helpers/collectors.py new file mode 100644 index 00000000000..e4fc0b1a294 --- /dev/null +++ b/torchrl/agents/helpers/collectors.py @@ -0,0 +1,478 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
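+ +# Helper functions for building (possibly multi-process) data collectors. A hedged +# usage sketch; `make_env` and `policy` below are placeholders for an environment +# factory returning an _EnvClass instance and a policy module, not names defined here: +# +# collector = sync_sync_collector( +# env_fns=make_env, +# env_kwargs=None, +# num_env_per_collector=2, +# num_collectors=2, +# policy=policy, # forwarded to the underlying data collectors +# frames_per_batch=200, +# total_frames=10_000, +# ) +# +# When num_env_per_collector > 1, each collector wraps its environments in a ParallelEnv.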
+ +from argparse import ArgumentParser, Namespace +from typing import Callable, List, Optional, Type, Union + +from torchrl.collectors.collectors import ( + _DataCollector, + _MultiDataCollector, + MultiaSyncDataCollector, + MultiSyncDataCollector, +) +from torchrl.data import MultiStep +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.envs import ParallelEnv + +__all__ = [ + "sync_sync_collector", + "sync_async_collector", + "make_collector_offpolicy", + "make_collector_onpolicy", + "parser_collector_args_offpolicy", + "parser_collector_args_onpolicy", +] + +from torchrl.envs.common import _EnvClass +from torchrl.modules import ProbabilisticTDModule, TDModuleWrapper + + +def sync_async_collector( + env_fns: Union[Callable, List[Callable]], + env_kwargs: Optional[Union[dict, List[dict]]], + num_env_per_collector: Optional[int] = None, + num_collectors: Optional[int] = None, + **kwargs, +) -> MultiaSyncDataCollector: + """ + Runs asynchronous collectors, each running synchronous environments. + + .. aafig:: + + + +----------------------------------------------------------------------+ + | "MultiConcurrentCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor | "step" | "step" | "actor" | | + | | | | | | + | "yield batch 1" | "actor" | |"collect, train"| + | | | | | + | "step" | "step" | | "yield batch 2" |"collect, train"| + | | | | | | + | | | "yield batch 3" | |"collect, train"| + | | | | | | + +----------------------------------------------------------------------+ + + Environment types can be identical or different. In the latter case, env_fns should be a list with all the creator + fns for the various envs, + and the policy should handle those envs in batch. + + Args: + env_fns: Callable (or list of Callables) returning an instance of _EnvClass class. + env_kwargs: Optional. Dictionary (or list of dictionaries) containing the kwargs for the environment being created. + num_env_per_collector: Number of environments per data collector. The product + num_env_per_collector * num_collectors should be less or equal to the number of workers available. + num_collectors: Number of data collectors to be run in parallel. + **kwargs: Other kwargs passed to the data collectors + + """ + + return _make_collector( + MultiaSyncDataCollector, + env_fns=env_fns, + env_kwargs=env_kwargs, + num_env_per_collector=num_env_per_collector, + num_collectors=num_collectors, + **kwargs, + ) + + +def sync_sync_collector( + env_fns: Union[Callable, List[Callable]], + env_kwargs: Optional[Union[dict, List[dict]]], + num_env_per_collector: Optional[int] = None, + num_collectors: Optional[int] = None, + **kwargs, +) -> MultiSyncDataCollector: + """ + Runs synchronous collectors, each running synchronous environments. + + E.g. + + .. 
aafig:: + + +----------------------------------------------------------------------+ + | "MultiConcurrentCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | Main | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor" | "step" | "step" | "actor" | | + | | | | | | + | | "actor" | | | + | | | | | + | "yield batch of traj 1"------->"collect, train"| + | | | + | "step" | "step" | "step" | "step" | "step" | "step" | | + | | | | | | | | + | "actor" | "actor" | | | | + | | "step" | "step" | "actor" | | + | | | | | | + | "step" | "step" | "actor" | "step" | "step" | | + | | | | | | | + | "actor" | | "actor" | | + | "yield batch of traj 2"------->"collect, train"| + | | | + +----------------------------------------------------------------------+ + + Envs can be identical or different. In the latter case, env_fns should be a list with all the creator fns + for the various envs, + and the policy should handle those envs in batch. + + Args: + env_fns: Callable (or list of Callables) returning an instance of _EnvClass class. + env_kwargs: Optional. Dictionary (or list of dictionaries) containing the kwargs for the environment being created. + num_env_per_collector: Number of environments per data collector. The product + num_env_per_collector * num_collectors should be less or equal to the number of workers available. + num_collectors: Number of data collectors to be run in parallel. 
+ **kwargs: Other kwargs passed to the data collectors + + """ + return _make_collector( + MultiSyncDataCollector, + env_fns=env_fns, + env_kwargs=env_kwargs, + num_env_per_collector=num_env_per_collector, + num_collectors=num_collectors, + **kwargs, + ) + + +def _make_collector( + collector_class: Type, + env_fns: Union[Callable, List[Callable]], + env_kwargs: Optional[Union[dict, List[dict]]], + policy: Callable[[_TensorDict], _TensorDict], + max_frames_per_traj: int = -1, + frames_per_batch: int = 200, + total_frames: Optional[int] = None, + postproc: Optional[Callable] = None, + num_env_per_collector: Optional[int] = None, + num_collectors: Optional[int] = None, + **kwargs, +) -> _MultiDataCollector: + if env_kwargs is None: + env_kwargs = dict() + if isinstance(env_fns, list): + num_env = len(env_fns) + if num_env_per_collector is None: + num_env_per_collector = -(num_env // -num_collectors) + elif num_collectors is None: + num_collectors = -(num_env // -num_env_per_collector) + else: + if num_env_per_collector * num_collectors < num_env: + raise ValueError( + f"num_env_per_collector * num_collectors={num_env_per_collector * num_collectors} " + f"has been found to be less than num_env={num_env}" + ) + else: + try: + num_env = num_env_per_collector * num_collectors + env_fns = [env_fns for _ in range(num_env)] + except (TypeError): + raise Exception( + "num_env was not a list but num_env_per_collector and num_collectors were not both specified," + f"got num_env_per_collector={num_env_per_collector} and num_collectors={num_collectors}" + ) + if not isinstance(env_kwargs, list): + env_kwargs = [env_kwargs for _ in range(num_env)] + + env_fns_split = [ + env_fns[i : i + num_env_per_collector] + for i in range(0, num_env, num_env_per_collector) + ] + env_kwargs_split = [ + env_kwargs[i : i + num_env_per_collector] + for i in range(0, num_env, num_env_per_collector) + ] + if len(env_fns_split) != num_collectors: + raise RuntimeError( + f"num_collectors={num_collectors} differs from len(env_fns_split)={len(env_fns_split)}" + ) + + if num_env_per_collector == 1: + env_fns = [_env_fn[0] for _env_fn in env_fns_split] + env_kwargs = [_env_kwargs[0] for _env_kwargs in env_kwargs_split] + else: + env_fns = [ + lambda: ParallelEnv( + num_workers=len(_env_fn), + create_env_fn=_env_fn, + create_env_kwargs=_env_kwargs, + ) + for _env_fn, _env_kwargs in zip(env_fns_split, env_kwargs_split) + ] + env_kwargs = None + return collector_class( + create_env_fn=env_fns, + create_env_kwargs=env_kwargs, + policy=policy, + total_frames=total_frames, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + postproc=postproc, + **kwargs, + ) + + +def make_collector_offpolicy( + make_env: Callable[[], _EnvClass], + actor_model_explore: Union[TDModuleWrapper, ProbabilisticTDModule], + args: Namespace, + make_env_kwargs=None, +) -> _DataCollector: + """ + Returns a data collector for off-policy algorithms. 
+ + Args: + make_env (Callable): environment creator + actor_model_explore (TDModule): Model instance used for evaluation and exploration update + args (Namespace): argument namespace built from the parser constructor + make_env_kwargs (dict): kwargs for the env creator + + """ + if args.async_collection: + collector_helper = sync_async_collector + else: + collector_helper = sync_sync_collector + + if args.multi_step: + ms = MultiStep( + gamma=args.gamma, + n_steps_max=args.n_steps_return, + ) + else: + ms = None + + env_kwargs = {} + if make_env_kwargs is not None: + env_kwargs.update(make_env_kwargs) + args.collector_devices = ( + args.collector_devices + if len(args.collector_devices) > 1 + else args.collector_devices[0] + ) + collector_helper_kwargs = { + "env_fns": make_env, + "env_kwargs": env_kwargs, + "policy": actor_model_explore, + "max_frames_per_traj": args.max_frames_per_traj, + "frames_per_batch": args.frames_per_batch, + "total_frames": args.total_frames, + "postproc": ms, + "num_env_per_collector": 1, + # we already took care of building the make_parallel_env function + "num_collectors": -args.num_workers // -args.env_per_collector, + "devices": args.collector_devices, + "passing_devices": args.collector_devices, + "init_random_frames": args.init_random_frames, + "pin_memory": args.pin_memory, + "split_trajs": ms is not None, + # trajectories must be separated if multi-step is used + "init_with_lag": args.init_with_lag, + "exploration_mode": args.exploration_mode, + } + + collector = collector_helper(**collector_helper_kwargs) + collector.set_seed(args.seed) + return collector + + +def make_collector_onpolicy( + make_env: Callable[[], _EnvClass], + actor_model_explore: Union[TDModuleWrapper, ProbabilisticTDModule], + args: Namespace, + make_env_kwargs=None, +) -> _DataCollector: + collector_helper = sync_sync_collector + + ms = None + + env_kwargs = {} + if make_env_kwargs is not None: + env_kwargs.update(make_env_kwargs) + args.collector_devices = ( + args.collector_devices + if len(args.collector_devices) > 1 + else args.collector_devices[0] + ) + collector_helper_kwargs = { + "env_fns": make_env, + "env_kwargs": env_kwargs, + "policy": actor_model_explore, + "max_frames_per_traj": args.max_frames_per_traj, + "frames_per_batch": args.frames_per_batch, + "total_frames": args.total_frames, + "postproc": ms, + "num_env_per_collector": 1, + # we already took care of building the make_parallel_env function + "num_collectors": -args.num_workers // -args.env_per_collector, + "devices": args.collector_devices, + "passing_devices": args.collector_devices, + "pin_memory": args.pin_memory, + "split_trajs": True, + # trajectories must be separated in online settings + "init_with_lag": args.init_with_lag, + "exploration_mode": args.exploration_mode, + } + + collector = collector_helper(**collector_helper_kwargs) + collector.set_seed(args.seed) + return collector + + +def _parser_collector_args(parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "--collector_devices", + "--collector-devices", + nargs="+", + default=["cpu"], + help="device on which the data collector should store the trajectories to be passed to this script." 
+ "If the collector device differs from the policy device (cuda:0 if available), then the " + "weights of the collector policy are synchronized with collector.update_policy_weights_().", + ) + parser.add_argument( + "--pin_memory", + "--pin-memory", + action="store_true", + help="if True, the data collector will call pin_memory before dispatching tensordicts onto the " + "passing device.", + ) + parser.add_argument( + "--init_with_lag", + "--init-with-lag", + action="store_true", + help="if True, the first trajectory will be truncated earlier at a random step. This is helpful" + " to desynchronize the environments, such that steps do no match in all collected " + "rollouts. Especially useful for online training, to prevent cyclic sample indices.", + ) + parser.add_argument( + "--frames_per_batch", + "--frames-per-batch", + type=int, + default=1000, + help="number of steps executed in the environment per collection." + "This value represents how many steps will the data collector execute and return in *each*" + "environment that has been created in between two rounds of optimization " + "(see the optim_steps_per_collection above). " + "On the one hand, a low value will enhance the data throughput between processes in async " + "settings, which can make the accessing of data a computational bottleneck. " + "High values will on the other hand lead to greater tensor sizes in memory and disk to be " + "written and read at each global iteration. One should look at the number of frames per second" + "in the log to assess the efficiency of the configuration.", + ) + parser.add_argument( + "--total_frames", + "--total-frames", + type=int, + default=50000000, + help="total number of frames collected for training. Does account for frame_skip (i.e. will be " + "divided by the frame_skip). Default=50e6.", + ) + parser.add_argument( + "--num_workers", + "--num-workers", + type=int, + default=32, + help="Number of workers used for data collection. ", + ) + parser.add_argument( + "--env_per_collector", + "--env-per-collector", + default=8, + type=int, + help="Number of environments per collector. If the env_per_collector is in the range: " + "1 ArgumentParser: + """ + Populates the argument parser to build a data collector for on-policy algorithms (DQN, DDPG, SAC, REDQ). + + Args: + parser (ArgumentParser): parser to be populated. + + """ + parser = _parser_collector_args(parser) + parser.add_argument( + "--multi_step", + "--multi-step", + dest="multi_step", + action="store_true", + help="whether or not multi-step rewards should be used.", + ) + parser.add_argument( + "--n_steps_return", + "--n-steps-return", + type=int, + default=3, + help="If multi_step is set to True, this value defines the number of steps to look ahead for the " + "reward computation.", + ) + parser.add_argument( + "--init_random_frames", + "--init-random-frames", + type=int, + default=50000, + help="Initial number of random frames used before the policy is being used. Default=5000.", + ) + return parser + + +def parser_collector_args_onpolicy(parser: ArgumentParser) -> ArgumentParser: + """ + Populates the argument parser to build a data collector for on-policy algorithms (PPO). + + Args: + parser (ArgumentParser): parser to be populated. + """ + parser = _parser_collector_args(parser) + return parser diff --git a/torchrl/agents/helpers/envs.py b/torchrl/agents/helpers/envs.py new file mode 100644 index 00000000000..2ade8b3d83a --- /dev/null +++ b/torchrl/agents/helpers/envs.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from argparse import ArgumentParser, Namespace
+from typing import Callable, Optional, Union
+
+import torch
+
+from torchrl.agents.env_creator import env_creator, EnvCreator
+from torchrl.envs import DMControlEnv, GymEnv, ParallelEnv, RetroEnv
+from torchrl.envs.common import _EnvClass
+from torchrl.envs.transforms import (
+    CatFrames,
+    CatTensors,
+    Compose,
+    DoubleToFloat,
+    FiniteTensorDictCheck,
+    GrayScale,
+    NoopResetEnv,
+    ObservationNorm,
+    Resize,
+    RewardScaling,
+    ToTensorImage,
+    TransformedEnv,
+    VecNorm,
+)
+from torchrl.envs.transforms.transforms import gSDENoise
+from torchrl.record.recorder import VideoRecorder
+
+__all__ = [
+    "correct_for_frame_skip",
+    "transformed_env_constructor",
+    "parallel_env_constructor",
+    "get_stats_random_rollout",
+    "parser_env_args",
+]
+
+LIBS = {
+    "gym": GymEnv,
+    "retro": RetroEnv,
+    "dm_control": DMControlEnv,
+}
+
+
+def correct_for_frame_skip(args: Namespace) -> Namespace:
+    """
+    Corrects the arguments for the input frame_skip by dividing all the arguments that reflect a count of frames
+    by the frame_skip.
+    This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targeting a total number of frames
+    of 1M but actually collecting frame_skip * 1M frames.
+
+    Args:
+        args (argparse.Namespace): Namespace containing some frame-counting argument, including:
+            "max_frames_per_traj", "total_frames", "frames_per_batch", "record_frames", "annealing_frames",
+            "init_random_frames", "init_env_steps"
+
+    Returns:
+        the input Namespace, modified in-place.
+
+    """
+    # Adapt all frame counts wrt frame_skip
+    if args.frame_skip != 1:
+        fields = [
+            "max_frames_per_traj",
+            "total_frames",
+            "frames_per_batch",
+            "record_frames",
+            "annealing_frames",
+            "init_random_frames",
+            "init_env_steps",
+            "noops",
+        ]
+        for field in fields:
+            if hasattr(args, field):
+                setattr(args, field, getattr(args, field) // args.frame_skip)
+    return args
+
+
+def transformed_env_constructor(
+    args: Namespace,
+    video_tag: str = "",
+    writer: Optional["SummaryWriter"] = None,
+    stats: Optional[dict] = None,
+    norm_obs_only: bool = False,
+    use_env_creator: bool = True,
+    custom_env_maker: Optional[Callable] = None,
+) -> Union[Callable, EnvCreator]:
+    """
+    Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
+
+    Args:
+        args (argparse.Namespace): script arguments originating from the parser built with parser_env_args
+        video_tag (str, optional): video tag to be passed to the SummaryWriter object
+        writer (SummaryWriter, optional): tensorboard writer associated with the script
+        stats (dict, optional): a dictionary containing the `loc` and `scale` for the `ObservationNorm` transform
+        norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online.
+            Default is `False`.
+        use_env_creator (bool, optional): whether the `EnvCreator` class should be used. By using `EnvCreator`,
+            one can make sure that running statistics will be put in shared memory and accessible for all workers
+            when using a `VecNorm` transform. Default is `True`.
+        custom_env_maker (callable, optional): if your env maker is not part of the torchrl env wrappers, a custom
+            callable can be passed instead. In this case it will override the constructor retrieved from `args`.
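A minimal usage sketch for the constructor above (a sketch only: the default values come from parser_env_args defined later in this file, and the environment name is just an example):

    >>> from argparse import ArgumentParser
    >>> from torchrl.agents.helpers.envs import parser_env_args, transformed_env_constructor
    >>> parser = parser_env_args(ArgumentParser())
    >>> args = parser.parse_args(["--env_name", "HalfCheetah-v2"])
    >>> make_env = transformed_env_constructor(args)  # an EnvCreator by default (use_env_creator=True)
    >>> env = make_env()                              # instantiates the TransformedEnv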
+ """ + + def make_transformed_env() -> TransformedEnv: + env_name = args.env_name + env_task = args.env_task + env_library = LIBS[args.env_library] + frame_skip = args.frame_skip + from_pixels = args.from_pixels + vecnorm = args.vecnorm + norm_rewards = vecnorm and args.norm_rewards + _norm_obs_only = norm_obs_only or not norm_rewards + reward_scaling = args.reward_scaling + + if custom_env_maker is None: + env_kwargs = { + "envname": env_name, + "device": "cpu", + "frame_skip": frame_skip, + "from_pixels": from_pixels or len(video_tag), + "pixels_only": from_pixels, + } + if env_library is DMControlEnv: + env_kwargs.update({"taskname": env_task}) + env = env_library(**env_kwargs) + else: + env = custom_env_maker() + + keys = env.reset().keys() + transforms = [] + + if args.noops: + transforms += [NoopResetEnv(env, args.noops)] + if from_pixels: + transforms += [ + ToTensorImage(), + Resize(84, 84), + GrayScale(), + CatFrames(keys=["next_observation_pixels"]), + ObservationNorm(loc=-1.0, scale=2.0, keys=["next_observation_pixels"]), + ] + if norm_rewards: + reward_scaling = 1.0 + if norm_obs_only: + reward_scaling = 1.0 + if reward_scaling is not None: + transforms.append(RewardScaling(0.0, reward_scaling)) + + double_to_float_list = [] + if env_library is DMControlEnv: + double_to_float_list += [ + "reward", + "action", + ] # DMControl requires double-precision + if not from_pixels: + selected_keys = [ + "next_" + key + for key in keys + if key.startswith("observation") and "pixels" not in key + ] + + # even if there is a single tensor, it'll be renamed in "next_observation_vector" + out_key = "next_observation_vector" + transforms.append(CatTensors(keys=selected_keys, out_key=out_key)) + + if not vecnorm: + if stats is None: + _stats = {"loc": 0.0, "scale": 1.0} + else: + _stats = stats + transforms.append( + ObservationNorm(**_stats, keys=[out_key], standard_normal=True) + ) + else: + transforms.append( + VecNorm( + keys=[out_key, "reward"] if not _norm_obs_only else [out_key], + decay=0.9999, + ) + ) + + double_to_float_list.append(out_key) + transforms.append(DoubleToFloat(keys=double_to_float_list)) + + if hasattr(args, "gSDE") and args.gSDE: + transforms.append( + gSDENoise( + action_dim=env.action_spec.shape[-1], + ) + ) + + else: + transforms.append(DoubleToFloat(keys=double_to_float_list)) + if hasattr(args, "gSDE") and args.gSDE: + raise RuntimeError("gSDE not compatible with from_pixels=True") + + if len(video_tag): + transforms = [ + VideoRecorder( + writer=writer, + tag=f"{video_tag}_{env_name}_video", + ), + *transforms, + ] + transforms.append(FiniteTensorDictCheck()) + env = TransformedEnv( + env, + Compose(*transforms), + ) + return env + + if use_env_creator: + return env_creator(make_transformed_env) + return make_transformed_env + + +def parallel_env_constructor(args: Namespace, **kwargs) -> EnvCreator: + """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor. + + Args: + args (argparse.Namespace): script arguments originating from the parser built with parser_env_args + kwargs: keyword arguments for the `transformed_env_constructor` method. 
+ """ + kwargs.update({"args": args, "use_env_creator": True}) + make_transformed_env = transformed_env_constructor(**kwargs) + env = ParallelEnv( + num_workers=args.env_per_collector, + create_env_fn=make_transformed_env, + create_env_kwargs=None, + pin_memory=args.pin_memory, + ) + return env + + +def get_stats_random_rollout(args: Namespace, proof_environment: _EnvClass): + if not hasattr(args, "init_env_steps"): + raise AttributeError("init_env_steps missing from arguments.") + + td_stats = proof_environment.rollout(n_steps=args.init_env_steps) + if args.from_pixels: + m = td_stats.get("observation_pixels").mean(dim=0) + s = td_stats.get("observation_pixels").std(dim=0).clamp_min(1e-5) + else: + m = td_stats.get("observation_vector").mean(dim=0) + s = td_stats.get("observation_vector").std(dim=0).clamp_min(1e-5) + if not torch.isfinite(m).all(): + raise RuntimeError("non-finite values found in mean") + if not torch.isfinite(s).all(): + raise RuntimeError("non-finite values found in sd") + stats = {"loc": m, "scale": s} + return stats + + +def parser_env_args(parser: ArgumentParser) -> ArgumentParser: + """ + Populates the argument parser to build an environment constructor. + + Args: + parser (ArgumentParser): parser to be populated. + + """ + + parser.add_argument( + "--env_library", + type=str, + default="gym", + choices=["dm_control", "gym"], + help="env_library used for the simulated environment. Default=gym", + ) + parser.add_argument( + "--env_name", + type=str, + default="Humanoid-v2", + help="name of the environment to be created. Default=Humanoid-v2", + ) + parser.add_argument( + "--env_task", + type=str, + default="", + help="task (if any) for the environment. Default=run", + ) + parser.add_argument( + "--from_pixels", + action="store_true", + help="whether the environment output should be state vector(s) (default) or the pixels.", + ) + parser.add_argument( + "--frame_skip", + type=int, + default=1, + help="frame_skip for the environment. Note that this value does NOT impact the buffer size," + "maximum steps per trajectory, frames per batch or any other factor in the algorithm," + "e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4," + "the actual number of frames retrieved will be 200e6. Default=1.", + ) + parser.add_argument("--reward_scaling", type=float, help="scale of the reward.") + parser.add_argument( + "--init_env_steps", + type=int, + default=1000, + help="number of random steps to compute normalizing constants", + ) + parser.add_argument( + "--vecnorm", + action="store_true", + help="Normalizes the environment observation and reward outputs with the running statistics " + "obtained across processes.", + ) + parser.add_argument( + "--norm_rewards", + action="store_true", + help="If True, rewards will be normalized on the fly. This may interfere with SAC update rule and " + "should be used cautiously.", + ) + parser.add_argument( + "--noops", + type=int, + default=0, + help="number of random steps to do after reset. Default is 0", + ) + parser.add_argument( + "--max_frames_per_traj", + type=int, + default=1000, + help="Number of steps before a reset of the environment is called (if it has not been flagged as " + "done before). ", + ) + + return parser diff --git a/torchrl/agents/helpers/losses.py b/torchrl/agents/helpers/losses.py new file mode 100644 index 00000000000..b4f6b5d26d9 --- /dev/null +++ b/torchrl/agents/helpers/losses.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import ArgumentParser, Namespace + +__all__ = [ + "make_sac_loss", + "make_dqn_loss", + "make_ddpg_loss", + "make_target_updater", + "make_ppo_loss", + "make_redq_loss", + "parser_loss_args", + "parser_loss_args_ppo", +] + +from typing import Optional, Tuple + +from torchrl.objectives import ( + ClipPPOLoss, + DDPGLoss, + DistributionalDoubleDQNLoss, + DistributionalDQNLoss, + DoubleDDPGLoss, + DoubleDQNLoss, + DoubleSACLoss, + DQNLoss, + GAE, + HardUpdate, + KLPENPPOLoss, + PPOLoss, + SACLoss, + SoftUpdate, +) +from torchrl.objectives.costs.common import _LossModule +from torchrl.objectives.costs.redq import DoubleREDQLoss, REDQLoss + +# from torchrl.objectives.costs.redq import REDQLoss, DoubleREDQLoss +from torchrl.objectives.costs.utils import _TargetNetUpdate + + +def make_target_updater( + args: Namespace, loss_module: _LossModule +) -> Optional[_TargetNetUpdate]: + """Builds a target network weight update object.""" + if args.loss == "double": + if not args.hard_update: + target_net_updater = SoftUpdate( + loss_module, 1 - 1 / args.value_network_update_interval + ) + else: + target_net_updater = HardUpdate( + loss_module, args.value_network_update_interval + ) + # assert len(target_net_updater.net_pairs) == 3, "length of target_net_updater nets should be 3" + target_net_updater.init_() + else: + assert not args.hard_update, ( + "hard/soft-update are supposed to be used with double SAC loss. " + "Consider using --loss=double or discarding the hard_update flag." + ) + target_net_updater = None + return target_net_updater + + +def make_sac_loss(model, args) -> Tuple[SACLoss, Optional[_TargetNetUpdate]]: + """Builds the SAC loss module.""" + loss_kwargs = {} + if hasattr(args, "distributional") and args.distributional: + raise NotImplementedError + else: + loss_kwargs.update({"loss_function": args.loss_function}) + if args.loss == "double": + loss_class = DoubleSACLoss + loss_kwargs.update( + { + "delay_actor": False, + "delay_qvalue": False, + } + ) + else: + loss_class = SACLoss + actor_model, qvalue_model, value_model = model + + loss_module = loss_class( + actor_network=actor_model, + qvalue_network=qvalue_model, + value_network=value_model, + num_qvalue_nets=args.num_q_values, + gamma=args.gamma, + **loss_kwargs + ) + target_net_updater = make_target_updater(args, loss_module) + return loss_module, target_net_updater + + +def make_redq_loss(model, args) -> Tuple[REDQLoss, Optional[_TargetNetUpdate]]: + """Builds the REDQ loss module.""" + loss_kwargs = {} + if hasattr(args, "distributional") and args.distributional: + raise NotImplementedError + else: + loss_kwargs.update({"loss_function": args.loss_function}) + if args.loss == "double": + loss_class = DoubleREDQLoss + else: + loss_class = REDQLoss + actor_model, qvalue_model = model + + loss_module = loss_class( + actor_network=actor_model, + qvalue_network=qvalue_model, + num_qvalue_nets=args.num_q_values, + gamma=args.gamma, + **loss_kwargs + ) + target_net_updater = make_target_updater(args, loss_module) + return loss_module, target_net_updater + + +def make_ddpg_loss(model, args) -> Tuple[DDPGLoss, Optional[_TargetNetUpdate]]: + """Builds the DDPG loss module.""" + actor, value_net = model + loss_kwargs = {} + if args.distributional: + raise NotImplementedError + else: + loss_kwargs.update({"loss_function": args.loss_function}) + if args.loss == "single": + loss_class = DDPGLoss + 
elif args.loss == "double": + loss_class = DoubleDDPGLoss + else: + raise NotImplementedError + loss_module = loss_class(actor, value_net, gamma=args.gamma, **loss_kwargs) + target_net_updater = make_target_updater(args, loss_module) + return loss_module, target_net_updater + + +def make_dqn_loss(model, args) -> Tuple[DQNLoss, Optional[_TargetNetUpdate]]: + """Builds the DQN loss module.""" + loss_kwargs = {} + if args.distributional: + if args.loss == "single": + loss_class = DistributionalDQNLoss + elif args.loss == "double": + loss_class = DistributionalDoubleDQNLoss + else: + raise NotImplementedError + else: + loss_kwargs.update({"loss_function": args.loss_function}) + if args.loss == "single": + loss_class = DQNLoss + elif args.loss == "double": + loss_class = DoubleDQNLoss + else: + raise NotImplementedError + loss_module = loss_class(model, gamma=args.gamma, **loss_kwargs) + target_net_updater = make_target_updater(args, loss_module) + return loss_module, target_net_updater + + +def make_ppo_loss(model, args) -> PPOLoss: + """Builds the PPO loss module.""" + loss_dict = { + "clip": ClipPPOLoss, + "kl": KLPENPPOLoss, + "base": PPOLoss, + "": PPOLoss, + } + actor_model = model.get_policy_operator() + critic_model = model.get_value_operator() + + advantage = GAE(args.gamma, args.lamda, critic=critic_model, average_rewards=True) + loss_module = loss_dict[args.loss]( + actor=actor_model, + critic=critic_model, + advantage_module=advantage, + loss_critic_type=args.loss_function, + entropy_factor=args.entropy_factor, + ) + return loss_module + + +def parser_loss_args(parser: ArgumentParser, algorithm: str) -> ArgumentParser: + """ + Populates the argument parser to build the off-policy loss function (REDQ, SAC, DDPG, DQN). + + Args: + parser (ArgumentParser): parser to be populated. + algorithm (str): one of `"DDPG"`, `"SAC"`, `"REDQ"`, `"DQN"` + + """ + parser.add_argument( + "--loss", + type=str, + default="double", + choices=["double", "single"], + help="whether double or single SAC loss should be used. Default=double", + ) + parser.add_argument( + "--hard_update", + action="store_true", + help="whether soft-update should be used with double SAC loss (default) or hard updates.", + ) + parser.add_argument( + "--loss_function", + type=str, + default="smooth_l1", + choices=["l1", "l2", "smooth_l1"], + help="loss function for the value network. Either one of l1, l2 or smooth_l1 (default).", + ) + parser.add_argument( + "--value_network_update_interval", + type=int, + default=1000, + help="how often the target value network weights are updated (in number of updates)." + "If soft-updates are used, the value is translated into a moving average decay by using " + "the formula decay=1-1/args.value_network_update_interval. Default=1000", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.99, + help="Decay factor for return computation. Default=0.99.", + ) + if algorithm in ("SAC", "REDQ"): + parser.add_argument( + "--num_q_values", + default=2, + type=int, + help="As suggested in the original SAC paper and in https://arxiv.org/abs/1802.09477, we can " + "use two (or more!) different qvalue networks trained independently and choose the lowest value " + "predicted to predict the state action value. This can be disabled by using this flag." 
+ "REDQ uses an arbitrary number of Q-value functions to speed up learning in MF contexts.", + ) + + return parser + + +def parser_loss_args_ppo(parser: ArgumentParser) -> ArgumentParser: + """ + Populates the argument parser to build the PPO loss function. + + Args: + parser (ArgumentParser): parser to be populated. + + """ + parser.add_argument( + "--loss", + type=str, + default="clip", + choices=["clip", "kl", "base", ""], + help="PPO loss class, either clip or kl or base/. Default=clip", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.99, + help="Decay factor for return computation. Default=0.99.", + ) + parser.add_argument( + "--lamda", + default=0.95, + type=float, + help="lambda factor in GAE (using 'lambda' as attribute is prohibited in python, " + "hence the misspelling)", + ) + parser.add_argument( + "--entropy_factor", + type=float, + default=1e-3, + help="Entropy factor for the PPO loss", + ) + parser.add_argument( + "--loss_function", + type=str, + default="smooth_l1", + choices=["l1", "l2", "smooth_l1"], + help="loss function for the value network. Either one of l1, l2 or smooth_l1 (default).", + ) + + return parser diff --git a/torchrl/agents/helpers/models.py b/torchrl/agents/helpers/models.py new file mode 100644 index 00000000000..62dadf7c8d7 --- /dev/null +++ b/torchrl/agents/helpers/models.py @@ -0,0 +1,1025 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import ArgumentParser, Namespace +from typing import Optional, Sequence + +import torch +from torch import nn + +from torchrl.data import DEVICE_TYPING, CompositeSpec +from torchrl.envs.common import _EnvClass +from torchrl.modules import ( + ActorValueOperator, + NoisyLinear, + TDModule, + NormalParamWrapper, +) +from torchrl.modules.distributions import ( + Delta, + OneHotCategorical, + TanhDelta, + TanhNormal, + TruncatedNormal, +) +from torchrl.modules.distributions.continuous import IndependentNormal +from torchrl.modules.models.exploration import gSDEWrapper +from torchrl.modules.models.models import ( + ConvNet, + DdpgCnnActor, + DdpgCnnQNet, + DdpgMlpActor, + DdpgMlpQNet, + DuelingCnnDQNet, + LSTMNet, + MLP, + DuelingMlpDQNet, +) +from torchrl.modules.td_module import ( + Actor, + DistributionalQValueActor, + QValueActor, +) +from torchrl.modules.td_module.actors import ( + ActorCriticWrapper, + ProbabilisticActor, + ValueOperator, +) + +DISTRIBUTIONS = { + "delta": Delta, + "tanh-normal": TanhNormal, + "categorical": OneHotCategorical, + "tanh-delta": TanhDelta, +} + +__all__ = [ + "make_dqn_actor", + "make_ddpg_actor", + "make_ppo_model", + "make_sac_model", + "make_redq_model", + "parser_model_args_continuous", + "parser_model_args_discrete", +] + + +def make_dqn_actor( + proof_environment: _EnvClass, args: Namespace, device: torch.device +) -> Actor: + """ + DQN constructor helper function. + + Args: + proof_environment (_EnvClass): a dummy environment to retrieve the observation and action spec. + args (argparse.Namespace): arguments of the DQN script + device (torch.device): device on which the model must be cast + + Returns: + A DQN policy operator. 
+ + Examples: + >>> from torchrl.agents.helpers.models import make_dqn_actor, parser_model_args_discrete + >>> from torchrl.envs import GymEnv + >>> from torchrl.envs.transforms import ToTensorImage, TransformedEnv + >>> import argparse + >>> proof_environment = TransformedEnv(GymEnv("ALE/Pong-v5", + ... pixels_only=True), ToTensorImage()) + >>> device = torch.device("cpu") + >>> args = parser_model_args_discrete(argparse.ArgumentParser()).parse_args([]) + >>> actor = make_dqn_actor(proof_environment, args, device) + >>> td = proof_environment.reset() + >>> print(actor(td)) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_pixels: Tensor(torch.Size([3, 210, 160]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.int64), + action_value: Tensor(torch.Size([6]), dtype=torch.float32), + chosen_action_value: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + + """ + env_specs = proof_environment.specs + + atoms = args.atoms if args.distributional else None + linear_layer_class = torch.nn.Linear if not args.noisy else NoisyLinear + + action_spec = env_specs["action_spec"] + if action_spec.domain != "discrete": + raise ValueError( + f"env {proof_environment} has an action domain " + f"{action_spec.domain} which is incompatible with " + f"DQN. Make sure your environment has a discrete " + f"domain." + ) + + if args.from_pixels: + net_class = DuelingCnnDQNet + default_net_kwargs = { + "cnn_kwargs": { + "bias_last_layer": True, + "depth": None, + "num_cells": [32, 64, 64], + "kernel_sizes": [8, 4, 3], + "strides": [4, 2, 1], + }, + "mlp_kwargs": {"num_cells": 512, "layer_class": linear_layer_class}, + } + in_key = "observation_pixels" + + else: + net_class = DuelingMlpDQNet + default_net_kwargs = { + "mlp_kwargs_feature": {}, # see class for details + "mlp_kwargs_output": {"num_cells": 512, "layer_class": linear_layer_class}, + } + in_key = "observation_vector" + + out_features = env_specs["action_spec"].shape[0] + actor_class = QValueActor + actor_kwargs = {} + if args.distributional: + if not atoms: + raise RuntimeError( + "Expected atoms to be a positive integer, " f"got {atoms}" + ) + vmin = -3 + vmax = 3 + + out_features = (atoms, out_features) + support = torch.linspace(vmin, vmax, atoms) + actor_class = DistributionalQValueActor + actor_kwargs.update({"support": support}) + default_net_kwargs.update({"out_features_value": (atoms, 1)}) + + net = net_class( + out_features=out_features, + **default_net_kwargs, + ) + + model = actor_class( + module=net, + spec=action_spec, + in_keys=[in_key], + safe=True, + **actor_kwargs, + ).to(device) + + # init + with torch.no_grad(): + td = proof_environment.reset() + model(td.to(device)) + return model + + +def make_ddpg_actor( + proof_environment: _EnvClass, + args: Namespace, + actor_net_kwargs: Optional[dict] = None, + value_net_kwargs: Optional[dict] = None, + device: DEVICE_TYPING = "cpu", +) -> torch.nn.ModuleList: + """ + DDPG constructor helper function. + + Args: + proof_environment (_EnvClass): a dummy environment to retrieve the observation and action spec + args (argparse.Namespace): arguments of the DDPG script + actor_net_kwargs (dict, optional): kwargs to be used for the policy network (either DdpgCnnActor or + DdpgMlpActor). + value_net_kwargs (dict, optional): kwargs to be used for the policy network (either DdpgCnnQNet or + DdpgMlpQNet). + device (torch.device, optional): device on which the model must be cast. 
Default is "cpu". + + Returns: + An actor and a value operators for DDPG. + + For more details on DDPG, refer to "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + https://arxiv.org/pdf/1509.02971.pdf. + + Examples: + >>> from torchrl.agents.helpers.envs import parser_env_args + >>> from torchrl.agents.helpers.models import make_ddpg_actor, parser_model_args_continuous + >>> from torchrl.envs import GymEnv + >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose + >>> import argparse + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), + ... CatTensors(["next_observation"], "next_observation_vector"))) + >>> device = torch.device("cpu") + >>> parser = argparse.ArgumentParser() + >>> parser = parser_env_args(parser) + >>> args = parser_model_args_continuous(parser, algorithm="DDPG").parse_args([]) + >>> actor, value = make_ddpg_actor( + ... proof_environment, + ... device=device, + ... args=args) + >>> td = proof_environment.reset() + >>> print(actor(td)) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(value(td)) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32), + state_action_value: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + """ + + # TODO: https://arxiv.org/pdf/1804.08617.pdf + + from_pixels = args.from_pixels + noisy = args.noisy + + actor_net_kwargs = actor_net_kwargs if actor_net_kwargs is not None else dict() + value_net_kwargs = value_net_kwargs if value_net_kwargs is not None else dict() + + linear_layer_class = torch.nn.Linear if not noisy else NoisyLinear + + env_specs = proof_environment.specs + out_features = env_specs["action_spec"].shape[0] + + # We use a ProbabilisticActor to make sure that we map the network output to the right space using a TanhDelta + # distribution. 
+ actor_class = ProbabilisticActor + + actor_net_default_kwargs = { + "action_dim": out_features, + "mlp_net_kwargs": {"layer_class": linear_layer_class}, + } + actor_net_default_kwargs.update(actor_net_kwargs) + if from_pixels: + in_keys = ["observation_pixels"] + actor_net = DdpgCnnActor(**actor_net_default_kwargs) + + else: + in_keys = ["observation_vector"] + actor_net = DdpgMlpActor(**actor_net_default_kwargs) + + actor = actor_class( + in_keys=in_keys, + spec=env_specs["action_spec"], + module=actor_net, + safe=True, + distribution_class=TanhDelta, + distribution_kwargs={ + "min": env_specs["action_spec"].space.minimum, + "max": env_specs["action_spec"].space.maximum, + }, + ) + + state_class = ValueOperator + if from_pixels: + value_net_default_kwargs = { + "mlp_net_kwargs": {"layer_class": linear_layer_class} + } + value_net_default_kwargs.update(value_net_kwargs) + + in_keys = ["observation_pixels", "action"] + out_keys = ["state_action_value"] + q_net = DdpgCnnQNet(**value_net_default_kwargs) + else: + value_net_default_kwargs1 = {"activation_class": torch.nn.ELU} + value_net_default_kwargs1.update( + value_net_kwargs.get( + "mlp_net_kwargs_net1", {"layer_class": linear_layer_class} + ) + ) + value_net_default_kwargs2 = { + "num_cells": [400, 300], + "depth": 2, + "activation_class": torch.nn.ELU, + } + value_net_default_kwargs2.update( + value_net_kwargs.get( + "mlp_net_kwargs_net2", {"layer_class": linear_layer_class} + ) + ) + in_keys = ["observation_vector", "action"] + out_keys = ["state_action_value"] + q_net = DdpgMlpQNet( + mlp_net_kwargs_net1=value_net_default_kwargs1, + mlp_net_kwargs_net2=value_net_default_kwargs2, + ) + + value = state_class( + in_keys=in_keys, + out_keys=out_keys, + module=q_net, + ) + + module = torch.nn.ModuleList([actor, value]).to(device) + + # init + with torch.no_grad(): + td = proof_environment.reset().to(device) + module[0](td) + module[1](td) + + return module + + +def make_ppo_model( + proof_environment: _EnvClass, + args: Namespace, + device: DEVICE_TYPING, + in_keys_actor: Optional[Sequence[str]] = None, + **kwargs, +) -> ActorValueOperator: + """ + Actor-value model constructor helper function. + Currently constructs MLP networks with immutable default arguments as described in "Proximal Policy Optimization + Algorithms", https://arxiv.org/abs/1707.06347 + Other configurations can easily be implemented by modifying this function at will. + + Args: + proof_environment (_EnvClass): a dummy environment to retrieve the observation and action spec + args (argparse.Namespace): arguments of the PPO script + device (torch.device): device on which the model must be cast. + in_keys_actor (iterable of strings, optional): observation key to be read by the actor, usually one of + `'observation_vector'` or `'observation_pixels'`. If none is provided, one of these two keys is chosen based on + the `args.from_pixels` argument. + + Returns: + A joined ActorCriticOperator. + + Examples: + >>> from torchrl.agents.helpers.envs import parser_env_args + >>> from torchrl.agents.helpers.models import make_ppo_model, parser_model_args_continuous + >>> from torchrl.envs import GymEnv + >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose + >>> import argparse + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), + ... 
CatTensors(["next_observation"], "next_observation_vector"))) + >>> device = torch.device("cpu") + >>> parser = argparse.ArgumentParser() + >>> parser = parser_env_args(parser) + >>> args = parser_model_args_continuous(parser, algorithm="PPO").parse_args(["--shared_mapping"]) + >>> actor_value = make_ppo_model( + ... proof_environment, + ... device=device, + ... args=args, + ... ) + >>> actor = actor_value.get_policy_operator() + >>> value = actor_value.get_value_operator() + >>> td = proof_environment.reset() + >>> print(actor(td.clone())) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + hidden: Tensor(torch.Size([300]), dtype=torch.float32), + action_dist_param_0: Tensor(torch.Size([6]), dtype=torch.float32), + action_dist_param_1: Tensor(torch.Size([6]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(value(td.clone())) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + hidden: Tensor(torch.Size([300]), dtype=torch.float32), + state_value: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + """ + # proof_environment.set_seed(args.seed) + specs = proof_environment.specs # TODO: use env.sepcs + action_spec = specs["action_spec"] + obs_spec = specs["observation_spec"] + + if in_keys_actor is None and proof_environment.from_pixels: + in_keys_actor = ["observation_pixels"] + in_keys_critic = ["observation_pixels"] + elif in_keys_actor is None: + in_keys_actor = ["observation_vector"] + in_keys_critic = ["observation_vector"] + out_keys = ["action"] + + if action_spec.domain == "continuous": + out_features = (2 - args.gSDE) * action_spec.shape[-1] + if args.gSDE: + policy_distribution_kwargs = { + "tanh_loc": args.tanh_loc, + } + policy_distribution_class = IndependentNormal + elif args.distribution == "tanh_normal": + policy_distribution_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": args.tanh_loc, + } + policy_distribution_class = TanhNormal + elif args.distribution == "truncated_normal": + policy_distribution_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": args.tanh_loc, + } + policy_distribution_class = TruncatedNormal + elif action_spec.domain == "discrete": + out_features = action_spec.shape[-1] + policy_distribution_kwargs = {} + policy_distribution_class = OneHotCategorical + else: + raise NotImplementedError( + f"actions with domain {action_spec.domain} are not supported" + ) + + if args.shared_mapping: + hidden_features = 300 + if proof_environment.from_pixels: + if in_keys_actor is None: + in_keys_actor = ["observation_pixels"] + common_module = ConvNet( + bias_last_layer=True, + depth=None, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + if args.gSDE: + raise NotImplementedError("must define the hidden_features accordingly") + else: + if args.lstm: + raise NotImplementedError( + "lstm not yet compatible with shared mapping for PPO" + ) + common_module = MLP( + num_cells=[ + 400, + ], + out_features=hidden_features, + activate_last_layer=True, + ) + common_operator = TDModule( + spec=None, + module=common_module, + in_keys=in_keys_actor, 
+ out_keys=["hidden"], + ) + + policy_net = MLP( + num_cells=[200], + out_features=out_features, + ) + if not args.gSDE: + actor_net = NormalParamWrapper( + policy_net, scale_mapping=f"biased_softplus_{args.default_policy_scale}" + ) + in_keys = ["hidden"] + else: + actor_net = gSDEWrapper( + policy_net, action_dim=action_spec.shape[0], state_dim=hidden_features + ) + in_keys = ["hidden", "gSDE_noise"] + out_keys += ["_action_duplicate"] + + policy_operator = ProbabilisticActor( + spec=action_spec, + module=actor_net, + in_keys=in_keys, + out_keys=out_keys, + default_interaction_mode="random" if not args.gSDE else "net_output", + distribution_class=policy_distribution_class, + distribution_kwargs=policy_distribution_kwargs, + return_log_prob=True, + save_dist_params=True, + ) + value_net = MLP( + num_cells=[200], + out_features=1, + ) + value_operator = ValueOperator(value_net, in_keys=["hidden"]) + actor_value = ActorValueOperator( + common_operator=common_operator, + policy_operator=policy_operator, + value_operator=value_operator, + ).to(device) + else: + if args.lstm: + policy_net = LSTMNet( + out_features=out_features, + lstm_kwargs={"input_size": 256, "hidden_size": 256}, + mlp_kwargs={"num_cells": [256, 256], "out_features": 256}, + ) + in_keys_actor += ["hidden0", "hidden1"] + out_keys += ["hidden0", "hidden1", "next_hidden0", "next_hidden1"] + else: + policy_net = MLP( + num_cells=[400, 300], + out_features=out_features, + ) + + if not args.gSDE: + actor_net = NormalParamWrapper( + policy_net, scale_mapping=f"biased_softplus_{args.default_policy_scale}" + ) + else: + actor_net = gSDEWrapper( + policy_net, action_dim=action_spec.shape[0], state_dim=obs_spec.shape[0] + ) + in_keys_actor += ["_eps_gSDE"] + out_keys += ["_action_duplicate"] + + policy_po = ProbabilisticActor( + actor_net, + action_spec, + distribution_class=policy_distribution_class, + distribution_kwargs=policy_distribution_kwargs, + in_keys=in_keys_actor, + out_keys=out_keys, + return_log_prob=True, + save_dist_params=True, + default_interaction_mode="random" if not args.gSDE else "net_output", + ) + + value_net = MLP( + num_cells=[400, 300], + out_features=1, + ) + value_po = ValueOperator( + value_net, + in_keys=in_keys_critic, + ) + actor_value = ActorCriticWrapper(policy_po, value_po).to(device) + + with torch.no_grad(): + td = proof_environment.reset() + td_device = td.to(device) + td_device = td_device.unsqueeze(0) + td_device = actor_value(td_device) # for init + return actor_value + + +def make_sac_model( + proof_environment: _EnvClass, + args: Namespace, + device: DEVICE_TYPING = "cpu", + in_keys: Optional[Sequence[str]] = None, + actor_net_kwargs=None, + qvalue_net_kwargs=None, + value_net_kwargs=None, + **kwargs, +) -> nn.ModuleList: + """ + Actor, Q-value and value model constructor helper function for SAC. + + Follows default parameters proposed in SAC original paper: https://arxiv.org/pdf/1801.01290.pdf. + Other configurations can easily be implemented by modifying this function at will. + + Args: + proof_environment (_EnvClass): a dummy environment to retrieve the observation and action spec + args (argparse.Namespace): arguments of the SAC script + device (torch.device, optional): device on which the model must be cast. Default is "cpu". + in_keys (iterable of strings, optional): observation key to be read by the actor, usually one of + `'observation_vector'` or `'observation_pixels'`. If none is provided, one of these two keys is chosen + based on the `args.from_pixels` argument. 
+ actor_net_kwargs (dict, optional): kwargs of the actor MLP. + qvalue_net_kwargs (dict, optional): kwargs of the qvalue MLP. + value_net_kwargs (dict, optional): kwargs of the value MLP. + + Returns: + A nn.ModuleList containing the actor, qvalue operator(s) and the value operator. + + Examples: + >>> from torchrl.agents.helpers.envs import parser_env_args + >>> from torchrl.agents.helpers.models import make_sac_model, parser_model_args_continuous + >>> from torchrl.envs import GymEnv + >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose + >>> import argparse + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), + ... CatTensors(["next_observation"], "next_observation_vector"))) + >>> device = torch.device("cpu") + >>> parser = argparse.ArgumentParser() + >>> parser = parser_env_args(parser) + >>> args = parser_model_args_continuous(parser, + ... algorithm="SAC").parse_args([]) + >>> model = make_sac_model( + ... proof_environment, + ... device=device, + ... args=args, + ... ) + >>> actor, qvalue, value = model + >>> td = proof_environment.reset() + >>> print(actor(td)) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(qvalue(td.clone())) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32), + state_action_value: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(value(td.clone())) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32), + state_value: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + """ + tanh_loc = args.tanh_loc + default_policy_scale = args.default_policy_scale + gSDE = args.gSDE + + td = proof_environment.reset() + action_spec = proof_environment.action_spec + obs_spec = proof_environment.observation_spec + if actor_net_kwargs is None: + actor_net_kwargs = {} + if value_net_kwargs is None: + value_net_kwargs = {} + if qvalue_net_kwargs is None: + qvalue_net_kwargs = {} + + if in_keys is None: + in_keys = ["observation_vector"] + + actor_net_kwargs_default = { + "num_cells": [256, 256], + "out_features": (2 - gSDE) * action_spec.shape[-1], + "activation_class": nn.ELU, + } + actor_net_kwargs_default.update(actor_net_kwargs) + actor_net = MLP(**actor_net_kwargs_default) + + qvalue_net_kwargs_default = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": nn.ELU, + } + qvalue_net_kwargs_default.update(qvalue_net_kwargs) + qvalue_net = MLP( + **qvalue_net_kwargs_default, + ) + + value_net_kwargs_default = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": nn.ELU, + } + value_net_kwargs_default.update(value_net_kwargs) + value_net = MLP( + **value_net_kwargs_default, + ) + + if not gSDE: + actor_net = NormalParamWrapper( + actor_net, scale_mapping=f"biased_softplus_{default_policy_scale}" + ) + in_keys_actor = in_keys + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.minimum, + 
"max": action_spec.space.maximum, + "tanh_loc": tanh_loc, + } + else: + if isinstance(obs_spec, CompositeSpec): + obs_spec = obs_spec["vector"] + obs_spec_len = obs_spec.shape[0] + actor_net = gSDEWrapper( + actor_net, action_dim=action_spec.shape[0], state_dim=obs_spec_len + ) + in_keys_actor = in_keys + ["_eps_gSDE"] + dist_class = IndependentNormal + dist_kwargs = { + "tanh_loc": tanh_loc, + } + + actor = ProbabilisticActor( + spec=action_spec, + in_keys=in_keys_actor, + module=actor_net, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random" if not gSDE else "net_output", + safe=True, + ) + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + value = ValueOperator( + in_keys=in_keys, + module=value_net, + ) + model = nn.ModuleList([actor, qvalue, value]).to(device) + + # init nets + td = td.to(device) + for net in model: + net(td) + del td + + return model + + +def make_redq_model( + proof_environment: _EnvClass, + args: Namespace, + device: DEVICE_TYPING = "cpu", + in_keys: Optional[Sequence[str]] = None, + actor_net_kwargs=None, + qvalue_net_kwargs=None, + **kwargs, +) -> nn.ModuleList: + """ + Actor and Q-value model constructor helper function for REDQ. + Follows default parameters proposed in REDQ original paper: https://openreview.net/pdf?id=AY8zfZm0tDd. + Other configurations can easily be implemented by modifying this function at will. + A single instance of the Q-value model is returned. It will be multiplicated by the loss function. + + Args: + proof_environment (_EnvClass): a dummy environment to retrieve the observation and action spec + args (argparse.Namespace): arguments of the REDQ script + device (torch.device, optional): device on which the model must be cast. Default is "cpu". + in_keys (iterable of strings, optional): observation key to be read by the actor, usually one of + `'observation_vector'` or `'observation_pixels'`. If none is provided, one of these two keys is chosen + based on the `args.from_pixels` argument. + actor_net_kwargs (dict, optional): kwargs of the actor MLP. + qvalue_net_kwargs (dict, optional): kwargs of the qvalue MLP. + + Returns: + A nn.ModuleList containing the actor, qvalue operator(s) and the value operator. + + Examples: + >>> from torchrl.agents.helpers.envs import parser_env_args + >>> from torchrl.agents.helpers.models import make_redq_model, parser_model_args_continuous + >>> from torchrl.envs import GymEnv + >>> from torchrl.envs.transforms import CatTensors, TransformedEnv, DoubleToFloat, Compose + >>> import argparse + >>> proof_environment = TransformedEnv(GymEnv("HalfCheetah-v2"), Compose(DoubleToFloat(["next_observation"]), + ... CatTensors(["next_observation"], "next_observation_vector"))) + >>> device = torch.device("cpu") + >>> parser = argparse.ArgumentParser() + >>> parser = parser_env_args(parser) + >>> args = parser_model_args_continuous(parser, + ... algorithm="REDQ").parse_args([]) + >>> model = make_redq_model( + ... proof_environment, + ... device=device, + ... args=args, + ... 
) + >>> actor, qvalue = model + >>> td = proof_environment.reset() + >>> print(actor(td)) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(qvalue(td.clone())) + TensorDict( + fields={ + done: Tensor(torch.Size([1]), dtype=torch.bool), + observation_vector: Tensor(torch.Size([17]), dtype=torch.float32), + action: Tensor(torch.Size([6]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([1]), dtype=torch.float32), + state_action_value: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + """ + + tanh_loc = args.tanh_loc + default_policy_scale = args.default_policy_scale + gSDE = args.gSDE + + td = proof_environment.reset() + action_spec = proof_environment.action_spec + obs_spec = proof_environment.observation_spec + + if actor_net_kwargs is None: + actor_net_kwargs = {} + if qvalue_net_kwargs is None: + qvalue_net_kwargs = {} + + if in_keys is None: + in_keys = ["observation_vector"] + + actor_net_kwargs_default = { + "num_cells": [256, 256], + "out_features": (2 - gSDE) * action_spec.shape[-1], + "activation_class": nn.ELU, + } + actor_net_kwargs_default.update(actor_net_kwargs) + actor_net = MLP(**actor_net_kwargs_default) + + qvalue_net_kwargs_default = { + "num_cells": [256, 256], + "out_features": 1, + "activation_class": nn.ELU, + } + qvalue_net_kwargs_default.update(qvalue_net_kwargs) + qvalue_net = MLP( + **qvalue_net_kwargs_default, + ) + + if not gSDE: + actor_net = NormalParamWrapper( + actor_net, scale_mapping=f"biased_softplus_{default_policy_scale}" + ) + in_keys_actor = in_keys + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.minimum, + "max": action_spec.space.maximum, + "tanh_loc": tanh_loc, + } + else: + if isinstance(obs_spec, CompositeSpec): + obs_spec = obs_spec["vector"] + obs_spec_len = obs_spec.shape[0] + actor_net = gSDEWrapper( + actor_net, action_dim=action_spec.shape[0], state_dim=obs_spec_len + ) + in_keys_actor = in_keys + ["_eps_gSDE"] + dist_class = IndependentNormal + dist_kwargs = { + "tanh_loc": tanh_loc, + } + + actor = ProbabilisticActor( + spec=action_spec, + in_keys=in_keys_actor, + module=actor_net, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_mode="random" if not args.gSDE else "net_output", + return_log_prob=True, + ) + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + td = td.to(device) + for net in model: + net(td) + del td + return model + + +def parser_model_args_continuous( + parser: ArgumentParser, algorithm: str +) -> ArgumentParser: + """ + Populates the argument parser to build a model for continuous actions. + + Args: + parser (ArgumentParser): parser to be populated. + algorithm (str): one of `"DDPG"`, `"SAC"`, `"REDQ"`, `"PPO"` + + """ + + if algorithm not in ("SAC", "DDPG", "PPO", "REDQ"): + raise NotImplementedError(f"Unknown algorithm {algorithm}") + + if algorithm in ("SAC", "DDPG", "REDQ"): + parser.add_argument( + "--annealing_frames", + type=int, + default=1000000, + help="float of frames used for annealing of the OrnsteinUhlenbeckProcess. 
Default=1e6.", + ) + parser.add_argument( + "--noisy", + action="store_true", + help="whether to use NoisyLinearLayers in the value network.", + ) + parser.add_argument( + "--ou_exploration", + action="store_true", + help="wraps the policy in an OU exploration wrapper, similar to DDPG. SAC being designed for " + "efficient entropy-based exploration, this should be left for experimentation only.", + ) + parser.add_argument( + "--distributional", + action="store_true", + help="whether a distributional loss should be used (TODO: not implemented yet).", + ) + parser.add_argument( + "--atoms", + type=int, + default=51, + help="number of atoms used for the distributional loss (TODO)", + ) + + if algorithm in ("SAC", "PPO", "REDQ"): + parser.add_argument( + "--tanh_loc", + "--tanh-loc", + action="store_true", + help="if True, uses a Tanh-Normal transform for the policy " + "location of the form " + "`upscale * tanh(loc/upscale)` (only available with " + "TanhTransform and TruncatedGaussian distributions)", + ) + parser.add_argument( + "--gSDE", + action="store_true", + help="if True, exploration is achieved using the gSDE technique.", + ) + parser.add_argument( + "--default_policy_scale", + default=1.0, + help="Default policy scale parameter", + ) + parser.add_argument( + "--distribution", + type=str, + default="tanh_normal", + help="if True, uses a Tanh-Normal-Tanh distribution for the policy", + ) + if algorithm == "PPO": + parser.add_argument( + "--lstm", + action="store_true", + help="if True, uses an LSTM for the policy.", + ) + parser.add_argument( + "--shared_mapping", + "--shared-mapping", + action="store_true", + help="if True, the first layers of the actor-critic are shared.", + ) + + return parser + + +def parser_model_args_discrete(parser: ArgumentParser) -> ArgumentParser: + """ + Populates the argument parser to build a model for discrete actions. + + Args: + parser (ArgumentParser): parser to be populated. + + """ + parser.add_argument( + "--annealing_frames", + type=int, + default=1000000, + help="Number of frames used for annealing of the EGreedy exploration. Default=1e6.", + ) + + parser.add_argument( + "--noisy", + action="store_true", + help="whether to use NoisyLinearLayers in the value network.", + ) + parser.add_argument( + "--distributional", + action="store_true", + help="whether a distributional loss should be used.", + ) + parser.add_argument( + "--atoms", + type=int, + default=51, + help="number of atoms used for the distributional loss", + ) + + return parser diff --git a/torchrl/agents/helpers/recorder.py b/torchrl/agents/helpers/recorder.py new file mode 100644 index 00000000000..34bbf1fd737 --- /dev/null +++ b/torchrl/agents/helpers/recorder.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import ArgumentParser + +__all__ = ["parser_recorder_args"] + + +def parser_recorder_args(parser: ArgumentParser) -> ArgumentParser: + """ + Populates the argument parser to build a recorder. + + Args: + parser (ArgumentParser): parser to be populated. + + """ + + parser.add_argument( + "--record_video", + action="store_true", + help="whether a video of the task should be rendered during logging.", + ) + parser.add_argument( + "--exp_name", + type=str, + default="", + help="experiment name. Used for logging directory. 
" + "A date and uuid will be joined to account for multiple experiments with the same name.", + ) + parser.add_argument( + "--record_interval", + type=int, + default=10000, + help="number of optimization steps in between two collections of validation rollouts. " + "Default=10000.", + ) + parser.add_argument( + "--record_frames", + type=int, + default=1000, + help="number of steps in validation rollouts. " "Default=1000.", + ) + + return parser diff --git a/torchrl/agents/helpers/replay_buffer.py b/torchrl/agents/helpers/replay_buffer.py new file mode 100644 index 00000000000..c43a2b9c95c --- /dev/null +++ b/torchrl/agents/helpers/replay_buffer.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import ArgumentParser, Namespace + +import torch + +from torchrl.data import ( + DEVICE_TYPING, + ReplayBuffer, + TensorDictPrioritizedReplayBuffer, +) + +__all__ = ["make_replay_buffer", "parser_replay_args"] + + +def make_replay_buffer(device: DEVICE_TYPING, args: Namespace) -> ReplayBuffer: + """Builds a replay buffer using the arguments build from the parser returned by parser_replay_args.""" + device = torch.device(device) + if not args.prb: + buffer = ReplayBuffer( + args.buffer_size, + # collate_fn=InPlaceSampler(device), + pin_memory=device != torch.device("cpu"), + prefetch=3, + ) + else: + buffer = TensorDictPrioritizedReplayBuffer( + args.buffer_size, + alpha=0.7, + beta=0.5, + # collate_fn=InPlaceSampler(device), + pin_memory=device != torch.device("cpu"), + prefetch=3, + ) + return buffer + + +def parser_replay_args(parser: ArgumentParser) -> ArgumentParser: + """ + Populates the argument parser to build a replay buffer. + + Args: + parser (ArgumentParser): parser to be populated. + + """ + + parser.add_argument( + "--buffer_size", + type=int, + default=1000000, + help="buffer size, in number of frames stored. Default=1e6", + ) + parser.add_argument( + "--prb", + action="store_true", + help="whether a Prioritized replay buffer should be used instead of a more basic circular one.", + ) + return parser diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py new file mode 100644 index 00000000000..5de5a86f148 --- /dev/null +++ b/torchrl/collectors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .collectors import * diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py new file mode 100644 index 00000000000..771e1f4f295 --- /dev/null +++ b/torchrl/collectors/collectors.py @@ -0,0 +1,1305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import abc +import math +import queue +import time +from collections import OrderedDict +from copy import deepcopy +from multiprocessing import connection, queues +from textwrap import indent +from typing import Callable, Iterator, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import multiprocessing as mp +from torch.utils.data import IterableDataset + +from torchrl.envs.utils import set_exploration_mode, step_tensor_dict +from torchrl.modules import ProbabilisticTDModule +from .utils import split_trajectories + +__all__ = [ + "SyncDataCollector", + "aSyncDataCollector", + "MultiaSyncDataCollector", + "MultiSyncDataCollector", +] + +from torchrl.envs.transforms import TransformedEnv +from ..data import TensorSpec +from ..data.tensordict.tensordict import _TensorDict, TensorDict +from ..data.utils import CloudpickleWrapper, DEVICE_TYPING +from ..envs.common import _EnvClass +from ..envs.vec_env import _BatchedEnv + +_TIMEOUT = 1.0 +_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory + + +class RandomPolicy: + def __init__(self, action_spec: TensorSpec): + """Random policy for a given action_spec. + This is a wrapper around the action_spec.rand method. + + + $ python example_google.py + + Args: + action_spec: TensorSpec object describing the action specs + + Examples: + >>> from torchrl.data.tensor_specs import NdBoundedTensorSpec + >>> from torchrl.data.tensordict import TensorDict + >>> action_spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3)) + >>> actor = RandomPolicy(spec=action_spec) + >>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1] + """ + self.action_spec = action_spec + + def __call__(self, td: _TensorDict) -> _TensorDict: + return td.set("action", self.action_spec.rand(td.batch_size)) + + +def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: + return OrderedDict( + **{ + k: recursive_map_to_cpu(item) + if isinstance(item, OrderedDict) + else item.cpu() + for k, item in dictionary.items() + } + ) + + +class _DataCollector(IterableDataset, metaclass=abc.ABCMeta): + def _get_policy_and_device( + self, + create_env_fn: Optional[ + Union[_EnvClass, "EnvCreator", Sequence[Callable[[], _EnvClass]]] + ] = None, + create_env_kwargs: Optional[dict] = None, + policy: Optional[ + Union[ProbabilisticTDModule, Callable[[_TensorDict], _TensorDict]] + ] = None, + device: Optional[DEVICE_TYPING] = None, + ) -> Tuple[ProbabilisticTDModule, torch.device, Union[None, Callable[[], dict]]]: + """From a policy and a device, assigns the self.device attribute to + the desired device and maps the policy onto it or (if the device is + ommitted) assigns the self.device attribute to the policy device. 
+ + Args: + create_env_fn (Callable or list of callables): an env creator + function (or a list of creators) + create_env_kwargs (dictionary): kwargs for the env creator + policy (ProbabilisticTDModule, optional): a policy to be used + device (int, str or torch.device, optional): device where to place + the policy + + """ + if create_env_fn is not None: + if create_env_kwargs is None: + create_env_kwargs = dict() + self.create_env_fn = create_env_fn + if isinstance(create_env_fn, _EnvClass): + env = create_env_fn + else: + env = self.create_env_fn(**create_env_kwargs) + else: + env = None + + if policy is None: + if env is None: + raise ValueError( + "env must be provided to _get_policy_and_device if policy is None" + ) + policy = RandomPolicy(env.action_spec) + try: + policy_device = next(policy.parameters()).device + except: # noqa + policy_device = ( + torch.device(device) if device is not None else torch.device("cpu") + ) + + device = torch.device(device) if device is not None else policy_device + if device is None: + # if device cannot be found in policy and is not specified, set cpu + device = torch.device("cpu") + get_weights_fn = None + if policy_device != device: + get_weights_fn = policy.state_dict + policy = deepcopy(policy).requires_grad_(False).to(device) + policy.share_memory() + # if not (len(list(policy.parameters())) == 0 or next(policy.parameters()).is_shared()): + # raise RuntimeError("Provided policy parameters must be shared.") + return policy, device, get_weights_fn + + def update_policy_weights_(self) -> None: + """Update the policy weights if the policy of the data collector and the trained policy live on different devices.""" + if self.get_weights_fn is not None: + self.policy.load_state_dict(self.get_weights_fn()) + + def __iter__(self) -> Iterator[_TensorDict]: + return self.iterator() + + @abc.abstractmethod + def iterator(self) -> Iterator[_TensorDict]: + raise NotImplementedError + + @abc.abstractmethod + def set_seed(self, seed: int) -> int: + raise NotImplementedError + + @abc.abstractmethod + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + raise NotImplementedError + + @abc.abstractmethod + def load_state_dict(self, state_dict: OrderedDict) -> None: + raise NotImplementedError + + def __repr__(self) -> str: + string = f"{self.__class__.__name__}()" + return string + + +class SyncDataCollector(_DataCollector): + """ + Generic data collector for RL problems. Requires and environment constructor and a policy. + + Args: + create_env_fn (Callable), returns an instance of _EnvClass class. + policy (Callable, optional): Policy to be executed in the environment. + Must accept _TensorDict object as input. + total_frames (int): lower bound of the total number of frames returned by the collector. The iterator will + stop once the total number of frames equates or exceeds the total number of frames passed to the + collector. + create_env_kwargs (dict, optional): Dictionary of kwargs for create_env_fn. + max_frames_per_traj (int, optional): Maximum steps per trajectory. Note that a trajectory can span over multiple batches + (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches n_steps_max, + the environment is reset. If the environment wraps multiple environments together, the number of steps + is tracked for each environment independently. Negative values are allowed, in which case this argument + is ignored. + default: -1 (i.e. no maximum number of steps) + frames_per_batch (int): Time-length of a batch. 
+ reset_at_each_iter and frames_per_batch == n_steps_max are equivalent configurations. + default: 200 + init_random_frames (int, optional): Number of frames for which the policy is ignored before it is called. + This feature is mainly intended to be used in offline/model-based settings, where a batch of random + trajectories can be used to initialize training. + default=-1 (i.e. no random frames) + reset_at_each_iter (bool): Whether or not environments should be reset for each batch. + default=False. + postproc (Callable, optional): A Batcher is an object that will read a batch of data and return it in a useful format for training. + default: None. + split_trajs (bool): Boolean indicating whether the resulting TensorDict should be split according to the trajectories. + See utils.split_trajectories for more information. + device (int, str or torch.device, optional): The device on which the policy will be placed. + If it differs from the input policy device, the update_policy_weights_() method should be queried + at appropriate times during the training loop to accommodate for the lag between parameter configuration + at various times. + default = None (i.e. policy is kept on its original device) + seed (int, optional): seed to be used for torch and numpy. + pin_memory (bool): whether pin_memory() should be called on the outputs. + passing_device (int, str or torch.device, optional): The device on which the output TensorDict will be stored. + For long trajectories, it may be necessary to store the data on a different device than the one where + the policy is stored. + default = "cpu" + return_in_place (bool): if True, the collector will yield the same tensordict container with updated values + at each iteration. + default = False + exploration_mode (str, optional): interaction mode to be used when collecting data. Must be one of "random", + "mode", "mean" or "net_output". + default = "random" + init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step. + This is helpful to desynchronize the environments, such that steps do no match in all collected rollouts. 
+ default = True + + """ + + def __init__( + self, + create_env_fn: Union[ + _EnvClass, "EnvCreator", Sequence[Callable[[], _EnvClass]] + ], + policy: Optional[ + Union[ProbabilisticTDModule, Callable[[_TensorDict], _TensorDict]] + ] = None, + total_frames: Optional[int] = -1, + create_env_kwargs: Optional[dict] = None, + max_frames_per_traj: int = -1, + frames_per_batch: int = 200, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Optional[Callable[[_TensorDict], _TensorDict]] = None, + split_trajs: bool = True, + device: DEVICE_TYPING = None, + passing_device: DEVICE_TYPING = "cpu", + seed: Optional[int] = None, + pin_memory: bool = False, + return_in_place: bool = False, + exploration_mode: str = "random", + init_with_lag: bool = False, + ): + if seed is not None: + torch.manual_seed(seed) + np.random.seed(seed) + + if create_env_kwargs is None: + create_env_kwargs = {} + if not isinstance(create_env_fn, _EnvClass): + env = create_env_fn(**create_env_kwargs) + else: + env = create_env_fn + + self.env: _EnvClass = env + self.n_env = self.env.numel() + + (self.policy, self.device, self.get_weights_fn,) = self._get_policy_and_device( + create_env_fn=create_env_fn, + create_env_kwargs=create_env_kwargs, + policy=policy, + device=device, + ) + + self.env_device = env.device + if not total_frames > 0: + total_frames = float("inf") + self.total_frames = total_frames + self.reset_at_each_iter = reset_at_each_iter + self.init_random_frames = init_random_frames + self.postproc = postproc + if self.postproc is not None: + self.postproc.to(self.passing_device) + self.max_frames_per_traj = max_frames_per_traj + self.frames_per_batch = -(-frames_per_batch // self.n_env) + self.pin_memory = pin_memory + self.exploration_mode = exploration_mode + self.init_with_lag = init_with_lag and max_frames_per_traj > 0 + + self.passing_device = torch.device(passing_device) + + self._tensor_dict = env.reset().to(self.passing_device) + self._tensor_dict.set( + "step_count", torch.zeros(*self.env.batch_size, 1, dtype=torch.int) + ) + self._tensor_dict_out = TensorDict( + {}, + batch_size=[*self.env.batch_size, self.frames_per_batch], + device=self.passing_device, + ) + + self.return_in_place = return_in_place + self.split_trajs = split_trajs + if self.return_in_place and self.split_trajs: + raise RuntimeError( + "the 'return_in_place' and 'split_trajs' argument are incompatible, but found to be both " + "True. split_trajs=True will cause the output tensordict to have an unpredictable output " + "shape, which prevents caching and overwriting the tensors." + ) + self._td_env = None + self._td_policy = None + self._has_been_done = None + self.closed = False + + def set_seed(self, seed: int) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed (int): integer representing the seed to be used for the environment. + + Returns: + Output seed. This is useful when more than one environment is contained in the DataCollector, as the + seed will be incremented for each of these. The resulting seed is the seed of the last environment. + + Examples: + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) + >>> collector = SyncDataCollector(env_fn_parallel) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + return self.env.set_seed(seed) + + def iterator(self) -> Iterator[_TensorDict]: + """Iterates through the DataCollector. 
+ + Yields: _TensorDict objects containing (chunks of) trajectories + + """ + total_frames = self.total_frames + i = -1 + self._frames = 0 + while True: + i += 1 + self._iter = i + tensor_dict_out = self.rollout() + self._frames += tensor_dict_out.numel() + if self._frames >= total_frames: + self.env.close() + + if self.split_trajs: + tensor_dict_out = split_trajectories(tensor_dict_out) + if self.postproc is not None: + tensor_dict_out = self.postproc(tensor_dict_out) + yield tensor_dict_out + del tensor_dict_out + if self._frames >= self.total_frames: + break + + def _cast_to_policy(self, td: _TensorDict) -> _TensorDict: + policy_device = self.device + if hasattr(self.policy, "in_keys"): + td = td.select(*self.policy.in_keys) + if self._td_policy is None: + self._td_policy = td.to(policy_device) + else: + if td.device == torch.device("cpu") and self.pin_memory: + td.pin_memory() + self._td_policy.update(td, inplace=True) + return self._td_policy + + def _cast_to_env( + self, td: _TensorDict, dest: Optional[_TensorDict] = None + ) -> _TensorDict: + env_device = self.env_device + if dest is None: + if self._td_env is None: + self._td_env = td.to(env_device) + else: + self._td_env.update(td, inplace=True) + return self._td_env + else: + return dest.update(td, inplace=True) + + def _reset_if_necessary(self) -> None: + done = self._tensor_dict.get("done") + steps = self._tensor_dict.get("step_count") + done_or_terminated = done | (steps == self.max_frames_per_traj) + if self._has_been_done is None: + self._has_been_done = done_or_terminated + else: + self._has_been_done = self._has_been_done | done_or_terminated + if not self._has_been_done.all() and self.init_with_lag: + _reset = torch.zeros_like(done_or_terminated).bernoulli_( + 1 / self.max_frames_per_traj + ) + _reset[self._has_been_done] = False + done_or_terminated = done_or_terminated | _reset + if done_or_terminated.any(): + traj_ids = self._tensor_dict.get("traj_ids").clone() + steps = steps.clone() + if len(self.env.batch_size): + self._tensor_dict.masked_fill_(done_or_terminated.squeeze(-1), 0) + self._tensor_dict.set("reset_workers", done_or_terminated) + else: + self._tensor_dict.zero_() + self.env.reset(tensor_dict=self._tensor_dict) + if self._tensor_dict.get("done").any(): + raise RuntimeError( + f"Got {sum(self._tensor_dict.get('done'))} done envs after reset." + ) + if len(self.env.batch_size): + self._tensor_dict.del_("reset_workers") + traj_ids[done_or_terminated] = traj_ids.max() + torch.arange( + 1, done_or_terminated.sum() + 1, device=traj_ids.device + ) + steps[done_or_terminated] = 0 + self._tensor_dict.set("traj_ids", traj_ids) # no ops if they already match + self._tensor_dict.set("step_count", steps) + + @torch.no_grad() + def rollout(self) -> _TensorDict: + """Computes a rollout in the environment using the provided policy. + + Returns: + _TensorDict containing the computed rollout. 
+ + """ + if self.reset_at_each_iter: + self._tensor_dict.update(self.env.reset()) + self._tensor_dict.fill_("step_count", 0) + + n = self.env.batch_size[0] if len(self.env.batch_size) else 1 + self._tensor_dict.set("traj_ids", torch.arange(n).unsqueeze(-1)) + + tensor_dict_out = [] + with set_exploration_mode(self.exploration_mode): + for t in range(self.frames_per_batch): + if self._frames < self.init_random_frames: + self.env.rand_step(self._tensor_dict) + else: + td_cast = self._cast_to_policy(self._tensor_dict) + td_cast = self.policy(td_cast) + self._cast_to_env(td_cast, self._tensor_dict) + self.env.step(self._tensor_dict) + + step_count = self._tensor_dict.get("step_count") + step_count += 1 + tensor_dict_out.append(self._tensor_dict.clone()) + + self._reset_if_necessary() + self._tensor_dict.update(step_tensor_dict(self._tensor_dict)) + if self.return_in_place and len(self._tensor_dict_out.keys()) > 0: + tensor_dict_out = torch.stack(tensor_dict_out, len(self.env.batch_size)) + tensor_dict_out = tensor_dict_out.select(*self._tensor_dict_out.keys()) + return self._tensor_dict_out.update_(tensor_dict_out) + return torch.stack( + tensor_dict_out, + len(self.env.batch_size), + out=self._tensor_dict_out, + ) # dim 0 for single env, dim 1 for batch + + def reset(self, index=None, **kwargs) -> None: + """Resets the environments to a new initial state.""" + if index is not None: + # check that the env supports partial reset + if np.prod(self.env.batch_size) == 0: + raise RuntimeError("resetting unique env with index is not permitted.") + reset_workers = torch.zeros( + *self.env.batch_size, + 1, + dtype=torch.bool, + device=self.env.device, + ) + reset_workers[index] = 1 + td_in = TensorDict({"reset_workers": reset_workers}, self.env.batch_size) + self._tensor_dict[index].zero_() + else: + td_in = None + self._tensor_dict.zero_() + + self._tensor_dict.update(self.env.reset(td_in, **kwargs)) + self._tensor_dict.fill_("step_count", 0) + + def shutdown(self) -> None: + """Shuts down all workers and/or closes the local environment.""" + if not self.closed: + self.closed = True + del self._tensor_dict, self._tensor_dict_out + self.env.close() + del self.env + + def __del__(self): + self.shutdown() # make sure env is closed + + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + """Returns the local state_dict of the data collector (environment + and policy). + + Args: + destination (optional): ordered dictionary to be updated. + + Returns: + an ordered dictionary with fields `"policy_state_dict"` and + `"env_state_dict"`. + + """ + + if isinstance(self.env, TransformedEnv): + env_state_dict = self.env.transform.state_dict() + elif isinstance(self.env, _BatchedEnv): + env_state_dict = self.env.state_dict() + else: + env_state_dict = OrderedDict() + + if hasattr(self.policy, "state_dict"): + policy_state_dict = self.policy.state_dict() + state_dict = OrderedDict( + policy_state_dict=policy_state_dict, + env_state_dict=env_state_dict, + ) + else: + state_dict = OrderedDict(env_state_dict=env_state_dict) + + if destination is not None: + destination.update(state_dict) + return destination + return state_dict + + def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: + """Loads a state_dict on the environment and policy. + + Args: + state_dict (OrderedDict): ordered dictionary containing the fields + `"policy_state_dict"` and `"env_state_dict"`. 
+ + """ + strict = kwargs.get("strict", True) + if strict or "env_state_dict" in state_dict: + self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) + if strict or "policy_state_dict" in state_dict: + self.policy.load_state_dict(state_dict["policy_state_dict"], **kwargs) + + def __repr__(self) -> str: + env_str = indent(f"env={self.env}", 4 * " ") + policy_str = indent(f"policy={self.policy}", 4 * " ") + td_out_str = indent(f"td_out={self._tensor_dict_out}", 4 * " ") + string = f"{self.__class__.__name__}(\n{env_str},\n{policy_str},\n{td_out_str})" + return string + + +class _MultiDataCollector(_DataCollector): + """Runs a given number of DataCollectors on separate processes. + + Args: + create_env_fn (list of Callabled): list of Callables, each returning an instance of _EnvClass + policy (Callable, optional): Instance of ProbabilisticTDModule class. Must accept _TensorDict object as input. + total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, + the actual number of frames may well be greater than this as the closing signals are sent to the + workers only once the total number of frames has been collected on the server. + create_env_kwargs (dict, optional): A (list of) dictionaries with the arguments used to create an environment + max_frames_per_traj: Maximum steps per trajectory. Note that a trajectory can span over multiple batches + (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches n_steps_max, + the environment is reset. If the environment wraps multiple environments together, the number of steps + is tracked for each environment independently. Negative values are allowed, in which case this argument + is ignored. + default: -1 (i.e. no maximum number of steps) + frames_per_batch (int): Time-length of a batch. + reset_at_each_iter and frames_per_batch == n_steps_max are equivalent configurations. + default: 200 + init_random_frames (int): Number of frames for which the policy is ignored before it is called. + This feature is mainly intended to be used in offline/model-based settings, where a batch of random + trajectories can be used to initialize training. + default=-1 (i.e. no random frames) + reset_at_each_iter (bool): Whether or not environments should be reset for each batch. + default=False. + postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a + useful format for training. + default: None. + split_trajs (bool): Boolean indicating whether the resulting TensorDict should be split according to the trajectories. + See utils.split_trajectories for more information. + devices (int, str, torch.device or sequence of such, optional): The devices on which the policy will be placed. + If it differs from the input policy device, the update_policy_weights_() method should be queried + at appropriate times during the training loop to accommodate for the lag between parameter configuration + at various times. + default = None (i.e. policy is kept on its original device) + passing_devices (int, str, torch.device or sequence of such, optional): The devices on which the output + TensorDict will be stored. For long trajectories, it may be necessary to store the data on a different + device than the one where the policy is stored. + default = "cpu" + update_at_each_batch (bool): if True, the policy weights will be updated every time a batch of trajectories + is collected. 
+ default=False + init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step. + This is helpful to desynchronize the environments, such that steps do no match in all collected rollouts. + default = True + exploration_mode (str, optional): interaction mode to be used when collecting data. Must be one of "random", + "mode", "mean" or "net_output". + default = "random" + + """ + + def __init__( + self, + create_env_fn: Sequence[Callable[[], _EnvClass]], + policy: Optional[ + Union[ProbabilisticTDModule, Callable[[_TensorDict], _TensorDict]] + ] = None, + total_frames: Optional[int] = -1, + create_env_kwargs: Optional[Sequence[dict]] = None, + max_frames_per_traj: int = -1, + frames_per_batch: int = 200, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Optional[Callable[[_TensorDict], _TensorDict]] = None, + split_trajs: bool = True, + devices: DEVICE_TYPING = None, + seed: Optional[int] = None, + pin_memory: bool = False, + passing_devices: Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]] = "cpu", + update_at_each_batch: bool = False, + init_with_lag: bool = False, + exploration_mode: str = "random", + ): + self.closed = True + self.create_env_fn = create_env_fn + self.num_workers = len(create_env_fn) + self.create_env_kwargs = ( + create_env_kwargs + if create_env_kwargs is not None + else [dict() for _ in range(self.num_workers)] + ) + # Preparing devices: + # We want the user to be able to choose, for each worker, on which + # device will the policy live and which device will be used to store + # data. Those devices may or may not match. + # One caveat is that, if there is only one device for the policy, and + # if there are multiple workers, sending the same device and policy + # to be copied to each worker will result in multiple copies of the + # same policy on the same device. + # To go around this, we do the copies of the policy in the server + # (this object) to each possible device, and send to all the + # processes their copy of the policy. + + def device_err_msg(device_name, devices_list): + return ( + f"The length of the {device_name} argument should match the " + f"number of workers of the collector. Got len(" + f"create_env_fn)={self.num_workers} and len(" + f"passing_devices)={len(devices_list)}" + ) + + if isinstance(devices, (str, int, torch.device)): + devices = [torch.device(devices) for _ in range(self.num_workers)] + elif devices is None: + devices = [None for _ in range(self.num_workers)] + elif isinstance(devices, Sequence): + if len(devices) != self.num_workers: + raise RuntimeError(device_err_msg("devices", devices)) + devices = [torch.device(_device) for _device in devices] + else: + raise ValueError( + "devices should be either None, a torch.device or equivalent " + "or an iterable of devices. " + f"Found {type(devices)} instead." 
+ ) + self._policy_dict = {} + self._get_weights_fn_dict = {} + for i, _device in enumerate(devices): + _policy, _device, _get_weight_fn = self._get_policy_and_device( + policy=policy, + device=_device, + ) + if _device not in self._policy_dict: + self._policy_dict[_device] = _policy + self._get_weights_fn_dict[_device] = _get_weight_fn + devices[i] = _device + self.devices = devices + + if isinstance(passing_devices, (str, int, torch.device)): + self.passing_devices = [ + torch.device(passing_devices) for _ in range(self.num_workers) + ] + elif isinstance(passing_devices, Sequence): + if len(passing_devices) != self.num_workers: + raise RuntimeError(device_err_msg("passing_devices", passing_devices)) + self.passing_devices = [ + torch.device(_passing_device) for _passing_device in passing_devices + ] + else: + raise ValueError( + "passing_devices should be either a torch.device or equivalent or an iterable of devices. " + f"Found {type(passing_devices)} instead." + ) + + self.total_frames = total_frames if total_frames > 0 else float("inf") + self.reset_at_each_iter = reset_at_each_iter + self.postprocs = dict() + if postproc is not None: + for _device in self.passing_devices: + self.postprocs[_device] = deepcopy(postproc).to(_device) + self.max_frames_per_traj = max_frames_per_traj + self.frames_per_batch = frames_per_batch + self.seed = seed + self.split_trajs = split_trajs + self.pin_memory = pin_memory + self.init_random_frames = init_random_frames + self.update_at_each_batch = update_at_each_batch + self.init_with_lag = init_with_lag + self.exploration_mode = exploration_mode + self.frames_per_worker = ( + -(self.total_frames // -self.num_workers) if total_frames > 0 else np.inf + ) # ceil(total_frames/num_workers) + self._run_processes() + + @property + def frames_per_batch_worker(self): + raise NotImplementedError + + def update_policy_weights_(self) -> None: + for _device in self._policy_dict: + if self._get_weights_fn_dict[_device] is not None: + self._policy_dict[_device].load_state_dict( + self._get_weights_fn_dict[_device]() + ) + + @property + def _queue_len(self) -> int: + raise NotImplementedError + + def _run_processes(self) -> None: + queue_out = mp.Queue(self._queue_len) # sends data from proc to main + self.procs = [] + self.pipes = [] + for i, (env_fun, env_fun_kwargs) in enumerate( + zip(self.create_env_fn, self.create_env_kwargs) + ): + _device = self.devices[i] + _passing_device = self.passing_devices[i] + pipe_parent, pipe_child = mp.Pipe() # send messages to procs + if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( + env_fun, _EnvClass + ): # to avoid circular imports + env_fun = CloudpickleWrapper(env_fun) + + kwargs = { + "pipe_parent": pipe_parent, + "pipe_child": pipe_child, + "queue_out": queue_out, + "create_env_fn": env_fun, + "create_env_kwargs": env_fun_kwargs, + "policy": self._policy_dict[_device], + "frames_per_worker": self.frames_per_worker, + "max_frames_per_traj": self.max_frames_per_traj, + "frames_per_batch": self.frames_per_batch_worker, + "reset_at_each_iter": self.reset_at_each_iter, + "device": _device, + "passing_device": _passing_device, + "seed": self.seed, + "pin_memory": self.pin_memory, + "init_with_lag": self.init_with_lag, + "exploration_mode": self.exploration_mode, + "idx": i, + } + proc = mp.Process(target=_main_async_collector, kwargs=kwargs) + # proc.daemon can't be set as daemonic processes may be launched by the process itself + proc.start() + pipe_child.close() + self.procs.append(proc) + 
self.pipes.append(pipe_parent) + self.queue_out = queue_out + self.closed = False + + def __del__(self): + self.shutdown() + + def shutdown(self) -> None: + """Shuts down all processes. This operation is irreversible.""" + self._shutdown_main() + + def _shutdown_main(self) -> None: + if self.closed: + return + self.closed = True + for idx in range(self.num_workers): + self.pipes[idx].send((None, "close")) + + for idx in range(self.num_workers): + msg = self.pipes[idx].recv() + if msg != "closed": + raise RuntimeError(f"got {msg} but expected 'close'") + + for proc in self.procs: + proc.join() + + self.queue_out.close() + for pipe in self.pipes: + pipe.close() + + def set_seed(self, seed: int) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed: integer representing the seed to be used for the environment. + + Returns: + Output seed. This is useful when more than one environment is + contained in the DataCollector, as the seed will be incremented for + each of these. The resulting seed is the seed of the last + environment. + + Examples: + >>> env_fn = lambda: GymEnv("Pendulum-v0") + >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) + >>> collector = SyncDataCollector(env_fn_parallel) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + + for idx in range(self.num_workers): + self.pipes[idx].send((seed, "seed")) + new_seed, msg = self.pipes[idx].recv() + if msg != "seeded": + raise RuntimeError(f"Expected msg='seeded', got {msg}") + seed = new_seed + if idx < self.num_workers - 1: + seed = seed + 1 + self.reset() + return seed + + def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None: + """Resets the environments to a new initial state. + + Args: + reset_idx: Optional. Sequence indicating which environments have + to be reset. If None, all environments are reset. + + """ + + if reset_idx is None: + reset_idx = [True for _ in range(self.num_workers)] + for idx in range(self.num_workers): + if reset_idx[idx]: + self.pipes[idx].send((None, "reset")) + for idx in range(self.num_workers): + if reset_idx[idx]: + j, msg = self.pipes[idx].recv() + if msg != "reset": + raise RuntimeError(f"Expected msg='reset', got {msg}") + + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + """ + Returns the state_dict of the data collector. + Each field represents a worker containing its own state_dict. + + Args: + destination (optional): A destination ordered dictionary where + to place the fetched data. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((None, "state_dict")) + state_dict = OrderedDict() + for idx in range(self.num_workers): + _state_dict, msg = self.pipes[idx].recv() + if msg != "state_dict": + raise RuntimeError(f"Expected msg='state_dict', got {msg}") + state_dict[f"worker{idx}"] = _state_dict + + if destination is not None: + destination.update(state_dict) + return destination + return state_dict + + def load_state_dict(self, state_dict: OrderedDict) -> None: + """ + Loads the state_dict on the workers. + + Args: + state_dict (OrderedDict): state_dict of the form + ``{"worker0": state_dict0, "worker1": state_dict1}``. 
+ + """ + + for idx in range(self.num_workers): + self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) + for idx in range(self.num_workers): + _, msg = self.pipes[idx].recv() + if msg != "loaded": + raise RuntimeError(f"Expected msg='loaded', got {msg}") + + +class MultiSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes + synchronously. + + The collection starts when the next item of the collector is queried, + and no environment step is computed in between the reception of a batch of + trajectory and the start of the next collection. + This class can be safely used with online RL algorithms. + """ + + __doc__ += _MultiDataCollector.__doc__ + + @property + def frames_per_batch_worker(self): + return -(-self.frames_per_batch // self.num_workers) + + @property + def _queue_len(self) -> int: + return self.num_workers + + def iterator(self) -> Iterator[_TensorDict]: + i = -1 + frames = 0 + out_tensordicts_shared = OrderedDict() + dones = [False for _ in range(self.num_workers)] + workers_frames = [0 for _ in range(self.num_workers)] + while not all(dones) and frames < self.total_frames: + if self.update_at_each_batch: + self.update_policy_weights_() + + for idx in range(self.num_workers): + if frames < self.init_random_frames: + msg = "continue_random" + else: + msg = "continue" + self.pipes[idx].send((None, msg)) + + i += 1 + max_traj_idx = None + for k in range(self.num_workers): + new_data, j = self.queue_out.get() + if j == 0: + data, idx = new_data + out_tensordicts_shared[idx] = data + else: + idx = new_data + workers_frames[idx] = ( + workers_frames[idx] + out_tensordicts_shared[idx].numel() + ) + + if workers_frames[idx] >= self.total_frames: + print(f"{idx} is done!") + dones[idx] = True + for idx in range(self.num_workers): + traj_ids = out_tensordicts_shared[idx].get("traj_ids") + if max_traj_idx is not None: + traj_ids += max_traj_idx + # out_tensordicts_shared[idx].set("traj_ids", traj_ids) + max_traj_idx = traj_ids.max() + 1 + # out = out_tensordicts_shared[idx] + out = torch.cat([item for key, item in out_tensordicts_shared.items()], 0) + if self.split_trajs: + out = split_trajectories(out) + frames += out.get("mask").sum() + else: + frames += math.prod(out.shape) + if self.postprocs: + out = self.postprocs[out.device](out) + yield out + + del out_tensordicts_shared + self._shutdown_main() + + +class MultiaSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes + asynchronously. + + The collection keeps on occuring on all processes even between the time + the batch of rollouts is collected and the next call to the iterator. + This class can be safely used with offline RL algorithms. 
+ """ + + __doc__ += _MultiDataCollector.__doc__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.out_tensordicts = dict() + self.running = False + + @property + def frames_per_batch_worker(self): + return self.frames_per_batch + + def _get_from_queue(self, timeout=None) -> Tuple[int, int, _TensorDict]: + new_data, j = self.queue_out.get(timeout=timeout) + if j == 0: + data, idx = new_data + self.out_tensordicts[idx] = data + else: + idx = new_data + out = self.out_tensordicts[idx] + return idx, j, out + + @property + def _queue_len(self) -> int: + return 1 + + def iterator(self) -> Iterator[_TensorDict]: + if self.update_at_each_batch: + self.update_policy_weights_() + + for i in range(self.num_workers): + if self.init_random_frames > 0: + self.pipes[i].send((None, "continue_random")) + else: + self.pipes[i].send((None, "continue")) + self.running = True + i = -1 + self._frames = 0 + + dones = [False for _ in range(self.num_workers)] + workers_frames = [0 for _ in range(self.num_workers)] + while self._frames < self.total_frames: + i += 1 + idx, j, out = self._get_from_queue() + + worker_frames = out.numel() + if self.split_trajs: + out = split_trajectories(out) + self._frames += worker_frames + workers_frames[idx] = workers_frames[idx] + worker_frames + if self.postprocs: + out = self.postprocs[out.device](out) + + # the function blocks here until the next item is asked, hence we send the message to the + # worker to keep on working in the meantime before the yield statement + if workers_frames[idx] < self.frames_per_worker: + if self._frames < self.init_random_frames: + msg = "continue_random" + else: + msg = "continue" + self.pipes[idx].send((idx, msg)) + else: + print(f"{idx} is done!") + dones[idx] = True + + yield out + + self._shutdown_main() + self.running = False + + def _shutdown_main(self) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + return super()._shutdown_main() + + def reset(self, reset_idx: Optional[Sequence[bool]] = None) -> None: + super().reset(reset_idx) + if self.queue_out.full(): + print("waiting") + time.sleep(_TIMEOUT) # wait until queue is empty + if self.queue_out.full(): + raise Exception("self.queue_out is full") + if self.running: + for idx in range(self.num_workers): + if self._frames < self.init_random_frames: + self.pipes[idx].send((idx, "continue_random")) + else: + self.pipes[idx].send((idx, "continue")) + + +class aSyncDataCollector(MultiaSyncDataCollector): + """Runs a single DataCollector on a separate process. + + This is mostly useful for offline RL paradigms where the policy being + trained can differ from the policy used to collect data. In online + settings, a regular DataCollector should be preferred. This class is + merely a wrapper around a MultiaSyncDataCollector where a single process + is being created. + + Args: + create_env_fn (Callabled): Callable returning an instance of _EnvClass + policy (Callable, optional): Instance of ProbabilisticTDModule class. + Must accept _TensorDict object as input. + total_frames (int): lower bound of the total number of frames returned + by the collector. In parallel settings, the actual number of + frames may well be greater than this as the closing signals are + sent to the workers only once the total number of frames has + been collected on the server. + create_env_kwargs (dict, optional): A dictionary with the arguments + used to create an environment + max_frames_per_traj: Maximum steps per trajectory. 
Note that a + trajectory can span over multiple batches (unless + reset_at_each_iter is set to True, see below). Once a trajectory + reaches n_steps_max, the environment is reset. If the + environment wraps multiple environments together, the number of + steps is tracked for each environment independently. Negative + values are allowed, in which case this argument is ignored. + Default is -1 (i.e. no maximum number of steps) + frames_per_batch (int): Time-length of a batch. + reset_at_each_iter and frames_per_batch == n_steps_max are equivalent configurations. + default: 200 + init_random_frames (int): Number of frames for which the policy is ignored before it is called. + This feature is mainly intended to be used in offline/model-based settings, where a batch of random + trajectories can be used to initialize training. + default=-1 (i.e. no random frames) + reset_at_each_iter (bool): Whether or not environments should be reset for each batch. + default=False. + postproc (callable, optional): A PostProcessor is an object that will read a batch of data and process it in a + useful format for training. + default: None. + split_trajs (bool): Boolean indicating whether the resulting TensorDict should be split according to the trajectories. + See utils.split_trajectories for more information. + device (int, str, torch.device, optional): The device on which the + policy will be placed. If it differs from the input policy + device, the update_policy_weights_() method should be queried + at appropriate times during the training loop to accommodate for + the lag between parameter configuration at various times. + Default is `None` (i.e. policy is kept on its original device) + passing_device (int, str, torch.device, optional): The device on which + the output TensorDict will be stored. For long trajectories, + it may be necessary to store the data on a different. + device than the one where the policy is stored. Default is `"cpu"`. + update_at_each_batch (bool): if True, the policy weights will be updated every time a batch of trajectories + is collected. + default=False + init_with_lag (bool, optional): if True, the first trajectory will be truncated earlier at a random step. + This is helpful to desynchronize the environments, such that steps do no match in all collected rollouts. 
+ default = True + + """ + + def __init__( + self, + create_env_fn: Callable[[], _EnvClass], + policy: Optional[ + Union[ProbabilisticTDModule, Callable[[_TensorDict], _TensorDict]] + ] = None, + total_frames: Optional[int] = -1, + create_env_kwargs: Optional[dict] = None, + max_frames_per_traj: int = -1, + frames_per_batch: int = 200, + init_random_frames: int = -1, + reset_at_each_iter: bool = False, + postproc: Optional[Callable[[_TensorDict], _TensorDict]] = None, + split_trajs: bool = True, + device: Optional[Union[int, str, torch.device]] = None, + passing_device: Union[int, str, torch.device] = "cpu", + seed: Optional[int] = None, + pin_memory: bool = False, + ): + super().__init__( + create_env_fn=[create_env_fn], + policy=policy, + total_frames=total_frames, + create_env_kwargs=[create_env_kwargs], + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + init_random_frames=init_random_frames, + postproc=postproc, + split_trajs=split_trajs, + devices=[device] if device is not None else None, + passing_devices=[passing_device], + seed=seed, + pin_memory=pin_memory, + ) + + +def _main_async_collector( + pipe_parent: connection.Connection, + pipe_child: connection.Connection, + queue_out: queues.Queue, + create_env_fn: Union[_EnvClass, "EnvCreator", Callable[[], _EnvClass]], + create_env_kwargs: dict, + policy: Callable[[_TensorDict], _TensorDict], + frames_per_worker: int, + max_frames_per_traj: int, + frames_per_batch: int, + reset_at_each_iter: bool, + device: Optional[Union[torch.device, str, int]], + passing_device: Optional[Union[torch.device, str, int]], + seed: Union[int, Sequence], + pin_memory: bool, + idx: int = 0, + init_with_lag: bool = False, + exploration_mode: str = "random", + verbose: bool = False, +) -> None: + pipe_parent.close() + #  init variables that will be cleared when closing + tensor_dict = data = d = data_in = dc = dc_iter = None + + dc = SyncDataCollector( + create_env_fn, + create_env_kwargs=create_env_kwargs, + policy=policy, + total_frames=-1, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + postproc=None, + split_trajs=False, + device=device, + seed=seed, + pin_memory=pin_memory, + passing_device=passing_device, + return_in_place=True, + init_with_lag=init_with_lag, + exploration_mode=exploration_mode, + ) + if verbose: + print("Sync data collector created") + dc_iter = iter(dc) + j = 0 + + has_timed_out = False + while True: + _timeout = _TIMEOUT if not has_timed_out else 1e-3 + if pipe_child.poll(_timeout): + data_in, msg = pipe_child.recv() + if verbose: + print(f"worker {idx} received {msg}") + else: + if verbose: + print(f"poll failed, j={j}") + # default is "continue" (after first iteration) + # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe + # in that case, the main process probably expects the worker to continue collect data + if has_timed_out: + if msg not in ("continue", "continue_random"): + raise RuntimeError(f"Unexpected message after time out: msg={msg}") + else: + continue + if msg in ("continue", "continue_random"): + if msg == "continue_random": + dc.init_random_frames = float("inf") + else: + dc.init_random_frames = -1 + + d = next(dc_iter) + if pipe_child.poll(_MIN_TIMEOUT): + # in this case, main send a message to the worker while it was busy collecting trajectories. 
+ # In that case, we skip the collected trajectory and get the message from main. This is faster than + # sending the trajectory in the queue until timeout when it's never going to be received. + continue + if j == 0: + tensor_dict = d + if passing_device is not None and tensor_dict.device != passing_device: + raise RuntimeError( + f"expected device to be {passing_device} but got {tensor_dict.device}" + ) + tensor_dict.share_memory_() + data = (tensor_dict, idx) + else: + if d is not tensor_dict: + raise RuntimeError( + "SyncDataCollector should return the same tensordict modified in-place." + ) + data = idx # flag the worker that has sent its data + try: + queue_out.put((data, j), timeout=_TIMEOUT) + if verbose: + print(f"worker {idx} successfully sent data") + j += 1 + has_timed_out = False + continue + except queue.Full: + if verbose: + print(f"worker {idx} has timed out") + has_timed_out = True + continue + # pipe_child.send("done") + + elif msg == "update": + dc.update_policy_weights_() + pipe_child.send((j, "updated")) + has_timed_out = False + continue + + elif msg == "seed": + new_seed = dc.set_seed(data_in) + torch.manual_seed(data_in) + np.random.seed(data_in) + pipe_child.send((new_seed, "seeded")) + has_timed_out = False + continue + + elif msg == "reset": + dc.reset() + pipe_child.send((j, "reset")) + continue + + elif msg == "state_dict": + state_dict = dc.state_dict() + # send state_dict to cpu first + state_dict = recursive_map_to_cpu(state_dict) + pipe_child.send((state_dict, "state_dict")) + has_timed_out = False + continue + + elif msg == "load_state_dict": + state_dict = data_in + dc.load_state_dict(state_dict) + pipe_child.send((j, "loaded")) + has_timed_out = False + continue + + elif msg == "close": + del tensor_dict, data, d, data_in + dc.shutdown() + del dc, dc_iter + pipe_child.send("closed") + if verbose: + print(f"collector {idx} closed") + break + + else: + raise Exception(f"Unrecognized message {msg}") diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py new file mode 100644 index 00000000000..cdb8abbf198 --- /dev/null +++ b/torchrl/collectors/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch + +from torchrl.data import TensorDict +from torchrl.data.tensordict.tensordict import _TensorDict + + +def _stack_output(fun) -> Callable: + def stacked_output_fun(*args, **kwargs): + out = fun(*args, **kwargs) + return tuple(torch.stack(_o, 0) for _o in out) + + return stacked_output_fun + + +def _stack_output_zip(fun) -> Callable: + def stacked_output_fun(*args, **kwargs): + out = fun(*args, **kwargs) + return tuple(torch.stack(_o, 0) for _o in zip(*out)) + + return stacked_output_fun + + +def split_trajectories(rollout_tensor_dict: _TensorDict) -> _TensorDict: + """Takes a tensordict with a key traj_ids that indicates the id of each trajectory. + From there, builds a B x T x ... 
zero-padded tensordict with B batches on max duration T + """ + traj_ids = rollout_tensor_dict.get("traj_ids") + ndim = len(rollout_tensor_dict.batch_size) + splits = traj_ids.view(-1) + splits = [(splits == i).sum().item() for i in splits.unique_consecutive()] + out_splits = { + key: _d.contiguous().view(-1, *_d.shape[ndim:]).split(splits, 0) + for key, _d in rollout_tensor_dict.items() + # if key not in ("step_count", "traj_ids") + } + # select complete rollouts + dones = out_splits["done"] + valid_ids = list(range(len(dones))) + out_splits = {key: [_out[i] for i in valid_ids] for key, _out in out_splits.items()} + mask = [torch.ones_like(_out, dtype=torch.bool) for _out in out_splits["done"]] + out_splits["mask"] = mask + out_dict = { + key: torch.nn.utils.rnn.pad_sequence(_o, batch_first=True) + for key, _o in out_splits.items() + } + td = TensorDict( + source=out_dict, + device=rollout_tensor_dict.device, + batch_size=out_dict["mask"].shape[:-1], + ) + if (out_dict["done"].sum(1) > 1).any(): + raise RuntimeError("Got more than one done per trajectory") + return td diff --git a/torchrl/csrc/CMakeLists.txt b/torchrl/csrc/CMakeLists.txt new file mode 100644 index 00000000000..6e5c8ef1804 --- /dev/null +++ b/torchrl/csrc/CMakeLists.txt @@ -0,0 +1,64 @@ +get_property(TORCHRL_THIRD_PARTIES GLOBAL PROPERTY TORCHRL_THIRD_PARTIES) + +################################################################################ +# _torchrl.so +################################################################################ +if (BUILD_TORCHRL_PYTHON_EXTENSION) + set( + EXTENSION_SOURCES + pybind.cpp + ) + add_library( + _torchrl + SHARED + ${EXTENSION_SOURCES} + ) + + set_target_properties(_torchrl PROPERTIES PREFIX "") + if (MSVC) + set_target_properties(_torchrl PROPERTIES SUFFIX ".pyd") + endif(MSVC) + + if (APPLE) + # https://github.com/facebookarchive/caffe2/issues/854#issuecomment-364538485 + # https://github.com/pytorch/pytorch/commit/73f6715f4725a0723d8171d3131e09ac7abf0666 + set_target_properties(_torchrl PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + endif() + + target_include_directories( + _torchrl + PRIVATE + ${PROJECT_SOURCE_DIR} + ${Python_INCLUDE_DIR} + ) + + if (APPLE) + # See https://github.com/pytorch/pytorch/issues/38122 + find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") + else() + # PyTorch dependency + execute_process( + COMMAND + python -c + "import os; import torch; print(os.path.dirname(torch.__file__), end='')" + OUTPUT_VARIABLE + TORCH_PATH + ) + list(APPEND CMAKE_PREFIX_PATH ${TORCH_PATH}) + find_package(Torch REQUIRED) + set(TORCH_PYTHON_LIBRARIES "${TORCH_PATH}/lib/libtorch_python.so") + endif() + + if (WIN32) + find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) + set(ADDITIONAL_ITEMS Python3::Python) + endif() + + target_link_libraries(_torchrl PUBLIC torch ${TORCH_PYTHON_LIBRARIES}) + + install( + TARGETS _torchrl + LIBRARY DESTINATION . + RUNTIME DESTINATION . # For Windows + ) +endif() diff --git a/torchrl/csrc/pybind.cpp b/torchrl/csrc/pybind.cpp new file mode 100644 index 00000000000..622e3c95323 --- /dev/null +++ b/torchrl/csrc/pybind.cpp @@ -0,0 +1,140 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
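Before the C++ extension, a small illustration (not from this patch) of what `split_trajectories` produces on a toy rollout: a flat batch of 5 steps holding two trajectories becomes a 2 x 3 tensordict, zero-padded to the longest trajectory, with a boolean "mask" entry marking the real steps.

    import torch
    from torchrl.data import TensorDict
    from torchrl.collectors.utils import split_trajectories

    # flat rollout of 5 steps: trajectory 0 spans 3 steps, trajectory 1 spans 2 steps
    rollout = TensorDict(
        {
            "traj_ids": torch.tensor([[0], [0], [0], [1], [1]]),
            "done": torch.tensor([[False], [False], [True], [False], [True]]),
            "reward": torch.randn(5, 1),
        },
        batch_size=[5],
    )
    padded = split_trajectories(rollout)
    # padded.batch_size == torch.Size([2, 3]); the shorter trajectory's last step is
    # padding, so padded["mask"][1, -1] is False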
+ +#include +#include +#include +#include + +#include + +#include "segment_tree.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_torchrl, m) { + py::class_, + std::shared_ptr>>(m, + "SumSegmentTree") + .def(py::init()) + .def("__len__", &torchrl::SumSegmentTree::size) + .def("size", &torchrl::SumSegmentTree::size) + .def("capacity", &torchrl::SumSegmentTree::capacity) + .def("identity_element", + &torchrl::SumSegmentTree::identity_element) + .def("__getitem__", py::overload_cast( + &torchrl::SumSegmentTree::At, py::const_)) + .def("__getitem__", py::overload_cast &>( + &torchrl::SumSegmentTree::At, py::const_)) + .def("__getitem__", py::overload_cast( + &torchrl::SumSegmentTree::At, py::const_)) + .def("at", py::overload_cast(&torchrl::SumSegmentTree::At, + py::const_)) + .def("at", py::overload_cast &>( + &torchrl::SumSegmentTree::At, py::const_)) + .def("at", py::overload_cast( + &torchrl::SumSegmentTree::At, py::const_)) + .def("__setitem__", py::overload_cast( + &torchrl::SumSegmentTree::Update)) + .def("__setitem__", + py::overload_cast &, const float &>( + &torchrl::SumSegmentTree::Update)) + .def("__setitem__", py::overload_cast &, + const py::array_t &>( + &torchrl::SumSegmentTree::Update)) + .def("__setitem__", + py::overload_cast( + &torchrl::SumSegmentTree::Update)) + .def("__setitem__", + py::overload_cast( + &torchrl::SumSegmentTree::Update)) + .def("update", py::overload_cast( + &torchrl::SumSegmentTree::Update)) + .def("update", + py::overload_cast &, const float &>( + &torchrl::SumSegmentTree::Update)) + .def("update", py::overload_cast &, + const py::array_t &>( + &torchrl::SumSegmentTree::Update)) + .def("update", py::overload_cast( + &torchrl::SumSegmentTree::Update)) + .def("update", + py::overload_cast( + &torchrl::SumSegmentTree::Update)) + .def("query", py::overload_cast( + &torchrl::SumSegmentTree::Query, py::const_)) + .def("query", py::overload_cast &, + const py::array_t &>( + &torchrl::SumSegmentTree::Query, py::const_)) + .def("query", + py::overload_cast( + &torchrl::SumSegmentTree::Query, py::const_)) + .def("scan_lower_bound", + py::overload_cast( + &torchrl::SumSegmentTree::ScanLowerBound, py::const_)) + .def("scan_lower_bound", + py::overload_cast &>( + &torchrl::SumSegmentTree::ScanLowerBound, py::const_)) + .def("scan_lower_bound", + py::overload_cast( + &torchrl::SumSegmentTree::ScanLowerBound, py::const_)); + + py::class_, + std::shared_ptr>>(m, + "MinSegmentTree") + .def(py::init()) + .def("__len__", &torchrl::MinSegmentTree::size) + .def("size", &torchrl::MinSegmentTree::size) + .def("capacity", &torchrl::MinSegmentTree::capacity) + .def("identity_element", + &torchrl::MinSegmentTree::identity_element) + .def("__getitem__", py::overload_cast( + &torchrl::MinSegmentTree::At, py::const_)) + .def("__getitem__", py::overload_cast &>( + &torchrl::MinSegmentTree::At, py::const_)) + .def("__getitem__", py::overload_cast( + &torchrl::MinSegmentTree::At, py::const_)) + .def("at", py::overload_cast(&torchrl::MinSegmentTree::At, + py::const_)) + .def("at", py::overload_cast &>( + &torchrl::MinSegmentTree::At, py::const_)) + .def("at", py::overload_cast( + &torchrl::MinSegmentTree::At, py::const_)) + .def("__setitem__", py::overload_cast( + &torchrl::MinSegmentTree::Update)) + .def("__setitem__", + py::overload_cast &, const float &>( + &torchrl::MinSegmentTree::Update)) + .def("__setitem__", py::overload_cast &, + const py::array_t &>( + &torchrl::MinSegmentTree::Update)) + .def("__setitem__", + py::overload_cast( + &torchrl::MinSegmentTree::Update)) + 
.def("__setitem__", + py::overload_cast( + &torchrl::MinSegmentTree::Update)) + .def("update", py::overload_cast( + &torchrl::MinSegmentTree::Update)) + .def("update", + py::overload_cast &, const float &>( + &torchrl::MinSegmentTree::Update)) + .def("update", py::overload_cast &, + const py::array_t &>( + &torchrl::MinSegmentTree::Update)) + .def("update", py::overload_cast( + &torchrl::MinSegmentTree::Update)) + .def("update", + py::overload_cast( + &torchrl::MinSegmentTree::Update)) + .def("query", py::overload_cast( + &torchrl::MinSegmentTree::Query, py::const_)) + .def("query", py::overload_cast &, + const py::array_t &>( + &torchrl::MinSegmentTree::Query, py::const_)) + .def("query", + py::overload_cast( + &torchrl::MinSegmentTree::Query, py::const_)); +} diff --git a/torchrl/csrc/segment_tree.h b/torchrl/csrc/segment_tree.h new file mode 100644 index 00000000000..44d5bc63cae --- /dev/null +++ b/torchrl/csrc/segment_tree.h @@ -0,0 +1,302 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "torch_data_type.h" + +namespace py = pybind11; + +namespace torchrl { + +// SegmentTree is a tree data structure to maintain statistics of intervals. +// https://en.wikipedia.org/wiki/Segment_tree +// Here is the implementaion of non-recursive SegmentTree for single point +// update and interval query. The time complexities of both Update and Query are +// O(logN). +// One example of a SegmentTree is shown below. +// +// 1: [0, 8) +// / \ +// 2 [0, 4) 3 [4, 8) +// / \ / \ +// 4: [0, 2) 5: [2, 4) 6: [4, 6) 7: [6, 8) +// / \ / \ / \ / \ +// 8: 0 9: 1 10: 2 11: 3 12: 4 13: 5 14: 6 15: 7 + +template class SegmentTree { +public: + SegmentTree(int64_t size, const T &identity_element) + : size_(size), identity_element_(identity_element) { + for (capacity_ = 1; capacity_ < size; capacity_ <<= 1) + ; + values_.assign(2 * capacity_, identity_element_); + } + + int64_t size() const { return size_; } + + int64_t capacity() const { return capacity_; } + + const T &identity_element() const { return identity_element_; } + + const T &At(int64_t index) const { return values_[index | capacity_]; } + + std::vector At(const std::vector &index) const { + const int64_t n = index.size(); + std::vector value(n); + BatchAtImpl(n, index.data(), value.data()); + return value; + } + + py::array_t At(const py::array_t &index) const { + assert(index.ndim() == 1); + const int64_t n = index.size(); + py::array_t value(n); + BatchAtImpl(n, index.data(), value.mutable_data()); + return value; + } + + torch::Tensor At(const torch::Tensor &index) const { + assert(index.dtype() == torch::kInt64); + const torch::Tensor index_contiguous = index.contiguous(); + const int64_t n = index_contiguous.numel(); + torch::Tensor value = + torch::empty_like(index_contiguous, utils::TorchDataType::value); + BatchAtImpl(n, index_contiguous.data_ptr(), value.data_ptr()); + return value; + } + + // Update the item at index to value. + // Time complexity: O(logN). 
+  void Update(int64_t index, const T &value) {
+    index |= capacity_;
+    for (values_[index] = value; index > 1; index >>= 1) {
+      values_[index >> 1] = op_(values_[index], values_[index ^ 1]);
+    }
+  }
+
+  void Update(const std::vector<int64_t> &index, const T &value) {
+    BatchUpdateImpl(index.size(), index.data(), value);
+  }
+
+  void Update(const std::vector<int64_t> &index, const std::vector<T> &value) {
+    assert(value.size() == 1 || index.size() == value.size());
+    const int64_t n = index.size();
+    if (value.size() == 1) {
+      BatchUpdateImpl(n, index.data(), value[0]);
+    } else {
+      BatchUpdateImpl(n, index.data(), value.data());
+    }
+  }
+
+  void Update(const py::array_t<int64_t> &index, const T &value) {
+    assert(index.ndim() == 1);
+    BatchUpdateImpl(index.size(), index.data(), value);
+  }
+
+  void Update(const py::array_t<int64_t> &index, const py::array_t<T> &value) {
+    assert(index.ndim() == 1);
+    assert(value.ndim() == 1);
+    assert(value.size() == 1 || index.size() == value.size());
+    const int64_t n = index.size();
+    if (value.size() == 1) {
+      BatchUpdateImpl(n, index.data(), *(value.data()));
+    } else {
+      BatchUpdateImpl(n, index.data(), value.data());
+    }
+  }
+
+  void Update(const torch::Tensor &index, const T &value) {
+    assert(index.dtype() == torch::kInt64);
+    const torch::Tensor index_contiguous = index.contiguous();
+    const int64_t n = index_contiguous.numel();
+    BatchUpdateImpl(n, index_contiguous.data_ptr<int64_t>(), value);
+  }
+
+  void Update(const torch::Tensor &index, const torch::Tensor &value) {
+    assert(index.dtype() == torch::kInt64);
+    assert(value.dtype() == utils::TorchDataType<T>::value);
+    assert(value.numel() == 1 || index.sizes() == value.sizes());
+    const torch::Tensor index_contiguous = index.contiguous();
+    const torch::Tensor value_contiguous = value.contiguous();
+    const int64_t n = index_contiguous.numel();
+    if (value_contiguous.numel() == 1) {
+      BatchUpdateImpl(n, index_contiguous.data_ptr<int64_t>(),
+                      *(value_contiguous.data_ptr<T>()));
+    } else {
+      BatchUpdateImpl(n, index_contiguous.data_ptr<int64_t>(),
+                      value_contiguous.data_ptr<T>());
+    }
+  }
+
+  // Reduce the range of [l, r) by Operator.
+  // Time complexity: O(logN)
+  T Query(int64_t l, int64_t r) const {
+    assert(l < r);
+    if (l <= 0 && r >= size_) {
+      return values_[1];
+    }
+    T ret = identity_element_;
+    l |= capacity_;
+    r |= capacity_;
+    while (l < r) {
+      if (l & 1) {
+        ret = op_(ret, values_[l++]);
+      }
+      if (r & 1) {
+        ret = op_(ret, values_[--r]);
+      }
+      l >>= 1;
+      r >>= 1;
+    }
+    return ret;
+  }
+
+  std::vector<T> Query(const std::vector<int64_t> &l,
+                       const std::vector<int64_t> &r) const {
+    assert(l.size() == r.size());
+    std::vector<T> ret(l.size());
+    const int64_t n = l.size();
+    BatchQueryImpl(n, l.data(), r.data(), ret.data());
+    return ret;
+  }
+
+  py::array_t<T> Query(const py::array_t<int64_t> &l,
+                       const py::array_t<int64_t> &r) const {
+    assert(l.ndim() == 1);
+    assert(r.ndim() == 1);
+    assert(l.size() == r.size());
+    const int64_t n = l.size();
+    py::array_t<T> ret(n);
+    BatchQueryImpl(n, l.data(), r.data(), ret.mutable_data());
+    return ret;
+  }
+
+  torch::Tensor Query(const torch::Tensor &l, const torch::Tensor &r) const {
+    assert(l.dtype() == torch::kInt64);
+    assert(r.dtype() == torch::kInt64);
+    assert(l.sizes() == r.sizes());
+    const torch::Tensor l_contiguous = l.contiguous();
+    const torch::Tensor r_contiguous = r.contiguous();
+    torch::Tensor ret =
+        torch::empty_like(l_contiguous, utils::TorchDataType<T>::value);
+    const int64_t n = l_contiguous.numel();
+    BatchQueryImpl(n, l_contiguous.data_ptr<int64_t>(),
+                   r_contiguous.data_ptr<int64_t>(), ret.data_ptr<T>());
+    return ret;
+  }
+
+protected:
+  void BatchAtImpl(int64_t n, const int64_t *index, T *value) const {
+    for (int64_t i = 0; i < n; ++i) {
+      value[i] = values_[index[i] | capacity_];
+    }
+  }
+
+  void BatchUpdateImpl(int64_t n, const int64_t *index, const T &value) {
+    for (int64_t i = 0; i < n; ++i) {
+      Update(index[i], value);
+    }
+  }
+
+  void BatchUpdateImpl(int64_t n, const int64_t *index, const T *value) {
+    for (int64_t i = 0; i < n; ++i) {
+      Update(index[i], value[i]);
+    }
+  }
+
+  void BatchQueryImpl(int64_t n, const int64_t *l, const int64_t *r,
+                      T *result) const {
+    for (int64_t i = 0; i < n; ++i) {
+      result[i] = Query(l[i], r[i]);
+    }
+  }
+
+  const Operator op_{};
+  const int64_t size_;
+  int64_t capacity_;
+  const T identity_element_;
+  std::vector<T> values_;
+};
+
+template <typename T>
+class SumSegmentTree final : public SegmentTree<T, std::plus<T>> {
+public:
+  SumSegmentTree(int64_t size) : SegmentTree<T, std::plus<T>>(size, T(0)) {}
+
+  // Get the 1st index where the scan (prefix sum) is not less than value.
+  // Time complexity: O(logN)
+  int64_t ScanLowerBound(const T &value) const {
+    if (value > this->values_[1]) {
+      return this->size_;
+    }
+    int64_t index = 1;
+    T current_value = value;
+    while (index < this->capacity_) {
+      index <<= 1;
+      const T &lvalue = this->values_[index];
+      if (current_value > lvalue) {
+        current_value -= lvalue;
+        index |= 1;
+      }
+    }
+    return index ^ this->capacity_;
+  }
+
+  std::vector<int64_t> ScanLowerBound(const std::vector<T> &value) const {
+    std::vector<int64_t> index(value.size());
+    BatchScanLowerBoundImpl(value.size(), value.data(), index.data());
+    return index;
+  }
+
+  py::array_t<int64_t> ScanLowerBound(const py::array_t<T> &value) const {
+    assert(value.ndim() == 1);
+    const int64_t n = value.size();
+    py::array_t<int64_t> index(n);
+    BatchScanLowerBoundImpl(n, value.data(), index.mutable_data());
+    return index;
+  }
+
+  torch::Tensor ScanLowerBound(const torch::Tensor &value) const {
+    assert(value.dtype() == utils::TorchDataType<T>::value);
+    const torch::Tensor value_contiguous = value.contiguous();
+    torch::Tensor index = torch::empty_like(value_contiguous, torch::kInt64);
+    const int64_t n = value_contiguous.numel();
+    BatchScanLowerBoundImpl(n, value_contiguous.data_ptr<T>(),
+                            index.data_ptr<int64_t>());
+    return index;
+  }
+
+protected:
+  void BatchScanLowerBoundImpl(int64_t n, const T *value,
+                               int64_t *index) const {
+    for (int64_t i = 0; i < n; ++i) {
+      index[i] = ScanLowerBound(value[i]);
+    }
+  }
+};
+
+template <typename T> struct MinOp {
+  T operator()(const T &lhs, const T &rhs) const { return std::min(lhs, rhs); }
+};
+
+template <typename T>
+class MinSegmentTree final : public SegmentTree<T, MinOp<T>> {
+public:
+  MinSegmentTree(int64_t size)
+      : SegmentTree<T, MinOp<T>>(size, std::numeric_limits<T>::max()) {}
+};
+
+} // namespace torchrl
diff --git a/torchrl/csrc/torch_data_type.h b/torchrl/csrc/torch_data_type.h
new file mode 100644
index 00000000000..c6a70a40fd8
--- /dev/null
+++ b/torchrl/csrc/torch_data_type.h
@@ -0,0 +1,30 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cstdint>
+
+#include <torch/extension.h>
+
+namespace torchrl {
+namespace utils {
+
+template <typename T> struct TorchDataType;
+
+template <> struct TorchDataType<int64_t> {
+  static constexpr torch::ScalarType value = torch::kInt64;
+};
+
+template <> struct TorchDataType<float> {
+  static constexpr torch::ScalarType value = torch::kFloat;
+};
+
+template <> struct TorchDataType<double> {
+  static constexpr torch::ScalarType value = torch::kDouble;
+};
+
+} // namespace utils
+} // namespace torchrl
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
new file mode 100644
index 00000000000..9394c3a4aee
--- /dev/null
+++ b/torchrl/data/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .postprocs import *
+from .replay_buffers import *
+from .tensor_specs import *
+from .tensordict import *
diff --git a/torchrl/data/postprocs/__init__.py b/torchrl/data/postprocs/__init__.py
new file mode 100644
index 00000000000..8c557abcb6f
--- /dev/null
+++ b/torchrl/data/postprocs/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
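For reference, a rough sketch of how the segment-tree classes defined above in torchrl/csrc/segment_tree.h behave once compiled into the _torchrl extension (the import path matches the one used by replay_buffers.py further down). The sizes, priorities and expected results are made-up illustrative values.

import numpy as np
from torchrl._torchrl import MinSegmentTree, SumSegmentTree

# Illustrative only: the capacity is rounded up to the next power of two internally.
sum_tree = SumSegmentTree(8)
min_tree = MinSegmentTree(8)
idx = np.arange(4, dtype=np.int64)
priorities = np.array([0.5, 1.0, 2.0, 0.5], dtype=np.float32)
sum_tree[idx] = priorities          # batched __setitem__ overload
min_tree[idx] = priorities
sum_tree.query(0, 8)                # sum over [0, 8)  -> 4.0
min_tree.query(0, 8)                # min over [0, 8)  -> 0.5
# scan_lower_bound maps a cumulative "mass" back to the first index whose
# prefix sum reaches it, which is the core of prioritized sampling: the prefix
# sums here are 0.5, 1.5, 3.5, 4.0, so a mass of 3.2 lands on index 2.
sum_tree.scan_lower_bound(np.array([3.2], dtype=np.float32))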
+ +from .postprocs import * diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py new file mode 100644 index 00000000000..33e65bdb7af --- /dev/null +++ b/torchrl/data/postprocs/postprocs.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Tuple + +import torch +from torch import nn + +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.utils import expand_as_right + +__all__ = ["MultiStep"] + + +def _conv1d( + reward: torch.Tensor, gammas: torch.Tensor, n_steps_max: int +) -> torch.Tensor: + if not (reward.ndimension() == 3 and reward.shape[-1] == 1): + raise RuntimeError( + f"Expected a B x T x 1 reward tensor, got reward.shape = {reward.shape}" + ) + reward_pad = torch.nn.functional.pad(reward, [0, 0, 0, n_steps_max]).transpose( + -1, -2 + ) + reward_pad = torch.conv1d(reward_pad, gammas).transpose(-1, -2) + return reward_pad + + +def _get_terminal( + done: torch.Tensor, n_steps_max: int +) -> Tuple[torch.Tensor, torch.Tensor]: + # terminal states (done or last) + terminal = done.clone() + terminal[:, -1] = done[:, -1] | (done.sum(1) != 1) + if not (terminal.sum(1) == 1).all(): + raise RuntimeError("Got more or less than one terminal state per episode.") + post_terminal = terminal.cumsum(1).cumsum(1) >= 2 + post_terminal = torch.cat( + [ + post_terminal, + torch.ones( + post_terminal.shape[0], + n_steps_max, + *post_terminal.shape[2:], + device=post_terminal.device, + dtype=torch.bool, + ), + ], + 1, + ) + return terminal, post_terminal + + +def _get_gamma( + gamma: float, reward: torch.Tensor, mask: torch.Tensor, n_steps_max: int +) -> torch.Tensor: + # Compute gamma for n-step value function + gamma_masked = gamma * torch.ones_like(reward) + gamma_masked = gamma_masked.masked_fill_(~mask, 1.0) + gamma_masked = torch.nn.functional.pad( + gamma_masked, [0, 0, 0, n_steps_max], value=1.0 + ) + gamma_masked = gamma_masked.unfold(1, n_steps_max + 1, 1) + gamma_masked = gamma_masked.flip(1).cumprod(-1).flip(1) + return gamma_masked[..., -1] + + +def _get_steps_to_next_obs(nonterminal: torch.Tensor, n_steps_max: int) -> torch.Tensor: + steps_to_next_obs = nonterminal.flip(1).cumsum(1).flip(1) + steps_to_next_obs.clamp_max_(n_steps_max + 1) + return steps_to_next_obs + + +def select_and_repeat( + tensor: torch.Tensor, + terminal: torch.Tensor, + post_terminal: torch.Tensor, + mask: torch.Tensor, + n_steps_max: int, +) -> torch.Tensor: + T = tensor.shape[1] + terminal = expand_as_right(terminal.squeeze(-1), tensor) + last_tensor = (terminal * tensor).sum(1, True) + + last_tensor = last_tensor.expand( + last_tensor.shape[0], post_terminal.shape[1], *last_tensor.shape[2:] + ) + post_terminal = expand_as_right(post_terminal.squeeze(-1), last_tensor) + post_terminal_tensor = last_tensor * post_terminal + + tensor_repeat = torch.zeros( + tensor.shape[0], + n_steps_max, + *tensor.shape[2:], + device=tensor.device, + dtype=tensor.dtype, + ) + tensor_cat = torch.cat([tensor, tensor_repeat], 1) + post_terminal_tensor + tensor_cat = tensor_cat[:, -T:] + mask = expand_as_right(mask.squeeze(-1), tensor_cat) + return tensor_cat.masked_fill(~mask, 0.0) + + +class MultiStep(nn.Module): + """ + Multistep reward, as presented in 'Sutton, R. S. 1988. Learning to + predict by the methods of temporal differences. Machine learning 3( + 1):9–44.' 
+ + Args: + gamma (float): Discount factor for return computation + n_steps_max (integer): maximum look-ahead steps. + + """ + + def __init__( + self, + gamma: float, + n_steps_max: int, + ): + super().__init__() + if n_steps_max < 0: + raise ValueError("n_steps_max must be a null or positive integer") + if not (gamma > 0 and gamma <= 1): + raise ValueError(f"got out-of-bounds gamma decay: gamma={gamma}") + + self.gamma = gamma + self.n_steps_max = n_steps_max + self.register_buffer( + "gammas", + torch.tensor( + [gamma ** i for i in range(n_steps_max + 1)], + dtype=torch.float, + ).reshape(1, 1, -1), + ) + + def forward(self, tensor_dict: _TensorDict) -> _TensorDict: + """Args: + tensor_dict: TennsorDict instance with Batch x Time-steps x ... + dimensions. + The TensorDict must contain a "reward" and "done" key. All + keys that start with the "next_" prefix will be shifted by ( + at most) self.n_steps_max frames. The TensorDict will also + be updated with new key-value pairs: + + - gamma: indicating the discount to be used for the next + reward; + + - nonterminal: boolean value indicating whether a step is + non-terminal (not done or not last of trajectory); + + - original_reward: previous reward collected in the + environment (i.e. before multi-step); + + - The "reward" values will be replaced by the newly computed + rewards. + + Returns: + in-place transformation of the input tensordict. + + """ + if tensor_dict.batch_dims != 2: + raise RuntimeError("Expected a tensordict with B x T x ... dimensions") + + done = tensor_dict.get("done") + try: + mask = tensor_dict.get("mask") + except KeyError: + mask = done.clone().flip(1).cumsum(1).flip(1).to(torch.bool) + reward = tensor_dict.get("reward") + b, T, *_ = mask.shape + + terminal, post_terminal = _get_terminal(done, self.n_steps_max) + + # Compute gamma for n-step value function + gamma_masked = _get_gamma(self.gamma, reward, mask, self.n_steps_max) + + # step_to_next_state + nonterminal = ~post_terminal[:, :T] + steps_to_next_obs = _get_steps_to_next_obs(nonterminal, self.n_steps_max) + + # Discounted summed reward + partial_return = _conv1d(reward, self.gammas, self.n_steps_max) + + selected_td = tensor_dict.select( + *[ + key + for key in tensor_dict.keys() + if (key.startswith("next_") or key == "done") + ] + ) + + for key, item in selected_td.items(): + tensor_dict.set_( + key, + select_and_repeat( + item, + terminal, + post_terminal, + mask, + self.n_steps_max, + ), + ) + + tensor_dict.set("gamma", gamma_masked) + tensor_dict.set("steps_to_next_obs", steps_to_next_obs) + tensor_dict.set("nonterminal", nonterminal) + tensor_dict.rename_key("reward", "original_reward") + tensor_dict.set("reward", partial_return) + + tensor_dict.set_("done", done) + return tensor_dict diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py new file mode 100644 index 00000000000..d428f5a74de --- /dev/null +++ b/torchrl/data/replay_buffers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .replay_buffers import * diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py new file mode 100644 index 00000000000..dc82678b1e3 --- /dev/null +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -0,0 +1,766 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
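As a side note before the replay buffers, here is a small self-contained check of the padded-conv1d trick that MultiStep and _conv1d above rely on; gamma, the horizon and the rewards are toy values chosen only for illustration.

import torch

gamma, n_steps_max = 0.9, 2
reward = torch.tensor([[[1.0], [0.0], [2.0], [1.0]]])            # B x T x 1
gammas = torch.tensor([gamma ** i for i in range(n_steps_max + 1)]).reshape(1, 1, -1)
# pad the time dimension with n_steps_max zeros, then convolve with the
# [1, gamma, gamma^2] kernel: entry t becomes the discounted sum of the next
# n_steps_max + 1 rewards, e.g. 1.0 + 0.9 * 0.0 + 0.81 * 2.0 = 2.62 at t = 0
reward_pad = torch.nn.functional.pad(reward, [0, 0, 0, n_steps_max]).transpose(-1, -2)
partial_return = torch.conv1d(reward_pad, gammas).transpose(-1, -2)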
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import concurrent.futures +import functools +import threading +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import Tensor + +from torchrl._torchrl import MinSegmentTree, SumSegmentTree +from torchrl.data.replay_buffers.utils import ( + cat_fields_to_device, + to_numpy, + to_torch, +) + +__all__ = [ + "ReplayBuffer", + "PrioritizedReplayBuffer", + "TensorDictReplayBuffer", + "TensorDictPrioritizedReplayBuffer", + "create_replay_buffer", + "create_prioritized_replay_buffer", +] + +from torchrl.data.tensordict.tensordict import _TensorDict, stack as stack_td +from torchrl.data.utils import DEVICE_TYPING + + +def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]: + """Zips a list of iterables containing tensor-like objects and stacks the + resulting lists of tensors together. + + Args: + list_of_tensor_iterators (list): Sequence containing similar iterators, + where each element of the nested iterator is a tensor whose + shape match the tensor of other iterators that have the same index. + + Returns: + Tuple of stacked tensors. + + Examples: + >>> list_of_tensor_iterators = [[torch.ones(3), torch.zeros(1,2)] + ... for _ in range(4)] + >>> stack_tensors(list_of_tensor_iterators) + (tensor([[1., 1., 1.], + [1., 1., 1.], + [1., 1., 1.], + [1., 1., 1.]]), tensor([[[0., 0.]], + + [[0., 0.]], + + [[0., 0.]], + + [[0., 0.]]])) + + """ + return tuple(torch.stack(tensors, 0) for tensors in zip(*list_of_tensor_iterators)) + + +def _pin_memory(output: Any) -> Any: + if hasattr(output, "pin_memory") and output.device == torch.device("cpu"): + return output.pin_memory() + else: + return output + + +def pin_memory_output(fun) -> Callable: + """Calls pin_memory on outputs of decorated function if they have such + method.""" + + def decorated_fun(self, *args, **kwargs): + output = fun(self, *args, **kwargs) + if self._pin_memory: + _tuple_out = True + if not isinstance(output, tuple): + _tuple_out = False + output = (output,) + output = tuple(_pin_memory(_output) for _output in output) + if _tuple_out: + return output + return output[0] + return output + + return decorated_fun + + +class ReplayBuffer: + """ + Circular replay buffer. + + Args: + size (int): integer indicating the maximum size of the replay buffer. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. 
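A minimal usage sketch of the buffer documented above, assuming each element is stored as a tuple of tensors so that the default stack_tensors collate applies; the sizes are arbitrary.

import torch
from torchrl.data import ReplayBuffer

rb = ReplayBuffer(size=3)
for i in range(5):
    # the cursor is circular: once 3 elements are stored, new entries
    # overwrite the oldest ones
    rb.add((torch.ones(3) * i, torch.zeros(1, 2)))
assert len(rb) == 3
obs, extra = rb.sample(2)   # obs: shape [2, 3], extra: shape [2, 1, 2]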
+ """ + + def __init__( + self, + size: int, + collate_fn: Optional[Callable] = None, + pin_memory: bool = False, + prefetch: Optional[int] = None, + ): + self._storage = [] + self._capacity = size + self._cursor = 0 + if collate_fn is not None: + self._collate_fn = collate_fn + else: + self._collate_fn = stack_tensors + self._pin_memory = pin_memory + + self._prefetch = prefetch is not None and prefetch > 0 + self._prefetch_cap = prefetch if prefetch is not None else 0 + self._prefetch_fut = collections.deque() + if self._prefetch_cap > 0: + self._prefetch_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self._prefetch_cap + ) + + self._replay_lock = threading.RLock() + self._future_lock = threading.RLock() + + def __len__(self) -> int: + with self._replay_lock: + return len(self._storage) + + @pin_memory_output + def __getitem__(self, index: Union[int, Tensor]) -> Any: + index = to_numpy(index) + + with self._replay_lock: + if isinstance(index, int): + data = self._storage[index] + else: + data = [self._storage[i] for i in index] + + if isinstance(data, list): + data = self._collate_fn(data) + return data + + @property + def capacity(self) -> int: + return self._capacity + + @property + def cursor(self) -> int: + with self._replay_lock: + return self._cursor + + def add(self, data: Any) -> int: + """Add a single element to the replay buffer. + + Args: + data (Any): data to be added to the replay buffer + + Returns: + index where the data lives in the replay buffer. + """ + with self._replay_lock: + ret = self._cursor + if self._cursor >= len(self._storage): + self._storage.append(data) + else: + self._storage[self._cursor] = data + self._cursor = (self._cursor + 1) % self._capacity + return ret + + def extend(self, data: Sequence[Any]): + """Extends the replay buffer with one or more elements contained in + an iterable. + + Args: + data (iterable): collection of data to be added to the replay + buffer. + + Returns: + Indices of the data aded to the replay buffer. 
+ + """ + if not len(data): + raise Exception("extending with empty data is not supported") + if not isinstance(data, list): + data = list(data) + with self._replay_lock: + cur_size = len(self._storage) + batch_size = len(data) + storage = self._storage + cursor = self._cursor + if cur_size + batch_size <= self._capacity: + index = np.arange(cur_size, cur_size + batch_size) + self._storage += data + self._cursor = (self._cursor + batch_size) % self._capacity + elif cur_size < self._capacity: + d = self._capacity - cur_size + index = np.empty(batch_size, dtype=np.int64) + index[:d] = np.arange(cur_size, self._capacity) + index[d:] = np.arange(batch_size - d) + storage += data[:d] + for i, v in enumerate(data[d:]): + storage[i] = v + self._cursor = batch_size - d + elif self._cursor + batch_size <= self._capacity: + index = np.arange(self._cursor, self._cursor + batch_size) + for i, v in enumerate(data): + storage[cursor + i] = v + self._cursor = (self._cursor + batch_size) % self._capacity + else: + d = self._capacity - self._cursor + index = np.empty(batch_size, dtype=np.int64) + index[:d] = np.arange(self._cursor, self._capacity) + index[d:] = np.arange(batch_size - d) + for i, v in enumerate(data[:d]): + storage[cursor + i] = v + for i, v in enumerate(data[d:]): + storage[i] = v + self._cursor = batch_size - d + + return index + + @pin_memory_output + def _sample(self, batch_size: int) -> Any: + index = np.random.randint(0, len(self._storage), size=batch_size) + + with self._replay_lock: + data = [self._storage[i] for i in index] + + data = self._collate_fn(data) + return data + + def sample(self, batch_size: int) -> Any: + """Samples a batch of data from the replay buffer. + + Args: + batch_size (int): float of data to be collected. + + Returns: + A batch of data randomly selected in the replay buffer. + + """ + if not self._prefetch: + return self._sample(batch_size) + + with self._future_lock: + if len(self._prefetch_fut) == 0: + ret = self._sample(batch_size) + else: + ret = self._prefetch_fut.popleft().result() + + while len(self._prefetch_fut) < self._prefetch_cap: + fut = self._prefetch_executor.submit(self._sample, batch_size) + self._prefetch_fut.append(fut) + + return ret + + def __repr__(self) -> str: + string = ( + f"{self.__class__.__name__}(size={len(self)}, " + f"pin_memory={self._pin_memory})" + ) + return string + + +class PrioritizedReplayBuffer(ReplayBuffer): + """ + Prioritized replay buffer as presented in + "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. + Prioritized experience replay." + (https://arxiv.org/abs/1511.05952) + + Args: + size (int): integer indicating the maximum size of the replay buffer. + alpha (float): exponent α determines how much prioritization is used, + with α = 0 corresponding to the uniform case. + beta (float): importance sampling negative exponent. + eps (float): delta added to the priorities to ensure that the buffer + does not contain null priorities. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. 
+ """ + + def __init__( + self, + size: int, + alpha: float, + beta: float, + eps: float = 1e-8, + collate_fn=None, + pin_memory: bool = False, + prefetch: Optional[int] = None, + ) -> None: + super(PrioritizedReplayBuffer, self).__init__( + size, collate_fn, pin_memory, prefetch + ) + if alpha <= 0: + raise ValueError( + f"alpha must be strictly greater than 0, got alpha={alpha}" + ) + if beta < 0: + raise ValueError(f"beta must be greater or equal to 0, got beta={beta}") + + self._alpha = alpha + self._beta = beta + self._eps = eps + self._sum_tree = SumSegmentTree(size) + self._min_tree = MinSegmentTree(size) + self._max_priority = 1.0 + + @pin_memory_output + def __getitem__(self, index: Union[int, Tensor]) -> Any: + index = to_numpy(index) + + with self._replay_lock: + p_min = self._min_tree.query(0, self._capacity) + if p_min <= 0: + raise ValueError(f"p_min must be greater than 0, got p_min={p_min}") + if isinstance(index, int): + data = self._storage[index] + weight = np.array(self._sum_tree[index]) + else: + data = [self._storage[i] for i in index] + weight = self._sum_tree[index] + + if isinstance(data, list): + data = self._collate_fn(data) + # weight = np.power(weight / (p_min + self._eps), -self._beta) + weight = np.power(weight / p_min, -self._beta) + # x = first_field(data) + # if isinstance(x, torch.Tensor): + device = data.device if hasattr(data, "device") else torch.device("cpu") + weight = to_torch(weight, device, self._pin_memory) + return data, weight + + @property + def alpha(self) -> float: + return self._alpha + + @property + def beta(self) -> float: + return self._beta + + @property + def eps(self) -> float: + return self._eps + + @property + def max_priority(self) -> float: + with self._replay_lock: + return self._max_priority + + @property + def _default_priority(self) -> float: + return (self._max_priority + self._eps) ** self._alpha + + def _add_or_extend( + self, + data: Any, + priority: Optional[torch.Tensor] = None, + do_add: bool = True, + ) -> torch.Tensor: + if priority is not None: + priority = to_numpy(priority) + max_priority = np.max(priority) + with self._replay_lock: + self._max_priority = max(self._max_priority, max_priority) + priority = np.power(priority + self._eps, self._alpha) + else: + with self._replay_lock: + priority = self._default_priority + + if do_add: + index = super(PrioritizedReplayBuffer, self).add(data) + else: + index = super(PrioritizedReplayBuffer, self).extend(data) + + if not ( + isinstance(priority, float) + or len(priority) == 1 + or len(priority) == len(index) + ): + raise RuntimeError( + "priority should be a scalar or an iterable of the same " + "length as index" + ) + + with self._replay_lock: + self._sum_tree[index] = priority + self._min_tree[index] = priority + + return index + + def add(self, data: Any, priority: Optional[torch.Tensor] = None) -> torch.Tensor: + return self._add_or_extend(data, priority, True) + + def extend( + self, data: Sequence, priority: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return self._add_or_extend(data, priority, False) + + @pin_memory_output + def _sample(self, batch_size: int) -> Tuple[Any, torch.Tensor, torch.Tensor]: + with self._replay_lock: + p_sum = self._sum_tree.query(0, self._capacity) + p_min = self._min_tree.query(0, self._capacity) + if p_sum <= 0: + raise RuntimeError("negative p_sum") + if p_min <= 0: + raise RuntimeError("negative p_min") + mass = np.random.uniform(0.0, p_sum, size=batch_size) + index = self._sum_tree.scan_lower_bound(mass) + if 
isinstance(index, torch.Tensor): + index.clamp_max_(len(self._storage) - 1) + else: + index = np.clip(index, None, len(self._storage) - 1) + data = [self._storage[i] for i in index] + weight = self._sum_tree[index] + + data = self._collate_fn(data) + + # Importance sampling weight formula: + # w_i = (p_i / sum(p) * N) ^ (-beta) + # weight_i = w_i / max(w) + # weight_i = (p_i / sum(p) * N) ^ (-beta) / + # ((min(p) / sum(p) * N) ^ (-beta)) + # weight_i = ((p_i / sum(p) * N) / (min(p) / sum(p) * N)) ^ (-beta) + # weight_i = (p_i / min(p)) ^ (-beta) + # weight = np.power(weight / (p_min + self._eps), -self._beta) + weight = np.power(weight / p_min, -self._beta) + + # x = first_field(data) # avoid calling tree.flatten + # if isinstance(x, torch.Tensor): + device = data.device if hasattr(data, "device") else torch.device("cpu") + weight = to_torch(weight, device, self._pin_memory) + return data, weight, index + + def sample(self, batch_size: int) -> Tuple[Any, np.ndarray, torch.Tensor]: + """Gather a batch of data according to the non-uniform multinomial + distribution with weights computed with the provided priorities of + each input. + + Args: + batch_size (int): float of data to be collected. + + Returns: + + """ + if not self._prefetch: + return self._sample(batch_size) + + with self._future_lock: + if len(self._prefetch_fut) == 0: + ret = self._sample(batch_size) + else: + ret = self._prefetch_fut.popleft().result() + + while len(self._prefetch_fut) < self._prefetch_cap: + fut = self._prefetch_executor.submit(self._sample, batch_size) + self._prefetch_fut.append(fut) + + return ret + + def update_priority( + self, index: Union[int, Tensor], priority: Union[float, Tensor] + ) -> None: + """Updates the priority of the data pointed by the index. + + Args: + index (int or torch.Tensor): indexes of the priorities to be + updated. + priority (Number or torch.Tensor): new priorities of the + indexed elements + + + """ + if isinstance(index, int): + if not isinstance(priority, float): + if len(priority) != 1: + raise RuntimeError( + f"priority length should be 1, got {len(priority)}" + ) + priority = priority.item() + else: + if not ( + isinstance(priority, float) + or len(priority) == 1 + or len(index) == len(priority) + ): + raise RuntimeError( + "priority should be a number or an iterable of the same " + "length as index" + ) + index = to_numpy(index) + priority = to_numpy(priority) + + with self._replay_lock: + self._max_priority = max(self._max_priority, np.max(priority)) + priority = np.power(priority + self._eps, self._alpha) + self._sum_tree[index] = priority + self._min_tree[index] = priority + + +class TensorDictReplayBuffer(ReplayBuffer): + """ + TensorDict-specific wrapper around the ReplayBuffer class. + """ + + def __init__( + self, + size: int, + collate_fn: Optional[Callable] = None, + pin_memory: bool = False, + prefetch: Optional[int] = None, + ): + if collate_fn is None: + + def collate_fn(x): + return stack_td(x, 0, contiguous=True) + + super().__init__(size, collate_fn, pin_memory, prefetch) + + def sample(self, size: int) -> Any: + return super(TensorDictReplayBuffer, self).sample(size)[0] + + +class TensorDictPrioritizedReplayBuffer(PrioritizedReplayBuffer): + """ + TensorDict-specific wrapper around the PrioritizedReplayBuffer class. + This class returns tensordicts with a new key "index" that represents + the index of each element in the replay buffer. 
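The sampling and importance-sampling arithmetic of _sample above can be mimicked with plain numpy, which may make the formula comments easier to follow; alpha, beta and the TD-errors below are placeholder values.

import numpy as np

alpha, beta, eps = 0.7, 0.9, 1e-8
td_error = np.array([0.1, 0.5, 2.0, 0.4])
p = (td_error + eps) ** alpha                  # priorities as stored in the trees
mass = np.random.uniform(0.0, p.sum(), size=3)
index = np.searchsorted(np.cumsum(p), mass)    # same role as scan_lower_bound
weight = (p[index] / p.min()) ** (-beta)       # w_i = (p_i / min(p)) ** (-beta) <= 1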
It also facilitates the + call to the 'update_priority' method, as it only requires for the + tensordict to be passed to it with its new priority value. + + Args: + size (int): integer indicating the maximum size of the replay buffer. + alpha (flaot): exponent α determines how much prioritization is + used, with α = 0 corresponding to the uniform case. + beta (float): importance sampling negative exponent. + priority_key (str, optional): key where the priority value can be + found in the stored tensordicts. Default is `"td_error"` + eps (float, optional): delta added to the priorities to ensure that the + buffer does not contain null priorities. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched loading + from a map-style dataset. + pin_memory (bool, optional): whether pin_memory() should be called on + the rb samples. Default is `False`. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + """ + + def __init__( + self, + size: int, + alpha: float, + beta: float, + priority_key: str = "td_error", + eps: float = 1e-8, + collate_fn=None, + pin_memory: bool = False, + prefetch: Optional[int] = None, + ) -> None: + if collate_fn is None: + + def collate_fn(x): + return stack_td(x, 0, contiguous=True) + + super(TensorDictPrioritizedReplayBuffer, self).__init__( + size=size, + alpha=alpha, + beta=beta, + eps=eps, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, + ) + self.priority_key = priority_key + + def _get_priority(self, tensor_dict: _TensorDict) -> torch.Tensor: + if tensor_dict.batch_dims: + raise RuntimeError( + "expected void batch_size for input tensor_dict in " + "rb._get_priority()" + ) + try: + priority = tensor_dict.get(self.priority_key).item() + except ValueError: + raise ValueError( + f"Found a priority key of size" + f" {tensor_dict.get(self.priority_key).shape} but expected " + f"scalar value" + ) + except KeyError: + priority = self._default_priority + return priority + + def add(self, tensor_dict: _TensorDict) -> torch.Tensor: + priority = self._get_priority(tensor_dict) + index = super().add(tensor_dict, priority) + tensor_dict.set("index", index) + return index + + def extend(self, tensor_dicts: _TensorDict) -> torch.Tensor: + if isinstance(tensor_dicts, _TensorDict): + try: + priorities = tensor_dicts.get(self.priority_key) + except KeyError: + priorities = None + tensor_dicts = list(tensor_dicts.unbind(0)) + else: + priorities = [self._get_priority(td) for td in tensor_dicts] + + stacked_td = torch.stack(tensor_dicts, 0) + idx = super().extend(tensor_dicts, priorities) + stacked_td.set("index", idx) + return idx + + def update_priority(self, tensor_dict: _TensorDict) -> None: + """Updates the priorities of the tensordicts stored in the replay + buffer. + + Args: + tensor_dict: tensordict with key-value pairs 'self.priority_key' + and 'index'. + + + """ + priority = tensor_dict.get(self.priority_key) + if (priority < 0).any(): + raise RuntimeError( + f"Priority must be a positive value, got " + f"{(priority < 0).sum()} negative priority values." + ) + return super().update_priority(tensor_dict.get("index"), priority=priority) + + def sample(self, size: int) -> _TensorDict: + """ + Gather a batch of tensordicts according to the non-uniform multinomial + distribution with weights computed with the priority_key of each + input tensordict. 
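A rough end-to-end sketch of the TensorDict wrapper described above, assuming a hypothetical "obs" entry and the default "td_error" priority key; the shapes are arbitrary.

import torch
from torchrl.data import TensorDictPrioritizedReplayBuffer
from torchrl.data.tensordict.tensordict import TensorDict

rb = TensorDictPrioritizedReplayBuffer(size=100, alpha=0.7, beta=0.9)
batch = TensorDict(
    {"obs": torch.randn(10, 4), "td_error": torch.rand(10)}, batch_size=[10]
)
rb.extend(batch)             # unbound along dim 0, one priority per element
sampled = rb.sample(4)       # comes back with an "index" entry
sampled.set("td_error", torch.rand(4))
rb.update_priority(sampled)  # reweights exactly the sampled elements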
+ + Args: + size (int): size of the batch to be returned + + Returns: + Stack of tensordicts + + """ + return super(TensorDictPrioritizedReplayBuffer, self).sample(size)[0] + + +def create_replay_buffer( + size: int, + device: Optional[DEVICE_TYPING] = None, + collate_fn: Callable = None, + pin_memory: bool = False, + prefetch: Optional[int] = None, +) -> ReplayBuffer: + """ + Helper function to create a Replay buffer. + + Args: + size (int): integer indicating the maximum size of the replay buffer. + device (str, int or torch.device, optional): device where to cast the + samples. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched loading + from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + + Returns: + a ReplayBuffer instance + + """ + if isinstance(device, str): + device = torch.device(device) + + if device.type == "cuda" and collate_fn is None: + # Postman will add batch_dim for uploaded data, so using cat instead of + # stack here. + collate_fn = functools.partial(cat_fields_to_device, device=device) + + return ReplayBuffer(size, collate_fn, pin_memory, prefetch) + + +def create_prioritized_replay_buffer( + size: int, + alpha: float, + beta: float, + eps: float = 1e-8, + device: Optional[DEVICE_TYPING] = "cpu", + collate_fn: Callable = None, + pin_memory: bool = False, + prefetch: Optional[int] = None, +) -> PrioritizedReplayBuffer: + """ + Helper function to create a Prioritized Replay buffer. + + Args: + size (int): integer indicating the maximum size of the replay buffer. + alpha (float): exponent α determines how much prioritization is used, + with α = 0 corresponding to the uniform case. + beta (float): importance sampling negative exponent. + eps (float): delta added to the priorities to ensure that the buffer + does not contain null priorities. + device (str, int or torch.device, optional): device where to cast the + samples. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched loading + from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + + Returns: + a ReplayBuffer instance + + """ + if isinstance(device, str): + device = torch.device(device) + + if device.type == "cuda" and collate_fn is None: + # Postman will add batch_dim for uploaded data, so using cat instead of + # stack here. + collate_fn = functools.partial(cat_fields_to_device, device=device) + + return PrioritizedReplayBuffer( + size, alpha, beta, eps, collate_fn, pin_memory, prefetch + ) + + +class InPlaceSampler: + def __init__(self, device: Optional[DEVICE_TYPING] = None): + self.out = None + if device is None: + device = "cpu" + self.device = torch.device(device) + + def __call__(self, list_of_tds): + if self.out is None: + self.out = torch.stack(list_of_tds, 0).contiguous() + if self.device is not None: + self.out = self.out.to(self.device) + else: + torch.stack(list_of_tds, 0, out=self.out) + return self.out diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py new file mode 100644 index 00000000000..9046e02239a --- /dev/null +++ b/torchrl/data/replay_buffers/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +# import tree +from torch import Tensor + + +def fields_pin_memory(input): # type: ignore + raise NotImplementedError + # return tree.map_structure(lambda x: pin_memory(x), input) + + +def pin_memory(data: Tensor) -> Tensor: + if isinstance(data, torch.Tensor): + return data.pin_memory() + else: + return data + + +def to_numpy(data: Tensor) -> np.ndarray: + return data.detach().cpu().numpy() if isinstance(data, torch.Tensor) else data + + +def fast_map(func, *inputs): # type: ignore + raise NotImplementedError + # flat_inputs = (tree.flatten(x) for x in inputs) + # entries = zip(*flat_inputs) + # return tree.unflatten_as(inputs[-1], [func(*x) for x in entries]) + + +def stack_tensors(input): # type: ignore + if not len(input): + raise RuntimeError("input length must be non-null") + if isinstance(input[0], torch.Tensor): + size = input[0].size() + if len(size) == 0: + return torch.stack(input) + else: + # torch.cat is much faster than torch.stack + # https://github.com/pytorch/pytorch/issues/22462 + return torch.cat(input).view(-1, *size) + else: + return np.stack(input) + + +def stack_fields(input): # type: ignore + if not len(input): + raise RuntimeError("stack_fields requires non-empty list if tensors") + return fast_map(lambda *x: stack_tensors(x), *input) + + +def first_field(data) -> Tensor: # type: ignore + raise NotImplementedError + # return next(iter(tree.flatten(data))) + + +def to_torch( + data: Tensor, device, pin_memory: bool = False, non_blocking: bool = False +) -> torch.Tensor: + if isinstance(data, np.generic): + return torch.tensor(data, device=device) + + if isinstance(data, np.ndarray): + data = torch.from_numpy(data) + + if pin_memory: + data = data.pin_memory() + if device is not None: + data = data.to(device, non_blocking=non_blocking) + + return data + + +def cat_fields_to_device( + input, device, pin_memory: bool = False, non_blocking: bool = False +): # type: ignore + input_on_device = fields_to_device(input, device, pin_memory, non_blocking) + return cat_fields(input_on_device) + + +def cat_fields(input): + if not input: + raise RuntimeError("cat_fields requires a non-empty input collection.") + return fast_map(lambda *x: torch.cat(x), *input) + + +def fields_to_device( + input, device, pin_memory: bool = False, non_blocking: bool = False +): # type:ignore + raise NotImplementedError diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py new file mode 100644 index 00000000000..48a98c9ea61 --- /dev/null +++ b/torchrl/data/tensor_specs.py @@ -0,0 +1,874 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
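As an aside on torchrl/data/replay_buffers/utils.py above, a quick numeric check of the cat-then-view shortcut used by its stack_tensors helper; the shapes are arbitrary.

import torch

items = [torch.randn(3, 4) for _ in range(5)]
# concatenating equal-shaped tensors and reshaping matches torch.stack while
# avoiding its overhead (see the pytorch issue linked in utils.py)
via_cat = torch.cat(items).view(-1, *items[0].size())
assert torch.equal(via_cat, torch.stack(items))   # shape [5, 3, 4]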
+ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from textwrap import indent +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +__all__ = [ + "TensorSpec", + "BoundedTensorSpec", + "OneHotDiscreteTensorSpec", + "UnboundedContinuousTensorSpec", + "NdBoundedTensorSpec", + "NdUnboundedContinuousTensorSpec", + "BinaryDiscreteTensorSpec", + "MultOneHotDiscreteTensorSpec", + "CompositeSpec", +] + +from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict + +DEVICE_TYPING = Union[torch.device, str, int] + +INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List] + + +def _default_dtype_and_device( + dtype: Union[None, torch.dtype], + device: Union[None, str, int, torch.device], +) -> Tuple[torch.dtype, torch.device]: + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device("cpu") + device = torch.device(device) + return dtype, device + + +class invertible_dict(dict): + def __init__(self, *args, inv_dict=dict(), **kwargs): + super().__init__(*args, **kwargs) + self.inv_dict = inv_dict + + def __setitem__(self, k, v): + if v in self.inv_dict or k in self: + raise Exception("overwriting in invertible_dict is not permitted") + self.inv_dict[v] = k + return super().__setitem__(k, v) + + def update(self, d): + raise NotImplementedError + + def invert(self): + d = invertible_dict() + for k, value in self.items(): + d[value] = k + return d + + def inverse(self): + return self.inv_dict + + +class Box: + """ + A box of values + """ + + def __iter__(self): + raise NotImplementedError + + +@dataclass(repr=False) +class Values: + values: Tuple + + +@dataclass(repr=False) +class ContinuousBox(Box): + """ + A continuous box of values, in between a minimum and a maximum. + + """ + + minimum: torch.Tensor + maximum: torch.Tensor + + def __iter__(self): + yield self.minimum + yield self.maximum + + +@dataclass(repr=False) +class DiscreteBox(Box): + """ + A box of discrete values + + """ + + n: int + register = invertible_dict() + + +@dataclass(repr=False) +class BinaryBox(Box): + """ + A box of n binary values + + """ + + n: int + + +@dataclass(repr=False) +class TensorSpec: + """ + Parent class of the tensor meta-data containers for observation, actions + and rewards. + + Args: + shape (torch.Size): size of the tensor + space (Box): Box instance describing what kind of values can be + expected + device (torch.device): device of the tensor + dtype (torch.dtype): dtype of the tensor + + """ + + shape: torch.Size + space: Union[None, Box] + device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float + domain: str = "" + + def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + """Encodes a value given the specified spec, and return the + corresponding tensor. + + Args: + val (np.ndarray or torch.Tensor): value to be encoded as tensor. + + Returns: + torch.Tensor matching the required tensor specs. + + """ + if not isinstance(val, torch.Tensor): + try: + val = torch.tensor(val, dtype=self.dtype) + except ValueError: + val = torch.tensor(deepcopy(val), dtype=self.dtype) + self.assert_is_in(val) + return val + + def to_numpy(self, val: torch.Tensor) -> np.ndarray: + """Returns the np.ndarray correspondent of an input tensor. 
+ + Args: + val (torch.Tensor): tensor to be transformed to numpy + + Returns: + a np.ndarray + + """ + self.assert_is_in(val) + return val.detach().cpu().numpy() + + def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: + """Indexes the input tensor + + Args: + index (int, torch.Tensor, slice or list): index of the tensor + tensor_to_index: tensor to be indexed + + Returns: + indexed tensor + + """ + raise NotImplementedError + + def _project(self, val: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def is_in(self, val: torch.Tensor) -> bool: + """If the value `val` is in the box defined by the TensorSpec, + returns True, otherwise False. + + Args: + val (torch.Tensor): value to be checked + + Returns: + boolean indicating if values belongs to the TensorSpec box + + """ + raise NotImplementedError + + def project(self, val: torch.Tensor) -> torch.Tensor: + """If the input tensor is not in the TensorSpec box, it maps it back + to it given some heuristic. + + Args: + val (torch.Tensor): tensor to be mapped to the box. + + Returns: + a torch.Tensor belonging to the TensorSpec box. + + """ + if not self.is_in(val): + return self._project(val) + return val + + def assert_is_in(self, value: torch.Tensor) -> None: + """Asserts whether a tensor belongs to the box, and raises an + exception otherwise. + + Args: + value (torch.Tensor): value to be checked. + + """ + if not self.is_in(value): + raise AssertionError( + f"Encoding failed because value is not in space. " + f"Consider calling project(val) first. value was = {value}" + ) + + def type_check(self, value: torch.Tensor, key: str = None) -> None: + """Checks the input value dtype against the TensorSpec dtype and + raises an exception if they don't match. + + Args: + value (torch.Tensor): tensor whose dtype has to be checked + key (str, optional): if the TensorSpec has keys, the value + dtype will be checked against the spec pointed by the + indicated key. + + """ + if value.dtype is not self.dtype: + raise TypeError( + f"value.dtype={value.dtype} but" + f" {self.__class__.__name__}.dtype={self.dtype}" + ) + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + """Returns a random tensor in the box. The sampling will be uniform + unless the box is unbounded. + + Args: + shape (torch.Size): shape of the random tensor + + Returns: + a random tensor sampled in the TensorSpec box. + + """ + raise NotImplementedError + + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec": + if isinstance(dest, (torch.device, str, int)): + self.device = torch.device(dest) + else: + self.dtype = dest + return self + + def __repr__(self): + shape_str = "shape=" + str(self.shape) + space_str = "space=" + str(self.space) + device_str = "device=" + str(self.device) + dtype_str = "dtype=" + str(self.dtype) + domain_str = "domain=" + str(self.domain) + sub_string = ",".join([shape_str, space_str, device_str, dtype_str, domain_str]) + string = f"{self.__class__.__name__}(\n {sub_string})" + return string + + +@dataclass(repr=False) +class BoundedTensorSpec(TensorSpec): + """ + A bounded, unidimensional, continuous tensor spec. + + Args: + minimum (np.ndarray, torch.Tensor or number): lower bound of the box. + maximum (np.ndarray, torch.Tensor or number): upper bound of the box. + device (str, int or torch.device, optional): device of the tensors. + dtype (str or torch.dtype, optional): dtype of the tensors. 
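A small illustrative sketch of this spec's contract, with made-up bounds: rand() samples inside the box, is_in() checks membership and project() clamps back into it.

import torch
from torchrl.data import BoundedTensorSpec

spec = BoundedTensorSpec(minimum=-1.0, maximum=1.0)
sample = spec.rand([3])     # shape [3, 1], uniformly drawn in [-1, 1]
spec.is_in(sample)          # -> True
bad = torch.tensor([[2.0]])
spec.is_in(bad)             # -> False
spec.project(bad)           # -> tensor([[1.]]), clamped to the upper bound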
+ """ + + shape: torch.Size + space: ContinuousBox + device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float + domain: str = "" + + def __init__( + self, + minimum: Union[np.ndarray, torch.Tensor, float], + maximum: Union[np.ndarray, torch.Tensor, float], + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[torch.dtype] = None, + ): + dtype, device = _default_dtype_and_device(dtype, device) + if not isinstance(minimum, torch.Tensor) or minimum.dtype is not dtype: + minimum = torch.tensor(minimum, dtype=dtype, device=device) + if not isinstance(maximum, torch.Tensor) or maximum.dtype is not dtype: + maximum = torch.tensor(maximum, dtype=dtype, device=device) + super().__init__( + torch.Size( + [ + 1, + ] + ), + ContinuousBox(minimum, maximum), + device, + dtype, + "continuous", + ) + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + a, b = self.space + out = ( + torch.zeros( + *shape, *self.shape, dtype=self.dtype, device=self.device + ).uniform_() + * (b - a) + + a + ) + if (out > b).any(): + out[out > b] = b.expand_as(out)[out > b] + if (out < a).any(): + out[out < a] = a.expand_as(out)[out < a] + return out + + def _project(self, val: torch.Tensor) -> torch.Tensor: + minimum = self.space.minimum.to(val.device) # type: ignore + maximum = self.space.maximum.to(val.device) # type: ignore + try: + val = val.clamp_(minimum.item(), maximum.item()) + except ValueError: + minimum = minimum.expand_as(val) + maximum = maximum.expand_as(val) + val[val < minimum] = minimum[val < minimum] + val[val > maximum] = maximum[val > maximum] + return val + + def is_in(self, val: torch.Tensor) -> bool: + return (val >= self.space.minimum.to(val.device)).all() and ( + val <= self.space.maximum.to(val.device) + ).all() # type: ignore + + +@dataclass(repr=False) +class OneHotDiscreteTensorSpec(TensorSpec): + """ + A unidimensional, one-hot discrete tensor spec. + By default, TorchRL assumes that categorical variables are encoded as + one-hot encodings of the variable. This allows for simple indexing of + tensors, e.g. + + >>> batch, size = 3, 4 + >>> action_value = torch.arange(batch*size) + >>> action_value = action_value.view(batch, size).to(torch.float) + >>> action = (action_value == action_value.max(-1, + ... keepdim=True)[0]).to(torch.long) + >>> chosen_action_value = (action * action_value).sum(-1) + >>> print(chosen_action_value) + tensor([ 3., 7., 11.]) + + Args: + n (int): number of possible outcomes. + device (str, int or torch.device, optional): device of the tensors. + dtype (str or torch.dtype, optional): dtype of the tensors. + user_register (bool): experimental feature. If True, every integer + will be mapped onto a binary vector in the order in which they + appear. This feature is designed for environment with no + a-priori definition of the number of possible outcomes (e.g. + discrete outcomes are sampled from an arbitrary set, whose + elements will be mapped in a register to a series of unique + one-hot binary vectors). 
+ + """ + + shape: torch.Size + space: DiscreteBox + device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float + domain: str = "" + + def __init__( + self, + n: int, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.long, + use_register: bool = False, + ): + + dtype, device = _default_dtype_and_device(dtype, device) + self.use_register = use_register + space = DiscreteBox( + n, + ) + shape = torch.Size((space.n,)) + super().__init__(shape, space, device, dtype, "discrete") + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + return torch.nn.functional.gumbel_softmax( + torch.rand(*shape, self.space.n, device=self.device), + hard=True, + dim=-1, + ).to(torch.long) + + def encode( + self, + val: Union[np.ndarray, torch.Tensor], + space: Optional[DiscreteBox] = None, + ) -> torch.Tensor: + if not isinstance(val, torch.Tensor): + val = torch.tensor(val) + + val = torch.tensor(val, dtype=torch.long) + if space is None: + space = self.space + + if self.use_register: + if val not in space.register: + space.register[val] = len(space.register) + val = space.register[val] + + val = torch.nn.functional.one_hot(val, space.n).to(torch.long) + return val + + def to_numpy(self, val: torch.Tensor) -> np.ndarray: + if not isinstance(val, torch.Tensor): + raise NotImplementedError + self.assert_is_in(val) + val = val.argmax(-1).cpu().numpy() + if self.use_register: + inv_reg = self.space.register.inverse() + vals = [] + for _v in val.view(-1): + vals.append(inv_reg[int(_v)]) + return np.array(vals).reshape(tuple(val.shape)) + return val + + def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: + if not isinstance(index, torch.Tensor): + raise ValueError( + f"Only tensors are allowed for indexing using " + f"{self.__class__.__name__}.index(...)" + ) + index = index.nonzero().squeeze() + index = index.expand(*tensor_to_index.shape[:-1], index.shape[-1]) + return tensor_to_index.gather(-1, index) + + def _project(self, val: torch.Tensor) -> torch.Tensor: + # idx = val.sum(-1) != 1 + out = torch.nn.functional.gumbel_softmax(val.to(torch.float)) + out = (out == out.max(dim=-1, keepdim=True)[0]).to(torch.long) + return out + + def is_in(self, val: torch.Tensor) -> bool: + return (val.sum(-1) == 1).all() + + +@dataclass(repr=False) +class UnboundedContinuousTensorSpec(TensorSpec): + """ + An unbounded, unidimensional, continuous tensor spec. + + Args: + device (str, int or torch.device, optional): device of the tensors. + dtype (str or torch.dtype, optional): dtype of the tensors. + + """ + + shape: torch.Size + space: ContinuousBox + device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float + domain: str = "" + + def __init__(self, device=None, dtype=None): + dtype, device = _default_dtype_and_device(dtype, device) + box = ContinuousBox(torch.tensor(-np.inf), torch.tensor(np.inf)) + super().__init__(torch.Size((1,)), box, device, dtype, "composite") + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + return torch.randn(*shape, *self.shape, device=self.device, dtype=self.dtype) + + def is_in(self, val: torch.Tensor) -> bool: + return True + + +@dataclass(repr=False) +class NdBoundedTensorSpec(BoundedTensorSpec): + """ + A bounded, multi-dimensional, continuous tensor spec. + + Args: + minimum (np.ndarray, torch.Tensor or number): lower bound of the box. + maximum (np.ndarray, torch.Tensor or number): upper bound of the box. 
+ device (str, int or torch.device, optional): device of the tensors. + dtype (str or torch.dtype, optional): dtype of the tensors. + + """ + + def __init__( + self, + minimum: Union[float, torch.Tensor, np.ndarray], + maximum: Union[float, torch.Tensor, np.ndarray], + shape: Optional[torch.Size] = None, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[torch.dtype, str]] = None, + ): + dtype, device = _default_dtype_and_device(dtype, device) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch._get_default_device() + + if not isinstance(minimum, torch.Tensor): + minimum = torch.tensor(minimum, dtype=dtype, device=device) + if not isinstance(maximum, torch.Tensor): + maximum = torch.tensor(maximum, dtype=dtype, device=device) + if dtype is not None and minimum.dtype is not dtype: + minimum = minimum.to(dtype) + if dtype is not None and maximum.dtype is not dtype: + maximum = maximum.to(dtype) + err_msg = ( + "NdBoundedTensorSpec requires the shape to be explicitely (via " + "the shape argument) or implicitely defined (via either the " + "minimum or the maximum or both). If the maximum and/or the " + "minimum have a non-singleton shape, they must match the " + "provided shape if this one is set explicitely." + ) + if shape is not None and not isinstance(shape, torch.Size): + if isinstance(shape, int): + shape = torch.Size([shape]) + else: + shape = torch.Size(list(shape)) + + if maximum.ndimension(): + if shape is not None and shape != maximum.shape: + raise RuntimeError(err_msg) + shape = maximum.shape + minimum = minimum.expand(*shape) + elif minimum.ndimension(): + if shape is not None and shape != minimum.shape: + raise RuntimeError(err_msg) + shape = minimum.shape + maximum = maximum.expand(*shape) + elif shape is None: + raise RuntimeError(err_msg) + else: + minimum = minimum.expand(*shape) + maximum = maximum.expand(*shape) + + if minimum.numel() > maximum.numel(): + maximum = maximum.expand_as(minimum) + elif maximum.numel() > minimum.numel(): + minimum = minimum.expand_as(maximum) + if shape is None: + shape = minimum.shape + else: + if isinstance(shape, float): + shape = torch.Size([shape]) + elif not isinstance(shape, torch.Size): + shape = torch.Size(shape) + shape_err_msg = ( + f"minimum and shape mismatch, got {minimum.shape} and {shape}" + ) + if len(minimum.shape) != len(shape): + raise RuntimeError(shape_err_msg) + if not all(_s == _sa for _s, _sa in zip(shape, minimum.shape)): + raise RuntimeError(shape_err_msg) + self.shape = shape + + super(BoundedTensorSpec, self).__init__( + shape, ContinuousBox(minimum, maximum), device, dtype, "continuous" + ) + + +@dataclass(repr=False) +class NdUnboundedContinuousTensorSpec(UnboundedContinuousTensorSpec): + """ + An unbounded, multi-dimensional, continuous tensor spec. + + Args: + device (str, int or torch.device, optional): device of the tensors. + dtype (str or torch.dtype, optional): dtype of the tensors. + """ + + def __init__( + self, + shape: Union[torch.Size, int], + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + ): + if isinstance(shape, int): + shape = torch.Size([shape]) + + dtype, device = _default_dtype_and_device(dtype, device) + super(UnboundedContinuousTensorSpec, self).__init__( + shape=shape, + space=None, + device=device, + dtype=dtype, + domain="continuous", + ) + + +@dataclass(repr=False) +class BinaryDiscreteTensorSpec(TensorSpec): + """ + A binary discrete tensor spec. 
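A hedged sketch of how NdBoundedTensorSpec above infers its shape: scalar bounds need an explicit shape, while tensor-valued bounds fix the shape themselves; the numbers are arbitrary.

import torch
from torchrl.data import NdBoundedTensorSpec

spec_a = NdBoundedTensorSpec(minimum=-1.0, maximum=1.0, shape=torch.Size([3, 4]))
spec_b = NdBoundedTensorSpec(minimum=torch.zeros(5), maximum=torch.ones(5))
spec_a.rand([2]).shape      # -> torch.Size([2, 3, 4])
spec_b.rand().shape         # -> torch.Size([5])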
+ + Args: + n (int): length of the binary vector. + device (str, int or torch.device, optional): device of the tensors. + dtype (str or torch.dtype, optional): dtype of the tensors. + + """ + + shape: torch.Size + space: BinaryBox + device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float + domain: str = "" + + def __init__( + self, + n: int, + device: Optional[DEVICE_TYPING] = None, + dtype: Union[str, torch.dtype] = torch.long, + ): + dtype, device = _default_dtype_and_device(dtype, device) + shape = torch.Size((n,)) + box = BinaryBox(n) + super().__init__(shape, box, device, dtype, domain="discrete") + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + return torch.zeros( + *shape, *self.shape, device=self.device, dtype=self.dtype + ).bernoulli_() + + def index( + self, index: INDEX_TYPING, tensor_to_index: torch.Tensor + ) -> torch.Tensor: # type: ignore + if not isinstance(index, torch.Tensor): + raise ValueError( + f"Only tensors are allowed for indexing using" + f" {self.__class__.__name__}.index(...)" + ) + index = index.nonzero().squeeze() + index = index.expand(*tensor_to_index.shape[:-1], index.shape[-1]) + return tensor_to_index.gather(-1, index) + + def is_in(self, val: torch.Tensor) -> bool: + return ((val == 0) | (val == 1)).all() + + +@dataclass(repr=False) +class MultOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): + """ + A concatenation of one-hot discrete tensor spec. + + Args: + nvec (iterable of integers): cardinality of each of the elements of + the tensor. + device (str, int or torch.device, optional): device of + the tensors. + dtype (str or torch.dtype, optional): dtype of the tensors. + + Examples: + >>> ts = MultOneHotDiscreteTensorSpec((3,2,3)) + >>> ts.is_in(torch.tensor([0,0,1, + ... 0,1, + ... 1,0,0])) + True + >>> ts.is_in(torch.tensor([1,0,1, + ... 0,1, + ... 
1,0,0])) # False + False + + """ + + def __init__( + self, + nvec: Sequence[int], + device=None, + dtype=torch.long, + use_register=False, + ): + dtype, device = _default_dtype_and_device(dtype, device) + shape = torch.Size((sum(nvec),)) + space = [DiscreteBox(n) for n in nvec] + self.use_register = use_register + super(OneHotDiscreteTensorSpec, self).__init__( + shape, space, device, dtype, domain="discrete" + ) + + def rand(self, shape: torch.Size = torch.Size([])) -> torch.Tensor: + x = torch.cat( + [ + torch.nn.functional.one_hot( + torch.randint( + space.n, + ( + *shape, + 1, + ), + device=self.device, + ), + space.n, + ).to(torch.long) + for space in self.space + ], + -1, + ).squeeze(-2) + return x + + def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + if not isinstance(val, torch.Tensor): + val = torch.tensor(val) + + x = [] + for v, space in zip(val.unbind(-1), self.space): + if not (v < space.n).all(): + raise RuntimeError( + f"value {v} is greater than the allowed max {space.n}" + ) + x.append(super(MultOneHotDiscreteTensorSpec, self).encode(v, space)) + return torch.cat(x, -1) + + def _split(self, val: torch.Tensor) -> torch.Tensor: + vals = val.split([space.n for space in self.space], dim=-1) + return vals + + def to_numpy(self, val: torch.Tensor) -> np.ndarray: + vals = self._split(val) + out = torch.stack([val.argmax(-1) for val in vals], -1).numpy() + return out + + def index( + self, index: INDEX_TYPING, tensor_to_index: torch.Tensor + ) -> torch.Tensor: # type: ignore + if not isinstance(index, torch.Tensor): + raise ValueError( + f"Only tensors are allowed for indexing using" + f" {self.__class__.__name__}.index(...)" + ) + indices = self._split(index) + tensor_to_index = self._split(tensor_to_index) + + out = [] + for _index, _tensor_to_index in zip(indices, tensor_to_index): + _index = _index.nonzero().squeeze() + _index = _index.expand(*_tensor_to_index.shape[:-1], _index.shape[-1]) + out.append(_tensor_to_index.gather(-1, _index)) + return torch.cat(out, -1) + + def is_in(self, val: torch.Tensor) -> bool: + vals = self._split(val) + return all( + [super(MultOneHotDiscreteTensorSpec, self).is_in(_val) for _val in vals] + ) + + def _project(self, val: torch.Tensor) -> torch.Tensor: + vals = self._split(val) + return torch.cat([super()._project(_val) for _val in vals], -1) + + +class CompositeSpec(TensorSpec): + """ + A composition of TensorSpecs. + + Args: + **kwargs (key (str): value (TensorSpec)): dictionary of tensorspecs + to be stored + + Examples: + >>> observation_pixels_spec = NdBoundedTensorSpec( + ... torch.zeros(3,32,32), + ... torch.ones(3, 32, 32)) + >>> observation_vector_spec = NdBoundedTensorSpec(torch.zeros(33), + ... torch.ones(33)) + >>> composite_spec = CompositeSpec( + ... observation_pixels=observation_pixels_spec, + ... observation_vector=observation_vector_spec) + >>> td = TensorDict({"observation_pixels": torch.rand(10,3,32,32), + ... "observation_vector": torch.rand(10,33)}, batch_size=[10]) + >>> print("td (rand) is within bounds: ", composite_spec.is_in(td)) + td (rand) is within bounds: True + >>> td = TensorDict({"observation_pixels": torch.randn(10,3,32,32), + ... 
"observation_vector": torch.randn(10,33)}, batch_size=[10]) + >>> print("td (randn) is within bounds: ", composite_spec.is_in(td)) + td (randn) is within bounds: False + >>> td_project = composite_spec.project(td) + >>> print("td modification done in place: ", td_project is td) + td modification done in place: True + >>> print("check td is within bounds after projection: ", + ... composite_spec.is_in(td_project)) + check td is within bounds after projection: True + >>> print("random td: ", composite_spec.rand([3,])) + random td: TensorDict( + fields={ + observation_pixels: Tensor(torch.Size([3, 3, 32, 32]), \ +dtype=torch.float32), + observation_vector: Tensor(torch.Size([3, 33]), \ +dtype=torch.float32)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False) + + """ + + domain: str = "composite" + + def __init__(self, **kwargs): + self._specs = kwargs + + def __getitem__(self, item): + if item in {"shape", "device", "dtype", "space"}: + raise AttributeError(f"CompositeSpec has no key {item}") + return self._specs[item] + + def __setitem__(self, key, value): + if key in {"shape", "device", "dtype", "space"}: + raise AttributeError(f"CompositeSpec[{key}] cannot be set") + self._specs[key] = value + + def __iter__(self): + for k in self._specs: + yield k + + def del_(self, key: str) -> None: + del self._specs[key] + + def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: + out = {} + for key, item in vals.items(): + out[key] = self[key].encode(item) + return out + + def __repr__(self) -> str: + sub_str = [ + indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items() + ] + sub_str = ",\n".join(sub_str) + return f"CompositeSpec(\n{sub_str})" + + def type_check(self, value, key): + for _key in self: + if _key in key: + self._specs[_key].type_check(value, _key) + + def is_in(self, val: Union[dict, _TensorDict]) -> bool: # type: ignore + return all([self[key].is_in(val.get(key)) for key in self._specs]) + + def project(self, val: _TensorDict) -> _TensorDict: # type: ignore + for key in self._specs: + _val = val.get(key) + if not self._specs[key].is_in(_val): + val.set(key, self._specs[key].project(_val)) + return val + + def rand(self, shape=torch.Size([])): + return TensorDict( + {key: value.rand(shape) for key, value in self._specs.items()}, + batch_size=shape, + ) diff --git a/torchrl/data/tensordict/__init__.py b/torchrl/data/tensordict/__init__.py new file mode 100644 index 00000000000..77c310ccb02 --- /dev/null +++ b/torchrl/data/tensordict/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .memmap import * +from .metatensor import * +from .tensordict import * diff --git a/torchrl/data/tensordict/memmap.py b/torchrl/data/tensordict/memmap.py new file mode 100644 index 00000000000..5c8ddfd88d4 --- /dev/null +++ b/torchrl/data/tensordict/memmap.py @@ -0,0 +1,441 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +import functools +import tempfile +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch + +from torchrl.data.utils import ( + DEVICE_TYPING, + INDEX_TYPING, + torch_to_numpy_dtype_dict, +) + +MEMMAP_HANDLED_FN = {} + +__all__ = ["MemmapTensor", "set_transfer_ownership"] + + +def implements_for_memmap(torch_function) -> Callable: + """Register a torch function override for ScalarTensor""" + + @functools.wraps(torch_function) + def decorator(func): + MEMMAP_HANDLED_FN[torch_function] = func + return func + + return decorator + + +def to_numpy(tensor: Union[torch.Tensor, np.ndarray]) -> np.ndarray: + if isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + else: + return tensor + + +class MemmapTensor(object): + """A torch.tensor interface with a np.memmap array. + + A temporary file is created and cleared once the object is out-of-scope. + This class is aimed at being used for data transfer in between processes + and remote workers that have access to + a common storage, and as such it supports serialization and + deserialization. It is possible to choose if the ownership is + transferred upon serialization / deserialization: If owenership is not + transferred (transfer_ownership=False, default), then the process where + the MemmapTensor was created will be responsible of clearing it once it + gets out of scope (in that process). Otherwise, the process that + deserialize the MemmapTensor will be responsible of clearing the files + once the object is out of scope. + + Supports (almost) all tensor operations. + + Args: + elem (torch.Tensor or MemmapTensor): Tensor to be stored on physical + storage. If MemmapTensor, a new MemmapTensor is created and the + same data is stored in it. + transfer_ownership: bool: affects the ownership after serialization: + if True, the current process looses ownership immediately after + serialization. If False, the current process keeps the ownership + of the temporary file. + Default: False. + + Examples: + >>> x = torch.ones(3,4) + >>> x_memmap = MemmapTensor(x) + >>> # indexing + >>> x0 = x_memmap[0] + >>> x0[:] = 2 + >>> assert (x_memmap[0]==2).all() + >>> + >>> # device + >>> x = x.to('cuda:0') + >>> x_memmap = MemmapTensor(x) + >>> assert (x_memmap.clone()).device == torch.device('cuda:0') + >>> + >>> # operations + >>> assert (x_memmap + 1 == x+1).all() + >>> assert (x_memmap / 2 == x/2).all() + >>> assert (x_memmap * 2 == x*2).all() + >>> + >>> # temp file clearance + >>> filename = x_memmap.filename + >>> assert os.path.isfile(filename) + >>> del x_memmap + >>> assert not os.path.isfile(filename) + + """ + + def __init__( + self, + elem: Union[torch.Tensor, MemmapTensor], + transfer_ownership: bool = False, + ): + if not isinstance(elem, (torch.Tensor, MemmapTensor)): + raise TypeError( + "convert input to torch.Tensor before calling MemmapTensor() " "on it." + ) + + if elem.requires_grad: + raise RuntimeError( + "MemmapTensor is incompatible with tensor.requires_grad. " + "Consider calling tensor.detach() first." 
+ ) + + self.idx = None + self._memmap_array = None + self.file = tempfile.NamedTemporaryFile() + self.filename = self.file.name + self._device = elem.device + self._shape = elem.shape + self.transfer_ownership = transfer_ownership + self.np_shape = tuple(self._shape) + self._dtype = elem.dtype + self._tensor_dir = elem.__dir__() + self._ndim = elem.ndimension() + self._numel = elem.numel() + self.mode = "r+" + self._has_ownership = True + if isinstance(elem, MemmapTensor): + prev_filename = elem.filename + self._copy_item(prev_filename) + if self.memmap_array is elem.memmap_array: + raise RuntimeError + else: + self._save_item(elem) + + def _get_memmap_array(self) -> np.memmap: + if self._memmap_array is None: + self._memmap_array = np.memmap( + self.filename, + dtype=torch_to_numpy_dtype_dict[self.dtype], + mode=self.mode, + shape=self.np_shape, + ) + return self._memmap_array + + def _set_memmap_array(self, value: np.memmap) -> None: + self._memmap_array = value + + memmap_array = property(_get_memmap_array, _set_memmap_array) + + def _save_item( + self, + value: Union[torch.Tensor, MemmapTensor, np.ndarray], + idx: Optional[int] = None, + ): + if isinstance(value, (torch.Tensor,)): + np_array = value.cpu().numpy() + else: + np_array = value + memmap_array = self.memmap_array + if idx is None: + memmap_array[:] = np_array + else: + memmap_array[idx] = np_array + + def _copy_item(self, filename: Union[bytes, str]) -> None: + self.memmap_array[:] = np.memmap( + filename, + dtype=torch_to_numpy_dtype_dict[self.dtype], + mode="r", + shape=self.np_shape, + ) + + def _load_item( + self, + idx: Optional[int] = None, + memmap_array: Optional[np.ndarray] = None, + ) -> torch.Tensor: + if memmap_array is None: + memmap_array = self.memmap_array + if idx is not None: + memmap_array = memmap_array[idx] + return self._np_to_tensor(memmap_array) # type: ignore + + def _np_to_tensor(self, memmap_array: np.ndarray) -> torch.Tensor: + return torch.as_tensor(memmap_array, device=self.device) + + @classmethod + def __torch_function__( + cls, + func: Callable, + types, + args: Tuple = (), + kwargs: Optional[dict] = None, + ): + if kwargs is None: + kwargs = {} + if func not in MEMMAP_HANDLED_FN: + args = tuple(a._tensor if hasattr(a, "_tensor") else a for a in args) + ret = func(*args, **kwargs) + return ret + + return MEMMAP_HANDLED_FN[func](*args, **kwargs) + + @property + def _tensor(self) -> torch.Tensor: + return self._load_item() + + def ndimension(self) -> int: + return self._ndim + + def numel(self) -> int: + return self._numel + + def clone(self) -> MemmapTensor: + """Clones the MemmapTensor onto another MemmapTensor + + Returns: + a new MemmapTensor with the same data but a new storage. + + """ + return MemmapTensor(self) + + def contiguous(self) -> torch.Tensor: + """Copies the MemmapTensor onto a torch.Tensor object. + + Returns: + a torch.Tensor instance with the data of the MemmapTensor + stored on the desired device. 
+ + """ + return self._tensor.clone() + + @property + def device(self) -> torch.device: + return self._device + + @property + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def shape(self) -> torch.Size: + return self._shape + + def cpu(self) -> torch.Tensor: + return self._tensor.cpu() + + def numpy(self) -> np.ndarray: + return self._tensor.numpy() + + def copy_(self, other: Union[torch.Tensor, MemmapTensor]) -> MemmapTensor: + self._save_item(other) + return self + + def set_transfer_ownership(self, value: bool = True) -> MemmapTensor: + """Controls whether the ownership will be transferred to another + process upon serialization/deserialization + + Args: + value (bool): if True, the ownership will be transferred. + Otherwise the process will keep ownership of the + MemmapTensor temp file. + Default = True + + Returns: + the MemmapTensor + + """ + if not isinstance(value, bool): + raise TypeError( + f"value provided to set_transfer_ownership should be a " + f"boolean, got {type(value)}" + ) + self.transfer_ownership = value + return self + + def __del__(self) -> None: + if hasattr(self, "file"): + self.file.close() + + def __eq__(self, other: Any) -> torch.Tensor: # type: ignore + if not isinstance(other, (MemmapTensor, torch.Tensor, float, int, np.ndarray)): + raise NotImplementedError(f"Unknown type {type(other)}") + return self._tensor == other + + def __getattr__(self, attr: str) -> Any: + if attr in self.__dir__(): + print(f"loading {attr} has raised an exception") + return self.__getattribute__( + attr + ) # make sure that appropriate exceptions are raised + if attr not in self.__getattribute__("_tensor_dir"): + raise AttributeError(f"{attr} not found") + _tensor = self.__getattribute__("_tensor") + return getattr(_tensor, attr) + + # if not hasattr(torch.Tensor, attr): + # raise AttributeError(attr) + # return getattr(self._tensor, attr) + + def is_shared(self) -> bool: + return False + + def __add__(self, other: Union[float, MemmapTensor, torch.Tensor]) -> torch.Tensor: + return torch.add(self, other) # type: ignore + + def __truediv__( + self, other: Union[float, MemmapTensor, torch.Tensor] + ) -> torch.Tensor: + return torch.div(self, other) # type: ignore + + def __neg__(self: Union[float, MemmapTensor, torch.Tensor]) -> torch.Tensor: + return torch.neg(self) # type: ignore + + def __sub__(self, other: Union[float, MemmapTensor, torch.Tensor]) -> torch.Tensor: + return torch.sub(self, other) # type: ignore + + def __matmul__( + self, other: Union[float, MemmapTensor, torch.Tensor] + ) -> torch.Tensor: + return torch.matmul(self, other) # type: ignore + + def __mul__(self, other: Union[float, MemmapTensor, torch.Tensor]) -> torch.Tensor: + return torch.mul(self, other) # type: ignore + + def __pow__(self, other: Union[float, MemmapTensor, torch.Tensor]) -> torch.Tensor: + return torch.pow(self, other) # type: ignore + + def __repr__(self) -> str: + return ( + f"MemmapTensor(shape={self.shape}, device={self.device}, " + f"dtype={self.dtype})" + ) + + def __getitem__(self, item: INDEX_TYPING) -> torch.Tensor: + # return self._load_item(memmap_array=self.memmap_array[item])#[item] + return self._load_item()[item] + + def __setitem__(self, idx: INDEX_TYPING, value: torch.Tensor): + # self.memmap_array[idx] = to_numpy(value) + self._load_item()[idx] = value + + def __setstate__(self, state: dict) -> None: + if state["file"] is None: + delete = state["transfer_ownership"] and state["_has_ownership"] + state["_has_ownership"] = delete + tmpfile = 
tempfile.NamedTemporaryFile(delete=delete) + tmpfile.name = state["filename"] + tmpfile._closer.name = state["filename"] + state["file"] = tmpfile + self.__dict__.update(state) + + def __getstate__(self) -> dict: + state = self.__dict__.copy() + state["file"] = None + state["_memmap_array"] = None + self._has_ownership = self.file.delete + return state + + def __reduce__(self, *args, **kwargs): + if self.transfer_ownership: + self.file.delete = False + self.file._closer.delete = False + return super(MemmapTensor, self).__reduce__(*args, **kwargs) + + def to( + self, dest: Union[DEVICE_TYPING, torch.dtype] + ) -> Union[torch.Tensor, MemmapTensor]: + """Maps a MemmapTensor to a given dtype or device. + + Args: + dest (device indicator or torch.dtype): where to cast the + MemmapTensor. For devices, this is a lazy operation + (as the data is stored on physical memory). For dtypes, the + tensor will be retrieved, mapped to the + desired dtype and cast to a new MemmapTensor. + + Returns: + + """ + if isinstance(dest, (int, str, torch.device)): + dest = torch.device(dest) + return self._tensor.to(dest) + elif isinstance(dest, torch.dtype): + return MemmapTensor(self._tensor.to(dest)) + else: + raise NotImplementedError( + f"argument dest={dest} to MemmapTensor.to(dest) is not " + f"handled. " + f"Please provide a dtype or a device." + ) + + def unbind(self, dim: int) -> Tuple[torch.Tensor, ...]: + """Unbinds a MemmapTensor along the desired dimension. + + Args: + dim (int): dimension along which the MemmapTensor will be split. + + Returns: + A tuple of indexed MemmapTensors that share the same storage. + + """ + idx = [ + (tuple(slice(None) for _ in range(dim)) + (i,)) + for i in range(self.shape[dim]) + ] + return tuple(self[_idx] for _idx in idx) + + +@implements_for_memmap(torch.stack) +def stack( + list_of_memmap: List[MemmapTensor], + dim: int, + out: Optional[Union[torch.Tensor]] = None, +) -> torch.Tensor: + list_of_tensors = [ + a._tensor if isinstance(a, MemmapTensor) else a for a in list_of_memmap + ] + return torch.stack(list_of_tensors, dim, out=out) + + +@implements_for_memmap(torch.unbind) +def unbind(memmap: MemmapTensor, dim: int) -> Tuple[torch.Tensor, ...]: + return memmap.unbind(dim) + + +@implements_for_memmap(torch.cat) +def cat( + list_of_memmap: List[MemmapTensor], + dim: int, + out: Optional[Union[torch.Tensor, MemmapTensor]] = None, +) -> torch.Tensor: + list_of_tensors = [ + a._tensor if isinstance(a, MemmapTensor) else a for a in list_of_memmap + ] + return torch.cat(list_of_tensors, dim, out=out) + + +def set_transfer_ownership(memmap: MemmapTensor, value: bool = True) -> None: + if isinstance(memmap, MemmapTensor): + memmap.set_transfer_ownership(value) diff --git a/torchrl/data/tensordict/metatensor.py b/torchrl/data/tensordict/metatensor.py new file mode 100644 index 00000000000..2fb457b9c47 --- /dev/null +++ b/torchrl/data/tensordict/metatensor.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
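# ---------------------------------------------------------------------------
# Illustrative sketch of the dispatch pattern shared by memmap.py above and
# metatensor.py below (hypothetical names, not the actual torchrl API): a
# registry maps torch functions to overrides, a decorator fills it, and
# __torch_function__ routes calls such as torch.stack to the registered
# handler.

import torch

_HANDLED = {}


def _implements(torch_function):
    """Register an override of `torch_function` for `_Wrapper` instances."""

    def decorator(func):
        _HANDLED[torch_function] = func
        return func

    return decorator


class _Wrapper:
    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _HANDLED:
            return _HANDLED[func](*args, **kwargs)
        # fall back to operating on the wrapped tensors
        args = tuple(a._tensor if isinstance(a, _Wrapper) else a for a in args)
        return func(*args, **kwargs)


@_implements(torch.stack)
def _stack(wrappers, dim=0, out=None):
    # unwrap before stacking, mirroring the handlers registered in these modules
    return torch.stack([w._tensor for w in wrappers], dim, out=out)


# torch.stack notices the __torch_function__ protocol on the list elements and
# dispatches to _stack instead of trying to stack the wrappers directly.
_stacked = torch.stack([_Wrapper(torch.zeros(3)), _Wrapper(torch.ones(3))])
assert _stacked.shape == torch.Size([2, 3])
# ---------------------------------------------------------------------------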
+ +from __future__ import annotations + +import functools +from numbers import Number +from typing import Callable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from torchrl.data.utils import DEVICE_TYPING, INDEX_TYPING +from .memmap import MemmapTensor +from .utils import _getitem_batch_size + +META_HANDLED_FUNCTIONS = dict() + + +def implements_for_meta(torch_function) -> Callable: + """Register a torch function override for ScalarTensor""" + + @functools.wraps(torch_function) + def decorator(func): + META_HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + + +class MetaTensor: + """MetaTensor is a custom class that stores the meta-information about a + tensor without requiring to access the tensor. + + This is intended to be used with tensors that have a high access cost. + MetaTensor supports more operations than tensors on 'meta' device ( + `torch.tensor(..., device='meta')`). + For instance, MetaTensor supports some operations on its shape and device, + such as `mt.to(device)`, `mt.view(*new_shape)`, `mt.expand( + *expand_shape)` etc. + + Args: + shape (iterable of integers): shape of the tensor. If the first + element of "shape" is a torch.Tensor, the + MetaTensor is built with this tensor specs. + device (int, str or torch.device): device on which the tensor is + stored. + dtype (torch.dtype): tensor dtype. + + Examples: + >>> meta1 = MetaTensor(3,4, device=torch.device("cpu")) + >>> meta2 = MetaTensor(torch.randn(3,4,device="cuda:0", + ... dtype=torch.double)) + >>> assert meta1.device != meta2.device + >>> assert meta1.dtype != meta2.dtype + >>> assert meta1.expand(2, 3, 4).shape == torch.Size([2, 3, 4]) + >>> assert torch.stack([MetaTensor(3,4) for _ in range(10)], + ... 1).shape == torch.Size([3, 10, 4]) + """ + + def __init__( + self, + *shape: Union[int, torch.Tensor, "MemmapTensor"], + device: Optional[DEVICE_TYPING] = "cpu", + dtype: torch.dtype = torch.get_default_dtype(), + _is_shared: bool = False, + _is_memmap: bool = False, + ): + + if len(shape) == 1 and not isinstance(shape[0], (Number,)): + tensor = shape[0] + shape = tensor.shape + try: + _is_shared = ( + tensor.is_shared() + if tensor.device != torch.device("meta") + else _is_shared + ) + except: # noqa + _is_shared = False + _is_memmap = ( + isinstance(tensor, MemmapTensor) + if tensor.device != torch.device("meta") + else _is_memmap + ) + device = tensor.device if tensor.device != torch.device("meta") else device + dtype = tensor.dtype + if not isinstance(shape, torch.Size): + shape = torch.Size(shape) + self.shape = shape + self.device = ( + torch.device(device) if not isinstance(device, torch.device) else device + ) + self.dtype = dtype + self._ndim = len(shape) + self._numel = np.prod(shape) + self._is_shared = _is_shared + self._is_memmap = _is_memmap + if _is_memmap: + name = "MemmapTensor" + elif _is_shared: + name = "SharedTensor" + else: + name = "Tensor" + self.class_name = name + + def memmap_(self) -> MetaTensor: + """Changes the storage of the MetaTensor to memmap. + + Returns: + self + + """ + self._is_memmap = True + self.class_name = "MemmapTensor" + return self + + def share_memory_(self) -> MetaTensor: + """Changes the storage of the MetaTensor to shared memory. 
+ + Returns: + self + + """ + + self._is_shared = True + self.class_name = "SharedTensor" + return self + + def is_shared(self) -> bool: + return self._is_shared + + def is_memmap(self) -> bool: + return self._is_memmap + + def numel(self) -> int: + return self._numel + + def ndimension(self) -> int: + return self._ndim + + def clone(self) -> MetaTensor: + """ + + Returns: + a new MetaTensor with the same specs. + + """ + return MetaTensor( + *self.shape, + device=self.device, + dtype=self.dtype, + _is_shared=self.is_shared(), + _is_memmap=self.is_memmap(), + ) + + def _to_meta(self) -> torch.Tensor: + return torch.empty(*self.shape, dtype=self.dtype, device="meta") + + def __getitem__(self, item: INDEX_TYPING) -> MetaTensor: + shape = _getitem_batch_size(self.shape, item) + return MetaTensor( + *shape, + dtype=self.dtype, + device=self.device, + _is_shared=self.is_shared(), + ) + + def __torch_function__( + self, + func: Callable, + types, + args: Tuple = (), + kwargs: Optional[dict] = None, + ): + if kwargs is None: + kwargs = {} + if func not in META_HANDLED_FUNCTIONS or not all( + issubclass(t, (torch.Tensor, MetaTensor)) for t in types + ): + return NotImplemented + return META_HANDLED_FUNCTIONS[func](*args, **kwargs) + + def expand(self, *shape: int) -> MetaTensor: + shape = torch.Size([*shape, *self.shape]) + return MetaTensor(*shape, device=self.device, dtype=self.dtype) + + def __repr__(self) -> str: + return ( + f"MetaTensor(shape={self.shape}, device={self.device}, " + f"dtype={self.dtype})" + ) + + def unsqueeze(self, dim: int) -> MetaTensor: + clone = self.clone() + new_shape = [] + shape = [i for i in clone.shape] + for i in range(len(shape) + 1): + if i == dim: + new_shape.append(1) + else: + new_shape.append(shape[0]) + shape = shape[1:] + clone.shape = torch.Size(new_shape) + return clone + + def squeeze(self, dim: Optional[int] = None) -> MetaTensor: + clone = self.clone() + shape = [i for i in clone.shape] + if dim is None: + new_shape = [i for i in shape if i != 1] + else: + new_shape = [] + for i in range(len(shape)): + if i == dim and shape[0] == 1: + continue + else: + new_shape.append(shape[0]) + shape = shape[1:] + clone.shape = torch.Size(new_shape) + return clone + + def view( + self, + *shape: Sequence, + size: Optional[Union[List, Tuple, torch.Size]] = None, + ) -> MetaTensor: + if len(shape) == 0 and size is not None: + return self.view(*size) + elif len(shape) == 1 and isinstance(shape[0], (list, tuple, torch.Size)): + return self.view(*shape[0]) + elif not isinstance(shape, torch.Size): + shape = torch.Size(shape) + new_shape = torch.zeros(self.shape, device="meta").view(*shape) + return MetaTensor(new_shape, device=self.device, dtype=self.dtype) + + +def _stack_meta( + list_of_meta_tensors: Sequence[MetaTensor], + dim: int = 0, + dtype: torch.dtype = torch.float, + device: DEVICE_TYPING = "cpu", + safe: bool = False, +) -> MetaTensor: + if not len(list_of_meta_tensors): + raise RuntimeError("empty list of meta tensors is not supported") + shape = list_of_meta_tensors[0].shape + if safe: + for tensor in list_of_meta_tensors: + if tensor.shape != shape: + raise RuntimeError( + f"Stacking meta tensors of different shapes is not " + f"allowed, got shapes {shape} and {tensor.shape}" + ) + if tensor.dtype != dtype: + raise TypeError( + f"Stacking meta tensors of different dtype is not " + f"allowed, got shapes {dtype} and {tensor.dtype}" + ) + shape = [s for s in shape] + shape.insert(dim, len(list_of_meta_tensors)) + return MetaTensor(*shape, dtype=dtype, 
device=device) + + +@implements_for_meta(torch.stack) +def stack_meta( + list_of_meta_tensors: Sequence[MetaTensor], + dim: int = 0, + safe: bool = False, +) -> MetaTensor: + dtype = ( + list_of_meta_tensors[0].dtype + if len(list_of_meta_tensors) + else torch.get_default_dtype() + ) + device = ( + list_of_meta_tensors[0].device + if len(list_of_meta_tensors) + else torch.device("cpu") + ) + return _stack_meta( + list_of_meta_tensors, dim=dim, dtype=dtype, device=device, safe=safe + ) diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py new file mode 100644 index 00000000000..1bc91520689 --- /dev/null +++ b/torchrl/data/tensordict/tensordict.py @@ -0,0 +1,3247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import abc +import functools +import math +import tempfile +import textwrap +from collections import OrderedDict +from collections.abc import Mapping +from copy import copy, deepcopy +from numbers import Number +from textwrap import indent +from typing import ( + Callable, + Dict, + Generator, + Iterator, + KeysView, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) +from warnings import warn + +import numpy as np +import torch + +from torchrl.data.tensordict.memmap import MemmapTensor +from torchrl.data.tensordict.metatensor import MetaTensor +from torchrl.data.tensordict.utils import _getitem_batch_size, _sub_index +from torchrl.data.utils import DEVICE_TYPING, expand_as_right, INDEX_TYPING + +__all__ = [ + "TensorDict", + "SubTensorDict", + "merge_tensor_dicts", + "LazyStackedTensorDict", + "SavedTensorDict", +] + +TD_HANDLED_FUNCTIONS: Dict = dict() +COMPATIBLE_TYPES = Union[ + torch.Tensor, + MemmapTensor, +] # None? # leaves space for _TensorDict +_accepted_classes = (torch.Tensor, MemmapTensor) + + +class _TensorDict(Mapping, metaclass=abc.ABCMeta): + """ + _TensorDict is an abstract parent class for TensorDicts, the torchrl + data container. + """ + + _safe = False + + def __init__( + self, + source: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + batch_size: Optional[Sequence[int]] = None, + ): + raise NotImplementedError + + @property + def shape(self) -> torch.Size: + """See _TensorDict.batch_size""" + return self.batch_size + + @property + @abc.abstractmethod + def batch_size(self) -> torch.Size: + """Shape of (or batch_size) of a TensorDict. + The shape of a tensordict corresponds to the common N first + dimensions of the tensors it contains, where N is an arbitrary + number. The TensorDict shape is controlled by the user upon + initialization (i.e. it is not inferred from the tensor shapes) and + it should not be changed dynamically. + + Returns: + a torch.Size object describing the TensorDict batch size. + + """ + raise NotImplementedError + + @property + def batch_dims(self) -> int: + """Length of the tensordict batch size. + + Returns: + int describing the number of dimensions of the tensordict. + + """ + return len(self.batch_size) + + def ndimension(self) -> int: + return self.batch_dims + + def dim(self) -> int: + return self.batch_dims + + @property + @abc.abstractmethod + def device(self) -> torch.device: + """Device of a TensorDict. All tensors of a tensordict must live on the + same device. + + Returns: + torch.device object indicating the device where the tensors + are placed. 
+ + """ + raise NotImplementedError + + def is_shared(self, no_check=True) -> bool: + """Checks if tensordict is in shared memory. + + This is always True for CUDA tensordicts, except when stored as + MemmapTensors. + + Args: + no_check (bool, optional): checks if all tensors are in shared + memory or not + + """ + if not no_check: + raise RuntimeError( + f"no_check=False is not compatible with TensorDict of type" + f" {self.__class__.__name__}." + ) + return all([item.is_shared() for key, item in self.items_meta()]) + + def is_memmap(self) -> bool: + """Checks if tensordict is stored with MemmapTensors.""" + + return all([item.is_memmap() for key, item in self.items_meta()]) + + def numel(self) -> int: + """Total number of elements in the batch.""" + return max(1, math.prod(self.batch_size)) + + def _check_batch_size(self) -> None: + bs = [value.shape[: self.batch_dims] for key, value in self.items_meta()] + if len(bs): + if bs[0] != self.batch_size: + raise RuntimeError( + "batch_size provided during initialization violates " + "batch size of registered tensors, " + f"got self._batch_size={self.batch_size} and " + f"tensor.shape[:batch_dim]={bs[0]}" + ) + if len(bs) > 1: + for _bs in bs[1:]: + if _bs != bs[0]: + raise RuntimeError( + f"batch_size are incongruent, got {_bs} and {bs[0]} " + f"-- expected {self.batch_size}" + ) + + def _check_is_shared(self) -> bool: + raise NotImplementedError(f"{self.__class__.__name__}") + + def _check_device(self) -> None: + raise NotImplementedError(f"{self.__class__.__name__}") + + def set( + self, key: str, item: COMPATIBLE_TYPES, inplace: bool = False, **kwargs + ) -> _TensorDict: # type: ignore + """Sets a new key-value pair. + + Args: + key (str): name of the value + item (torch.Tensor): value to be stored in the tensordict + inplace (bool, optional): if True and if a key matches an existing + key in the tensordict, then the update will occur in-place + for that key-value pair. Default is `False`. + + Returns: + self + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + @abc.abstractmethod + def set_(self, key: str, item: COMPATIBLE_TYPES) -> _TensorDict: + """Sets a value to an existing key while keeping the original storage. + + Args: + key (str): name of the value + item (torch.Tensor): value to be stored in the tensordict + + Returns: + self + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + def _default_get( + self, key: str, default: Union[str, COMPATIBLE_TYPES] = "_no_default_" + ) -> COMPATIBLE_TYPES: + if not isinstance(default, str): + return default + if default == "_no_default_": + raise KeyError( + f"key {key} not found in {self.__class__.__name__} with " + f"keys {sorted(list(self.keys()))}" + ) + else: + raise ValueError( + f"default should be None or a torch.Tensor instance, " f"got {default}" + ) + + @abc.abstractmethod + def get( # type: ignore + self, key: str, default: Union[str, COMPATIBLE_TYPES] = "_no_default_" + ) -> COMPATIBLE_TYPES: # type: ignore + """ + Gets the value stored with the input key. + + Args: + key (str): key to be queried. + default: default value if the key is not found in the tensordict. + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + def _get_meta(self, key) -> MetaTensor: + raise NotImplementedError(f"{self.__class__.__name__}") + + def apply_(self, fn: Callable) -> _TensorDict: + """Applies a callable to all values stored in the tensordict and + re-writes them in-place. 
+ + Args: + fn (Callable): function to be applied to the tensors in the + tensordict. + + Returns: + self + + """ + for key, item in self.items(): + item_trsf = fn(item) + if item_trsf is not None: + self.set(key, item_trsf, inplace=True) + return self + + def apply( + self, fn: Callable, batch_size: Optional[Sequence[int]] = None + ) -> _TensorDict: + """Applies a callable to all values stored in the tensordict and sets + them in a new tensordict. + + Args: + fn (Callable): function to be applied to the tensors in the + tensordict. + batch_size (sequence of int, optional): if provided, + the resulting TensorDict will have the desired batch_size. + The `batch_size` argument should match the batch_size after + the transformation. + + Returns: + a new tensordict with transformed tensors. + + """ + if batch_size is None: + td = TensorDict({}, batch_size=self.batch_size) + else: + td = TensorDict({}, batch_size=torch.Size(batch_size)) + for key, item in self.items(): + item_trsf = fn(item) + td.set(key, item_trsf) + return td + + def update( # type: ignore + self, + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + clone: bool = False, + inplace: bool = False, + **kwargs, + ) -> _TensorDict: + """Updates the TensorDict with values from either a dictionary or + another TensorDict. + + Args: + input_dict_or_td (_TensorDict or dict): Does not keyword arguments + (unlike `dict.update()`). + clone (bool, optional): whether the tensors in the input ( + tensor) dict should be cloned before being set. Default is + `False`. + inplace (bool, optional): if True and if a key matches an existing + key in the tensordict, then the update will occur in-place + for that key-value pair. Default is `False`. + **kwargs: keyword arguments for the `TensorDict.set` method + + Returns: + self + + """ + if input_dict_or_td is self: + # no op + return self + for key, value in input_dict_or_td.items(): + if not isinstance(value, _accepted_classes): + raise TypeError( + f"Expected value to be one of types " + f"{_accepted_classes} but got {type(value)}" + ) + if clone: + value = value.clone() + self.set(key, value, inplace=inplace, **kwargs) + return self + + def update_( + self, + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + clone: bool = False, + ) -> _TensorDict: + """Updates the TensorDict in-place with values from either a dictionary + or another TensorDict. + + Unlike TensorDict.update, this function will + throw an error if the key is unknown to the TensorDict + + Args: + input_dict_or_td (_TensorDict or dict): Does not keyword + arguments (unlike `dict.update()`). + clone (bool, optional): whether the tensors in the input ( + tensor) dict should be cloned before being set. Default is + `False`. + + Returns: + self + + """ + if input_dict_or_td is self: + # no op + return self + for key, value in input_dict_or_td.items(): + if not isinstance(value, _accepted_classes): + raise TypeError( + f"Expected value to be one of types {_accepted_classes} " + f"but got {type(value)}" + ) + if clone: + value = value.clone() + self.set_(key, value) + return self + + def update_at_( + self, + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + idx: INDEX_TYPING, + clone: bool = False, + ) -> _TensorDict: + """Updates the TensorDict in-place at the specified index with + values from either a dictionary or another TensorDict. + + Unlike TensorDict.update, this function will throw an error if the + key is unknown to the TensorDict. 
+ + Args: + input_dict_or_td (_TensorDict or dict): Does not keyword arguments + (unlike `dict.update()`). + idx (int, torch.Tensor, iterable, slice): index of the tensordict + where the update should occur. + clone (bool, optional): whether the tensors in the input ( + tensor) dict should be cloned before being set. Default is + `False`. + + Returns: + self + + Examples: + >>> td = TensorDict(source={'a': torch.zeros(3, 4, 5), + ... 'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4]) + >>> td.update_at_( + ... TensorDict(source={'a': torch.ones(1, 4, 5), + ... 'b': torch.ones(1, 4, 10)}, batch_size=[1, 4]), + ... slice(1, 2)) + TensorDict( + fields={a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32), + b: Tensor(torch.Size([3, 4, 10]),\ +dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3, 4]), + device=cpu) + + """ + + for key, value in input_dict_or_td.items(): + if not isinstance(value, _accepted_classes): + raise TypeError( + f"Expected value to be one of types {_accepted_classes} " + f"but got {type(value)}" + ) + if clone: + value = value.clone() + self.set_at_( + key, + value, + idx, + ) + return self + + def _convert_to_tensor( + self, array: np.ndarray + ) -> Union[torch.Tensor, MemmapTensor]: + return torch.tensor(array, device=self.device) + + def _process_tensor( + self, + input: Union[COMPATIBLE_TYPES, np.ndarray], + check_device: bool = True, + check_tensor_shape: bool = True, + check_shared: bool = True, + ) -> Union[torch.Tensor, MemmapTensor]: + + # TODO: move to _TensorDict? + if not isinstance(input, _accepted_classes): + tensor = self._convert_to_tensor(input) + else: + tensor = input + + if ( + check_device + and (self.device is not None) + and (tensor.device is not self.device) + ): + tensor = tensor.to(self.device) + + if check_shared: + if self.is_shared(): + tensor = tensor.share_memory_() + elif self.is_memmap(): + tensor = MemmapTensor(tensor) # type: ignore + elif tensor.is_shared() and len(self): + tensor = tensor.clone() + + if check_tensor_shape and tensor.shape[: self.batch_dims] != self.batch_size: + raise RuntimeError( + f"batch dimension mismatch, got self.batch_size" + f"={self.batch_size} and tensor.shape[:self.batch_dims]" + f"={tensor.shape[: self.batch_dims]}" + ) + + # minimum ndimension is 1 + if tensor.ndimension() - self.ndimension() == 0: + tensor = tensor.unsqueeze(-1) + + return tensor + + @abc.abstractmethod + def pin_memory(self) -> _TensorDict: + """Calls pin_memory() on the stored tensors.""" + raise NotImplementedError(f"{self.__class__.__name__}") + + # @abc.abstractmethod + # def is_pinned(self) -> bool: + # """Checks if tensors are pinned.""" + # raise NotImplementedError(f"{self.__class__.__name__}") + + @abc.abstractmethod + def items(self) -> Iterator[Tuple[str, COMPATIBLE_TYPES]]: # type: ignore + """ + Returns a generator of key-value pairs for the tensordict. + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + @abc.abstractmethod + def items_meta(self) -> Iterator[Tuple[str, MetaTensor]]: + """Returns a generator of key-value pairs for the tensordict, where the + values are MetaTensor instances corresponding to the stored tensors. 
+ + """ + + raise NotImplementedError(f"{self.__class__.__name__}") + + @abc.abstractmethod + def keys(self) -> KeysView: + """Returns a generator of tensordict keys.""" + + raise NotImplementedError(f"{self.__class__.__name__}") + + def expand(self, *shape: int) -> _TensorDict: + """Expands each tensors of the tensordict according to + `tensor.expand(*shape, *tensor.shape)` + + Examples: + >>> td = TensorDict(source={'a': torch.zeros(3, 4, 5), + ... 'b': torch.zeros(3, 4, 10)}, batch_size=[3, 4]) + >>> td_expand = td.expand(10) + >>> assert td_expand.shape == torch.Size([10, 3, 4]) + >>> assert td_expand.get("a").shape == torch.Size([10, 3, 4, 5]) + """ + + return TensorDict( + source={ + key: value.expand(*shape, *value.shape) for key, value in self.items() + }, + batch_size=[*shape, *self.batch_size], + ) + + def __ne__(self, other: object) -> _TensorDict: # type: ignore + """XOR operation over two tensordicts, for evey key. The two + tensordicts must have the same key set. + + Returns: + a new TensorDict instance with all tensors are boolean + tensors of the same shape as the original tensors. + + """ + + if not isinstance(other, _TensorDict): + raise TypeError( + f"TensorDict comparision requires both objects to be " + f"_TensorDict subclass, got {type(other)}" + ) + keys1 = set(self.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError( + f"keys in {self} and {other} mismatch, got {keys1} and {keys2}" + ) + d = dict() + for (key, item1) in self.items(): + d[key] = item1 != other.get(key) + return TensorDict(batch_size=self.batch_size, source=d) + + def __eq__(self, other: object) -> _TensorDict: # type: ignore + """Compares two tensordicts against each other, for evey key. The two + tensordicts must have the same key set. + + Returns: + a new TensorDict instance with all tensors are boolean + tensors of the same shape as the original tensors. + + """ + if not isinstance(other, _TensorDict): + raise TypeError( + f"TensorDict comparision requires both objects to be " + f"_TensorDict subclass, got {type(other)}" + ) + keys1 = set(self.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError( + f"keys in {self} and {other} mismatch, got {keys1} and {keys2}" + ) + d = dict() + for (key, item1) in self.items(): + d[key] = item1 == other.get(key) + return TensorDict(batch_size=self.batch_size, source=d) + + @abc.abstractmethod + def del_(self, key: str) -> _TensorDict: + """Deletes a key of the tensordict. + + Args: + key (str): key to be deleted + + Returns: + self + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + @abc.abstractmethod + def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + """Selects the keys of the tensordict and returns an new tensordict + with only the selected keys. + + The values are not copied: in-place modifications a tensor of either + of the original or new tensordict will result in a change in both + tensordicts. + + Args: + *keys (str): keys to select + inplace (bool): if True, the tensordict is pruned in place. + Default is `False`. + + Returns: + A new tensordict with the selected keys only. + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + @abc.abstractmethod + def set_at_( + self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING + ) -> _TensorDict: + """Sets the values in-place at the index indicated by `idx`. + + Args: + key (str): key to be modified. 
+ value (torch.Tensor): value to be set at the index `idx` + idx (int, tensor or tuple): index where to write the values. + + Returns: + self + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + def copy_(self, tensor_dict: _TensorDict) -> _TensorDict: + """See `_TensorDict.update_`.""" + return self.update_(tensor_dict) + + def copy_at_(self, tensor_dict: _TensorDict, idx: INDEX_TYPING) -> _TensorDict: + """See `_TensorDict.update_at_`.""" + return self.update_at_(tensor_dict, idx) + + def get_at( + self, key: str, idx: INDEX_TYPING, default: COMPATIBLE_TYPES = None + ) -> COMPATIBLE_TYPES: + """Get the value of a tensordict from the key `key` at the index `idx`. + + Args: + key (str): key to be retrieved. + idx (int, slice, torch.Tensor, iterable): index of the tensor. + default (torch.Tensor): default value to return if the key is + not present in the tensordict. + + Returns: + indexed tensor. + + """ + try: + return self.get(key)[idx] + except KeyError: + if default is not None: + return default + raise KeyError( + f"key {key} not found in {self.__class__.__name__} with keys" + f" {sorted(list(self.keys()))}" + ) + + @abc.abstractmethod + def share_memory_(self) -> _TensorDict: + """Places all the tensors in shared memory. + + Returns: + self. + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + @abc.abstractmethod + def memmap_(self) -> _TensorDict: + """Writes all tensors onto a MemmapTensor. + + Returns: + self. + + """ + + raise NotImplementedError(f"{self.__class__.__name__}") + + @abc.abstractmethod + def detach_(self) -> _TensorDict: + """Detach the tensors in the tensordict in-place. + + Returns: + self. + + """ + raise NotImplementedError(f"{self.__class__.__name__}") + + def detach(self) -> _TensorDict: + """Detach the tensors in the tensordict. + + Returns: + a new tensordict with no tensor requiring gradient. + + """ + + return TensorDict( + {key: item.detach() for key, item in self.items()}, + batch_size=self.batch_size, + ) + + def to_tensordict(self): + """Returns a regular TensorDict instance from the _TensorDict. + + Returns: + a new TensorDict object containing the same values. + + """ + return self.to(TensorDict) + + def zero_(self) -> _TensorDict: + """Zeros all tensors in the tensordict in-place.""" + for key in self.keys(): + self.fill_(key, 0) + return self + + def unbind(self, dim: int) -> Tuple[_TensorDict, ...]: + """Returns a tuple of indexed tensordicts unbound along the + indicated dimension. Resulting tensordicts will share + the storage of the initial tensordict. + + """ + idx = [ + (tuple(slice(None) for _ in range(dim)) + (i,)) + for i in range(self.shape[dim]) + ] + return tuple(self[_idx] for _idx in idx) + + def chunk(self, chunks: int, dim: int = 0) -> Tuple[_TensorDict, ...]: + """Attempts to split a tendordict into the specified number of + chunks. Each chunk is a view of the input tensordict. + + Args: + chunks (int): number of chunks to return + dim (int, optional): dimension along which to split the + tensordict. Default is 0. + + """ + if chunks < 1: + raise ValueError( + f"chunks must be a strictly positive integer, got {chunks}." 
+ ) + indices = [] + _idx_start = 0 + if chunks > 1: + interval = _idx_end = self.batch_size[dim] // chunks + else: + interval = _idx_end = self.batch_size[dim] + for c in range(chunks): + indices.append(slice(_idx_start, _idx_end)) + _idx_start = _idx_end + if c < chunks - 2: + _idx_end = _idx_end + interval + else: + _idx_end = self.batch_size[dim] + if dim < 0: + dim = len(self.batch_size) + dim + return tuple(self[(*[slice(None) for _ in range(dim)], idx)] for idx in indices) + + def clone(self, recursive: bool = True) -> _TensorDict: + """Clones a _TensorDict subclass instance onto a new TensorDict. + + Args: + recursive (bool, optional): if True, each tensor contained in the + TensorDict will be copied too. Default is `True`. + """ + return TensorDict( + source={ + key: value.clone() if recursive else value + for key, value in self.items() + }, + batch_size=self.batch_size, + ) + + def __torch_function__( + self, + func: Callable, + types, + args: Tuple = (), + kwargs: Optional[dict] = None, + ) -> Callable: + if kwargs is None: + kwargs = {} + if func not in TD_HANDLED_FUNCTIONS or not all( + issubclass(t, (torch.Tensor, _TensorDict)) for t in types + ): + return NotImplemented + return TD_HANDLED_FUNCTIONS[func](*args, **kwargs) + + @abc.abstractmethod + def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: + """Maps a _TensorDict subclass either on a new device or to another + _TensorDict subclass (if permitted). Casting tensors to a new dtype + is not allowed, as tensordicts are not bound to contain a single + tensor dtype. + + Args: + dest (device or _TensorDict subclass): destination of the + tensordict. + + Returns: + a new tensordict. If device indicated by dest differs from + the tensordict device, this is a no-op. + + """ + raise NotImplementedError + + def cpu(self) -> _TensorDict: + """Casts a tensordict to cpu (if not already on cpu).""" + return self.to("cpu") + + def cuda(self, device: int = 0) -> _TensorDict: + """Casts a tensordict to a cuda device (if not already on it).""" + return self.to(f"cuda:{device}") + + @abc.abstractmethod + def masked_fill_( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> _TensorDict: + """Fills the values corresponding to the mask with the desired value. + + Args: + mask (boolean torch.Tensor): mask of values to be filled. Shape + must match tensordict shape. + value: value to used to fill the tensors. + + Returns: + self + + Examples: + >>> td = TensorDict(source={'a': torch.zeros(3, 4)}, + ... batch_size=[3]) + >>> mask = torch.tensor([True, False, False]) + >>> _ = td.masked_fill_(mask, 1.0) + >>> td.get("a") + tensor([[1., 1., 1., 1.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]) + """ + raise NotImplementedError + + def masked_select(self, mask: torch.Tensor) -> _TensorDict: + """Masks all tensors of the TensorDict and return a new TensorDict + instance with similar keys pointing to masked values. + + Args: + mask (torch.Tensor): boolean mask to be used for the tensors. + Shape must match the TensorDict batch_size. + + Examples: + >>> td = TensorDict(source={'a': torch.zeros(3, 4)}, + ... 
batch_size=[3]) + >>> mask = torch.tensor([True, False, False]) + >>> td_mask = td.masked_select(mask) + >>> td_mask.get("a") + tensor([[0., 0., 0., 0.]]) + + """ + d = dict() + for key, value in self.items(): + mask_expand = mask.squeeze(-1) + value_select = value[mask_expand] + d[key] = value_select + dim = int(mask.sum().item()) + return TensorDict(device=self.device, source=d, batch_size=torch.Size([dim])) + + @abc.abstractmethod + def is_contiguous(self) -> bool: + """ + + Returns: + boolean indicating if all the tensors are contiguous. + + """ + raise NotImplementedError + + @abc.abstractmethod + def contiguous(self) -> _TensorDict: + """ + + Returns: + a new tensordict of the same type with contiguous values ( + or self if values are already contiguous). + + """ + raise NotImplementedError + + def to_dict(self) -> dict: + """ + + Returns: + dictionary with key-value pairs matching those of the + tensordict. + + """ + return {key: value for key, value in self.items()} + + def unsqueeze(self, dim: int) -> _TensorDict: + """Unsqueeze all tensors for a dimension comprised in between + `-td.batch_dims` and `td.batch_dims` and returns them in a new + tensordict. + + Args: + dim (int): dimension along which to unsqueeze + + """ + if dim < 0: + dim = self.batch_dims + dim + 1 + + if (dim > self.batch_dims) or (dim < 0): + raise RuntimeError( + f"unsqueezing is allowed for dims comprised between " + f"`-td.batch_dims` and `td.batch_dims` only. Got " + f"dim={dim} with a batch size of {self.batch_size}." + ) + return UnsqueezedTensorDict( + source=self, + custom_op="unsqueeze", + inv_op="squeeze", + custom_op_kwargs={"dim": dim}, + inv_op_kwargs={"dim": dim}, + ) + + def squeeze(self, dim: int) -> _TensorDict: + """Squeezes all tensors for a dimension comprised in between + `-td.batch_dims+1` and `td.batch_dims-1` and returns them + in a new tensordict. + + Args: + dim (int): dimension along which to squeeze + + """ + if dim < 0: + dim = self.batch_dims + dim + + if self.batch_dims and (dim >= self.batch_dims or dim < 0): + raise RuntimeError( + f"squeezing is allowed for dims comprised between 0 and " + f"td.batch_dims only. Got dim={dim} and batch_size" + f"={self.batch_size}." + ) + + if dim >= self.batch_dims or self.batch_size[dim] != 1: + return self + return SqueezedTensorDict( + source=self, + custom_op="squeeze", + inv_op="unsqueeze", + custom_op_kwargs={"dim": dim}, + inv_op_kwargs={"dim": dim}, + ) + + def reshape( + self, + *shape: int, + size: Optional[Union[List, Tuple, torch.Size]] = None, + ) -> _TensorDict: + """Returns a contiguous, reshaped tensor of the desired shape. + + Args: + *shape (int): new shape of the resulting tensordict. 
+ size: iterable + + Returns: + A TensorDict with reshaped keys + + """ + if len(shape) == 0 and size is not None: + return self.view(*size) + elif len(shape) == 1 and isinstance(shape[0], (list, tuple, torch.Size)): + return self.view(*shape[0]) + elif not isinstance(shape, torch.Size): + shape = torch.Size(shape) + + d = {} + for key, item in self.items(): + d[key] = item.reshape(*shape, *item.shape[self.ndimension() :]) + if len(d): + batch_size = d[key].shape[: len(shape)] + else: + if any(not isinstance(i, int) or i < 0 for i in shape): + raise RuntimeError( + "Implicit reshaping is not permitted with empty " "tensordicts" + ) + batch_size = shape + return TensorDict(d, batch_size) + + def view( + self, + *shape: int, + size: Optional[Union[List, Tuple, torch.Size]] = None, + ) -> _TensorDict: + """Returns a tensordict with views of the tensors according to a new + shape, compatible with the tensordict batch_size. + + Args: + *shape (int): new shape of the resulting tensordict. + size: iterable + + Returns: + a new tensordict with the desired batch_size. + + Examples: + >>> td = TensorDict(source={'a': torch.zeros(3,4,5), + ... 'b': torch.zeros(3,4,10,1)}, batch_size=torch.Size([3, 4])) + >>> td_view = td.view(12) + >>> print(td_view.get("a").shape) # torch.Size([12, 5]) + >>> print(td_view.get("b").shape) # torch.Size([12, 10, 1]) + >>> td_view = td.view(-1, 4, 3) + >>> print(td_view.get("a").shape) # torch.Size([1, 4, 3, 5]) + >>> print(td_view.get("b").shape) # torch.Size([1, 4, 3, 10, 1]) + + """ + if len(shape) == 0 and size is not None: + return self.view(*size) + elif len(shape) == 1 and isinstance(shape[0], (list, tuple, torch.Size)): + return self.view(*shape[0]) + elif not isinstance(shape, torch.Size): + shape = torch.Size(shape) + return ViewedTensorDict( + source=self, + custom_op="view", + inv_op="view", + custom_op_kwargs={"size": shape}, + inv_op_kwargs={"size": self.batch_size}, + ) + + def __repr__(self) -> str: + fields = _td_fields(self) + field_str = indent(f"fields={{{fields}}}", 4 * " ") + batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ") + device_str = indent(f"device={self.device}", 4 * " ") + is_shared_str = indent(f"is_shared={self.is_shared()}", 4 * " ") + string = ",\n".join([field_str, batch_size_str, device_str, is_shared_str]) + return f"{type(self).__name__}(\n{string})" + + def all(self, dim: int = None) -> Union[bool, _TensorDict]: + """Checks if all values are True/non-null in the tensordict. + + Args: + dim (int, optional): if None, returns a boolean indicating + whether all tensors return `tensor.all() == True` + If integer, all is called upon the dimension specified if + and only if this dimension is compatible with the tensordict + shape. + + """ + if dim is not None and (dim >= self.batch_dims or dim <= -self.batch_dims): + raise RuntimeError( + "dim must be greater than -tensordict.batch_dims and smaller " + "than tensordict.batchdims" + ) + if dim is not None: + if dim < 0: + dim = self.batch_dims + dim + return TensorDict( + source={key: value.all(dim=dim) for key, value in self.items()}, + batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], + ) + return all([value.all() for key, value in self.items()]) + + def any(self, dim: int = None) -> Union[bool, _TensorDict]: + """Checks if any value is True/non-null in the tensordict. + + Args: + dim (int, optional): if None, returns a boolean indicating + whether all tensors return `tensor.any() == True`. 
+ If integer, all is called upon the dimension specified if + and only if this dimension is compatible with + the tensordict shape. + + """ + if dim is not None and (dim >= self.batch_dims or dim <= -self.batch_dims): + raise RuntimeError( + "dim must be greater than -tensordict.batch_dims and smaller " + "than tensordict.batchdims" + ) + if dim is not None: + if dim < 0: + dim = self.batch_dims + dim + return TensorDict( + source={key: value.any(dim=dim) for key, value in self.items()}, + batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], + ) + return any([value.any() for key, value in self.items()]) + + def get_sub_tensor_dict(self, idx: INDEX_TYPING) -> _TensorDict: + """Returns a SubTensorDict with the desired index.""" + sub_td = SubTensorDict( + source=self, + idx=idx, + ) + return sub_td + + def __iter__(self) -> Generator: + if not self.batch_dims: + raise StopIteration + length = self.batch_size[0] + for i in range(length): + yield self[i] + + def __len__(self) -> int: + """ + + Returns: + Number of keys in _TensorDict instance. + + """ + return len(list(self.keys())) + + def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict: + """Indexes all tensors according to idx and returns a new tensordict + where the values share the storage of the original tensors (even + when the index is a torch.Tensor). Any in-place modification to the + resulting tensordict will impact the parent tensordict too. + + Examples: + >>> td = TensorDict(source={'a': torch.zeros(3,4,5)}, + ... batch_size=torch.Size([3, 4])) + >>> subtd = td[torch.zeros(1, dtype=torch.long)] + >>> assert subtd.shape == torch.Size([1,4]) + >>> subtd.set("a", torch.ones(1,4,5)) + >>> print(td.get("a")) # first row is full of 1 + >>> # Warning: this will not work as expected + >>> subtd.get("a")[:] = 2.0 + >>> print(td.get("a")) # values have not changed + + """ + if isinstance(idx, str): + return self.get(idx) + if isinstance(idx, Number): + idx = (idx,) + elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool: + return self.masked_select(idx) + + contiguous_input = (int, slice) + return_simple_view = isinstance(idx, contiguous_input) or ( + isinstance(idx, tuple) + and all(isinstance(_idx, contiguous_input) for _idx in idx) + ) + if not self.batch_size: + raise RuntimeError( + "indexing a tensordict with td.batch_dims==0 is not permitted" + ) + if return_simple_view and not self.is_memmap(): + return TensorDict( + source={key: item[idx] for key, item in self.items()}, + batch_size=_getitem_batch_size(self.batch_size, idx), + ) + # SubTensorDict keeps the same storage as TensorDict + # in all cases not accounted for above + return self.get_sub_tensor_dict(idx) + + def __setitem__(self, index: INDEX_TYPING, value: _TensorDict) -> None: + indexed_bs = _getitem_batch_size(self.batch_size, index) + if value.batch_size != indexed_bs: + raise RuntimeError( + f"indexed destination TensorDict batch size is {indexed_bs} " + f"(batch_size = {self.batch_size}, index={index}), " + f"which differs from the source batch size {value.batch_size}" + ) + for key, item in value.items(): + self.set_at_(key, item, index) + + @abc.abstractmethod + def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorDict: + """Renames a key with a new string. + + Args: + old_key (str): key to be renamed + new_key (str): new name + safe (bool, optional): if True, an error is thrown when the new + key is already present in the TensorDict. 
+
+        Returns:
+            self
+
+        """
+        raise NotImplementedError
+
+    def fill_(self, key: str, value: Union[float, bool]) -> _TensorDict:
+        """Fills the tensor pointed to by the key with a given value.
+
+        Args:
+            key (str): key of the tensor to be filled
+            value (Number, bool): value to use for the filling
+
+        Returns:
+            self
+
+        """
+
+        meta_tensor = self._get_meta(key)
+        shape = meta_tensor.shape
+        device = meta_tensor.device
+        dtype = meta_tensor.dtype
+        value = torch.full(shape, value, device=device, dtype=dtype)
+        self.set_(key, value)
+        return self
+
+    def empty(self) -> _TensorDict:
+        """Returns a new, empty tensordict with the same device and batch size."""
+        return self.select()
+
+    def is_empty(self):
+        for _ in self.items_meta():
+            return False
+        return True
+
+
+class TensorDict(_TensorDict):
+    """A batched dictionary of tensors.
+
+    TensorDict is a tensor container where all tensors are stored in a
+    key-value pair fashion and where each element shares at least the
+    following features:
+    - device;
+    - memory location (shared, memory-mapped array, ...);
+    - batch size (i.e. the first n dimensions).
+
+    TensorDict instances support many regular tensor operations as long as
+    they are dtype-independent (as a TensorDict instance can contain tensors
+    of many different dtypes). Those operations include (but are not limited
+    to):
+
+    - operations on shape: when a shape operation is called (indexing,
+      reshape, view, expand, transpose, permute,
+      unsqueeze, squeeze, masking etc), the operation is done as if it
+      were done on a tensor of the same shape as the batch size and then
+      expanded to the right, e.g.:
+
+        >>> td = TensorDict({'a': torch.zeros(3,4,5)}, batch_size=[3, 4])
+        >>> # returns a TensorDict of batch size [3, 4, 1]
+        >>> td_unsqueeze = td.unsqueeze(-1)
+        >>> # returns a TensorDict of batch size [12]
+        >>> td_view = td.view(-1)
+        >>> # returns a tensor of shape [12, 5]
+        >>> a_view = td.view(-1).get("a")
+
+    - casting operations: a TensorDict can be cast to a different device
+      or to another TensorDict type using
+
+        >>> td_cpu = td.to("cpu")
+        >>> td_saved = td.to(SavedTensorDict)  # TensorDict saved on disk
+        >>> dictionary = td.to_dict()
+
+      Calling the `.to()` method with a dtype will raise an error.
+
+    - Cloning, contiguous
+
+    - Reading: `td.get(key)`, `td.get_at(key, index)`
+
+    - Content modification: `td.set(key, value)`, `td.set_(key, value)`,
+      `td.update(td_or_dict)`, `td.update_(td_or_dict)`, `td.fill_(key,
+      value)`, `td.rename_key(old_name, new_name)`, etc.
+
+    - Operations on multiple tensordicts: `torch.cat(tensordict_list, dim)`,
+      `torch.stack(tensordict_list, dim)`, `td1 == td2` etc.
+
+    Args:
+        source (TensorDict or dictionary): a data source. If empty, the
+            tensordict can be populated subsequently.
+        batch_size (iterable of int, optional): a batch size for the
+            tensordict. The batch size is immutable and can only be modified
+            by calling operations that create a new TensorDict. Unless the
+            source is another TensorDict, the batch_size argument must be
+            provided as it won't be inferred from the data.
+        device (torch.device or compatible type, optional): a device for the
+            TensorDict. If the source is non-empty and the device is
+            missing, it will be inferred from the input dictionary, assuming
+            that all tensors are on the same device.
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.data import TensorDict
+        >>> source = {'random': torch.randn(3, 4),
+        ...     
'zeros': torch.zeros(3, 4, 5)} + >>> batch_size = [3] + >>> td = TensorDict(source, batch_size) + >>> print(td.shape) # equivalent to td.batch_size + torch.Size([3]) + >>> td_unqueeze = td.unsqueeze(-1) + >>> print(td_unqueeze.get("zeros").shape) + torch.Size([3, 1, 4, 5]) + >>> print(td_unqueeze[0].shape) + torch.Size([1]) + >>> print(td_unqueeze.view(-1).shape) + torch.Size([3]) + >>> print((td.clone()==td).all()) + True + + """ + + # TODO: split, transpose, permute + _safe = True + + def __init__( + self, + source: Union[_TensorDict, dict], + batch_size: Optional[Union[Sequence[int], torch.Size, int]] = None, + device: Optional[DEVICE_TYPING] = None, + _meta_source: Optional[dict] = None, + ): + self._tensor_dict: Dict = dict() + self._tensor_dict_meta: OrderedDict = OrderedDict() + if not isinstance(source, (_TensorDict, dict)): + raise ValueError( + "A TensorDict source is expected to be a _TensorDict " + f"sub-type or a dictionary, found type(source)={type(source)}." + ) + if isinstance( + batch_size, + ( + Number, + Sequence, + ), + ): + if not isinstance(batch_size, torch.Size): + if isinstance(batch_size, int): + batch_size = torch.Size([batch_size]) + else: + batch_size = torch.Size(batch_size) + self._batch_size = batch_size + self._batch_dims = len(batch_size) + elif isinstance(source, _TensorDict): + self._batch_size = source.batch_size + else: + raise ValueError( + "batch size was not specified when creating the TensorDict " + "instance and it could not be retrieved from source." + ) + + if isinstance(device, (int, str)): + device = torch.device(device) + map_item_to_device = device is not None + self._device = device + if source is not None: + for key, value in source.items(): + if not isinstance(key, str): + raise TypeError( + f"Expected key to be a string but found {type(key)}" + ) + if not isinstance(value, _accepted_classes): + raise TypeError( + f"Expected value to be one of types" + f" {_accepted_classes} but got {type(value)}" + ) + if map_item_to_device: + value = value.to(device) # type: ignore + _meta_val = None if _meta_source is None else _meta_source[key] + self.set(key, value, _meta_val=_meta_val, _run_checks=False) + + self._check_batch_size() + self._check_device() + + @property + def batch_dims(self) -> int: + if self._safe and hasattr(self, "_batch_dims"): + if len(self.batch_size) != self._batch_dims: + raise RuntimeError("len(self.batch_size) and self._batch_dims mismatch") + return len(self.batch_size) + + @batch_dims.setter + def batch_dims(self, value: COMPATIBLE_TYPES) -> None: + raise RuntimeError( + f"Setting batch dims on {self.__class__.__name__} instances is " + f"not allowed." + ) + + def is_shared(self, no_check: bool = False) -> bool: + if no_check: + for key, item in self.items_meta(): + return item.is_shared() + return self._check_is_shared() + + def is_memmap(self) -> bool: + return self._check_is_memmap() + + @property + def device(self) -> torch.device: + device = self._device + if device is None and not self.is_empty(): + for _, item in self.items_meta(): + device = item.device + break + if (not isinstance(device, torch.device)) and (device is not None): + device = torch.device(device) + self._device = device + return device # type: ignore + + @device.setter + def device(self, value: DEVICE_TYPING) -> None: + raise RuntimeError( + f"Setting device on {self.__class__.__name__} instances is not " + f"allowed. Please call {self.__class__.__name__}.to(device) " + f"instead." 
+        )
+
+    @property
+    def batch_size(self) -> torch.Size:
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, value: COMPATIBLE_TYPES) -> None:
+        raise RuntimeError(
+            f"Setting batch size on {self.__class__.__name__} instances is "
+            f"not allowed."
+        )
+
+    # Checks
+    def _check_is_shared(self) -> bool:
+        share_list = [value.is_shared() for key, value in self.items_meta()]
+        if any(share_list) and not all(share_list):
+            shared_str = ", ".join(
+                [f"{key}: {value.is_shared()}" for key, value in self.items_meta()]
+            )
+            raise RuntimeError(
+                f"tensors must be either all shared or not, but mixed "
+                f"features is not allowed. "
+                f"Found: {shared_str}"
+            )
+        return all(share_list) and len(share_list) > 0
+
+    def _check_is_memmap(self) -> bool:
+        memmap_list = [value.is_memmap() for key, value in self.items_meta()]
+        if any(memmap_list) and not all(memmap_list):
+            memmap_str = ", ".join(
+                [f"{key}: {value.is_memmap()}" for key, value in self.items_meta()]
+            )
+            raise RuntimeError(
+                f"tensors must be either all MemmapTensor or not, but mixed "
+                f"features is not allowed. "
+                f"Found: {memmap_str}"
+            )
+        return all(memmap_list) and len(memmap_list) > 0
+
+    def _check_device(self) -> None:
+        devices = {key: value.device for key, value in self.items_meta()}
+        if len(devices):
+            if not (
+                len(np.unique([str(device) for key, device in devices.items()])) == 1
+            ):
+                raise RuntimeError(
+                    f"expected tensors to be on a single device, found" f" {devices}"
+                )
+            device = devices[list(devices.keys())[0]]
+            if torch.device(device) != self.device:
+                raise RuntimeError(
+                    f"expected {self.__class__.__name__}.device to be "
+                    f"identical to tensors device, found"
+                    f" {self.__class__.__name__}.device={self.device} and"
+                    f" {device}"
+                )
+
+    def pin_memory(self) -> _TensorDict:
+        if self.device == torch.device("cpu"):
+            for key, value in self.items():
+                if value.dtype in (torch.half, torch.float, torch.double):
+                    self.set(key, value.pin_memory(), inplace=False)
+        return self
+
+    def expand(self, *shape: int) -> _TensorDict:
+        """Expands every tensor to shape `(*shape, *tensor.shape)` and returns
+        a new tensordict with the corresponding expanded batch size.
+        """
+        _batch_size = torch.Size([*shape, *self.batch_size])
+        d = {key: value.expand(*shape, *value.shape) for key, value in self.items()}
+        return TensorDict(source=d, batch_size=_batch_size)
+
+    def set(  # type: ignore
+        self,
+        key: str,
+        value: COMPATIBLE_TYPES,
+        inplace: bool = False,
+        _run_checks: bool = True,
+        _meta_val: Optional[MetaTensor] = None,
+    ) -> _TensorDict:
+        """Sets a value in the TensorDict. If inplace=True (default is False),
+        and if the key already exists, set will call set_ (in place setting).
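+
+        Example (illustrative sketch, not an exhaustive specification):
+            >>> td = TensorDict({}, batch_size=[3])
+            >>> td.set("x", torch.zeros(3, 4))               # creates the entry
+            >>> td.set("x", torch.ones(3, 4), inplace=True)  # writes into the existing tensor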
+ """ + if not isinstance(key, str): + raise TypeError(f"Expected key to be a string but found {type(key)}") + + if key in self._tensor_dict and value is self._tensor_dict[key]: + return self + + proc_value = self._process_tensor( + value, + check_tensor_shape=_run_checks, + check_shared=_run_checks, + check_device=_run_checks, + ) # check_tensor_shape=_run_checks + if key in self._tensor_dict and inplace: + return self.set_(key, proc_value) + self._tensor_dict[key] = proc_value + self._tensor_dict_meta[key] = ( + MetaTensor(proc_value) if _meta_val is None else _meta_val + ) + return self + + def del_(self, key: str) -> _TensorDict: + del self._tensor_dict[key] + del self._tensor_dict_meta[key] + return self + + def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorDict: + if not isinstance(old_key, str): + raise TypeError( + f"Expected old_name to be a string but found {type(old_key)}" + ) + if not isinstance(new_key, str): + raise TypeError( + f"Expected new_name to be a string but found {type(new_key)}" + ) + + if safe and (new_key in self.keys()): + raise KeyError(f"key {new_key} already present in TensorDict.") + self.set(new_key, self.get(old_key)) + self.del_(old_key) + return self + + def set_(self, key: str, value: COMPATIBLE_TYPES) -> _TensorDict: + if not isinstance(key, str): + raise TypeError(f"Expected key to be a string but found {type(key)}") + + if key in self.keys(): + proc_value = self._process_tensor( + value, check_device=False, check_shared=False + ) + target_shape = self._get_meta(key).shape + if proc_value.shape != target_shape: + raise RuntimeError( + f'calling set_("{key}", tensor) with tensors of ' + f"different shape: got tensor.shape={proc_value.shape} " + f'and get("{key}").shape={target_shape}' + ) + if proc_value is not self._tensor_dict[key]: + self._tensor_dict[key].copy_(proc_value) + else: + raise AttributeError( + f"key {key} not found in tensordict, " + f"call td.set({key}, value) for populating tensordict with " + f"new key-value pair" + ) + return self + + def set_at_( + self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING + ) -> _TensorDict: + if not isinstance(key, str): + raise TypeError(f"Expected key to be a string but found {type(key)}") + + # do we need this? + # value = self._process_tensor( + # value, check_tensor_shape=False, check_device=False + # ) + if key not in self.keys(): + raise KeyError(f"did not find key {key} in {self.__class__.__name__}") + tensor_in = self._tensor_dict[key] + if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple): + warn( + "Multiple indexing can lead to unexpected behaviours when " + "setting items, for instance `td[idx1][idx2] = other` may " + "not write to the desired location if idx1 is a list/tensor." 
+ ) + tensor_in = _sub_index(tensor_in, idx) + tensor_in.copy_(value) + else: + tensor_in[idx] = value + return self + + def get( # type: ignore + self, key: str, default: Union[str, COMPATIBLE_TYPES] = "_no_default_" + ) -> COMPATIBLE_TYPES: # type: ignore + if not isinstance(key, str): + raise TypeError(f"Expected key to be a string but found {type(key)}") + + try: + return self._tensor_dict[key] + except KeyError: + return self._default_get(key, default) + + def _get_meta(self, key: str) -> MetaTensor: + if not isinstance(key, str): + raise TypeError(f"Expected key to be a string but found {type(key)}") + + try: + return self._tensor_dict_meta[key] + except KeyError: + raise KeyError( + f"key {key} not found in {self.__class__.__name__} with keys" + f" {sorted(list(self.keys()))}" + ) + + def share_memory_(self) -> _TensorDict: + if self.is_memmap(): + raise RuntimeError( + "memmap and shared memory are mutually exclusive features." + ) + if not len(self._tensor_dict): + raise Exception( + "share_memory_ must be called when the TensorDict is (" + "partially) populated. Set a tensor first." + ) + for key, value in self.items(): + value.share_memory_() + for key, value in self.items_meta(): + value.share_memory_() + return self + + def detach_(self) -> _TensorDict: + for key, value in self.items(): + value.detach_() + return self + + def memmap_(self) -> _TensorDict: + if self.is_memmap(): + raise RuntimeError( + "memmap and shared memory are mutually exclusive features." + ) + if not len(self._tensor_dict): + raise Exception( + "memmap_() must be called when the TensorDict is (partially) " + "populated. Set a tensor first." + ) + for key, value in self.items(): + self._tensor_dict[key] = MemmapTensor(value) + for key, value in self.items_meta(): + value.memmap_() + return self + + def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: + if isinstance(dest, type) and issubclass(dest, _TensorDict): + td = dest( + source=self, + **kwargs, # type: ignore + ) + return td + elif isinstance(dest, (torch.device, str, int)): + # must be device + if not isinstance(dest, torch.device): + dest = torch.device(dest) + if dest == self.device: + return self + + self_copy = TensorDict( + source={key: value.to(dest) for key, value in self.items()}, + batch_size=self.batch_size, + ) + if self._safe: + # sanity check + self_copy._check_device() + self_copy._check_is_shared() + return self_copy + else: + raise NotImplementedError( + f"dest must be a string, torch.device or a TensorDict " + f"instance, {dest} not allowed" + ) + + def masked_fill_( + self, mask: torch.Tensor, val: Union[float, int, bool] + ) -> _TensorDict: + for key, value in self.items(): + mask_expand = expand_as_right(mask, value) + value.masked_fill_(mask_expand, val) + return self + + def is_contiguous(self) -> bool: + return all([value.is_contiguous() for _, value in self.items()]) + + def contiguous(self) -> _TensorDict: + if not self.is_contiguous(): + return self.clone() + return self + + def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + d = {key: value for (key, value) in self.items() if key in keys} + d_meta = {key: value for (key, value) in self.items_meta() if key in keys} + if inplace: + self._tensor_dict = d + self._tensor_dict_meta = OrderedDict( + {key: value for (key, value) in self.items_meta() if key in keys} + ) + return self + return TensorDict( + device=self.device, + batch_size=self.batch_size, + source=d, + _meta_source=d_meta, + ) + + def items(self) -> Iterator[Tuple[str, 
COMPATIBLE_TYPES]]: # type: ignore + for k in self._tensor_dict: + yield k, self.get(k) + + def items_meta(self) -> Iterator[Tuple[str, MetaTensor]]: + for k in self._tensor_dict_meta: + yield k, self._get_meta(k) + + def keys(self) -> KeysView: + return self._tensor_dict.keys() + + +def implements_for_td(torch_function: Callable) -> Callable: + """Register a torch function override for ScalarTensor""" + + @functools.wraps(torch_function) + def decorator(func): + TD_HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + + +# @implements_for_td(torch.testing.assert_allclose) TODO +def assert_allclose_td( + actual: _TensorDict, + expected: _TensorDict, + rtol: float = None, + atol: float = None, + equal_nan: bool = True, + msg: str = "", +) -> bool: + if not isinstance(actual, _TensorDict) or not isinstance(expected, _TensorDict): + raise TypeError("assert_allclose inputs must be of TensorDict type") + set1 = set(actual.keys()) + set2 = set(expected.keys()) + if not (len(set1.difference(set2)) == 0 and len(set2) == len(set1)): + raise KeyError( + "actual and expected tensordict keys mismatch, " + f"keys {(set1 - set2).union(set2 - set1)} appear in one but not " + f"the other." + ) + keys = sorted(list(actual.keys())) + for key in keys: + input1 = actual.get(key) + input2 = expected.get(key) + mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() + mse = mse.div(input1.numel()).sqrt().item() + + default_msg = f"key {key} does not match, got mse = {mse:4.4f}" + if len(msg): + msg = "\t".join([default_msg, msg]) + else: + msg = default_msg + torch.testing.assert_allclose( + input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg + ) + return True + + +@implements_for_td(torch.unbind) +def unbind(td: _TensorDict, *args, **kwargs) -> Tuple[_TensorDict, ...]: + return td.unbind(*args, **kwargs) + + +@implements_for_td(torch.clone) +def clone(td: _TensorDict, *args, **kwargs) -> _TensorDict: + return td.clone(*args, **kwargs) + + +@implements_for_td(torch.cat) +def cat( + list_of_tensor_dicts: Sequence[_TensorDict], + dim: int = 0, + device: DEVICE_TYPING = None, + out: _TensorDict = None, +) -> _TensorDict: + if not list_of_tensor_dicts: + raise RuntimeError("list_of_tensor_dicts cannot be empty") + if not device: + device = list_of_tensor_dicts[0].device + if dim < 0: + raise RuntimeError( + f"negative dim in torch.dim(list_of_tensor_dicts, dim=dim) not " + f"allowed, got dim={dim}" + ) + + batch_size = list(list_of_tensor_dicts[0].batch_size) + if dim >= len(batch_size): + raise RuntimeError( + f"dim must be in the range 0 <= dim < len(batch_size), got dim" + f"={dim} and batch_size={batch_size}" + ) + batch_size[dim] = sum([td.batch_size[dim] for td in list_of_tensor_dicts]) + batch_size = torch.Size(batch_size) + + # check that all tensordict match + keys = _check_keys(list_of_tensor_dicts, strict=True) + if out is None: + out = TensorDict({}, device=device, batch_size=batch_size) + for key in keys: + out.set( + key, + torch.cat([td.get(key) for td in list_of_tensor_dicts], dim) + # type: ignore + ) + return out + else: + if out.batch_size != batch_size: + raise RuntimeError( + "out.batch_size and cat batch size must match, " + f"got out.batch_size={out.batch_size} and batch_size" + f"={batch_size}" + ) + + for key in keys: + out.set_( + key, + torch.cat([td.get(key) for td in list_of_tensor_dicts], dim) + # type: ignore + ) + return out + + +@implements_for_td(torch.stack) +def stack( + list_of_tensor_dicts: Sequence[_TensorDict], + dim: int = 
0,
+    out: _TensorDict = None,
+    strict=False,
+    contiguous=False,
+) -> _TensorDict:
+    if not list_of_tensor_dicts:
+        raise RuntimeError("list_of_tensor_dicts cannot be empty")
+    batch_size = list_of_tensor_dicts[0].batch_size
+    if dim < 0:
+        dim = len(batch_size) + dim + 1
+    if len(list_of_tensor_dicts) > 1:
+        for td in list_of_tensor_dicts[1:]:
+            if td.batch_size != list_of_tensor_dicts[0].batch_size:
+                raise RuntimeError(
+                    "stacking tensor_dicts requires them to have congruent "
+                    f"batch sizes, got td1.batch_size={td.batch_size} and "
+                    f"td2.batch_size={list_of_tensor_dicts[0].batch_size}"
+                )
+    # check that all tensordict match
+    keys = _check_keys(list_of_tensor_dicts)
+    batch_size = list(batch_size)
+    batch_size.insert(dim, len(list_of_tensor_dicts))
+    batch_size = torch.Size(batch_size)
+
+    if out is None:
+        out = LazyStackedTensorDict(
+            *list_of_tensor_dicts,
+            stack_dim=dim,
+        )
+        if contiguous:
+            out = out.contiguous()
+        return out
+    else:
+        if out.batch_size != batch_size:
+            raise RuntimeError(
+                "out.batch_size and stacked batch size must match, "
+                f"got out.batch_size={out.batch_size} and batch_size"
+                f"={batch_size}"
+            )
+        if strict:
+            out_keys = set(out.keys())
+            in_keys = set(keys)
+            if len(out_keys - in_keys) > 0:
+                raise RuntimeError(
+                    "The output tensordict has keys that are missing in the "
+                    f"tensordict that has to be written: {out_keys - in_keys}. "
+                    "As per the call to `stack(..., strict=True)`, this "
+                    "is not permitted."
+                )
+            elif len(in_keys - out_keys) > 0:
+                raise RuntimeError(
+                    "The resulting tensordict has keys that are missing in "
+                    f"its destination: {in_keys - out_keys}. As per the call "
+                    "to `stack(..., strict=True)`, this is not permitted."
+                )
+
+        for key in keys:
+            out.set(
+                key,
+                torch.stack([td.get(key) for td in list_of_tensor_dicts], dim),
+                # type: ignore
+                inplace=True,
+            )
+        return out
+
+
+# @implements_for_td(torch.nn.utils.rnn.pad_sequence)
+def pad_sequence_td(
+    list_of_tensor_dicts: Sequence[_TensorDict],
+    batch_first: bool = True,
+    padding_value: float = 0.0,
+    out: _TensorDict = None,
+    device: Optional[DEVICE_TYPING] = None,
+):
+    if not list_of_tensor_dicts:
+        raise RuntimeError("list_of_tensor_dicts cannot be empty")
+    # check that all tensordict match
+    keys = _check_keys(list_of_tensor_dicts)
+    if out is None:
+        out = TensorDict({}, [], device=device)
+        for key in keys:
+            out.set(
+                key,
+                torch.nn.utils.rnn.pad_sequence(
+                    [td.get(key) for td in list_of_tensor_dicts],
+                    # type: ignore
+                    batch_first=batch_first,
+                    padding_value=padding_value,
+                ),
+            )
+        return out
+    else:
+        for key in keys:
+            out.set_(
+                key,
+                torch.nn.utils.rnn.pad_sequence(
+                    [td.get(key) for td in list_of_tensor_dicts],
+                    # type: ignore
+                    batch_first=batch_first,
+                    padding_value=padding_value,
+                ),
+            )
+        return out
+
+
+class SubTensorDict(_TensorDict):
+    """
+    A TensorDict that only sees an index of the stored tensors.
+
+    By default, indexing a tensordict with an iterable will result in a
+    SubTensorDict. This is done such that a TensorDict indexed with a
+    non-contiguous index (e.g. a Tensor) will still point to the original
+    memory location (unlike regular indexing of tensors).
+
+    Examples:
+        >>> from torchrl.data import TensorDict, SubTensorDict
+        >>> source = {'random': torch.randn(3, 4, 5, 6),
+        ...     
'zeros': torch.zeros(3, 4, 1, dtype=torch.bool)} + >>> batch_size = torch.Size([3, 4]) + >>> td = TensorDict(source, batch_size) + >>> td_index = td[:, 2] + >>> print(type(td_index), td_index.shape) + \ +torch.Size([3]) + >>> td_index = td[:, slice(None)] + >>> print(type(td_index), td_index.shape) + \ +torch.Size([3, 4]) + >>> td_index = td[:, torch.Tensor([0, 2]).to(torch.long)] + >>> print(type(td_index), td_index.shape) + \ +torch.Size([3, 2]) + >>> _ = td_index.fill_('zeros', 1) + >>> # the indexed tensors are updated with Trues + >>> print(td.get('zeros')) + tensor([[[ True], + [False], + [ True], + [False]], + + [[ True], + [False], + [ True], + [False]], + + [[ True], + [False], + [ True], + [False]]]) + + """ + + _safe = False + + def __init__( + self, + source: _TensorDict, + idx: INDEX_TYPING, + batch_size: Optional[Sequence[int]] = None, + ): + if not isinstance(source, _TensorDict): + raise TypeError( + f"Expected source to be a subclass of _TensorDict, " + f"got {type(source)}" + ) + self._source = source + if not isinstance(idx, (tuple, list)): + idx = (idx,) + else: + idx = tuple(idx) + self.idx = idx + self._batch_size = _getitem_batch_size(self._source.batch_size, self.idx) + if batch_size is not None and batch_size != self.batch_size: + raise RuntimeError("batch_size does not match self.batch_size.") + + @property + def batch_size(self) -> torch.Size: + return self._batch_size + + @property + def device(self) -> torch.device: + return self._source.device + + def _preallocate(self, key: str, value: COMPATIBLE_TYPES) -> _TensorDict: + return self._source.set(key, value) + + def set( # type: ignore + self, + key: str, + tensor: COMPATIBLE_TYPES, + inplace: bool = False, + _run_checks: bool = True, + ) -> _TensorDict: # type: ignore + if inplace and key in self.keys(): + return self.set_(key, tensor) + + tensor = self._process_tensor( + tensor, check_device=False, check_tensor_shape=False + ) + parent = self.get_parent_tensor_dict() + tensor_expand = torch.zeros( + *parent.batch_size, + *tensor.shape[self.batch_dims :], + dtype=tensor.dtype, + device=self.device, + ) + + if self.is_shared(): + tensor_expand.share_memory_() + elif self.is_memmap(): + tensor_expand = MemmapTensor(tensor_expand) # type: ignore + + parent.set(key, tensor_expand, _run_checks=_run_checks) + self.set_(key, tensor) + return self + + def keys(self) -> KeysView: + return self._source.keys() + + def set_(self, key: str, tensor: COMPATIBLE_TYPES) -> SubTensorDict: + if key not in self.keys(): + raise KeyError(f"key {key} not found in {self.keys()}") + if tensor.shape[: self.batch_dims] != self.batch_size: + raise RuntimeError( + f"tensor.shape={tensor.shape[:self.batch_dims]} and " + f"self.batch_size={self.batch_size} mismatch" + ) + self._source.set_at_(key, tensor, self.idx) + return self + + def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: + if isinstance(dest, type) and issubclass(dest, _TensorDict): + return dest( + source=self.clone(), + ) + elif isinstance(dest, (torch.device, str, int)): + if not isinstance(dest, torch.device): + dest = torch.device(dest) + if dest == self.device: + return self + td = self.to(TensorDict) + # must be device + return td.to(dest, **kwargs) + else: + raise NotImplementedError( + f"dest must be a string, torch.device or a TensorDict " + f"instance, {dest} not allowed" + ) + + def get( # type: ignore + self, key: str, default: Optional[Union[torch.Tensor, str]] = None + ) -> COMPATIBLE_TYPES: # type: ignore + return self._source.get_at(key, 
self.idx) + + def _get_meta(self, key: str) -> MetaTensor: + return self._source._get_meta(key)[self.idx] + + def set_at_( + self, + key: str, + value: COMPATIBLE_TYPES, + idx: INDEX_TYPING, + discard_idx_attr: bool = False, + ) -> SubTensorDict: + if not isinstance(idx, tuple): + idx = (idx,) + if discard_idx_attr: + self._source.set_at_(key, value, idx) + else: + tensor = self._source.get_at(key, self.idx) + tensor[idx] = value # type: ignore + self._source.set_at_(key, tensor, self.idx) + # self._source.set_at_(key, value, (self.idx, idx)) + return self + + def get_at( # type: ignore + self, key: str, idx: INDEX_TYPING, discard_idx_attr: bool = False + ) -> COMPATIBLE_TYPES: + if not isinstance(idx, tuple): + idx = (idx,) + if discard_idx_attr: + return self._source.get_at(key, idx) + else: + return self._source.get_at(key, self.idx)[idx] + + def update_( # type: ignore + self, + input_dict: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + clone: bool = False, + ) -> SubTensorDict: + return self.update_at_( + input_dict, idx=self.idx, discard_idx_attr=True, clone=clone + ) + + def update_at_( # type: ignore + self, + input_dict: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + idx: INDEX_TYPING, + discard_idx_attr: bool = False, + clone: bool = False, + ) -> SubTensorDict: + for key, value in input_dict.items(): + if not isinstance(value, _accepted_classes): + raise TypeError( + f"Expected value to be one of types {_accepted_classes} " + f"but got {type(value)}" + ) + if clone: + value = value.clone() + self.set_at_( + key, + value, + idx, + discard_idx_attr=discard_idx_attr, + ) + return self + + def get_parent_tensor_dict(self) -> _TensorDict: + if not isinstance(self._source, _TensorDict): + raise TypeError( + f"SubTensorDict was initialized with a source of type" + f" {self._source.__class__.__name__}, " + "parent tensordict not accessible" + ) + return self._source + + def del_(self, key: str) -> _TensorDict: + self._source = self._source.del_(key) + return self + + def clone(self, recursive: bool = True) -> SubTensorDict: + if not recursive: + return copy(self) + return SubTensorDict( + source=self._source, + idx=self.idx, + ) + + def is_contiguous(self) -> bool: + return all([value.is_contiguous() for _, value in self.items()]) + + def contiguous(self) -> _TensorDict: + if self.is_contiguous(): + return self + return TensorDict( + batch_size=self.batch_size, + source={key: value for key, value in self.items()}, + ) + + def items(self) -> Iterator[Tuple[str, COMPATIBLE_TYPES]]: # type: ignore + for k in self.keys(): + yield k, self.get(k) + + def items_meta(self) -> Iterator[Tuple[str, MetaTensor]]: + for key, value in self._source.items_meta(): + yield key, value[self.idx] + + def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + if inplace: + self._source = self._source.select(*keys) + return self + return self._source.select(*keys)[self.idx] + + def expand(self, *shape: int, inplace: bool = False) -> _TensorDict: + new_source = self._source.expand(*shape) + idx = tuple(slice(None) for _ in shape) + tuple(self.idx) + if inplace: + self._source = new_source + self.idx = idx + return new_source[idx] + + def is_shared(self, no_check: bool = True) -> bool: + return self._source.is_shared(no_check=no_check) + + def is_memmap(self) -> bool: + return self._source.is_memmap() + + def rename_key( + self, old_key: str, new_key: str, safe: bool = False + ) -> SubTensorDict: + self._source.rename_key(old_key, new_key, safe=safe) + return self + + def pin_memory(self) -> 
_TensorDict: + self._source.pin_memory() + return self + + def detach_(self) -> _TensorDict: + raise RuntimeError("Detaching a sub-tensordict in-place cannot be done.") + + def masked_fill_( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> _TensorDict: + for key, item in self.items(): + self.set_(key, torch.full_like(item, value)) + return self + + def memmap_(self) -> _TensorDict: + raise RuntimeError( + "Converting a sub-tensordict values to memmap cannot be done." + ) + + def share_memory_(self) -> _TensorDict: + raise RuntimeError( + "Casting a sub-tensordict values to shared memory cannot be done." + ) + + +def merge_tensor_dicts(*tensor_dicts: _TensorDict) -> _TensorDict: + if len(tensor_dicts) < 2: + raise RuntimeError( + f"at least 2 tensor_dicts must be provided, got" f" {len(tensor_dicts)}" + ) + d = tensor_dicts[0].to_dict() + for td in tensor_dicts[1:]: + d.update(td.to_dict()) + return TensorDict({}, [], device=td.device).update(d) + + +class LazyStackedTensorDict(_TensorDict): + """A Lazy stack of TensorDicts. + + When stacking TensorDicts together, the default behaviour is to put them + in a stack that is not instantiated. + This allows to seamlessly work with stacks of tensordicts with operations + that will affect the original tensordicts. + + Args: + *tensor_dicts (TensorDict instances): a list of tensordict with + same batch size. + stack_dim (int): a dimension (between `-td.ndimension()` and + `td.ndimension()-1` along which the stack should be performed. + + Examples: + >>> from torchrl.data import TensorDict + >>> import torch + >>> tds = [TensorDict({'a': torch.randn(3, 4)}, batch_size=[3]) + ... for _ in range(10)] + >>> td_stack = torch.stack(tds, -1) + >>> print(td_stack.shape) + torch.Size([3, 10]) + >>> print(td_stack.get("a").shape) + torch.Size([3, 10, 4]) + >>> print(td_stack[:, 0] is tds[0]) + True + """ + + _safe = False + + def __init__( + self, + *tensor_dicts: _TensorDict, + stack_dim: int = 0, + batch_size: Optional[Sequence[int]] = None, # TODO: remove + ): + # sanity check + N = len(tensor_dicts) + if not isinstance(tensor_dicts[0], _TensorDict): + raise TypeError( + f"Expected input to be _TensorDict instance" + f" but got {type(tensor_dicts[0])} instead." + ) + if stack_dim < 0: + raise RuntimeError( + f"stack_dim must be non negative, got stack_dim={stack_dim}" + ) + if not N: + raise RuntimeError( + "at least one tensordict must be provided to " + "StackedTensorDict to be instantiated" + ) + _batch_size = tensor_dicts[0].batch_size + device = tensor_dicts[0].device + + for i, td in enumerate(tensor_dicts[1:]): + if not isinstance(td, _TensorDict): + raise TypeError( + f"Expected input to be _TensorDict instance" + f" but got {type(tensor_dicts[0])} instead." + ) + _bs = td.batch_size + _device = td.device + if device != _device: + raise RuntimeError(f"devices differ, got {device} and {_device}") + if _bs != _batch_size: + raise RuntimeError( + f"batch sizes in tensor_dicts differs, StackedTensorDict " + f"cannot be created. 
Got td[0].batch_size={_batch_size} "
+                    f"and td[i].batch_size={_bs} "
+                )
+        self.tensor_dicts: List[_TensorDict] = list(tensor_dicts)
+        self.stack_dim = stack_dim
+        self._batch_size = self._compute_batch_size(_batch_size, stack_dim, N)
+        self._batch_dims = len(self._batch_size)
+        self._update_valid_keys()
+        self._meta_dict = dict()
+        self._meta_dict.update({k: value for k, value in self.items_meta()})
+        if batch_size is not None and batch_size != self.batch_size:
+            raise RuntimeError("batch_size does not match self.batch_size.")
+
+    @property
+    def device(self) -> torch.device:
+        device_set = {td.device for td in self.tensor_dicts}
+        if len(device_set) != 1:
+            raise RuntimeError(
+                f"found multiple devices in {self.__class__.__name__}:" f" {device_set}"
+            )
+        return self.tensor_dicts[0].device
+
+    @property
+    def batch_size(self) -> torch.Size:
+        return self._batch_size
+
+    def is_shared(self, no_check: bool = False) -> bool:
+        are_shared = [td.is_shared(no_check=no_check) for td in self.tensor_dicts]
+        if any(are_shared) and not all(are_shared):
+            raise RuntimeError(
+                f"tensor_dicts shared status mismatch, got {sum(are_shared)} "
+                f"shared tensor_dicts and "
+                f"{len(are_shared) - sum(are_shared)} non-shared tensordicts."
+            )
+        return all(are_shared)
+
+    def is_memmap(self, no_check: bool = False) -> bool:
+        are_memmap = [td.is_memmap() for td in self.tensor_dicts]
+        if any(are_memmap) and not all(are_memmap):
+            raise RuntimeError(
+                f"tensor_dicts memmap status mismatch, got {sum(are_memmap)} "
+                f"memmap tensor_dicts and "
+                f"{len(are_memmap) - sum(are_memmap)} non-memmap tensordicts."
+            )
+        return all(are_memmap)
+
+    # def is_memmap(self) -> bool:
+    #     return all(td.is_memmap() for td in self.tensor_dicts)
+
+    def get_valid_keys(self) -> Set[str]:
+        self._update_valid_keys()
+        return self._valid_keys
+
+    def set_valid_keys(self, keys: Sequence[str]) -> None:
+        raise RuntimeError(
+            "setting valid keys is not permitted. valid keys are defined as "
+            "the intersection of all the key sets from the TensorDicts in a "
+            "stack and cannot be defined explicitly."
+        )
+
+    valid_keys = property(get_valid_keys, set_valid_keys)
+
+    @staticmethod
+    def _compute_batch_size(
+        batch_size: torch.Size, stack_dim: int, N: int
+    ) -> torch.Size:
+        s = list(batch_size)
+        s.insert(stack_dim, N)
+        return torch.Size(s)
+
+    def set(
+        self, key: str, tensor: COMPATIBLE_TYPES, **kwargs
+    ) -> _TensorDict:  # type: ignore
+        if self.batch_size != tensor.shape[: self.batch_dims]:
+            raise RuntimeError(
+                "Setting tensor to tensordict failed because the shapes "
+                f"mismatch: got tensor.shape = {tensor.shape} and "
+                f"tensordict.batch_size={self.batch_size}"
+            )
+        proc_tensor = self._process_tensor(
+            tensor, check_device=False, check_tensor_shape=False
+        )
+        proc_tensor = proc_tensor.unbind(self.stack_dim)
+        for td, _item in zip(self.tensor_dicts, proc_tensor):
+            td.set(key, _item, **kwargs)
+        return self
+
+    def set_(self, key: str, tensor: COMPATIBLE_TYPES) -> _TensorDict:
+        if self.batch_size != tensor.shape[: self.batch_dims]:
+            raise RuntimeError(
+                "Setting tensor to tensordict failed because the shapes "
+                f"mismatch: got tensor.shape = {tensor.shape} and "
+                f"tensordict.batch_size={self.batch_size}"
+            )
+        if key not in self.valid_keys:
+            raise KeyError(
+                "setting a value in-place on a stack of TensorDict is only "
+                "permitted if all members of the stack have this key in "
+                "their register."
+ ) + tensor = self._process_tensor( + tensor, check_device=False, check_tensor_shape=False + ) + tensor = tensor.unbind(self.stack_dim) + for td, _item in zip(self.tensor_dicts, tensor): + td.set_(key, _item) + return self + + def set_at_( + self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING + ) -> _TensorDict: + sub_td = self[idx] + sub_td.set_(key, value) + return self + + def get( # type: ignore + self, + key: str, + default: Union[str, COMPATIBLE_TYPES] = "_no_default_", + ) -> COMPATIBLE_TYPES: # type: ignore + if not (key in self.valid_keys): + return self._default_get(key, default) + tensors = [td.get(key, default=default) for td in self.tensor_dicts] + shapes = set(tensor.shape for tensor in tensors) # type: ignore + if len(shapes) != 1: + raise RuntimeError( + f"found more than one unique shape in the tensors to be " + f"stacked ({shapes}). This is likely due to a modification " + f"of one of the stacked TensorDicts, where a key has been " + f"updated/created with an uncompatible shape." + ) + return torch.stack(tensors, self.stack_dim) # type: ignore + + def _get_meta(self, key: str) -> MetaTensor: + if key in self._meta_dict: + return self._meta_dict[key] + if key not in self.valid_keys: + raise KeyError(f"key {key} not found in {list(self._valid_keys)}") + return torch.stack( # type: ignore + [td._get_meta(key) for td in self.tensor_dicts], + self.stack_dim + # type: ignore + ) + + def is_contiguous(self) -> bool: + return False + + def contiguous(self) -> _TensorDict: + return TensorDict( + source={key: value for key, value in self.items()}, + batch_size=self.batch_size, + _meta_source={k: value for k, value in self.items_meta()}, + ) + + def clone(self, recursive: bool = True) -> _TensorDict: + if recursive: + return LazyStackedTensorDict( + *[td.clone() for td in self.tensor_dicts], + stack_dim=self.stack_dim, + ) + return LazyStackedTensorDict( + *[td for td in self.tensor_dicts], stack_dim=self.stack_dim + ) + + def pin_memory(self) -> _TensorDict: + for td in self.tensor_dicts: + td.pin_memory() + return self + + def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: + if isinstance(dest, type) and issubclass(dest, _TensorDict): + return dest(source=self, batch_size=self.batch_size) # type: ignore + elif isinstance(dest, (torch.device, str, int)): + if not isinstance(dest, torch.device): + dest = torch.device(dest) + if dest == self.device: + return self + tds = [td.to(dest) for td in self.tensor_dicts] + return LazyStackedTensorDict(*tds, stack_dim=self.stack_dim) + else: + raise NotImplementedError( + f"dest must be a string, torch.device or a TensorDict " + f"instance, {dest} not allowed" + ) + + def items(self) -> Iterator[Tuple[str, COMPATIBLE_TYPES]]: # type: ignore + for key in self.keys(): + item = self.get(key) + yield key, item + + def items_meta(self) -> Iterator[Tuple[str, MetaTensor]]: + for key in self.keys(): + item = self._get_meta(key) + yield key, item + + def keys(self) -> Iterator[str]: # type: ignore + for key in self.valid_keys: + yield key + + def _update_valid_keys(self) -> None: + valid_keys = set(self.tensor_dicts[0].keys()) + for td in self.tensor_dicts[1:]: + valid_keys = valid_keys.intersection(td.keys()) + self._valid_keys = valid_keys + + def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + if len(self.valid_keys.intersection(keys)) != len(keys): + raise KeyError( + f"Selected and existing keys mismatch, got self.valid_keys" + f"={self.valid_keys} and keys={keys}" + ) + tensor_dicts = 
[td.select(*keys, inplace=inplace) for td in self.tensor_dicts] + if inplace: + return self + return LazyStackedTensorDict( + *tensor_dicts, + stack_dim=self.stack_dim, + ) + + def __getitem__(self, item: INDEX_TYPING) -> _TensorDict: + + if isinstance(item, torch.Tensor) and item.dtype == torch.bool: + return self.masked_select(item) + elif ( + isinstance(item, (Number,)) + or (isinstance(item, torch.Tensor) and item.ndimension() == 0) + ) and self.stack_dim == 0: + return self.tensor_dicts[item] + elif isinstance(item, (torch.Tensor, list)) and self.stack_dim == 0: + return LazyStackedTensorDict( + *[self.tensor_dicts[_item] for _item in item], + stack_dim=self.stack_dim, + ) + elif isinstance(item, slice) and self.stack_dim == 0: + return LazyStackedTensorDict( + *self.tensor_dicts[item], stack_dim=self.stack_dim + ) + elif isinstance(item, (slice, Number)): + new_stack_dim = ( + self.stack_dim - 1 if isinstance(item, Number) else self.stack_dim + ) + return LazyStackedTensorDict( + *[td[item] for td in self.tensor_dicts], + stack_dim=new_stack_dim, + ) + elif isinstance(item, tuple): + _sub_item = tuple( + _item for i, _item in enumerate(item) if i == self.stack_dim + ) + if len(_sub_item): + tensor_dicts = self.tensor_dicts[_sub_item[0]] + if isinstance(tensor_dicts, _TensorDict): + return tensor_dicts + else: + tensor_dicts = self.tensor_dicts + # select sub tensor_dicts + _sub_item = tuple( + _item for i, _item in enumerate(item) if i != self.stack_dim + ) + if len(_sub_item): + tensor_dicts = [td[_sub_item] for td in tensor_dicts] + new_stack_dim = self.stack_dim - sum( + [isinstance(_item, Number) for _item in item[: self.stack_dim]] + ) + return torch.stack(list(tensor_dicts), dim=new_stack_dim) # type: ignore + else: + raise NotImplementedError( + f"selecting StackedTensorDicts with type " + f"{item.__class__.__name__} is not supported yet" + ) + + def del_(self, *args, **kwargs) -> _TensorDict: + for td in self.tensor_dicts: + td.del_(*args, **kwargs) + return self + + def share_memory_(self) -> _TensorDict: + for td in self.tensor_dicts: + td.share_memory_() + return self + + def detach_(self) -> _TensorDict: + for td in self.tensor_dicts: + td.detach_() + return self + + def memmap_(self) -> _TensorDict: + for td in self.tensor_dicts: + td.memmap_() + return self + + def expand(self, *shape: int, inplace: bool = False) -> _TensorDict: + stack_dim = self.stack_dim + len(shape) + tensor_dicts = [td.expand(*shape) for td in self.tensor_dicts] + if inplace: + self.tensor_dicts = tensor_dicts + self.stack_dim = stack_dim + return self + return torch.stack(tensor_dicts, stack_dim) # type: ignore + + def update( # type: ignore + self, input_dict_or_td: _TensorDict, clone: bool = False, **kwargs + ) -> _TensorDict: + if input_dict_or_td is self: + # no op + return self + for key, value in input_dict_or_td.items(): + if not isinstance(value, _accepted_classes): + raise TypeError( + f"Expected value to be one of types {_accepted_classes} " + f"but got {type(value)}" + ) + if clone: + value = value.clone() + self.set(key, value, **kwargs) + return self + + def update_( + self, + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + clone: bool = False, + **kwargs, + ) -> _TensorDict: + if input_dict_or_td is self: + # no op + return self + for key, value in input_dict_or_td.items(): + if not isinstance(value, _accepted_classes): + raise TypeError( + f"Expected value to be one of types {_accepted_classes} " + f"but got {type(value)}" + ) + if clone: + value = value.clone() 
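+            # in-place update: set_ unbinds the value along stack_dim and
+            # writes each slice into the corresponding stacked tensordict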
+ self.set_(key, value, **kwargs) + return self + + def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorDict: + for td in self.tensor_dicts: + td.rename_key(old_key, new_key, safe=safe) + return self + + def masked_fill_( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> _TensorDict: + mask_unbind = mask.unique(dim=self.stack_dim) + for _mask, td in zip(mask_unbind, self.tensor_dicts): + td.masked_fill_(_mask, value) + return self + + +class SavedTensorDict(_TensorDict): + _safe = False + + def __init__( + self, + source: _TensorDict, + device: Optional[torch.device] = None, + batch_size: Optional[Sequence[int]] = None, + ): + if not isinstance(source, _TensorDict): + raise TypeError( + f"Expected source to be a _TensorDict instance, but got {type(source)} instead." + ) + elif isinstance(source, SavedTensorDict): + source = source._load() + + self.file = tempfile.NamedTemporaryFile() + self.filename = self.file.name + if source.is_memmap(): + source = source.clone() + self._device = ( + torch.device(device) + if device is not None + else source.device + if hasattr(source, "device") + else source[list(source.keys())[0]].device + if len(source) + else torch.device("cpu") + ) + td = source + self._save(td) + if batch_size is not None and batch_size != self.batch_size: + raise RuntimeError("batch_size does not match self.batch_size.") + + def _save(self, tensor_dict: _TensorDict) -> None: + self._keys = list(tensor_dict.keys()) + self._batch_size = tensor_dict.batch_size + self._td_fields = _td_fields(tensor_dict) + self._tensor_dict_meta = {key: value for key, value in tensor_dict.items_meta()} + torch.save(tensor_dict, self.filename) + + def _load(self) -> _TensorDict: + return torch.load(self.filename, self.device) + + def _get_meta(self, key: str) -> MetaTensor: + return self._tensor_dict_meta.get(key) # type: ignore + + @property + def batch_size(self) -> torch.Size: + return self._batch_size + + @property + def device(self) -> torch.device: + return self._device + + def keys(self) -> Sequence[str]: # type: ignore + for k in self._keys: + yield k + + def get( # type: ignore + self, key: str, default: Union[str, COMPATIBLE_TYPES] = "_no_default_" + ) -> COMPATIBLE_TYPES: # type: ignore + td = self._load() + return td.get(key, default=default) + + def set( + self, key: str, value: COMPATIBLE_TYPES, **kwargs + ) -> _TensorDict: # type: ignore + td = self._load() + td.set(key, value, **kwargs) + self._save(td) + return self + + def expand(self, *shape: int, inplace: bool = False) -> _TensorDict: + td = self._load() + td = td.expand(*shape) + if inplace: + self._save(td) + return self + return td.to(SavedTensorDict) + + def set_(self, key: str, value: COMPATIBLE_TYPES) -> _TensorDict: + self.set(key, value) + return self + + def set_at_( + self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING + ) -> _TensorDict: + td = self._load() + td.set_at_(key, value, idx) + self._save(td) + return self + + def update( # type: ignore + self, + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + clone: bool = False, + **kwargs, + ) -> _TensorDict: + if input_dict_or_td is self: + # no op + return self + td = self._load() + for key, value in input_dict_or_td.items(): + if not isinstance(value, _accepted_classes): + raise TypeError( + f"Expected value to be one of types {_accepted_classes} " + f"but got {type(value)}" + ) + if clone: + value = value.clone() + td.set(key, value, **kwargs) + self._save(td) + return self + + def update_( + self, + 
input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + clone: bool = False, + ) -> _TensorDict: + if input_dict_or_td is self: + return self + return self.update(input_dict_or_td, clone=clone) + + def __del__(self) -> None: + if hasattr(self, "file"): + self.file.close() + + def is_shared(self, no_check: bool = False) -> bool: + return False + + def is_memmap(self, no_check: bool = False) -> bool: + return False + + def share_memory_(self) -> _TensorDict: + raise RuntimeError("SavedTensorDict cannot be put in shared memory.") + + def memmap_(self) -> _TensorDict: + raise RuntimeError( + "SavedTensorDict and memmap are mutually exclusive features." + ) + + def detach_(self) -> _TensorDict: + raise RuntimeError("SavedTensorDict cannot be put detached.") + + def items(self) -> Iterator[Tuple[str, COMPATIBLE_TYPES]]: # type: ignore + return self._load().items() + + def items_meta(self) -> Iterator[Tuple[str, MetaTensor]]: + return self._tensor_dict_meta.items() # type: ignore + + def is_contiguous(self) -> bool: + return False + + def contiguous(self) -> _TensorDict: + return self._load().contiguous() + + def clone(self, recursive: bool = True) -> _TensorDict: + return SavedTensorDict(self, device=self.device) + + def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + _source = self.contiguous().select(*keys) + if inplace: + self._save(_source) + return self + return SavedTensorDict(source=_source) + + def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorDict: + td = self._load() + td.rename_key(old_key, new_key, safe=safe) + self._save(td) + return self + + def __repr__(self) -> str: + return ( + f"SavedTensorDict(\n\tfields={{{self._td_fields}}}, \n\t" + f"batch_size={self.batch_size}, \n\tfile={self.filename})" + ) + + def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs): + if isinstance(dest, type) and issubclass(dest, _TensorDict): + td = dest( + source=TensorDict(self.to_dict(), batch_size=self.batch_size), + **kwargs, + ) + return td + elif isinstance(dest, (torch.device, str, int)): + # must be device + if not isinstance(dest, torch.device): + dest = torch.device(dest) + if dest == self.device: + return self + self_copy = copy(self) + self_copy._device = dest + for k, item in self.items_meta(): + item.device = dest + return self_copy + else: + raise NotImplementedError( + f"dest must be a string, torch.device or a TensorDict " + f"instance, {dest} not allowed" + ) + + def del_(self, key: str) -> _TensorDict: + td = self._load() + td = td.del_(key) + self._save(td) + return self + + def pin_memory(self) -> _TensorDict: + raise RuntimeError("pin_memory requires tensordicts that live in memory.") + + def __reduce__(self, *args, **kwargs): + if hasattr(self, "file"): + file = self.file + del self.file + self_copy = copy(self) + self.file = file + return super(SavedTensorDict, self_copy).__reduce__(*args, **kwargs) + return super().__reduce__(*args, **kwargs) + + def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict: + if isinstance(idx, Number): + idx = (idx,) + elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool: + return self.masked_select(idx) + if not self.batch_size: + raise IndexError( + "indexing a tensordict with td.batch_dims==0 is not permitted" + ) + return self.get_sub_tensor_dict(idx) + + def masked_fill_( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> _TensorDict: + td = self._load() + td.masked_fill_(mask, value) + self._save(td) + return self + + +class _CustomOpTensorDict(_TensorDict): + 
def __init__( + self, + source: _TensorDict, + custom_op: str, + inv_op: Optional[str] = None, + custom_op_kwargs: Optional[dict] = None, + inv_op_kwargs: Optional[dict] = None, + batch_size: Optional[Sequence[int]] = None, + ): + """Encodes lazy operations on tensors contained in a TensorDict.""" + + if not isinstance(source, _TensorDict): + raise TypeError( + f"Expected source to be a _TensorDict isntance, " + f"but got {type(source)} instead." + ) + self._source = source + self.custom_op = custom_op + self.inv_op = inv_op + self.custom_op_kwargs = custom_op_kwargs if custom_op_kwargs is not None else {} + self.inv_op_kwargs = inv_op_kwargs if inv_op_kwargs is not None else {} + if batch_size is not None and batch_size != self.batch_size: + raise RuntimeError("batch_size does not match self.batch_size.") + + def _update_custom_op_kwargs(self, source_meta_tensor: MetaTensor) -> dict: + """Allows for a transformation to be customized for a certain shape, + device or dtype. By default, this is a no-op on self.custom_op_kwargs + + Args: + source_meta_tensor: corresponding MetaTensor + + Returns: + a dictionary with the kwargs of the operation to execute + for the tensor + + """ + return self.custom_op_kwargs + + def _update_inv_op_kwargs(self, source_meta_tensor: MetaTensor) -> dict: + """Allows for an inverse transformation to be customized for a + certain shape, device or dtype. + + By default, this is a no-op on self.inv_op_kwargs + + Args: + source_meta_tensor: corresponding MetaTensor + + Returns: + a dictionary with the kwargs of the operation to execute for + the tensor + + """ + return self.inv_op_kwargs + + @property + def device(self) -> torch.device: + return self._source.device + + def _get_meta(self, key: str) -> MetaTensor: + item = self._source._get_meta(key) + return getattr(item, self.custom_op)(**self._update_custom_op_kwargs(item)) + + def items_meta(self) -> Iterator[Tuple[str, MetaTensor]]: + for key, value in self._source.items_meta(): + yield key, self._get_meta(key) + + def items(self) -> Iterator[Tuple[str, COMPATIBLE_TYPES]]: # type: ignore + for key in self._source.keys(): + yield key, self.get(key) + + @property + def batch_size(self) -> torch.Size: + return getattr(MetaTensor(*self._source.batch_size), self.custom_op)( + **self.custom_op_kwargs + ).shape + + def get( # type: ignore + self, + key: str, + default: Union[str, COMPATIBLE_TYPES] = "_no_default_", + _return_original_tensor: bool = False, + ) -> COMPATIBLE_TYPES: # type: ignore + try: + source_meta_tensor = self._source._get_meta(key) + item = self._source.get(key) + transformed_tensor = getattr(item, self.custom_op)( + **self._update_custom_op_kwargs(source_meta_tensor) + ) + if not _return_original_tensor: + return transformed_tensor + return transformed_tensor, item # type: ignore + except KeyError: + if _return_original_tensor: + raise RuntimeError( + "_return_original_tensor not compatible with get(..., " + "default=smth)" + ) + return self._default_get(key, default) + + def set( + self, key: str, value: COMPATIBLE_TYPES, **kwargs + ) -> _TensorDict: # type: ignore + if self.inv_op is None: + raise Exception( + f"{self.__class__.__name__} does not support setting values. " + f"Consider calling .contiguous() before calling this method." 
+ ) + proc_value = self._process_tensor( + value, check_device=False, check_tensor_shape=False + ) + if key in self.keys(): + source_meta_tensor = self._source._get_meta(key) + else: + source_meta_tensor = MetaTensor( + *proc_value.shape, + device=proc_value.device, + dtype=proc_value.dtype, + ) + proc_value = getattr(proc_value, self.inv_op)( + **self._update_inv_op_kwargs(source_meta_tensor) + ) + self._source.set(key, proc_value, **kwargs) + return self + + def set_(self, key: str, value: COMPATIBLE_TYPES) -> _CustomOpTensorDict: + if self.inv_op is None: + raise Exception( + f"{self.__class__.__name__} does not support setting values. " + f"Consider calling .contiguous() before calling this method." + ) + meta_tensor = self._source._get_meta(key) + value = getattr(value, self.inv_op)(**self._update_inv_op_kwargs(meta_tensor)) + self._source.set_(key, value) + return self + + def set_at_( + self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING + ) -> _CustomOpTensorDict: + transformed_tensor, original_tensor = self.get( # type: ignore + key, _return_original_tensor=True + ) + if transformed_tensor.data_ptr() != original_tensor.data_ptr(): + raise RuntimeError( + f"{self} original tensor and transformed do not point to the " + f"same storage. Setting values in place is not currently " + f"supported in this setting, consider calling " + f"`td.clone()` before `td.set_at_(...)`" + ) + transformed_tensor[idx] = value # type: ignore + return self + + def __repr__(self) -> str: + custom_op_kwargs_str = ", ".join( + [f"{key}={value}" for key, value in self.custom_op_kwargs.items()] + ) + indented_source = textwrap.indent(f"source={self._source}", "\t") + return ( + f"{self.__class__.__name__}(\n{indented_source}, " + f"\n\top={self.custom_op}({custom_op_kwargs_str}))" + ) + + def keys(self) -> KeysView: + return self._source.keys() + + def select(self, *keys: str, inplace: bool = False) -> _CustomOpTensorDict: + if inplace: + self._source.select(*keys, inplace=inplace) + return self + try: + return type(self)( + source=self._source.select(*keys), + custom_op=self.custom_op, + inv_op=self.inv_op, + custom_op_kwargs=self.custom_op_kwargs, + inv_op_kwargs=self.inv_op_kwargs, + ) + except TypeError: + self_copy = deepcopy(self) + self_copy._source = self._source.select(*keys) + return self_copy + + def clone(self, recursive: bool = True) -> _TensorDict: + if not recursive: + return copy(self) + return TensorDict( + source=self.to_dict(), + batch_size=self.batch_size, + ) + + def is_contiguous(self) -> bool: + return all([value.is_contiguous() for _, value in self.items()]) # type: ignore + + def contiguous(self) -> _TensorDict: + if self.is_contiguous(): + return self + return self.to(TensorDict) + + def rename_key( + self, old_key: str, new_key: str, safe: bool = False + ) -> _CustomOpTensorDict: + self._source.rename_key(old_key, new_key, safe=safe) + return self + + def del_(self, key: str) -> _CustomOpTensorDict: + self._source = self._source.del_(key) + return self + + def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: + if isinstance(dest, type) and issubclass(dest, _TensorDict): + return dest(source=self.contiguous().clone()) + elif isinstance(dest, (torch.device, str, int)): + if torch.device(dest) == self.device: + return self + td = self._source.to(dest) + self_copy = copy(self) + self_copy._source = td + return self_copy + else: + raise NotImplementedError( + f"dest must be a string, torch.device or a TensorDict " + f"instance, {dest} not allowed" + ) + + def 
pin_memory(self) -> _TensorDict: + self._source.pin_memory() + return self + + def detach_(self): + self._source.detach_() + + def masked_fill_( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> _TensorDict: + for key, item in self.items(): + source_meta_tensor = self._get_meta(key) + mask_proc_inv = getattr(mask, self.inv_op)( + **self._update_inv_op_kwargs(source_meta_tensor) + ) + val = self._source.get(key) + val[mask_proc_inv] = value + self._source.set(key, val) + return self + + def memmap_(self): + self._source.memmap_() + + def share_memory_(self): + self._source.share_memory_() + + +class UnsqueezedTensorDict(_CustomOpTensorDict): + """A lazy view on an unsqueezed TensorDict. + + When calling `tensordict.unsqueeze(dim)`, a lazy view of this operation is + returned such that the following code snippet works without raising an + exception: + + >>> assert tensordict.unsqueeze(dim).squeeze(dim) is tensordict + + Examples: + >>> from torchrl.data import TensorDict + >>> import torch + >>> td = TensorDict({'a': torch.randn(3, 4)}, batch_size=[3]) + >>> td_unsqueeze = td.unsqueeze(-1) + >>> print(td_unsqueeze.shape) + torch.Size([3, 1]) + >>> print(td_unsqueeze.squeeze(-1) is td) + True + """ + + def squeeze(self, dim: int) -> _TensorDict: + if dim < 0: + dim = self.batch_dims + dim + if dim == self.custom_op_kwargs.get("dim"): + return self._source + return super().squeeze(dim) + + +class SqueezedTensorDict(_CustomOpTensorDict): + """ + A lazy view on a squeezed TensorDict. + See the `UnsqueezedTensorDict` class documentation for more information. + """ + + def unsqueeze(self, dim: int) -> _TensorDict: + if dim < 0: + dim = self.batch_dims + dim + 1 + if dim == self.inv_op_kwargs.get("dim"): + return self._source + return super().unsqueeze(dim) + + +class ViewedTensorDict(_CustomOpTensorDict): + def _update_custom_op_kwargs(self, source_meta_tensor: MetaTensor) -> dict: + new_dim_list = list(self.custom_op_kwargs.get("size")) # type: ignore + new_dim_list += list( + source_meta_tensor.shape[self._source.batch_dims :] + ) # type: ignore + new_dim = torch.Size(new_dim_list) + new_dict = deepcopy(self.custom_op_kwargs) + new_dict.update({"size": new_dim}) + return new_dict + + def _update_inv_op_kwargs(self, source_meta_tensor: MetaTensor) -> Dict: + size = list(self.inv_op_kwargs.get("size")) # type: ignore + size += list( + source_meta_tensor.shape[self._source.batch_dims :] + ) # type: ignore + new_dim = torch.Size(size) + new_dict = deepcopy(self.inv_op_kwargs) + new_dict.update({"size": new_dim}) + return new_dict + + def view( + self, *shape, size: Optional[Union[List, Tuple, torch.Size]] = None + ) -> _TensorDict: + if len(shape) == 0 and size is not None: + return self.view(*size) + elif len(shape) == 1 and isinstance(shape[0], (list, tuple, torch.Size)): + return self.view(*shape[0]) + elif not isinstance(shape, torch.Size): + shape = torch.Size(shape) + if shape == self._source.batch_size: + return self._source + return super().view(*shape) + + +def _td_fields(td: _TensorDict) -> str: + return indent( + "\n" + + ",\n".join( + [ + f"{key}: {item.class_name}({item.shape}, dtype={item.dtype})" + for key, item in td.items_meta() + ] + ), + 4 * " ", + ) + + +def _check_keys( + list_of_tensor_dicts: Sequence[_TensorDict], strict: bool = False +) -> Set[str]: + keys: Set[str] = set() + for td in list_of_tensor_dicts: + if not len(keys): + keys = set(td.keys()) + else: + if not strict: + keys = keys.intersection(set(td.keys())) + else: + if 
len(set(td.keys()).difference(keys)) or len(set(td.keys())) != len( + keys + ): + raise KeyError( + f"got keys {keys} and {set(td.keys())} which are " + f"incompatible" + ) + return keys diff --git a/torchrl/data/tensordict/utils.py b/torchrl/data/tensordict/utils.py new file mode 100644 index 00000000000..72bc2953149 --- /dev/null +++ b/torchrl/data/tensordict/utils.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from numbers import Number + +import torch + +from torchrl.data.utils import INDEX_TYPING + + +def _sub_index(tensor: torch.Tensor, idx: INDEX_TYPING) -> torch.Tensor: + """Allows indexing of tensors with nested tuples, i.e. + tensor[tuple1][tuple2] can be indexed via _sub_index(tensor, (tuple1, + tuple2)) + """ + if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple): + idx0 = idx[0] + idx1 = idx[1:] + return _sub_index(_sub_index(tensor, idx0), idx1) + return tensor[idx] + + +def _getitem_batch_size( + shape: torch.Size, + items: INDEX_TYPING, +): + """ + Given an input shape and an index, returns the size of the resulting + indexed tensor. + + This function is aimed to be used when indexing is an + expensive operation. + Args: + shape: Input shape + items: Index of the hypothetical tensor + + Returns: + + """ + if not isinstance(items, tuple): + items = (items,) + bs = [] + iter_bs = iter(shape) + if all(isinstance(_item, torch.Tensor) for _item in items) and len(items) == len( + shape + ): + shape0 = items[0].shape + for _item in items[1:]: + if _item.shape != shape0: + raise RuntimeError( + f"all tensor indices must have the same shape, " + f"got {_item.shape} and {shape0}" + ) + return shape0 + + for _item in items: + if isinstance(_item, slice): + batch = next(iter_bs) + v = len(range(*_item.indices(batch))) + elif isinstance(_item, (list, torch.Tensor)): + batch = next(iter_bs) + v = len(_item) + elif _item is None: + v = 1 + elif isinstance(_item, Number): + batch = next(iter_bs) + continue + else: + raise NotImplementedError( + f"batch dim cannot be computed for type {type(_item)}" + ) + bs.append(v) + list_iter_bs = list(iter_bs) + bs += list_iter_bs + return torch.Size(bs) diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py new file mode 100644 index 00000000000..f36aaebfa57 --- /dev/null +++ b/torchrl/data/utils.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
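Reviewer note (sketch, assumes the package added by this diff is importable): `_getitem_batch_size` predicts the shape that real indexing would produce without allocating or indexing anything, which is what makes it useful for lazy and memory-mapped tensordicts. A quick consistency check against plain tensor indexing:

import torch
from torchrl.data.tensordict.utils import _getitem_batch_size

shape = torch.Size([3, 4, 5])
for index in (
    0,                       # integer: the dimension is consumed
    slice(0, 2),             # slice: the dimension is resized
    (slice(0, 2), None, 0),  # mixed: slice, new axis, integer
    torch.tensor([0, 2]),    # 1d tensor index
):
    predicted = _getitem_batch_size(shape, index)
    actual = torch.zeros(shape)[index].shape
    assert predicted == actual, (index, predicted, actual)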
+
+from typing import Any, Callable, List, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+
+numpy_to_torch_dtype_dict = {
+    np.dtype("bool"): torch.bool,
+    np.dtype("uint8"): torch.uint8,
+    np.dtype("int8"): torch.int8,
+    np.dtype("int16"): torch.int16,
+    np.dtype("int32"): torch.int32,
+    np.dtype("int64"): torch.int64,
+    np.dtype("float16"): torch.float16,
+    np.dtype("float32"): torch.float32,
+    np.dtype("float64"): torch.float64,
+    np.dtype("complex64"): torch.complex64,
+    np.dtype("complex128"): torch.complex128,
+}
+torch_to_numpy_dtype_dict = {
+    value: key for key, value in numpy_to_torch_dtype_dict.items()
+}
+DEVICE_TYPING = Union[torch.device, str]  # , int]
+
+INDEX_TYPING = Union[None, int, slice, Tensor, List[Any], Tuple[Any, ...]]
+
+
+class CloudpickleWrapper(object):
+    def __init__(self, fn: Callable):
+        if fn.__class__.__name__ == "EnvCreator":
+            raise RuntimeError(
+                "CloudpickleWrapper usage with EnvCreator class is "
+                "prohibited as it breaks the transmission of shared tensors."
+            )
+        self.fn = fn
+
+    def __getstate__(self):
+        import cloudpickle
+
+        return cloudpickle.dumps(self.fn)
+
+    def __setstate__(self, ob: bytes):
+        import pickle
+
+        self.fn = pickle.loads(ob)
+
+    def __call__(self, **kwargs) -> Any:
+        return self.fn(**kwargs)
+
+
+def expand_as_right(
+    tensor: Union[torch.Tensor, "MemmapTensor"],
+    dest: Union[torch.Tensor, "MemmapTensor"],
+):
+    """Expand a tensor on the right to match another tensor shape.
+    Args:
+        tensor: tensor to be expanded
+        dest: tensor providing the target shape
+
+    Returns:
+        a tensor with shape matching the dest input tensor shape.
+
+    Examples:
+        >>> tensor = torch.zeros(3,4)
+        >>> dest = torch.zeros(3,4,5)
+        >>> print(expand_as_right(tensor, dest).shape)
+        torch.Size([3, 4, 5])
+    """
+
+    if dest.ndimension() < tensor.ndimension():
+        raise RuntimeError(
+            "expand_as_right requires the destination tensor to have at least "
+            f"as many dimensions as the input tensor, got"
+            f" tensor.ndimension()={tensor.ndimension()} and "
+            f"dest.ndimension()={dest.ndimension()}"
+        )
+    if not (tensor.shape == dest.shape[: tensor.ndimension()]):
+        raise RuntimeError(
+            f"tensor shape is incompatible with dest shape, "
+            f"got: tensor.shape={tensor.shape}, dest={dest.shape}"
+        )
+    for _ in range(dest.ndimension() - tensor.ndimension()):
+        tensor = tensor.unsqueeze(-1)
+    return tensor.expand_as(dest)
+
+
+def expand_right(
+    tensor: Union[torch.Tensor, "MemmapTensor"], shape: Sequence[int]
+) -> torch.Tensor:
+    """Expand a tensor on the right to match a desired shape.
+    Args:
+        tensor: tensor to be expanded
+        shape: target shape
+
+    Returns:
+        a tensor with shape matching the target shape.
+
+    Examples:
+        >>> tensor = torch.zeros(3,4)
+        >>> shape = (3,4,5)
+        >>> print(expand_right(tensor, shape).shape)
+        torch.Size([3, 4, 5])
+    """
+
+    tensor_expand = tensor
+    while tensor_expand.ndimension() < len(shape):
+        tensor_expand = tensor_expand.unsqueeze(-1)
+    tensor_expand = tensor_expand.expand(*shape)
+    return tensor_expand
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
new file mode 100644
index 00000000000..4ae0ebd4599
--- /dev/null
+++ b/torchrl/envs/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
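Reviewer note (sketch, assumes `torchrl.data.utils` from this diff is importable): right-expansion is the usual way to broadcast per-batch quantities, e.g. a done mask or a per-trajectory weight, over trailing feature dimensions:

import torch
from torchrl.data.utils import expand_as_right, expand_right

done = torch.tensor([True, False, True])  # shape [3]
obs = torch.randn(3, 4, 5)                # shape [3, 4, 5]

# zero out observations of terminated trajectories
mask = expand_as_right(done, obs)         # shape [3, 4, 5]
masked = obs * (~mask).float()
assert masked[0].abs().sum() == 0

# same idea with an explicit target shape
weights = expand_right(torch.tensor([0.1, 0.5, 1.0]), (3, 4, 5))
assert weights.shape == torch.Size([3, 4, 5])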
+ +from .common import * +from .libs import * +from .vec_env import * +from .transforms import * diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py new file mode 100644 index 00000000000..78f071380fa --- /dev/null +++ b/torchrl/envs/common.py @@ -0,0 +1,654 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import math +from collections import OrderedDict +from numbers import Number +from typing import Any, Callable, Iterator, Optional, Tuple, Union + +import numpy as np +import torch + +from torchrl.data import CompositeSpec, TensorDict +from ..data.tensordict.tensordict import _TensorDict +from ..data.utils import DEVICE_TYPING +from .utils import get_available_libraries, step_tensor_dict + +LIBRARIES = get_available_libraries() + + +def _tensor_to_np(t): + return t.detach().cpu().numpy() + + +dtype_map = { + torch.float: np.float32, + torch.double: np.float64, + torch.bool: bool, +} + +__all__ = ["Specs", "GymLikeEnv", "make_tensor_dict"] + + +class Specs: + """Container for action, observation and reward specs. + + This class allows one to create an environment, retrieve all of the specs + in a single data container (and access them in one place) before erasing + the environment from the workspace. + + Args: + env (_EnvClass): environment from which the specs have to be read. + + """ + + _keys = {"action_spec", "observation_spec", "reward_spec", "from_pixels"} + + def __init__(self, env: _EnvClass): + self.env = env + + def __getitem__(self, item: str) -> Any: + if item not in self._keys: + raise KeyError(f"item must be one of {self._keys}") + return getattr(self.env, item) + + def keys(self) -> dict: + return self._keys + + def build_tensor_dict( + self, next_observation: bool = True, log_prob: bool = False + ) -> _TensorDict: + """returns a TensorDict with empty tensors of the desired shape""" + # build a tensordict from specs + td = TensorDict({}, batch_size=torch.Size([])) + action_placeholder = torch.zeros( + self["action_spec"].shape, dtype=self["action_spec"].dtype + ) + if not isinstance(self["observation_spec"], CompositeSpec): + observation_placeholder = torch.zeros( + self["observation_spec"].shape, + dtype=self["observation_spec"].dtype, + ) + td.set("observation", observation_placeholder) + else: + for i, key in enumerate(self["observation_spec"]): + item = self["observation_spec"][key] + observation_placeholder = torch.zeros(item.shape, dtype=item.dtype) + td.set(f"observation_{key}", observation_placeholder) + if next_observation: + td.set( + f"next_observation_{key}", + observation_placeholder.clone(), + ) + + reward_placeholder = torch.zeros( + self["reward_spec"].shape, dtype=self["reward_spec"].dtype + ) + done_placeholder = torch.zeros_like(reward_placeholder, dtype=torch.bool) + + td.set("action", action_placeholder) + td.set("reward", reward_placeholder) + + if log_prob: + td.set( + "log_prob", + torch.zeros_like(reward_placeholder, dtype=torch.float32), + ) # we assume log_prob to be of type float32 + td.set("done", done_placeholder) + return td + + +class _EnvClass: + """ + Abstract environment parent class for TorchRL. 
+ + Properties: + - observation_spec (TensorSpec): sampling spec of the observations; + - action_spec (TensorSpec): sampling spec of the actions; + - reward_spec (TensorSpec): sampling spec of the rewards; + - batch_size (torch.Size): number of environments contained in the instance; + - device (torch.device): device where the env input and output are expected to live + - is_done (torch.Tensor): boolean value(s) indicating if the environment has reached a done state since the + last reset + - current_tensordict (_TensorDict): last tensordict returned by `reset` or `step`. + + Methods: + step (_TensorDict -> _TensorDict): step in the environment + reset (_TensorDict, optional -> _TensorDict): reset the environment + set_seed (int -> int): sets the seed of the environment + rand_step (_TensorDict, optional -> _TensorDict): random step given the action spec + rollout (Callable, ... -> _TensorDict): executes a rollout in the environment with the given policy (or random + steps if no policy is provided) + + """ + + action_spec = None + reward_spec = None + observation_spec = None + from_pixels: bool + device = torch.device("cpu") + batch_size = torch.Size([]) + + def __init__( + self, + device: DEVICE_TYPING = "cpu", + dtype: Optional[Union[torch.dtype, np.dtype]] = None, + ): + self.device = device + self.dtype = dtype_map.get(dtype, dtype) + self._is_done = torch.zeros(self.batch_size, device=device) + self._cache = dict() + + def step(self, tensor_dict: _TensorDict) -> _TensorDict: + """Makes a step in the environment. + Step accepts a single argument, tensor_dict, which usually carries an 'action' key which indicates the action + to be taken. + Step will call an out-place private method, _step, which is the method to be re-written by _EnvClass subclasses. + + Args: + tensor_dict (_TensorDict): Tensordict containing the action to be taken. + + Returns: + the input tensor_dict, modified in place with the resulting observations, done state and reward + (+ others if needed). + + """ + + # sanity check + if tensor_dict.get("action").dtype is not self.action_spec.dtype: + raise TypeError( + f"expected action.dtype to be {self.action_spec.dtype} " + f"but got {tensor_dict.get('action').dtype}" + ) + + tensor_dict_out = self._step(tensor_dict) + + if tensor_dict_out is tensor_dict: + raise RuntimeError( + "_EnvClass._step should return outplace changes to the input " + "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or " + "tensordict.select()) inside _step before writing new tensors onto this new instance." 
+ ) + self.is_done = tensor_dict_out.get("done") + self._current_tensordict = step_tensor_dict(tensor_dict_out) + + for key in self._select_observation_keys(tensor_dict_out): + obs = tensor_dict_out.get(key) + self.observation_spec.type_check(obs, key) + + if tensor_dict_out.get("reward").dtype is not self.reward_spec.dtype: + raise TypeError( + f"expected reward.dtype to be {self.reward_spec.dtype} " + f"but got {tensor_dict_out.get('reward').dtype}" + ) + + if tensor_dict_out.get("done").dtype is not torch.bool: + raise TypeError( + f"expected done.dtype to be torch.bool but got {tensor_dict_out.get('done').dtype}" + ) + tensor_dict.update(tensor_dict_out, inplace=True) + + del tensor_dict_out + return tensor_dict + + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + if destination is not None: + return destination + return OrderedDict() + + def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: + pass + + def eval(self) -> _EnvClass: + return self + + def train(self, mode: bool = True) -> _EnvClass: + return self + + def _step( + self, + tensor_dict: _TensorDict, + ) -> _TensorDict: + raise NotImplementedError + + def _reset(self, tensor_dict: _TensorDict, **kwargs) -> _TensorDict: + raise NotImplementedError + + def reset(self, tensor_dict: Optional[_TensorDict] = None, **kwargs) -> _TensorDict: + """Resets the environment. + As for step and _step, only the private method `_reset` should be overwritten by _EnvClass subclasses. + + Args: + tensor_dict (_TensorDict, optional): tensor_dict to be used to contain the resulting new observation. + In some cases, this input can also be used to pass argument to the reset function. + kwargs (optional): other arguments to be passed to the native + reset function. + Returns: + a tensor_dict (or the input tensor_dict, if any), modified in place with the resulting observations. + + """ + # if tensor_dict is None: + # tensor_dict = self.specs.build_tensor_dict() + if tensor_dict is None: + tensor_dict = TensorDict({}, device=self.device, batch_size=self.batch_size) + tensor_dict_reset = self._reset(tensor_dict, **kwargs) + if tensor_dict_reset is tensor_dict: + raise RuntimeError( + "_EnvClass._reset should return outplace changes to the input " + "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or " + "tensordict.select()) inside _reset before writing new tensors onto this new instance." + ) + + self._current_tensordict = tensor_dict_reset + self.is_done = tensor_dict_reset.get( + "done", + torch.zeros(self.batch_size, dtype=torch.bool, device=self.device), + ) + if self.is_done: + raise RuntimeError( + f"Env {self} was done after reset. This is (currently) not allowed." + ) + if tensor_dict is not None: + tensor_dict.update(tensor_dict_reset) + else: + tensor_dict = tensor_dict_reset + del tensor_dict_reset + return tensor_dict + + @property + def current_tensordict(self) -> _TensorDict: + """Returns the last tensordict encountered after calling `reset` or `step`.""" + try: + return self._current_tensordict + except AttributeError: + print( + f"env {self} does not have a _current_tensordict attribute. Consider calling reset() before step()." 
+ ) + + def numel(self) -> int: + return math.prod(self.batch_size) + + def set_seed(self, seed: int) -> int: + """Sets the seed of the environment and returns the last seed used ( + which is the input seed if a single environment is present) + + Args: + seed: integer + + Returns: + integer representing the "final seed" in case the environment has + a non-empty batch. This feature makes sure that the same seed + won't be used for two different environments. + + """ + raise NotImplementedError + + def set_state(self): + raise NotImplementedError + + def _assert_tensordict_shape(self, tensor_dict: _TensorDict) -> None: + if tensor_dict.batch_size != self.batch_size: + raise RuntimeError( + f"Expected a tensor_dict with shape==env.shape, " + f"got {tensor_dict.batch_size} and {self.batch_size}" + ) + + def is_done_get_fn(self) -> bool: + return self._is_done.all() + + def is_done_set_fn(self, val: bool) -> None: + self._is_done = val + + is_done = property(is_done_get_fn, is_done_set_fn) + + def rand_step(self, tensor_dict: Optional[_TensorDict] = None) -> _TensorDict: + """Performs a random step in the environment given the action_spec attribute. + + Args: + tensor_dict (_TensorDict, optional): tensordict where the resulting info should be written. + + Returns: + a tensordict object with the new observation after a random step in the environment. The action will + be stored with the "action" key. + + """ + if tensor_dict is None: + tensor_dict = self.current_tensordict + action = self.action_spec.rand(self.batch_size) + tensor_dict.set("action", action) + return self.step(tensor_dict) + + @property + def specs(self) -> Specs: + """ + + Returns a Specs container where all the environment specs are contained. + This feature allows one to create an environment, retrieve all of the specs in a single data container and then + erase the environment from the workspace. + + """ + return Specs(self) + + def rollout( + self, + policy: Optional[Callable[[_TensorDict], _TensorDict]] = None, + n_steps: int = 1, + callback: Optional[Callable[[_TensorDict, ...], _TensorDict]] = None, + auto_reset: bool = True, + ) -> _TensorDict: + """ + + Args: + policy (callable, optional): callable to be called to compute the desired action. If no policy is provided, + actions will be called using `env.rand_step()` + default = None + n_steps (int, optional): maximum number of steps to be executed. The actual number of steps can be smaller if + the environment reaches a done state before n_steps have been executed. + default = 1 + callback (callable, optional): function to be called at each iteration with the given TensorDict. + auto_reset (bool): if True, resets automatically the environment if it is in a done state when the rollout + is initiated. + default = True. + + Returns: + TensorDict object containing the resulting trajectory. 
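Reviewer note (usage sketch, assumes gym and Pendulum-v0 are available): the reset / rand_step / rollout flow described above, with a trivial policy. Any callable that writes a valid "action" entry into the tensordict qualifies as a policy:

import torch
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v0")
td = env.reset()              # carries "observation" and "done"
td = env.rand_step(td)        # adds "action", "reward", "next_observation"

def policy(td):
    # constant zero action with the shape/dtype advertised by the action spec
    action = torch.zeros(env.action_spec.shape, dtype=env.action_spec.dtype)
    return td.set("action", action)

traj = env.rollout(policy=policy, n_steps=10)
print(traj.batch_size)        # torch.Size([10]) unless a done state is reached earlier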
+ + """ + try: + policy_device = next(policy.parameters()).device + except AttributeError: + policy_device = "cpu" + + if auto_reset: + tensor_dict = self.reset() + else: + # tensor_dict = ( + # self.specs.build_tensor_dict().expand(*self.batch_size).contiguous() + # ) + tensor_dict = self.current_tensordict.clone() + + if policy is None: + + def policy(td): + return td.set("action", self.action_spec.rand(self.batch_size)) + + tensor_dicts = [] + if not self.is_done: + for i in range(n_steps): + td = tensor_dict.to(policy_device) + td = policy(td) + tensor_dict = td.to("cpu") + + tensor_dict = self.step(tensor_dict.clone()) + + tensor_dicts.append(tensor_dict.clone()) + if tensor_dict.get("done").all() or i == n_steps - 1: + break + tensor_dict = step_tensor_dict(tensor_dict) + + if callback is not None: + callback(self, tensor_dict) + else: + raise Exception("reset env before calling rollout!") + + out_td = torch.stack(tensor_dicts, len(self.batch_size)) + return out_td + + def _select_observation_keys(self, tensor_dict: _TensorDict) -> Iterator[str]: + for key in tensor_dict.keys(): + if key.rfind("observation") >= 0: + yield key + + def _to_tensor( + self, + value: Union[dict, bool, float, torch.Tensor, np.ndarray], + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[torch.dtype] = None, + ) -> Union[torch.Tensor, dict]: + + if isinstance(value, dict): + return { + _key: self._to_tensor(_value, dtype=dtype, device=device) + for _key, _value in value.items() + } + elif isinstance(value, (bool, Number)): + value = np.array(value) + + if dtype is None and self.dtype is not None: + dtype = self.dtype + elif dtype is not None: + dtype = dtype_map.get(dtype, dtype) + else: + dtype = value.dtype + + if device is None: + device = self.device + + if not isinstance(value, torch.Tensor): + if dtype is not None: + try: + value = value.astype(dtype) + except TypeError: + raise Exception( + "dtype must be a numpy-compatible dtype. Got {dtype}" + ) + value = torch.from_numpy(value) + if device != "cpu": + value = value.to(device) + else: + value = value.to(device) + # if dtype is not None: + # value = value.to(dtype) + return value + + def close(self): + pass + + def __del__(self): + self.close() + + +class _EnvWrapper(_EnvClass): + """Abstract environment wrapper class. + + Unlike _EnvClass, _EnvWrapper comes with a `_build_env` private method that will be called upon instantiation. + Interfaces with other libraries should be coded using _EnvWrapper. 
+ """ + + git_url: str = "" + available_envs: dict = {} + libname: str = "" + + def __init__( + self, + envname: str, + taskname: str = "", + frame_skip: int = 1, + dtype: Optional[np.dtype] = None, + device: DEVICE_TYPING = "cpu", + seed: Optional[int] = None, + **kwargs, + ): + super().__init__( + device=device, + dtype=dtype, + ) + self.envname = envname + self.taskname = taskname + + self.frame_skip = frame_skip + self.wrapper_frame_skip = frame_skip # this value can be changed if frame_skip is passed during env construction + + self.constructor_kwargs = kwargs + if not ( + (envname in self.available_envs) + and ( + taskname in self.available_envs[envname] + if isinstance(self.available_envs, dict) + else True + ) + ): + raise RuntimeError( + f"{envname} with task {taskname} is unknown in {self.libname}" + ) + self._build_env(envname, taskname, **kwargs) # writes the self._env attribute + self._init_env(seed=seed) # runs all the steps to have a ready-to-use env + + def _init_env(self, seed: Optional[int] = None) -> Optional[int]: + """Runs all the necessary steps such that the environment is ready to use. + + This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment + is reset (if needed). For instance, DMControl envs require the env to be reset before being used, but Gym envs + don't. + + Args: + seed (int, optional): seed to be set, if any. + + Returns: + the resulting seed + + """ + + raise NotImplementedError + + def _build_env( + self, envname: str, taskname: Optional[str] = None, **kwargs + ) -> None: + """Creates an environment from the target library and stores it with the `_env` attribute. + + When overwritten, this function should pass all the required kwargs to the env instantiation method. + + Args: + envname (str): name of the environment + taskname: (str, optional): task to be performed, if any. + + + """ + raise NotImplementedError + + def close(self) -> None: + """Closes the contained environment if possible.""" + self.is_closed = True + try: + self._env.close() + except AttributeError: + pass + + +class GymLikeEnv(_EnvWrapper): + """ + A gym-like env is an environment. + + + A `GymLikeEnv` has a `.step()` method with the following signature: + + ``env.step(action: np.ndarray) -> Tuple[Union[np.ndarray, dict], double, bool, *info]`` + + where the outputs are the observation, reward and done state respectively. + In this implementation, the info output is discarded. + + By default, the first output is written at the "next_observation" key-value pair in the output tensordict, unless + the first output is a dictionary. In that case, each observation output will be put at the corresponding + "next_observation_{key}" location. + + It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. 
+ """ + + def _step(self, tensor_dict: _TensorDict) -> _TensorDict: + action = tensor_dict.get("action") + action_np = self.action_spec.to_numpy(action) + + reward = 0.0 + for _ in range(self.wrapper_frame_skip): + obs, _reward, done, *info = self._output_transform( + self._env.step(action_np) + ) + if _reward is None: + _reward = 0.0 + reward += _reward + if done: + break + + obs_dict = self._read_obs(obs) + + if reward is None: + reward = np.nan + reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) + done = self._to_tensor(done, dtype=torch.bool) + self._is_done = done + self._current_tensordict = obs_dict + + tensor_dict_out = TensorDict({}, batch_size=tensor_dict.batch_size) + for key, value in obs_dict.items(): + tensor_dict_out.set(f"next_{key}", value) + tensor_dict_out.set("reward", reward) + tensor_dict_out.set("done", done) + return tensor_dict_out + + def set_seed(self, seed: Optional[int] = None) -> Optional[int]: + if seed is not None: + torch.manual_seed(seed) + return self._set_seed(seed) + + def _set_seed(self, seed: Optional[int]) -> Optional[int]: + raise NotImplementedError + + def _reset( + self, tensor_dict: Optional[_TensorDict] = None, **kwargs + ) -> _TensorDict: + obs, *_ = self._output_transform((self._env.reset(**kwargs),)) + tensor_dict_out = TensorDict( + source=self._read_obs(obs), batch_size=self.batch_size + ) + self._is_done = torch.zeros(1, dtype=torch.bool) + tensor_dict_out.set("done", self._is_done) + return tensor_dict_out + + def _read_obs(self, observations: torch.Tensor) -> dict: + observations = self.observation_spec.encode(observations) + if isinstance(observations, dict): + obs_dict = {f"observation_{key}": obs for key, obs in observations.items()} + else: + obs_dict = {"observation": observations} + obs_dict = self._to_tensor(obs_dict) + return obs_dict + + def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: + """To be overwritten when step_outputs differ from Tuple[Observation: Union[np.ndarray, dict], reward: Number, done:Bool]""" + if not isinstance(step_outputs_tuple, tuple): + raise TypeError( + f"Expected step_outputs_tuple type to be Tuple but got {type(step_outputs_tuple)}" + ) + return step_outputs_tuple + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(env={self.envname}, task={self.taskname if self.taskname else None}, batch_size={self.batch_size})" + + +def make_tensor_dict( + env: _EnvClass, + policy: Optional[Callable[[_TensorDict, ...], _TensorDict]] = None, +) -> _TensorDict: + """ + Returns a zeroed-tensordict with fields matching those required for a full step + (action selection and environment step) in the environment + + Args: + env (_EnvWrapper): environment defining the observation, action and reward space; + policy (Callable, optional): policy corresponding to the environment. + + """ + with torch.no_grad(): + tensor_dict = env.reset() + if policy is not None: + tensor_dict = tensor_dict.unsqueeze(0) + tensor_dict = policy(tensor_dict.to(next(policy.parameters()).device)) + tensor_dict = tensor_dict.squeeze(0) + else: + tensor_dict.set("action", env.action_spec.rand(), inplace=False) + tensor_dict = env.step(tensor_dict.to("cpu")) + return tensor_dict.zero_() diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py new file mode 100644 index 00000000000..52722542368 --- /dev/null +++ b/torchrl/envs/libs/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .dm_control import * +from .gym import * diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py new file mode 100644 index 00000000000..b1e27560da2 --- /dev/null +++ b/torchrl/envs/libs/dm_control.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from torchrl.data import ( + CompositeSpec, + NdBoundedTensorSpec, + NdUnboundedContinuousTensorSpec, + TensorSpec, +) +from ...data.utils import numpy_to_torch_dtype_dict +from ..common import GymLikeEnv + +__all__ = ["DMControlEnv"] +try: + import collections + + import dm_env + from dm_control import suite + from dm_control.suite.wrappers import pixels + + _has_dmc = True + +except ImportError: + _has_dmc = False + + +def _dmcontrol_to_torchrl_spec_transform( + spec, dtype: Optional[torch.dtype] = None +) -> TensorSpec: + if isinstance(spec, collections.OrderedDict): + spec = { + k: _dmcontrol_to_torchrl_spec_transform(item) for k, item in spec.items() + } + return CompositeSpec(**spec) + elif isinstance(spec, dm_env.specs.BoundedArray): + if dtype is None: + dtype = numpy_to_torch_dtype_dict[spec.dtype] + return NdBoundedTensorSpec( + shape=spec.shape, + minimum=spec.minimum, + maximum=spec.maximum, + dtype=dtype, + ) + elif isinstance(spec, dm_env.specs.Array): + if dtype is None: + dtype = numpy_to_torch_dtype_dict[spec.dtype] + return NdUnboundedContinuousTensorSpec(shape=spec.shape, dtype=dtype) + else: + raise NotImplementedError + + +def _get_envs(to_dict: bool = True) -> dict: + if not _has_dmc: + return dict() + if not to_dict: + return tuple(suite.BENCHMARKING) + tuple(suite.EXTRA) + d = dict() + for tup in suite.BENCHMARKING: + envname = tup[0] + d.setdefault(envname, []).append(tup[1]) + for tup in suite.EXTRA: + envname = tup[0] + d.setdefault(envname, []).append(tup[1]) + return d + + +def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor: + if isinstance(array, np.ndarray): + return torch.tensor(array.copy()) + else: + return torch.tensor(array) + + +class DMControlEnv(GymLikeEnv): + """ + DeepMind Control lab environment wrapper. + + Args: + envname (str): name of the environment + taskname (str): name of the task + seed (int, optional): seed to use for the environment + from_pixels (bool): if True, the observation + + Examples: + >>> env = DMControlEnv(envname="cheetah", taskname="run", + ... from_pixels=True, frame_skip=4) + >>> td = env.rand_step() + >>> print(td) + >>> print(env.available_envs) + """ + + git_url = "https://github.com/deepmind/dm_control" + libname = "dm_control" + available_envs = _get_envs() + + def _build_env( + self, + envname: str, + taskname: str, + _seed: Optional[int] = None, + from_pixels: bool = False, + render_kwargs: Optional[dict] = None, + pixels_only: bool = False, + **kwargs, + ): + if not _has_dmc: + raise RuntimeError( + f"dm_control not found, unable to create {envname}:" + f" {taskname}. 
Consider downloading and installing " + f"dm_control from {self.git_url}" + ) + self.from_pixels = from_pixels + self.pixels_only = pixels_only + + if _seed is not None: + random_state = np.random.RandomState(_seed) + kwargs = {"random": random_state} + env = suite.load(envname, taskname, task_kwargs=kwargs) + if from_pixels: + self.render_kwargs = {"camera_id": 0} + if render_kwargs is not None: + self.render_kwargs.update(render_kwargs) + env = pixels.Wrapper( + env, + pixels_only=self.pixels_only, + render_kwargs=self.render_kwargs, + ) + self._env = env + return env + + def _init_env(self, seed: Optional[int] = None) -> Optional[int]: + seed = self.set_seed(seed) + return seed + + def _set_seed(self, _seed: Optional[int]) -> Optional[int]: + self._env = self._build_env( + self.envname, self.taskname, _seed=_seed, **self.constructor_kwargs + ) + self.reset() + return _seed + + def _output_transform( + self, timestep_tuple: Tuple["TimeStep"] + ) -> Tuple[np.ndarray, float, bool]: + if type(timestep_tuple) is not tuple: + timestep_tuple = (timestep_tuple,) + reward = timestep_tuple[0].reward + + done = False # dm_control envs are non-terminating + observation = timestep_tuple[0].observation + return observation, reward, done + + @property + def action_spec(self) -> TensorSpec: + return _dmcontrol_to_torchrl_spec_transform(self._env.action_spec()) + + @property + def observation_spec(self) -> TensorSpec: + return _dmcontrol_to_torchrl_spec_transform(self._env.observation_spec()) + + @property + def reward_spec(self) -> TensorSpec: + return _dmcontrol_to_torchrl_spec_transform(self._env.reward_spec()) diff --git a/torchrl/envs/libs/dmlab.py b/torchrl/envs/libs/dmlab.py new file mode 100644 index 00000000000..7bec24cb17b --- /dev/null +++ b/torchrl/envs/libs/dmlab.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py new file mode 100644 index 00000000000..856ddcb3f4b --- /dev/null +++ b/torchrl/envs/libs/gym.py @@ -0,0 +1,187 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
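Reviewer note (usage sketch, assumes dm_control is installed): with `from_pixels=True` the dm_control observation dict gains a "pixels" entry, which `_read_obs` exposes as "observation_pixels" ("next_observation_pixels" after a step):

from torchrl.envs import DMControlEnv

env = DMControlEnv("cheetah", "run", from_pixels=True, frame_skip=2)
td = env.reset()
print(td.get("observation_pixels").shape)       # rendered frame, e.g. [240, 320, 3]
td = env.rand_step(td)
print(td.get("next_observation_pixels").shape)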
+
+from types import ModuleType
+from typing import List, Optional, Sequence
+
+import torch
+
+from torchrl.data import (
+    BinaryDiscreteTensorSpec,
+    CompositeSpec,
+    MultOneHotDiscreteTensorSpec,
+    NdBoundedTensorSpec,
+    OneHotDiscreteTensorSpec,
+    TensorSpec,
+    UnboundedContinuousTensorSpec,
+)
+from ...data.utils import numpy_to_torch_dtype_dict
+from ..common import GymLikeEnv
+from ..utils import classproperty
+
+try:
+    import gym
+
+    _has_gym = True
+    from gym.wrappers.pixel_observation import PixelObservationWrapper
+
+except ImportError:
+    _has_gym = False
+
+try:
+    import retro
+
+    _has_retro = True
+except ImportError:
+    _has_retro = False
+
+__all__ = ["GymEnv", "RetroEnv"]
+
+
+def _gym_to_torchrl_spec_transform(spec, dtype=None, device="cpu") -> TensorSpec:
+    if isinstance(spec, gym.spaces.tuple.Tuple):
+        raise NotImplementedError("gym.spaces.tuple.Tuple mapping not yet implemented")
+    if isinstance(spec, gym.spaces.discrete.Discrete):
+        return OneHotDiscreteTensorSpec(spec.n)
+    elif isinstance(spec, gym.spaces.multi_binary.MultiBinary):
+        return BinaryDiscreteTensorSpec(spec.n)
+    elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete):
+        return MultOneHotDiscreteTensorSpec(spec.nvec)
+    elif isinstance(spec, gym.spaces.Box):
+        if dtype is None:
+            dtype = numpy_to_torch_dtype_dict[spec.dtype]
+        return NdBoundedTensorSpec(
+            torch.tensor(spec.low, device=device, dtype=dtype),
+            torch.tensor(spec.high, device=device, dtype=dtype),
+            torch.Size(spec.shape),
+            dtype=dtype,
+        )
+    elif isinstance(spec, (dict, gym.spaces.dict.Dict)):
+        spec = {k: _gym_to_torchrl_spec_transform(spec[k]) for k in spec}
+        return CompositeSpec(**spec)
+    else:
+        raise NotImplementedError(
+            f"spec of type {type(spec).__name__} is currently unaccounted for"
+        )
+
+
+def _get_envs(to_dict=False) -> List:
+    envs = gym.envs.registration.registry.env_specs.keys()
+    envs = list(envs)
+    envs = sorted(envs)
+    return envs
+
+
+def _get_gym():
+    if _has_gym:
+        return gym
+    else:
+        return None
+
+
+def _is_from_pixels(observation_space):
+    return (
+        isinstance(observation_space, gym.spaces.Box)
+        and (observation_space.low == 0).all()
+        and (observation_space.high == 255).all()
+        and observation_space.low.shape[-1] == 3
+        and observation_space.low.ndim == 3
+    )
+
+
+class GymEnv(GymLikeEnv):
+    """
+    OpenAI Gym environment wrapper.
+
+    Examples:
+        >>> env = GymEnv(envname="Pendulum-v0", frame_skip=4)
+        >>> td = env.rand_step()
+        >>> print(td)
+        >>> print(env.available_envs)
+    """
+
+    git_url = "https://github.com/openai/gym"
+    libname = "gym"
+
+    @classproperty
+    def available_envs(cls) -> List[str]:
+        return _get_envs()
+
+    @property
+    def lib(self) -> ModuleType:
+        return gym
+
+    def _set_seed(self, seed: int) -> int:
+        self.reset(seed=seed)
+        return seed
+
+    def _build_env(
+        self,
+        envname: str,
+        taskname: str,
+        from_pixels: bool = False,
+        pixels_only: bool = False,
+    ) -> gym.core.Env:
+        self.pixels_only = pixels_only
+        if not _has_gym:
+            raise RuntimeError(
+                f"gym not found, unable to create {envname}. "
+                f"Consider downloading and installing gym from"
+                f" {self.git_url}"
+            )
+        if not ((taskname == "") or (taskname is None)):
+            raise ValueError(
+                f"gym does not support taskname, received {taskname} instead."
+ ) + try: + env = self.lib.make(envname, frameskip=self.frame_skip) + self.wrapper_frame_skip = 1 + except TypeError as err: + if "unexpected keyword argument 'frameskip" not in str(err): + raise TypeError(err) + env = self.lib.make(envname) + self.wrapper_frame_skip = self.frame_skip + self._env = env + + from_pixels = from_pixels or _is_from_pixels(self._env.observation_space) + self.from_pixels = from_pixels + if from_pixels: + self._env.reset() + self._env = PixelObservationWrapper(self._env, pixels_only) + + self.action_spec = _gym_to_torchrl_spec_transform(self._env.action_space) + self.observation_spec = _gym_to_torchrl_spec_transform( + self._env.observation_space + ) + self.reward_spec = UnboundedContinuousTensorSpec( + device=self.device, + ) # default + + def _init_env(self, seed: Optional[int] = None) -> Optional[int]: + if seed is not None: + seed = self.set_seed(seed) + self.reset() # make sure that _current_observation and + # _is_done are populated + return seed + + +def _get_retro_envs() -> Sequence: + if not _has_retro: + return tuple() + else: + return retro.data.list_games() + + +def _get_retro() -> Optional[ModuleType]: + if _has_retro: + return retro + else: + return None + + +class RetroEnv(GymEnv): + available_envs = _get_retro_envs() + lib = "retro" + lib = _get_retro() diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py new file mode 100644 index 00000000000..90831d099bc --- /dev/null +++ b/torchrl/envs/transforms/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .transforms import * diff --git a/torchrl/envs/transforms/functional.py b/torchrl/envs/transforms/functional.py new file mode 100644 index 00000000000..2fb7bff62ad --- /dev/null +++ b/torchrl/envs/transforms/functional.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +from torch import Tensor + + +# copied from torchvision +def _get_image_num_channels(img: Tensor) -> int: + if img.ndim == 2: + return 1 + elif img.ndim > 2: + return img.shape[-3] + + raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim)) + + +def _assert_channels(img: Tensor, permitted: List[int]) -> None: + c = _get_image_num_channels(img) + if c not in permitted: + raise TypeError( + "Input image tensor permitted channel values are {}, but found" + "{}".format(permitted, c) + ) + + +def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: + if img.ndim < 3: + raise TypeError( + "Input image tensor should have at least 3 dimensions, but found" + "{}".format(img.ndim) + ) + _assert_channels(img, [3]) + + if num_output_channels not in (1, 3): + raise ValueError("num_output_channels should be either 1 or 3") + + r, g, b = img.unbind(dim=-3) + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype) + l_img = l_img.unsqueeze(dim=-3) + + if num_output_channels == 3: + return l_img.expand(img.shape) + + return l_img diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py new file mode 100644 index 00000000000..8a55ddbc443 --- /dev/null +++ b/torchrl/envs/transforms/transforms.py @@ -0,0 +1,1514 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
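Reviewer note (sketch, assumes gym is installed): how gym spaces are expected to map onto torchrl specs through `_gym_to_torchrl_spec_transform`:

import gym
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

discrete = _gym_to_torchrl_spec_transform(gym.spaces.Discrete(4))
box = _gym_to_torchrl_spec_transform(gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))
print(type(discrete).__name__)  # OneHotDiscreteTensorSpec
print(type(box).__name__)       # NdBoundedTensorSpec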
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, List, Optional, OrderedDict, Sequence, Union + +import torch +from torch import nn +from torchvision.transforms.functional_tensor import ( + resize, +) # as of now resize is imported from torchvision + +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + ContinuousBox, + NdUnboundedContinuousTensorSpec, + TensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.envs.common import _EnvClass, make_tensor_dict +from torchrl.envs.transforms import functional as F +from torchrl.envs.transforms.utils import FiniteTensor +from torchrl.envs.utils import step_tensor_dict + +__all__ = [ + "Transform", + "TransformedEnv", + "RewardClipping", + "Resize", + "GrayScale", + "Compose", + "ToTensorImage", + "ObservationNorm", + "RewardScaling", + "ObservationTransform", + "CatFrames", + "FiniteTensorDictCheck", + "DoubleToFloat", + "CatTensors", + "NoopResetEnv", + "BinerizeReward", + "PinMemoryTransform", + "VecNorm", + "gSDENoise", +] + +IMAGE_KEYS = ["next_observation", "next_observation_pixels"] +_MAX_NOOPS_TRIALS = 10 + + +class Transform(nn.Module): + """Environment transform parent class. + + In principle, a transform receives a tensordict as input and returns ( + the same or another) tensordict as output, where a series of values have + been modified or created with a new key. When instantiating a new + transform, the keys that are to be read from are passed to the + constructor via the `keys` argument. + + Transforms are to be combined with their target environments with the + TransformedEnv class, which takes as arguments an `_EnvClass` instance + and a transform. If multiple transforms are to be used, they can be + concatenated using the `Compose` class. + A transform can be stateless or stateful (e.g. CatTransform). Because of + this, Transforms support the `reset` operation, which should reset the + transform to its initial state (such that successive trajectories are kept + independent). + + Notably, `Transform` subclasses take care of transforming the affected + specs from an environment: when querying + `transformed_env.observation_spec`, the resulting objects will describe + the specs of the transformed tensors. + + """ + + invertible = False + + def __init__(self, keys: Sequence[str]): + super().__init__() + self.keys = keys + + def reset(self, tensor_dict: _TensorDict) -> _TensorDict: + """Resets a tranform if it is stateful.""" + return tensor_dict + + def _check_inplace(self) -> None: + if not hasattr(self, "inplace"): + raise AttributeError( + f"Transform of class {self.__class__.__name__} has no " + f"attribute inplace, consider implementing it." + ) + + def init(self, tensor_dict) -> None: + pass + + def _apply(self, obs: torch.Tensor) -> None: + """Applies the transform to a tensor. + This operation can be called multiple times (if multiples keys of the + tensordict match the keys of the transform). + + """ + raise NotImplementedError + + def _call(self, tensor_dict: _TensorDict) -> _TensorDict: + """Reads the input tensordict, and for the selected keys, applies the + transform. 
+ + """ + self._check_inplace() + for _obs_key in tensor_dict.keys(): + if _obs_key in self.keys: + observation = self._apply(tensor_dict.get(_obs_key)) + tensor_dict.set(_obs_key, observation, inplace=self.inplace) + return tensor_dict + + def forward(self, tensor_dict: _TensorDict) -> _TensorDict: + self._call(tensor_dict) + return tensor_dict + + def _inv_apply(self, obs: torch.Tensor) -> torch.Tensor: + if self.invertible: + raise NotImplementedError + else: + return obs + + def _inv_call(self, tensor_dict: _TensorDict) -> _TensorDict: + self._check_inplace() + for _obs_key in tensor_dict.keys(): + if _obs_key in self.keys: + observation = self._inv_apply(tensor_dict.get(_obs_key)) + tensor_dict.set(_obs_key, observation, inplace=self.inplace) + return tensor_dict + + def inv(self, tensor_dict: _TensorDict) -> _TensorDict: + self._inv_call(tensor_dict) + return tensor_dict + + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + """Transforms the action spec such that the resulting spec matches + transform mapping. + + Args: + action_spec (TensorSpec): spec before the transform + + Returns: + expected spec after the transform + + """ + return action_spec + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + """Transforms the observation spec such that the resulting spec + matches transform mapping. + + Args: + observation_spec (TensorSpec): spec before the transform + + Returns: + expected spec after the transform + + """ + return observation_spec + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + """Transforms the reward spec such that the resulting spec matches + transform mapping. + + Args: + reward_spec (TensorSpec): spec before the transform + + Returns: + expected spec after the transform + + """ + + return reward_spec + + def dump(self) -> None: + pass + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(keys={self.keys})" + + def set_parent(self, parent: Union[Transform, _EnvClass]) -> None: + self.__dict__["_parent"] = parent + + @property + def parent(self) -> _EnvClass: + if not hasattr(self, "_parent"): + raise AttributeError("transform parent uninitialized") + parent = self._parent + while not isinstance(parent, _EnvClass): + if not isinstance(parent, Transform): + raise ValueError( + "A transform parent must be either another transform or an environment object." + ) + parent = parent.parent + return parent + + +class TransformedEnv(_EnvClass): + """ + A transformed environment. + + Args: + env (_EnvClass): original environment to be transformed. + transform (Transform): transform to apply to the tensordict resulting + from env.step(td) + cache_specs (bool, optional): if True, the specs will be cached once + and for all after the first call (i.e. the specs will be + transformed only once). If the transform changes during + training, the original spec transform may not be valid anymore, + in which case this value should be set to `False`. Default is + `True`. 
+ + Examples: + >>> env = GymEnv("Pendulum-v0") + >>> transform = RewardScaling(0.0, 1.0) + >>> transformed_env = TransformedEnv(env, transform) + + """ + + def __init__( + self, + env: _EnvClass, + transform: Transform, + cache_specs: bool = True, + **kwargs, + ): + self.env = env + self.transform = transform + transform.set_parent(self) # allows to find env specs from the transform + + self._last_obs = None + self.cache_specs = cache_specs + + self._action_spec = None + self._reward_spec = None + self._observation_spec = None + self.batch_size = self.env.batch_size + self.is_closed = False + + super().__init__(**kwargs) + + @property + def observation_spec(self) -> TensorSpec: + """Observation spec of the transformed environment""" + if self._observation_spec is None or not self.cache_specs: + observation_spec = self.transform.transform_observation_spec( + deepcopy(self.env.observation_spec) + ) + if self.cache_specs: + self._observation_spec = observation_spec + else: + observation_spec = self._observation_spec + return observation_spec + + @property + def action_spec(self) -> TensorSpec: + """Action spec of the transformed environment""" + + if self._action_spec is None or not self.cache_specs: + action_spec = self.transform.transform_action_spec( + deepcopy(self.env.action_spec) + ) + if self.cache_specs: + self._action_spec = action_spec + else: + action_spec = self._action_spec + return action_spec + + @property + def reward_spec(self) -> TensorSpec: + """Reward spec of the transformed environment""" + + if self._reward_spec is None or not self.cache_specs: + reward_spec = self.transform.transform_reward_spec( + deepcopy(self.env.reward_spec) + ) + if self.cache_specs: + self._reward_spec = reward_spec + else: + reward_spec = self._reward_spec + return reward_spec + + def _step(self, tensor_dict: _TensorDict) -> _TensorDict: + selected_keys = [key for key in tensor_dict.keys() if "action" in key] + tensor_dict_in = tensor_dict.select(*selected_keys).clone() + tensor_dict_in = self.transform.inv(tensor_dict_in) + tensor_dict_out = self.env._step(tensor_dict_in).to(self.device) + # tensor_dict should already have been processed by the transforms + # for logging purposes + tensor_dict_out = self.transform(tensor_dict_out) + return tensor_dict_out + + def set_seed(self, seed: int) -> int: + """Set the seeds of the environment""" + return self.env.set_seed(seed) + + def _reset(self, tensor_dict: Optional[_TensorDict] = None, **kwargs): + out_tensor_dict = self.env.reset(**kwargs).to(self.device) + out_tensor_dict = self.transform.reset(out_tensor_dict) + + # Transforms are made for "next_observations" and alike. 
We convert + # all the observations in next_observations, then map them back to + # their original key name + keys = list(out_tensor_dict.keys()) + for key in keys: + if key.startswith("observation"): + out_tensor_dict.rename_key(key, "next_" + key, safe=True) + + out_tensor_dict = self.transform(out_tensor_dict) + keys = list(out_tensor_dict.keys()) + for key in keys: + if key.startswith("next_observation"): + out_tensor_dict.rename_key(key, key[5:], safe=True) + return out_tensor_dict + + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + state_dict = self.transform.state_dict(destination) + return state_dict + + def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: + self.transform.load_state_dict(state_dict, **kwargs) + + def eval(self) -> TransformedEnv: + self.transform.eval() + return self + + def train(self, mode: bool = True) -> TransformedEnv: + self.transform.train(mode) + return self + + def __getattr__(self, attr: str) -> Any: + if attr in self.__dir__(): + return self.__getattribute__( + attr + ) # make sure that appropriate exceptions are raised + + elif "env" in self.__dir__(): + env = self.__getattribute__("env") + return getattr(env, attr) + + raise AttributeError( + f"env not set in {self.__class__.__name__}, cannot access {attr}" + ) + + def __repr__(self) -> str: + return f"TransformedEnv(env={self.env}, transform={self.transform})" + + def close(self): + self.is_closed = True + self.env.close() + + +class ObservationTransform(Transform): + """ + Abstract class for transformations of the observations. + + """ + + inplace = False + + def __init__(self, keys: Optional[Sequence[str]] = None): + if keys is None: + keys = [ + "next_observation", + "next_observation_pixels", + "next_observation_state", + ] + super(ObservationTransform, self).__init__(keys=keys) + + +class Compose(Transform): + """ + Composes a chain of transforms. 
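Reviewer note (sketch, assumes torchvision and a pixel-based gym env; the env id below is a placeholder): a typical pixel preprocessing chain built with Compose and attached through TransformedEnv:

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import CatFrames, Compose, GrayScale, Resize, ToTensorImage

pipeline = Compose(
    ToTensorImage(),   # uint8 (... x W x H x 3) -> float (... x 3 x W x H) in [0, 1]
    Resize(84, 84),
    GrayScale(),
    CatFrames(),       # stack the last N frames (default N=4) along the channel dim
)
env = TransformedEnv(GymEnv("PongNoFrameskip-v4"), pipeline)  # placeholder env id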
+ + Examples: + >>> env = GymEnv("Pendulum-v0") + >>> transforms = [RewardScaling(1.0, 1.0), RewardClipping(-2.0, 2.0)] + >>> transforms = Compose(*transforms) + >>> transformed_env = TransformedEnv(env, transforms) + + """ + + inplace = False + + def __init__(self, *transforms: Transform): + super().__init__(keys=[]) + self.transforms = nn.ModuleList(transforms) + for t in self.transforms: + t.set_parent(self) + + def _call(self, tensor_dict: _TensorDict) -> _TensorDict: + for t in self.transforms: + tensor_dict = t(tensor_dict) + return tensor_dict + + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + for t in self.transforms: + action_spec = t.transform_action_spec(action_spec) + return action_spec + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + for t in self.transforms: + observation_spec = t.transform_observation_spec(observation_spec) + return observation_spec + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + for t in self.transforms: + reward_spec = t.transform_reward_spec(reward_spec) + return reward_spec + + def __getitem__(self, item: Union[int, slice, List]) -> Union: + transform = self.transforms + transform = transform[item] + if not isinstance(transform, Transform): + return Compose(*self.transforms[item]) + return transform + + def dump(self) -> None: + for t in self: + t.dump() + + def reset(self, tensor_dict: _TensorDict) -> _TensorDict: + for t in self.transforms: + tensor_dict = t.reset(tensor_dict) + return tensor_dict + + def init(self, tensor_dict: _TensorDict) -> None: + for t in self.transforms: + t.init(tensor_dict) + + def __repr__(self) -> str: + layers_str = ", \n\t".join([str(trsf) for trsf in self.transforms]) + return f"{self.__class__.__name__}(\n\t{layers_str})" + + +class ToTensorImage(ObservationTransform): + """Transforms a numpy-like image (3 x W x H) to a pytorch image + (3 x W x H). + + Transforms an observation image from a (... x W x H x 3) 0..255 uint8 + tensor to a single/double precision floating point (3 x W x H) tensor + with values between 0 and 1. + + Args: + unsqueeze (bool): if True, the observation tensor is unsqueezed + along the first dimension. default=False. + dtype (torch.dtype, optional): dtype to use for the resulting + observations. + + Examples: + >>> transform = ToTensorImage(keys=["next_observation_pixels"]) + >>> ri = torch.randint(0, 255, (1,1,10,11,3), dtype=torch.uint8) + >>> td = TensorDict( + ... {"next_observation_pixels": ri}, + ... 
[1, 1]) + >>> _ = transform(td) + >>> obs = td.get("next_observation_pixels") + >>> print(obs.shape, obs.dtype) + torch.Size([1, 1, 3, 10, 11]) torch.float32 + """ + + inplace = False + + def __init__( + self, + unsqueeze: bool = False, + dtype: Optional[torch.device] = None, + keys: Optional[Sequence[str]] = None, + ): + if keys is None: + keys = IMAGE_KEYS # default + super().__init__(keys=keys) + self.unsqueeze = unsqueeze + self.dtype = dtype if dtype is not None else torch.get_default_dtype() + + def _apply(self, observation: torch.FloatTensor) -> torch.Tensor: + observation = observation.div(255).to(self.dtype) + observation = observation.permute( + *list(range(observation.ndimension() - 3)), -1, -3, -2 + ) + if observation.ndimension() == 3 and self.unsqueeze: + observation = observation.unsqueeze(0) + return observation + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if isinstance(observation_spec, CompositeSpec): + _observation_spec = observation_spec["pixels"] + else: + _observation_spec = observation_spec + self._pixel_observation(_observation_spec) + _observation_spec.shape = torch.Size( + [ + *_observation_spec.shape[:-3], + _observation_spec.shape[-1], + _observation_spec.shape[-3], + _observation_spec.shape[-2], + ] + ) + _observation_spec.dtype = self.dtype + if isinstance(observation_spec, CompositeSpec): + observation_spec["pixels"] = _observation_spec + else: + observation_spec = _observation_spec + return observation_spec + + def _pixel_observation(self, spec: TensorSpec) -> None: + if isinstance(spec, BoundedTensorSpec): + spec.space.maximum = self._apply(spec.space.maximum) + spec.space.minimum = self._apply(spec.space.minimum) + + +class RewardClipping(Transform): + """ + Clips the reward between `clamp_min` and `clamp_max`. + + Args: + clip_min (scalar): minimum value of the resulting reward. + clip_max (scalar): maximum value of the resulting reward. + + """ + + inplace = True + + def __init__( + self, + clamp_min: float = None, + clamp_max: float = None, + keys: Optional[Sequence[str]] = None, + ): + if keys is None: + keys = ["reward"] + super().__init__(keys=keys) + self.clamp_min = clamp_min + self.clamp_max = clamp_max + + def _apply(self, reward: torch.Tensor) -> torch.Tensor: + if self.clamp_max is not None and self.clamp_min is not None: + reward = reward.clamp_(self.clamp_min, self.clamp_max) + elif self.clamp_min is not None: + reward = reward.clamp_min_(self.clamp_min) + elif self.clamp_max is not None: + reward = reward.clamp_max_(self.clamp_max) + return reward + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + if isinstance(reward_spec, UnboundedContinuousTensorSpec): + return BoundedTensorSpec( + self.clamp_min, + self.clamp_max, + device=reward_spec.device, + dtype=reward_spec.dtype, + ) + else: + raise NotImplementedError( + f"{self.__class__.__name__}.transform_reward_spec not " + f"implemented for tensor spec of type" + f" {type(reward_spec).__name__}" + ) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"clamp_min={float(self.clamp_min):4.4f}, clamp_max" + f"={float(self.clamp_max):4.4f}, keys={self.keys})" + ) + + +class BinerizeReward(Transform): + """ + Maps the reward to a binary value (0 or 1) if the reward is null or + non-null, respectively. 
+ + """ + + inplace = True + + def __init__(self, keys: Optional[Sequence[str]] = None): + if keys is None: + keys = ["reward"] + super().__init__(keys=keys) + + def _apply(self, reward: torch.Tensor) -> torch.Tensor: + return (reward != 0.0).to(reward.dtype) + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + if isinstance(reward_spec, UnboundedContinuousTensorSpec): + return BoundedTensorSpec( + 0.0, 1.0, device=reward_spec.device, dtype=reward_spec.dtype + ) + else: + raise NotImplementedError( + f"{self.__class__.__name__}.transform_reward_spec not " + f"implemented for tensor spec of type " + f"{type(reward_spec).__name__}" + ) + + +class Resize(ObservationTransform): + """ + Resizes an pixel observation. + + Args: + w (int): resulting width + h (int): resulting height + interpolation (str): interpolation method + """ + + inplace = False + + def __init__( + self, + w: int, + h: int, + interpolation: str = "bilinear", + keys: Optional[Sequence[str]] = None, + ): + if keys is None: + keys = IMAGE_KEYS # default + super().__init__(keys=keys) + self.w = w + self.h = h + self.interpolation = interpolation + + def _apply(self, observation: torch.Tensor) -> torch.Tensor: + observation = resize( + observation, [self.w, self.h], interpolation=self.interpolation + ) + + return observation + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if isinstance(observation_spec, CompositeSpec): + _observation_spec = observation_spec["pixels"] + else: + _observation_spec = observation_spec + space = _observation_spec.space + if isinstance(space, ContinuousBox): + space.minimum = self._apply(space.minimum) + space.maximum = self._apply(space.maximum) + _observation_spec.shape = space.minimum.shape + else: + _observation_spec.shape = self._apply( + torch.zeros(_observation_spec.shape) + ).shape + + if isinstance(observation_spec, CompositeSpec): + observation_spec["pixels"] = _observation_spec + else: + observation_spec = _observation_spec + return observation_spec + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"w={float(self.w):4.4f}, h={float(self.h):4.4f}, " + f"interpolation={self.interpolation}, keys={self.keys})" + ) + + +class GrayScale(ObservationTransform): + """ + Turns a pixel observation to grayscale. + + """ + + inplace = False + + def __init__(self, keys: Optional[Sequence[str]] = None): + if keys is None: + keys = IMAGE_KEYS + super(GrayScale, self).__init__(keys=keys) + + def _apply(self, observation: torch.Tensor) -> torch.Tensor: + observation = F.rgb_to_grayscale(observation) + return observation + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if isinstance(observation_spec, CompositeSpec): + _observation_spec = observation_spec["pixels"] + else: + _observation_spec = observation_spec + space = _observation_spec.space + if isinstance(space, ContinuousBox): + space.minimum = self._apply(space.minimum) + space.maximum = self._apply(space.maximum) + _observation_spec.shape = space.minimum.shape + else: + _observation_spec.shape = self._apply( + torch.zeros(_observation_spec.shape) + ).shape + if isinstance(observation_spec, CompositeSpec): + observation_spec["pixels"] = _observation_spec + else: + observation_spec = _observation_spec + return observation_spec + + +class ObservationNorm(ObservationTransform): + """ + Normalizes an observation according to + + .. 
math:: + obs = obs * scale + loc + + Args: + loc (number or tensor): location of the affine transform + scale (number or tensor): scale of the affine transform + standard_normal (bool, optional): if True, the transform will be + + .. math:: + obs = (obs-loc)/scale + + as it is done for standardization. Default is `False`. + + Examples: + >>> torch.set_default_tensor_type(torch.DoubleTensor) + >>> r = torch.randn(100, 3)*torch.randn(3) + torch.randn(3) + >>> td = TensorDict({'next_obs': r}, [100]) + >>> transform = ObservationNorm( + ... loc = td.get('next_obs').mean(0), + ... scale = td.get('next_obs').std(0), + ... keys=["next_obs"], + ... standard_normal=True) + >>> _ = transform(td) + >>> print(torch.isclose(td.get('next_obs').mean(0), + ... torch.zeros(3)).all()) + Tensor(True) + >>> print(torch.isclose(td.get('next_obs').std(0), + ... torch.ones(3)).all()) + Tensor(True) + + """ + + inplace = True + + def __init__( + self, + loc: Union[float, torch.Tensor], + scale: Union[float, torch.Tensor], + keys: Optional[Sequence[str]] = None, + # observation_spec_key: =None, + standard_normal: bool = False, + ): + if keys is None: + keys = [ + "next_observation", + "next_observation_pixels", + "next_observation_state", + ] + super().__init__(keys=keys) + if not isinstance(loc, torch.Tensor): + loc = torch.tensor(loc, dtype=torch.float) + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float) + + # self.observation_spec_key = observation_spec_key + self.standard_normal = standard_normal + self.register_buffer("loc", loc) + eps = 1e-6 + self.register_buffer("scale", scale.clamp_min(eps)) + + def _apply(self, obs: torch.Tensor) -> torch.Tensor: + if self.standard_normal: + # converts the transform (x-m)/sqrt(v) to x * s + loc + scale = self.scale.reciprocal() + loc = -self.loc * scale + else: + scale = self.scale + loc = self.loc + return obs * scale + loc + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if isinstance(observation_spec, CompositeSpec): + key = [key.split("observation_")[-1] for key in self.keys] + if len(set(key)) != 1: + raise RuntimeError(f"Too many compatible observation keys: {key}") + key = key[0] + _observation_spec = observation_spec[key] + else: + _observation_spec = observation_spec + space = _observation_spec.space + if isinstance(space, ContinuousBox): + space.minimum = self._apply(space.minimum) + space.maximum = self._apply(space.maximum) + return observation_spec + + def __repr__(self) -> str: + if self.loc.numel() == 1 and self.scale.numel() == 1: + return ( + f"{self.__class__.__name__}(" + f"loc={float(self.loc):4.4f}, scale" + f"={float(self.scale):4.4f}, keys={self.keys})" + ) + else: + return super().__repr__() + + +class CatFrames(ObservationTransform): + """Concatenates successive observation frames into a single tensor. + + This can, for instance, account for movement/velocity of the observed + feature. Proposed in "Playing Atari with Deep Reinforcement Learning" ( + https://arxiv.org/pdf/1312.5602.pdf). + + CatFrames is a stateful class and it can be reset to its native state by + calling the `reset()` method. + + Args: + N (int, optional): number of observation to concatenate. + Default is `4`. + cat_dim (int, optional): dimension along which concatenate the + observations. Default is `cat_dim=-3`. + keys (list of int, optional): keys pointing to the franes that have + to be concatenated. 
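+
+    The snippet below is an illustrative sketch only; it assumes a [C, W, H]
+    entry stored under the `next_observation_pixels` key and the default
+    `cat_dim=-3`, so the first call returns the same frame repeated N times
+    along the channel dimension:
+
+    Examples:
+        >>> transform = CatFrames(N=4, keys=["next_observation_pixels"])
+        >>> pixels = torch.zeros(3, 16, 16)
+        >>> td = TensorDict({"next_observation_pixels": pixels}, [])
+        >>> _ = transform(td)
+        >>> print(td.get("next_observation_pixels").shape)
+        torch.Size([12, 16, 16])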
+ + """ + + inplace = False + + def __init__( + self, + N: int = 4, + cat_dim: int = -3, + keys: Optional[Sequence[str]] = None, + ): + if keys is None: + keys = IMAGE_KEYS + super().__init__(keys=keys) + self.N = N + self.cat_dim = cat_dim + self.buffer = [] + + def reset(self, tensor_dict: _TensorDict) -> _TensorDict: + self.buffer = [] + return tensor_dict + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if isinstance(observation_spec, CompositeSpec): + _observation_spec = observation_spec["pixels"] + else: + _observation_spec = observation_spec + space = _observation_spec.space + if isinstance(space, ContinuousBox): + space.minimum = torch.cat([space.minimum] * self.N, 0) + space.maximum = torch.cat([space.maximum] * self.N, 0) + _observation_spec.shape = space.minimum.shape + else: + _observation_spec.shape = torch.Size([self.N, *_observation_spec.shape]) + if isinstance(observation_spec, CompositeSpec): + observation_spec["pixels"] = _observation_spec + else: + observation_spec = _observation_spec + return observation_spec + + def _apply(self, obs: torch.Tensor) -> torch.Tensor: + self.buffer.append(obs) + self.buffer = self.buffer[-self.N :] + buffer = list(reversed(self.buffer)) + buffer = [buffer[0]] * (self.N - len(buffer)) + buffer + if len(buffer) != self.N: + raise RuntimeError( + f"actual buffer length ({buffer}) differs from expected (" f"{self.N})" + ) + return torch.cat(buffer, self.cat_dim) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(N={self.N}, cat_dim" + f"={self.cat_dim}, keys={self.keys})" + ) + + +class RewardScaling(Transform): + """ + Affine transform of the reward according to + + .. math:: + reward = reward * scale + loc + + Args: + loc (number or torch.Tensor): location of the affine transform + scale (number or torch.Tensor): scale of the affine transform + """ + + inplace = True + + def __init__( + self, + loc: Union[float, torch.Tensor], + scale: Union[float, torch.Tensor], + keys: Optional[Sequence[str]] = None, + ): + if keys is None: + keys = ["reward"] + super().__init__(keys=keys) + if not isinstance(loc, torch.Tensor): + loc = torch.tensor(loc) + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + + self.register_buffer("loc", loc) + self.register_buffer("scale", scale.clamp_min(1e-6)) + + def _apply(self, reward: torch.Tensor) -> torch.Tensor: + reward.mul_(self.scale).add_(self.loc) + return reward + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + if isinstance(reward_spec, UnboundedContinuousTensorSpec): + return reward_spec + else: + raise NotImplementedError( + f"{self.__class__.__name__}.transform_reward_spec not " + f"implemented for tensor spec of type" + f" {type(reward_spec).__name__}" + ) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"loc={self.loc.item():4.4f}, scale={self.scale.item():4.4f}, " + f"keys={self.keys})" + ) + + +class FiniteTensorDictCheck(Transform): + """ + This transform will check that all the items of the tensordict are + finite, and raise an exception if they are not. 
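+
+    A minimal sketch (values chosen arbitrarily for illustration); a tensordict
+    holding `nan` or `inf` entries would raise an exception instead:
+
+    Examples:
+        >>> check = FiniteTensorDictCheck()
+        >>> td = TensorDict({"observation": torch.ones(3)}, [])
+        >>> td = check(td)  # passes since all entries are finite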
+
+    """
+
+    inplace = False
+
+    def __init__(self):
+        super().__init__(keys=[])
+
+    def _call(self, tensor_dict: _TensorDict) -> _TensorDict:
+        source = {}
+        for key, item in tensor_dict.items():
+            try:
+                source[key] = FiniteTensor(item)
+            except RuntimeError as err:
+                if str(err).rfind("FiniteTensor encountered") > -1:
+                    raise Exception(f"Found non-finite elements in {key}")
+                else:
+                    raise RuntimeError(str(err))
+
+        finite_tensor_dict = TensorDict(
+            batch_size=tensor_dict.batch_size, source=source
+        )
+        return finite_tensor_dict
+
+
+class DoubleToFloat(Transform):
+    """
+    Casts the selected double-precision entries to single precision, and
+    inverse-casts single-precision actions back to double precision before
+    they are passed to the environment.
+
+    Examples:
+        >>> td = TensorDict(
+        ...     {'next_obs': torch.ones(1, dtype=torch.double)}, [])
+        >>> transform = DoubleToFloat(keys=["next_obs"])
+        >>> _ = transform(td)
+        >>> print(td.get("next_obs").dtype)
+        torch.float32
+
+    """
+
+    invertible = True
+    inplace = False
+
+    def __init__(self, keys: Optional[Sequence[str]] = None):
+        if keys is None:
+            keys = ["action"]
+        super().__init__(keys=keys)
+
+    def _apply(self, obs: torch.Tensor) -> torch.Tensor:
+        return obs.to(torch.float)
+
+    def _inv_apply(self, obs: torch.Tensor) -> torch.Tensor:
+        return obs.to(torch.double)
+
+    def _transform_spec(self, spec: TensorSpec) -> None:
+        if isinstance(spec, CompositeSpec):
+            for key in spec:
+                self._transform_spec(spec[key])
+        else:
+            spec.dtype = torch.float
+            space = spec.space
+            if isinstance(space, ContinuousBox):
+                space.minimum = space.minimum.to(torch.float)
+                space.maximum = space.maximum.to(torch.float)
+
+    def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec:
+        if "action" in self.keys:
+            if action_spec.dtype is not torch.double:
+                raise TypeError("action_spec.dtype is not double")
+            self._transform_spec(action_spec)
+        return action_spec
+
+    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
+        if "reward" in self.keys:
+            if reward_spec.dtype is not torch.double:
+                raise TypeError("reward_spec.dtype is not double")
+
+            self._transform_spec(reward_spec)
+        return reward_spec
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        keys = [key for key in self.keys if "observation" in key]
+        if keys:
+            keys = [key.split("observation_")[-1] for key in keys]
+        if len(keys) > 1 or isinstance(observation_spec, CompositeSpec):
+            if not isinstance(observation_spec, CompositeSpec):
+                raise TypeError(
+                    f"observation_spec was found to be of type"
+                    f" {type(observation_spec)} when CompositeSpec "
+                    f"was expected (as more than one observation key has to "
+                    f"be converted to float)."
+                )
+            for key in keys:
+                self._transform_spec(observation_spec[key])
+        elif len(keys):
+            self._transform_spec(observation_spec)
+        return observation_spec
+
+
+class CatTensors(Transform):
+    """
+    Concatenates several keys into a single tensor.
+    This is especially useful if multiple keys describe a single state (e.g.
+    "observation_position" and
+    "observation_velocity").
+
+    Args:
+        keys (Sequence of str): keys to be concatenated.
+        out_key (str): key of the resulting tensor. Default is
+            "observation_vector".
+
+    Examples:
+        >>> transform = CatTensors(keys=["key1", "key2"])
+        >>> td = TensorDict({"key1": torch.zeros(1, 1),
+        ...     "key2": torch.ones(1, 1)}, [1])
+        >>> _ = transform(td)
+        >>> print(td.get("observation_vector"))
+        tensor([[0., 1.]])
+
+    """
+
+    invertible = False
+    inplace = False
+
+    def __init__(
+        self,
+        keys: Optional[Sequence[str]] = None,
+        out_key: str = "observation_vector",
+    ):
+        if keys is None:
+            raise Exception("CatTensors requires keys to be non-empty")
+        super().__init__(keys=keys)
+        if "observation_" not in out_key:
+            raise KeyError("CatTensors is currently restricted to observation_* keys")
+        self.out_key = out_key
+        self.keys = sorted(list(self.keys))
+        if (
+            ("reward" in self.keys)
+            or ("action" in self.keys)
+            or ("done" in self.keys)
+        ):
+            raise RuntimeError(
+                "Concatenating observations and reward / action / done state "
+                "is not allowed."
+            )
+
+    def _call(self, tensor_dict: _TensorDict) -> _TensorDict:
+        if all([key in tensor_dict.keys() for key in self.keys]):
+            out_tensor = torch.cat([tensor_dict.get(key) for key in self.keys], -1)
+            tensor_dict.set(self.out_key, out_tensor)
+            for key in self.keys:
+                tensor_dict.del_(key)
+        else:
+            raise Exception(
+                f"CatTensor failed, as it expected input keys ="
+                f" {sorted(list(self.keys))} but got a TensorDict with keys"
+                f" {sorted(list(tensor_dict.keys()))}"
+            )
+        return tensor_dict
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        if not isinstance(observation_spec, CompositeSpec):
+            # then there is a single tensor to be concatenated
+            return observation_spec
+
+        keys = [key.split("observation_")[-1] for key in self.keys]
+
+        if all([key in observation_spec for key in keys]):
+            sum_shape = sum(
+                [
+                    observation_spec[key].shape[-1]
+                    if observation_spec[key].shape
+                    else 1
+                    for key in keys
+                ]
+            )
+            spec0 = observation_spec[keys[0]]
+            out_key = self.out_key.split("observation_")[-1]
+            observation_spec[out_key] = NdUnboundedContinuousTensorSpec(
+                shape=torch.Size([*spec0.shape[:-1], sum_shape]),
+                dtype=spec0.dtype,
+            )
+            for key in keys:
+                observation_spec.del_(key)
+        return observation_spec
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(in_keys={self.keys}, out_key"
+            f"={self.out_key})"
+        )
+
+
+class DiscreteActionProjection(Transform):
+    """Projects discrete actions from a high dimensional space to a low
+    dimensional space.
+
+    Given a discrete action (from 0 to N-1) encoded as a one-hot vector and a
+    maximum action index M (with M < N), transforms the action such that
+    action_out is one of the M allowed actions.
+
+    If the input action index is greater than or equal to M, it is replaced by
+    a random value strictly below M. Otherwise the same action is kept.
+    This is intended to be used with policies applied over multiple discrete
+    control environments with different action spaces.
+
+    Args:
+        max_N (int): maximum number of actions considered.
+        M (int): resulting number of actions.
+ + Examples: + >>> torch.manual_seed(0) + >>> N = 2 + >>> M = 1 + >>> action = torch.zeros(N, dtype=torch.long) + >>> action[-1] = 1 + >>> td = TensorDict({"action": action}, []) + >>> transform = DiscreteActionProjection(N, M) + >>> _ = transform.inv(td) + >>> print(td.get("action")) + tensor([1]) + """ + + inplace = False + + def __init__(self, max_N: int, M: int, action_key: str = "action"): + super().__init__([action_key]) + self.max_N = max_N + self.M = M + + def _inv_apply(self, action: torch.Tensor) -> torch.Tensor: + if action.shape[-1] < self.M: + raise RuntimeError( + f"action.shape[-1]={action.shape[-1]} is smaller than " + f"DiscreteActionProjection.M={self.M}" + ) + action = action.argmax(-1) # bool to int + idx = action >= self.M + if idx.any(): + action[idx] = torch.randint(self.M, (idx.sum(),)) + action = nn.functional.one_hot(action, self.M) + return action + + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + shape = action_spec.shape + shape = torch.Size([*shape[:-1], self.max_N]) + action_spec.shape = shape + action_spec.space.n = self.max_N + return action_spec + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(max_N={self.max_N}, M={self.M}, " + f"keys={self.keys})" + ) + + +class NoopResetEnv(Transform): + """ + Runs a series of random actions when an environment is reset. + + Args: + env (_EnvClass): env on which the random actions have to be + performed. Can be the same env as the one provided to the + TransformedEnv class + noops (int, optional): number of actions performed after reset. + Default is `30`. + random (bool, optional): if False, the number of random ops will + always be equal to the noops value. If True, the number of + random actions will be randomly selected between 0 and noops. + Default is `True`. + + """ + + inplace = True + + def __init__(self, env: _EnvClass, noops: int = 30, random: bool = True): + """Sample initial states by taking random number of no-ops on reset. + No-op is assumed to be action 0. + """ + super().__init__([]) + self.env = env + self.noops = noops + self.random = random + + def reset(self, tensor_dict: _TensorDict) -> _TensorDict: + """Do no-op action for a number of steps in [1, noop_max].""" + keys = tensor_dict.keys() + noops = ( + self.noops if not self.random else torch.randint(self.noops, (1,)).item() + ) + i = 0 + trial = 0 + while i < noops: + i += 1 + tensor_dict = self.env.rand_step() + if self.env.is_done: + self.env.reset() + i = 0 + trial += 1 + if trial > _MAX_NOOPS_TRIALS: + self.env.reset() + tensor_dict = self.env.rand_step() + break + if self.env.is_done: + raise RuntimeError("NoopResetEnv concluded with done environment") + td = step_tensor_dict(tensor_dict).select(*keys) + for k in keys: + if k not in td.keys(): + td.set(k, tensor_dict.get(k)) + return td + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(noops={self.noops}, random" + f"={self.random}, keys={self.keys})" + ) + + +class PinMemoryTransform(Transform): + """ + Calls pin_memory on the tensordict to facilitate writing on CUDA devices. 
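+
+    A usage sketch (tensor values are arbitrary; pinning memory assumes a
+    machine with an available CUDA context):
+
+    Examples:
+        >>> transform = PinMemoryTransform()
+        >>> td = TensorDict({"observation": torch.randn(3)}, [])
+        >>> td = transform(td)  # entries now live in page-locked memory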
+ + """ + + def __init__(self): + super().__init__([]) + + def _call(self, tensor_dict: _TensorDict) -> _TensorDict: + return tensor_dict.pin_memory() + + +def _sum_left(val, dest): + while val.ndimension() > dest.ndimension(): + val = val.sum(0) + return val + + +class gSDENoise(Transform): + inplace = False + + def __init__(self, action_dim: int, state_dim: Optional[int] = None) -> None: + super().__init__(keys=[]) + self.action_dim = action_dim + self.state_dim = state_dim + + def reset(self, tensor_dict: _TensorDict) -> _TensorDict: + tensor_dict = super().reset(tensor_dict=tensor_dict) + if self.state_dim is None: + obs_spec = self.parent.observation_spec + if isinstance(obs_spec, CompositeSpec): + obs_spec = obs_spec["vector"] + state_dim = obs_spec.shape[-1] + else: + state_dim = self.state_dim + + tensor_dict.set( + "_eps_gSDE", + torch.randn( + *tensor_dict.batch_size, + self.action_dim, + state_dim, + device=tensor_dict.device, + ), + ) + return tensor_dict + + +class VecNorm(Transform): + """ + Moving average normalization layer for torchrl environments. + VecNorm keeps track of the summary statistics of a dataset to standardize + it on-the-fly. If the transform is in 'eval' mode, the running + statistics are not updated. + + If multiple processes are running a similar environment, one can pass a + _TensorDict instance that is placed in shared memory: if so, every time + the normalization layer is queried it will update the values for all + processes that share the same reference. + + Args: + keys (iterable of str, optional): keys to be updated. + default: ["next_observation", "reward"] + shared_td (_TensorDict, optional): A shared tensordict containing the + keys of the transform. + decay (number, optional): decay rate of the moving average. + default: 0.99 + eps (number, optional): lower bound of the running standard + deviation (for numerical underflow). Default is 1e-4. + + Examples: + >>> from torchrl.envs import GymEnv + >>> t = VecNorm(decay=0.9) + >>> env = GymEnv("Pendulum-v0") + >>> env = TransformedEnv(env, t) + >>> tds = [] + >>> for _ in range(1000): + ... td = env.rand_step() + ... if td.get("done"): + ... _ = env.reset() + ... tds += [td] + >>> tds = torch.stack(tds, 0) + >>> print((abs(tds.get("next_observation").mean(0))<0.2).all()) + tensor(True) + >>> print((abs(tds.get("next_observation").std(0)-1)<0.2).all()) + tensor(True) + + """ + + inplace = True + + def __init__( + self, + keys: Optional[Sequence[str]] = None, + shared_td: Optional[_TensorDict] = None, + decay: float = 0.9999, + eps: float = 1e-4, + ) -> None: + if keys is None: + keys = ["next_observation", "reward"] + super().__init__(keys) + self._td = shared_td + if shared_td is not None and not ( + shared_td.is_shared() or shared_td.is_memmap() + ): + raise RuntimeError( + "shared_td must be either in shared memory or a memmap " "tensordict." 
+ ) + if shared_td is not None: + for key in keys: + if ( + (key + "_sum" not in shared_td.keys()) + or (key + "_ssq" not in shared_td.keys()) + or (key + "_count" not in shared_td.keys()) + ): + raise KeyError( + f"key {key} not present in the shared tensordict " + f"with keys {shared_td.keys()}" + ) + + self.decay = decay + self.eps = eps + + def _call(self, tensordict: _TensorDict) -> _TensorDict: + for key in self.keys: + if key not in tensordict.keys(): + continue + self._init(tensordict, key) + # update anb standardize + new_val = self._update( + key, tensordict.get(key), N=max(1, tensordict.numel()) + ) + + tensordict.set_(key, new_val) + return tensordict + + def _init(self, tensordict: _TensorDict, key: str) -> None: + if self._td is None or key + "_sum" not in self._td.keys(): + td_view = tensordict.view(-1) + td_select = td_view[0] + d = {key + "_sum": torch.zeros_like(td_select.get(key))} + d.update({key + "_ssq": torch.zeros_like(td_select.get(key))}) + d.update( + { + key + + "_count": torch.zeros( + 1, device=td_select.get(key).device, dtype=torch.float + ) + } + ) + if self._td is None: + self._td = TensorDict(d, batch_size=[]) + else: + self._td.update(d) + else: + pass + + def _update(self, key, value, N) -> torch.Tensor: + _sum = self._td.get(key + "_sum") + _ssq = self._td.get(key + "_ssq") + _count = self._td.get(key + "_count") + + if self.training: + value_sum = _sum_left(value, _sum) + value_ssq = _sum_left(value.pow(2), _ssq) + + _sum = self.decay * _sum + value_sum + _ssq = self.decay * _ssq + value_ssq + _count = self.decay * _count + N + + self._td.set_(key + "_sum", _sum) + self._td.set_(key + "_ssq", _ssq) + self._td.set_(key + "_count", _count) + + mean = _sum / _count + std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt() + return (value - mean) / std.clamp_min(self.eps) + + @staticmethod + def build_td_for_shared_vecnorm( + env: _EnvClass, + keys_prefix: Optional[Sequence[str]] = None, + memmap: bool = False, + ) -> _TensorDict: + """Creates a shared tensordict that can be sent to different processes + for normalization across processes. + + Args: + env (_EnvClass): example environment to be used to create the + tensordict + keys_prefix (iterable of str, optional): prefix of the keys that + have to be normalized. Default is `["next_", "reward"]` + memmap (bool): if True, the resulting tensordict will be cast into + memmory map (using `memmap_()`). Otherwise, the tensordict + will be placed in shared memory. + + Returns: + A memory in shared memory to be sent to each process. + + Examples: + >>> from torch import multiprocessing as mp + >>> queue = mp.Queue() + >>> env = make_env() + >>> td_shared = VecNorm.build_td_for_shared_vecnorm(env, + ... ["next_observation", "reward"]) + >>> assert td_shared.is_shared() + >>> queue.put(td_shared) + >>> # on workers + >>> v = VecNorm(shared_td=queue.get()) + >>> env = TransformedEnv(make_env(), v) + + """ + if keys_prefix is None: + keys_prefix = ["next_", "reward"] + td = make_tensor_dict(env) + keys = set( + key + for key in td.keys() + if any(key.startswith(_prefix) for _prefix in keys_prefix) + ) + td_select = td.select(*keys) + if td.batch_dims: + raise RuntimeError( + f"VecNorm should be used with non-batched environments. 
" + f"Got batch_size={td.batch_size}" + ) + for key in keys: + td_select.set(key + "_ssq", td_select.get(key).clone()) + td_select.set( + key + "_count", + torch.zeros( + *td.batch_size, + 1, + device=td_select.device, + dtype=torch.float, + ), + ) + td_select.rename_key(key, key + "_sum") + td_select.zero_() + if memmap: + return td_select.memmap_() + return td_select.share_memory_() + + def get_extra_state(self) -> _TensorDict: + return self._td + + def set_extra_state(self, td: _TensorDict) -> None: + if not td.is_shared(): + raise RuntimeError( + "Only shared tensordicts can be set in VecNorm transforms" + ) + self._td = td + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(decay={self.decay:4.4f}," + f"eps={self.eps:4.4f}, keys={self.keys})" + ) diff --git a/torchrl/envs/transforms/utils.py b/torchrl/envs/transforms/utils.py new file mode 100644 index 00000000000..dba5ab1622a --- /dev/null +++ b/torchrl/envs/transforms/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Callable, Optional, Tuple + +import torch +from torch.utils._pytree import tree_map + + +@contextlib.contextmanager +def no_dispatch(): + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +class FiniteTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem: torch.Tensor, *args, **kwargs): + if not torch.isfinite(elem).all(): + raise RuntimeError("FiniteTensor encountered a non-finite tensor.") + return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + + def __repr__(self) -> str: + return f"FiniteTensor({super().__repr__()})" + + @classmethod + def __torch_dispatch__( + cls, + func: Callable, + types, + args: Tuple = (), + kwargs: Optional[dict] = None, + ): + # TODO: also explicitly recheck invariants on inplace/out mutation + if kwargs: + raise Exception("Expected empty kwargs") + with no_dispatch(): + rs = func(*args) + return tree_map( + lambda e: FiniteTensor(e) if isinstance(e, torch.Tensor) else e, rs + ) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py new file mode 100644 index 00000000000..e1d875e9b45 --- /dev/null +++ b/torchrl/envs/utils.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Union + +import pkg_resources +from torch.autograd.grad_mode import _DecoratorContextManager + +from torchrl.data.tensordict.tensordict import _TensorDict + +AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set} + + +class classproperty(property): + def __get__(self, cls, owner): + return classmethod(self.fget).__get__(None, owner)() + + +def step_tensor_dict( + tensor_dict: _TensorDict, next_tensor_dict: _TensorDict = None +) -> _TensorDict: + """ + Given a tensor_dict retrieved after a step, returns another tensordict with all the 'next_' prefixes are removed, + i.e. all the `'next_some_other_string'` keys will be renamed onto `'some_other_string'` keys. + + + Args: + tensor_dict (_TensorDict): tensordict with keys to be renamed + next_tensor_dict (_TensorDict, optional): destination tensordict + + Returns: + A new tensordict (or next_tensor_dict) with the "next_*" keys renamed without the "next_" prefix. 
+ + Examples: + This funtion allows for this kind of loop to be used: + >>> td_out = [] + >>> env = make_env() + >>> policy = make_policy() + >>> td = env.current_tensordict + >>> for i in range(n_steps): + >>> td = env.step(td) + >>> next_td = step_tensor_dict(td) + >>> assert next_td is not td # make sure that keys are not overwritten + >>> td_out.append(td) + >>> td = next_td + >>> td_out = torch.stack(td_out, 0) + >>> print(td_out) # should contain keys 'observation', 'next_observation', 'action', 'reward', 'done' or similar + + """ + keys = [key for key in tensor_dict.keys() if key.rfind("next_") == 0] + select_tensor_dict = tensor_dict.select(*keys).clone() + for key in keys: + select_tensor_dict.rename_key(key, key[5:], safe=True) + if next_tensor_dict is not None: + return next_tensor_dict.update(select_tensor_dict) + else: + return select_tensor_dict + + +def get_available_libraries(): + """ + + Returns: + all the supported libraries + + """ + return SUPPORTED_LIBRARIES + + +def _check_gym(): + """ + + Returns: + True if the gym library is installed + + """ + return "gym" in AVAILABLE_LIBRARIES + + +def _check_gym_atari(): + """ + + Returns: + True if the gym library is installed and atari envs can be found. + + """ + if not _check_gym(): + return False + return "atari-py" in AVAILABLE_LIBRARIES + + +def _check_mario(): + """ + + Returns: + True if the "gym-super-mario-bros" library is installed. + + """ + + return "gym-super-mario-bros" in AVAILABLE_LIBRARIES + + +def _check_dmcontrol(): + """ + + Returns: + True if the "dm-control" library is installed. + + """ + + return "dm-control" in AVAILABLE_LIBRARIES + + +def _check_dmlab(): + """ + + Returns: + True if the "deepmind-lab" library is installed. + + """ + + return "deepmind-lab" in AVAILABLE_LIBRARIES + + +SUPPORTED_LIBRARIES = { + "gym": _check_gym(), # OpenAI + "gym[atari]": _check_gym_atari(), # + "vizdoom": None, # 1.2k, https://github.com/mwydmuch/ViZDoom + "ml-agents": None, + # 11.5k, unity, https://github.com/Unity-Technologies/ml-agents + "pysc2": None, # 7.3k, DM, https://github.com/deepmind/pysc2 + "deepmind_lab": _check_dmlab(), + # 6.5k DM, https://github.com/deepmind/lab, https://github.com/deepmind/lab/tree/master/python/pip_package + "serpent.ai": None, # 6k, https://github.com/SerpentAI/SerpentAI + "gfootball": None, # 2.8k G, https://github.com/google-research/football + "dm_control": _check_dmcontrol(), + # 2.3k DM, https://github.com/deepmind/dm_control + "habitat": None, + # 1.2k FB, https://github.com/facebookresearch/habitat-sim + "meta-world": None, # 500, https://github.com/rlworkgroup/metaworld + "minerl": None, # 300, https://github.com/minerllabs/minerl + "multi-agent-emergence-environments": None, + # 1.2k, OpenAI, https://github.com/openai/multi-agent-emergence-environments + "openspiel": None, # 2.8k, DM, https://github.com/deepmind/open_spiel + "procgen": None, # 500, OpenAI, https://github.com/openai/procgen + "pybullet": None, # 641, https://github.com/benelot/pybullet-gym + "realworld_rl_suite": None, + # 250, G, https://github.com/google-research/realworldrl_suite + "rlcard": None, # 1.4k, https://github.com/datamllab/rlcard + "screeps": None, # 2.3k https://github.com/screeps/screeps + "gym-super-mario-bros": _check_mario(), +} + +EXPLORATION_MODE = None + + +class set_exploration_mode(_DecoratorContextManager): + """ + Sets the exploration mode of all ProbabilisticTDModules to the desired mode. + + Args: + mode (str): mode to use when the policy is being called. 
+ + Examples: + >>> policy = Actor(action_spec, module=network, default_interaction_mode="mode") + >>> env.rollout(policy=policy, n_steps=100) # rollout with the "mode" interaction mode + >>> with set_exploration_mode("random"): + >>> env.rollout(policy=policy, n_steps=100) # rollout with the "random" interaction mode + """ + + def __init__(self, mode: str = "mode"): + super().__init__() + self.mode = mode + + def __enter__(self) -> None: + global EXPLORATION_MODE + self.prev = EXPLORATION_MODE + EXPLORATION_MODE = self.mode + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global EXPLORATION_MODE + EXPLORATION_MODE = self.prev + + +def exploration_mode() -> Union[str, None]: + """Returns the exploration mode currently set.""" + return EXPLORATION_MODE diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py new file mode 100644 index 00000000000..0e978e64b6b --- /dev/null +++ b/torchrl/envs/vec_env.py @@ -0,0 +1,617 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import os +from collections import OrderedDict +from multiprocessing import connection +from typing import Callable, Optional, Sequence, Union + +import torch +from torch import multiprocessing as mp + +from torchrl.data import TensorDict, TensorSpec +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING +from torchrl.envs.common import _EnvClass, make_tensor_dict + +__all__ = ["SerialEnv", "ParallelEnv"] + + +def _check_start(fun): + def decorated_fun(self: _BatchedEnv, *args, **kwargs): + if self.is_closed: + self._create_td() + self._start_workers() + return fun(self, *args, **kwargs) + + return decorated_fun + + +class _BatchedEnv(_EnvClass): + """ + Batched environment abstract class. + + Args: + num_workers: number of workers (i.e. env instances) to be deployed simultaneously; + create_env_fn (callable or list of callables): function (or list of functions) to be used for the environment + creation; + create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created; + device (str, int, torch.device): device of the environment; + action_keys (list of str, optional): list of keys that are to be considered policy-output. If the policy has it, + the attribute policy.out_keys can be used. + Providing the action_keys permit to select which keys to update after the policy is called, which can + drastically decrease the IO burden when the tensordict is placed in shared memory / memory map. + pin_memory (bool): if True and device is "cpu", calls `pin_memory` on the tensordicts when created. + selected_keys (list of str, optional): keys that have to be returned by the environment. + When creating a batch of environment, it might be the case that only some of the keys are to be returned. + For instance, if the environment returns 'observation_pixels' and 'observation_vector', the user might only + be interested in, say, 'observation_vector'. By indicating which keys must be returned in the tensordict, + one can easily control the amount of data occupied in memory (for instance to limit the memory size of a + replay buffer) and/or limit the amount of data passed from one process to the other; + excluded_keys (list of str, optional): list of keys to be excluded from the returned tensordicts. 
+ See selected_keys for more details; + share_individual_td (bool): if True, a different tensordict is created for every process/worker and a lazy + stack is returned. + default = False; + shared_memory (bool): whether or not the returned tensordict will be placed in shared memory; + memmap (bool): whether or not the returned tensordict will be placed in memory map. + + """ + + _verbose: bool = False + + def __init__( + self, + num_workers: int, + create_env_fn: Union[ + Callable[[], _EnvClass], Sequence[Callable[[], _EnvClass]] + ], + create_env_kwargs: Union[dict, Sequence[dict]] = None, + device: DEVICE_TYPING = "cpu", + action_keys: Optional[Sequence[str]] = None, + pin_memory: bool = False, + selected_keys: Optional[Sequence[str]] = None, + excluded_keys: Optional[Sequence[str]] = None, + share_individual_td: bool = False, + shared_memory: bool = True, + memmap: bool = False, + ): + super().__init__(device=device) + self.is_closed = True + + create_env_kwargs = dict() if create_env_kwargs is None else create_env_kwargs + if callable(create_env_fn): + create_env_fn = [create_env_fn for _ in range(num_workers)] + else: + if len(create_env_fn) != num_workers: + raise RuntimeError( + f"num_workers and len(create_env_fn) mismatch, " + f"got {len(create_env_fn)} and {num_workers}" + ) + if isinstance(create_env_kwargs, dict): + create_env_kwargs = [create_env_kwargs for _ in range(num_workers)] + self._dummy_env = create_env_fn[0](**create_env_kwargs[0]) + self.num_workers = num_workers + self.create_env_fn = create_env_fn + self.create_env_kwargs = create_env_kwargs + self.action_keys = action_keys + self.pin_memory = pin_memory + self.selected_keys = selected_keys + self.excluded_keys = excluded_keys + self.share_individual_td = share_individual_td + self._share_memory = shared_memory + self._memmap = memmap + if self._share_memory and self._memmap: + raise RuntimeError( + "memmap and shared memory are mutually exclusive features." 
+ ) + + self.batch_size = torch.Size([self.num_workers, *self._dummy_env.batch_size]) + self._action_spec = self._dummy_env.action_spec + self._observation_spec = self._dummy_env.observation_spec + self._reward_spec = self._dummy_env.reward_spec + self._dummy_env.close() + + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + raise NotImplementedError + + def load_state_dict(self, state_dict: OrderedDict) -> None: + raise NotImplementedError + + @property + def action_spec(self) -> TensorSpec: + return self._action_spec + + @property + def observation_spec(self) -> TensorSpec: + return self._observation_spec + + @property + def reward_spec(self) -> TensorSpec: + return self._reward_spec + + def is_done_set_fn(self, value: bool) -> None: + self._is_done = value.all() + + def _create_td(self) -> None: + """Creates self.shared_tensor_dict_parent, a TensorDict used to store the most recent observations.""" + shared_tensor_dict_parent = make_tensor_dict( + self._dummy_env, + None, + ) + + shared_tensor_dict_parent = shared_tensor_dict_parent.expand( + self.num_workers + ).clone() + + raise_no_selected_keys = False + if self.selected_keys is None: + self.selected_keys = list(shared_tensor_dict_parent.keys()) + if self.excluded_keys is not None: + self.selected_keys = set(self.selected_keys) - set(self.excluded_keys) + else: + raise_no_selected_keys = True + if self.action_keys is not None: + if not all( + action_key in self.selected_keys for action_key in self.action_keys + ): + raise KeyError( + "One of the action keys is not part of the selected keys or is part of the excluded keys. Action " + "keys need to be part of the selected keys for env.step() to be called." + ) + else: + self.action_keys = [ + key for key in self.selected_keys if key.startswith("action") + ] + if not len(self.action_keys): + raise RuntimeError( + f"found 0 action keys in {sorted(list(self.selected_keys))}" + ) + shared_tensor_dict_parent = shared_tensor_dict_parent.select( + *self.selected_keys + ) + self.shared_tensor_dict_parent = shared_tensor_dict_parent.to(self.device) + + if self.share_individual_td: + self.shared_tensor_dicts = [ + td.clone() for td in self.shared_tensor_dict_parent.unbind(0) + ] + if self._share_memory: + for td in self.shared_tensor_dicts: + td.share_memory_() + elif self._memmap: + for td in self.shared_tensor_dicts: + td.memmap_() + self.shared_tensor_dict_parent = torch.stack(self.shared_tensor_dicts, 0) + else: + if self._share_memory: + self.shared_tensor_dict_parent.share_memory_() + if not self.shared_tensor_dict_parent.is_shared(): + raise RuntimeError("share_memory_() failed") + elif self._memmap: + self.shared_tensor_dict_parent.memmap_() + if not self.shared_tensor_dict_parent.is_memmap(): + raise RuntimeError("memmap_() failed") + + self.shared_tensor_dicts = self.shared_tensor_dict_parent.unbind(0) + if self.pin_memory: + self.shared_tensor_dict_parent.pin_memory() + + if raise_no_selected_keys: + if self._verbose: + print( + f"\n {self.__class__.__name__}.shared_tensor_dict_parent is \n{self.shared_tensor_dict_parent}. \n" + f"You can select keys to be synchronised by setting the selected_keys and/or excluded_keys " + f"arguments when creating the batched environment." 
+ ) + + def _start_workers(self) -> None: + """Starts the various envs.""" + raise NotImplementedError + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(\n\tenv={self._dummy_env}, \n\tbatch_size={self.batch_size})" + + def __del__(self) -> None: + if not self.is_closed: + self.close() + + def close(self) -> None: + self.shutdown() + self.is_closed = True + + def shutdown(self) -> None: + self._shutdown_workers() + + def _shutdown_workers(self) -> None: + raise NotImplementedError + + +class SerialEnv(_BatchedEnv): + """ + Creates a series of environments in the same process. + + """ + + _share_memory = False + + def _start_workers(self) -> None: + _num_workers = self.num_workers + + self._envs = [] + + for idx in range(_num_workers): + env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) + self._envs.append(env) + self.is_closed = False + + @_check_start + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + state_dict = OrderedDict() + for idx, env in enumerate(self._envs): + state_dict[f"worker{idx}"] = env.state_dict() + + if destination is not None: + destination.update(state_dict) + return destination + return state_dict + + @_check_start + def load_state_dict(self, state_dict: OrderedDict) -> None: + if "worker0" not in state_dict: + state_dict = OrderedDict( + **{f"worker{idx}": state_dict for idx in range(self.num_workers)} + ) + for idx, env in enumerate(self._envs): + env.load_state_dict(state_dict[f"worker{idx}"]) + + @_check_start + def _step( + self, + tensor_dict: TensorDict, + ) -> TensorDict: + self._assert_tensordict_shape(tensor_dict) + + self.shared_tensor_dict_parent.update_(tensor_dict) + for i in range(self.num_workers): + self._envs[i].step(self.shared_tensor_dicts[i]) + + return self.shared_tensor_dict_parent + + def _shutdown_workers(self) -> None: + if not self.is_closed: + for env in self._envs: + env.close() + del self._envs + + def __del__(self) -> None: + self.close() + + def close(self) -> None: + self.shutdown() + self.is_closed = True + + @_check_start + def set_seed(self, seed: int) -> int: + for i, env in enumerate(self._envs): + env.set_seed(seed) + if i < self.num_workers - 1: + seed = seed + 1 + return seed + + @_check_start + def _reset(self, tensor_dict: _TensorDict, **kwargs) -> _TensorDict: + if tensor_dict is not None and "reset_workers" in tensor_dict.keys(): + self._assert_tensordict_shape(tensor_dict) + reset_workers = tensor_dict.get("reset_workers") + else: + reset_workers = torch.ones(self.num_workers, 1, dtype=torch.bool) + + keys = set() + for i, _env in enumerate(self._envs): + if not reset_workers[i]: + continue + _td = _env.reset(**kwargs) + keys = keys.union(_td.keys()) + self.shared_tensor_dicts[i].update_(_td) + + return self.shared_tensor_dict_parent.select(*keys).clone() + + +class ParallelEnv(_BatchedEnv): + """ + Creates one environment per process. + TensorDicts are passed via shared memory or memory map. 
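+
+    A minimal usage sketch; `make_env` stands for any user-provided factory
+    returning an `_EnvClass` instance (e.g. `lambda: GymEnv("Pendulum-v0")`):
+
+    Examples:
+        >>> env = ParallelEnv(num_workers=4, create_env_fn=make_env)
+        >>> td = env.reset()
+        >>> td = env.rand_step()
+        >>> env.close()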
+ + """ + + def _start_workers(self) -> None: + + _num_workers = self.num_workers + ctx = mp.get_context("spawn") + + self.parent_channels = [] + self._workers = [] + + for idx in range(_num_workers): + if self._verbose: + print(f"initiating worker {idx}") + # No certainty which module multiprocessing_context is + channel1, channel2 = ctx.Pipe() + env_fun = self.create_env_fn[idx] + if env_fun.__class__.__name__ != "EnvCreator": + env_fun = CloudpickleWrapper(env_fun) + + w = mp.Process( + target=_run_worker_pipe_shared_mem, + args=( + idx, + channel1, + channel2, + env_fun, + self.create_env_kwargs[idx], + False, + self.action_keys, + ), + ) + w.daemon = True + w.start() + channel2.close() + self.parent_channels.append(channel1) + self._workers.append(w) + + # send shared tensordict to workers + for channel, shared_tensor_dict in zip( + self.parent_channels, self.shared_tensor_dicts + ): + channel.send(("init", shared_tensor_dict)) + self.is_closed = False + + @_check_start + def state_dict(self, destination: Optional[OrderedDict] = None) -> OrderedDict: + state_dict = OrderedDict() + for idx, channel in enumerate(self.parent_channels): + channel.send(("state_dict", None)) + for idx, channel in enumerate(self.parent_channels): + msg, _state_dict = channel.recv() + if msg != "state_dict": + raise RuntimeError(f"Expected 'state_dict' but received {msg}") + state_dict[f"worker{idx}"] = _state_dict + + if destination is not None: + destination.update(state_dict) + return destination + return state_dict + + @_check_start + def load_state_dict(self, state_dict: OrderedDict) -> None: + if "worker0" not in state_dict: + state_dict = OrderedDict( + **{f"worker{idx}": state_dict for idx in range(self.num_workers)} + ) + for i, channel in enumerate(self.parent_channels): + channel.send(("load_state_dict", state_dict[f"worker{i}"])) + for channel in self.parent_channels: + msg, _ = channel.recv() + if msg != "loaded": + raise RuntimeError(f"Expected 'loaded' but received {msg}") + + @_check_start + def _step(self, tensor_dict: _TensorDict) -> _TensorDict: + self._assert_tensordict_shape(tensor_dict) + + self.shared_tensor_dict_parent.update_(tensor_dict.select(*self.action_keys)) + for i in range(self.num_workers): + self.parent_channels[i].send(("step", None)) + + keys = set() + for i in range(self.num_workers): + msg, data = self.parent_channels[i].recv() + if msg != "step_result": + if msg != "done": + raise RuntimeError( + f"Expected 'done' but received {msg} from worker {i}" + ) + # data is the set of updated keys + keys = keys.union(data) + return self.shared_tensor_dict_parent.select(*keys) + + @_check_start + def _shutdown_workers(self) -> None: + if self.is_closed: + raise RuntimeError( + "calling {self.__class__.__name__}._shutdown_workers only allowed when env.is_closed = False" + ) + for i, channel in enumerate(self.parent_channels): + if self._verbose: + print(f"closing {i}") + channel.send(("close", None)) + msg, _ = channel.recv() + if msg != "closing": + raise RuntimeError( + f"Expected 'closing' but received {msg} from worker {i}" + ) + + del self.shared_tensor_dicts, self.shared_tensor_dict_parent + + for channel in self.parent_channels: + channel.close() + for proc in self._workers: + proc.join() + self.is_closed = True + del self._workers + del self.parent_channels + + def close(self) -> None: + if self.is_closed: + return None + if self._verbose: + print(f"closing {self.__class__.__name__}") + self.shutdown() + if not self.is_closed: + raise RuntimeError(f"expected 
{self.__class__.__name__} to be closed") + + @_check_start + def set_seed(self, seed: int) -> int: + for i, channel in enumerate(self.parent_channels): + channel.send(("seed", seed)) + if i < self.num_workers - 1: + seed = seed + 1 + for channel in self.parent_channels: + msg, _ = channel.recv() + if msg != "seeded": + raise RuntimeError(f"Expected 'seeded' but received {msg}") + return seed + + @_check_start + def _reset(self, tensor_dict: _TensorDict, **kwargs) -> _TensorDict: + cmd_out = "reset" + if tensor_dict is not None and "reset_workers" in tensor_dict.keys(): + self._assert_tensordict_shape(tensor_dict) + reset_workers = tensor_dict.get("reset_workers") + else: + reset_workers = torch.ones(self.num_workers, 1, dtype=torch.bool) + + for i, channel in enumerate(self.parent_channels): + if not reset_workers[i]: + continue + channel.send((cmd_out, kwargs)) + + keys = set() + for i, channel in enumerate(self.parent_channels): + if not reset_workers[i]: + continue + cmd_in, new_keys = channel.recv() + keys = keys.union(new_keys) + if cmd_in != "reset_obs": + raise RuntimeError(f"received cmd {cmd_in} instead of reset_obs") + if self.shared_tensor_dict_parent.get("done").any(): + raise RuntimeError("Envs have just been reset but some are still done") + return self.shared_tensor_dict_parent.select(*keys).clone() + + def __reduce__(self): + self.close() + return super().__reduce__() + + +def _run_worker_pipe_shared_mem( + idx: int, + parent_pipe: connection.Connection, + child_pipe: connection.Connection, + env_fun: Union[_EnvClass, Callable], + env_fun_kwargs: dict, + pin_memory: bool, + action_keys: dict, + verbose: bool = False, +) -> None: + parent_pipe.close() + pid = os.getpid() + if not isinstance(env_fun, _EnvClass): + env = env_fun(**env_fun_kwargs) + else: + if env_fun_kwargs: + raise RuntimeError( + "env_fun_kwargs must be empty if an environment is passed to a process." 
+ ) + env = env_fun + i = -1 + initialized = False + + # make sure that process can be closed + tensor_dict = None + _td = None + data = None + + while True: + try: + cmd, data = child_pipe.recv() + except EOFError: + raise EOFError(f"proc {pid} failed, last command: {cmd}") + if cmd == "seed": + if not initialized: + raise RuntimeError("call 'init' before closing") + # torch.manual_seed(data) + # np.random.seed(data) + env.set_seed(data) + child_pipe.send(("seeded", None)) + + elif cmd == "init": + if verbose: + print(f"initializing {pid}") + if initialized: + raise RuntimeError("worker already initialized") + i = 0 + tensor_dict = data + if not (tensor_dict.is_shared() or tensor_dict.is_memmap()): + raise RuntimeError( + "tensor_dict must be placed in shared memory (share_memory_() or memmap_())" + ) + initialized = True + + elif cmd == "reset": + reset_kwargs = data + if verbose: + print(f"resetting worker {pid}") + if not initialized: + raise RuntimeError("call 'init' before resetting") + # _td = tensor_dict.select("observation").to(env.device).clone() + _td = env.reset(**reset_kwargs) + keys = set(_td.keys()) + if pin_memory: + _td.pin_memory() + tensor_dict.update_(_td) + child_pipe.send(("reset_obs", keys)) + just_reset = True + if env.is_done: + raise RuntimeError( + f"{env.__class__.__name__}.is_done is {env.is_done} after reset" + ) + + elif cmd == "step": + if not initialized: + raise RuntimeError("called 'init' before step") + i += 1 + _td = tensor_dict.select(*action_keys).to(env.device).clone() + if env.is_done: + raise RuntimeError( + f"calling step when env is done, just reset = {just_reset}" + ) + _td = env.step(_td) + keys = set(_td.keys()) - {key for key in action_keys} + if pin_memory: + _td.pin_memory() + tensor_dict.update_(_td.select(*keys)) + if _td.get("done"): + msg = "done" + else: + msg = "step_result" + data = (msg, keys) + child_pipe.send(data) + just_reset = False + + if cmd == "close": + del tensor_dict, _td, data + if not initialized: + raise RuntimeError("call 'init' before closing") + env.close() + del env + + child_pipe.send(("closing", None)) + child_pipe.close() + if verbose: + print(f"{pid} closed") + break + + if cmd == "load_state_dict": + env.load_state_dict(data) + msg = "loaded" + child_pipe.send((msg, None)) + + if cmd == "state_dict": + state_dict = env.state_dict() + msg = "state_dict" + child_pipe.send((msg, state_dict)) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py new file mode 100644 index 00000000000..b45874a3235 --- /dev/null +++ b/torchrl/modules/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .distributions import * +from .models import * +from .td_module import * diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py new file mode 100644 index 00000000000..a128ac2486a --- /dev/null +++ b/torchrl/modules/distributions/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .continuous import * +from .continuous import __all__ as _all_continuous +from .discrete import * +from .discrete import __all__ as _all_discrete + +distributions_maps = { + distribution_class.lower(): eval(distribution_class) + for distribution_class in _all_continuous + _all_discrete +} diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py new file mode 100644 index 00000000000..656607cf506 --- /dev/null +++ b/torchrl/modules/distributions/continuous.py @@ -0,0 +1,545 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from numbers import Number +from typing import Dict, Sequence, Union, Optional + +import numpy as np +import torch +from torch import distributions as D, nn +from torch.distributions import constraints + +from torchrl.modules.utils import mappings +from .truncated_normal import TruncatedNormal as _TruncatedNormal + +__all__ = ["NormalParamWrapper", "TanhNormal", "Delta", "TanhDelta", "TruncatedNormal"] + +D.Distribution.set_default_validate_args(False) + + +class IndependentNormal(D.Independent): + """Implements a Normal distribution with location scaling. + + Location scaling prevents the location to be "too far" from 0, which ultimately + leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion). + In practice, the location is computed according to + + .. math:: + loc = tanh(loc / upscale) * upscale. + + This behaviour can be disabled by switching off the tanh_loc parameter (see below). + + + Args: + loc (torch.Tensor): normal distribution location parameter + scale (torch.Tensor): normal distribution sigma parameter (squared root of variance) + upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula: + + .. math:: + loc = tanh(loc / upscale) * upscale. + + Default is 5.0 + + tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw value + is kept. + Default is `True`; + """ + + num_params: int = 2 + + def __init__( + self, + loc: torch.Tensor, + scale: torch.Tensor, + upscale: float = 5.0, + tanh_loc: bool = True, + event_dim: int = 1, + **kwargs, + ): + self.tanh_loc = tanh_loc + self.upscale = upscale + self._event_dim = event_dim + self._kwargs = kwargs + super().__init__(D.Normal(loc, scale, **kwargs), event_dim) + + def update(self, loc, scale): + if self.tanh_loc: + loc = self.upscale * (loc / self.upscale).tanh() + super().__init__(D.Normal(loc, scale, **self._kwargs), self._event_dim) + + @property + def mode(self): + return self.base_dist.mean + + +class SafeTanhTransform(D.TanhTransform): + """ + TanhTransform subclass that ensured that the transformation is numerically invertible. + + """ + + delta = 1e-4 + + def _call(self, x: torch.Tensor) -> torch.Tensor: + y = super()._call(x) + y.data.clamp_(-1 + self.delta, 1 - self.delta) + return y + + def _inverse(self, y: torch.Tensor) -> torch.Tensor: + y.data.clamp_(-1 + self.delta, 1 - self.delta) + x = super()._inverse(y) + return x + + +class NormalParamWrapper(nn.Module): + """ + A wrapper for normal distirbution parameters. + + Args: + operator (nn.Module): operator whose output will be transformed in location and scale parameters + scale_mapping (str, optional): positive mapping function to be used with the std. + default = "biased_softplus_1.0" (i.e. 
softplus map with bias such that fn(0.0) = 1.0) + choices: "softplus", "exp", "relu", "biased_softplus_1"; + scale_lb (Number, optional): The minimum value that the variance can take. Default is 1e-4. + """ + + def __init__( + self, + operator: nn.Module, + scale_mapping: str = "biased_softplus_1.0", + scale_lb: Number = 1e-4, + ) -> None: + super().__init__() + self.operator = operator + self.scale_mapping = scale_mapping + self.scale_lb = scale_lb + + def forward(self, *tensors): + net_output = self.operator(*tensors) + loc, scale = net_output.chunk(2, -1) + scale = mappings(self.scale_mapping)(scale).clamp_min(self.scale_lb) + return loc, scale + + +class TruncatedNormal(D.Independent): + """Implements a Truncated Normal distribution with location scaling. + + Location scaling prevents the location to be "too far" from 0, which ultimately + leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion). + In practice, the location is computed according to + + .. math:: + loc = tanh(loc / upscale) * upscale. + + This behaviour can be disabled by switching off the tanh_loc parameter (see below). + + + Args: + loc (torch.Tensor): normal distribution location parameter + scale (torch.Tensor): normal distribution sigma parameter (squared root of variance) + upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula: + + .. math:: + loc = tanh(loc / upscale) * upscale. + + Default is 5.0 + + min (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0; + max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0; + tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw value + is kept. + Default is `True`; + """ + + num_params: int = 2 + + arg_constraints = { + "loc": constraints.real, + "scale": constraints.greater_than(1e-6), + } + + def __init__( + self, + loc: torch.Tensor, + scale: torch.Tensor, + upscale: Union[torch.Tensor, float] = 5.0, + min: Union[torch.Tensor, float] = -1.0, + max: Union[torch.Tensor, float] = 1.0, + tanh_loc: bool = True, + ): + err_msg = "TanhNormal max values must be strictly greater than min values" + if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor): + if not (max > min).all(): # type: ignore + raise RuntimeError(err_msg) + elif isinstance(max, Number) and isinstance(min, Number): + if not max > min: + raise RuntimeError(err_msg) + else: + if not all(max > min): # type: ignore + raise RuntimeError(err_msg) + + if isinstance(max, torch.Tensor): + self.non_trivial_max = (max != 1.0).any() + else: + self.non_trivial_max = max != 1.0 + + if isinstance(min, torch.Tensor): + self.non_trivial_min = (min != -1.0).any() + else: + self.non_trivial_min = min != -1.0 + self.tanh_loc = tanh_loc + + self.device = loc.device + self.upscale = ( + upscale + if not isinstance(upscale, torch.Tensor) + else upscale.to(self.device) + ) + + if isinstance(max, torch.Tensor): + max = max.to(self.device) + else: + max = torch.tensor(max, device=self.device) + if isinstance(min, torch.Tensor): + min = min.to(self.device) + else: + min = torch.tensor(min, device=self.device) + self.min = min + self.max = max + self.update(loc, scale) + + def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: + if self.tanh_loc: + loc = (loc / self.upscale).tanh() * self.upscale + if self.non_trivial_max or self.non_trivial_min: + loc = loc + (self.max - self.min) / 2 + self.min + self.loc = loc + self.scale = scale + + 
base_dist = _TruncatedNormal( + loc, scale, self.min.expand_as(loc), self.max.expand_as(scale) + ) + super().__init__(base_dist, 1, validate_args=False) + + @property + def mode(self): + m = self.base_dist.loc + a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0 + b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0 + m = torch.min(torch.stack([m, b], -1), dim=-1)[0] + return torch.max(torch.stack([m, a], -1), dim=-1)[0] + + def log_prob(self, value, **kwargs): + a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0 + a = a.expand_as(value) + b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0 + b = b.expand_as(value) + value = torch.min(torch.stack([value, b], -1), dim=-1)[0] + value = torch.max(torch.stack([value, a], -1), dim=-1)[0] + return super().log_prob(value, **kwargs) + + +class TanhNormal(D.TransformedDistribution): + """Implements a TanhNormal distribution with location scaling. + + Location scaling prevents the location to be "too far" from 0 when a TanhTransform is applied, which ultimately + leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion). + In practice, the location is computed according to + + .. math:: + loc = tanh(loc / upscale) * upscale. + + This behaviour can be disabled by switching off the tanh_loc parameter (see below). + + + Args: + loc (torch.Tensor): normal distribution location parameter + scale (torch.Tensor): normal distribution sigma parameter (squared root of variance) + upscale (torch.Tensor or number): 'a' scaling factor in the formula: + + .. math:: + loc = tanh(loc / upscale) * upscale. + + min (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0; + max (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0; + event_dims (int, optional): number of dimensions describing the action. + Default is 1; + tanh_loc (bool, optional): if True, the above formula is used for the location scaling, otherwise the raw + value is kept. 
Default is `True`; + """ + + arg_constraints = { + "loc": constraints.real, + "scale": constraints.greater_than(1e-6), + } + + num_params = 2 + + def __init__( + self, + loc: torch.Tensor, + scale: torch.Tensor, + upscale: Union[torch.Tensor, Number] = 5.0, + min: Union[torch.Tensor, Number] = -1.0, + max: Union[torch.Tensor, Number] = 1.0, + event_dims: int = 1, + tanh_loc: bool = True, + ): + err_msg = "TanhNormal max values must be strictly greater than min values" + if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor): + if not (max > min).all(): # type: ignore + raise RuntimeError(err_msg) + elif isinstance(max, Number) and isinstance(min, Number): + if not max > min: + raise RuntimeError(err_msg) + else: + if not all(max > min): # type: ignore + raise RuntimeError(err_msg) + + if isinstance(max, torch.Tensor): + self.non_trivial_max = (max != 1.0).any() + else: + self.non_trivial_max = max != 1.0 + + if isinstance(min, torch.Tensor): + self.non_trivial_min = (min != -1.0).any() + else: + self.non_trivial_min = min != -1.0 + + self.tanh_loc = tanh_loc + self._event_dims = event_dims + + self.device = loc.device + self.upscale = ( + upscale + if not isinstance(upscale, torch.Tensor) + else upscale.to(self.device) + ) + + if isinstance(max, torch.Tensor): + max = max.to(loc.device) + if isinstance(min, torch.Tensor): + min = min.to(loc.device) + self.min = min + self.max = max + + t = SafeTanhTransform() + if self.non_trivial_max or self.non_trivial_min: + t = D.ComposeTransform( + [ + t, + D.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2), + ] + ) + self._t = t + + self.update(loc, scale) + + def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: + if self.tanh_loc: + loc = (loc / self.upscale).tanh() * self.upscale + if self.non_trivial_max or self.non_trivial_min: + loc = loc + (self.max - self.min) / 2 + self.min + self.loc = loc + self.scale = scale + + if ( + hasattr(self, "base_dist") + and (self.base_dist.base_dist.loc.shape == self.loc.shape) + and (self.base_dist.base_dist.scale.shape == self.scale.shape) + ): + self.base_dist.base_dist.loc = self.loc + self.base_dist.base_dist.scale = self.scale + else: + base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims) + super().__init__(base, self._t) + + @property + def mode(self): + m = self.base_dist.base_dist.mean + for t in self.transforms: + m = t(m) + return m + + +def uniform_sample_tanhnormal(dist: TanhNormal, size=torch.Size([])) -> torch.Tensor: + """ + Defines what uniform sampling looks like for a TanhNormal distribution. + + Args: + dist (TanhNormal): distribution defining the space where the sampling should occur. + size (torch.Size): batch-size of the output tensor + + Returns: + a tensor sampled uniformly in the boundaries defined by the input distribution. + + """ + return torch.rand_like(dist.sample(size)) * (dist.max - dist.min) + dist.min + + +class Delta(D.Distribution): + """ + Delta distribution. + + Args: + param (torch.Tensor): parameter of the delta distribution; + atol (number, optional): absolute tolerance to consider that a tensor matches the distribution parameter; + Default is 1e-6 + rtol (number, optional): relative tolerance to consider that a tensor matches the distribution parameter; + Default is 1e-6 + batch_shape (torch.Size, optional): batch shape; + event_shape (torch.Size, optional): shape of the outcome. 
+ + """ + + arg_constraints: Dict = {} + + def __init__( + self, + param: torch.Tensor, + atol: float = 1e-6, + rtol: float = 1e-6, + batch_shape: Union[torch.Size, Sequence[int]] = torch.Size([]), + event_shape: Union[torch.Size, Sequence[int]] = torch.Size([]), + ): + self.update(param) + self.atol = atol + self.rtol = rtol + if not len(batch_shape) and not len(event_shape): + batch_shape = param.shape[:-1] + event_shape = param.shape[-1:] + super().__init__(batch_shape=batch_shape, event_shape=event_shape) + + def update(self, param): + self.param = param + + def _is_equal(self, value: torch.Tensor) -> torch.Tensor: + param = self.param.expand_as(value) + is_equal = abs(value - param) < self.atol + self.rtol * abs(param) + for i in range(-1, -len(self.event_shape) - 1, -1): + is_equal = is_equal.all(i) + return is_equal + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + is_equal = self._is_equal(value) + out = torch.zeros_like(is_equal, dtype=value.dtype) + out.masked_fill_(is_equal, np.inf) + out.masked_fill_(~is_equal, -np.inf) + return out + + @torch.no_grad() + def sample(self, size=torch.Size([])) -> torch.Tensor: + return self.param.expand(*size, *self.param.shape) + + def rsample(self, size=torch.Size([])) -> torch.Tensor: + return self.param.expand(*size, *self.param.shape) + + @property + def mode(self) -> torch.Tensor: + return self.param + + @property + def mean(self) -> torch.Tensor: + return self.param + + +class TanhDelta(D.TransformedDistribution): + """ + Implements a Tanh transformed Delta distribution. + + Args: + net_output (torch.Tensor): parameter of the delta distribution; + min (torch.Tensor or number): minimum value of the distribution. Default is -1.0; + min (torch.Tensor or number, optional): minimum value of the distribution. Default is 1.0; + max (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0; + event_dims (int, optional): number of dimensions describing the action. 
+ Default is 1; + atol (number, optional): absolute tolerance to consider that a tensor matches the distribution parameter; + Default is 1e-6 + rtol (number, optional): relative tolerance to consider that a tensor matches the distribution parameter; + Default is 1e-6 + batch_shape (torch.Size, optional): batch shape; + event_shape (torch.Size, optional): shape of the outcome; + + """ + + arg_constraints = { + "loc": constraints.real, + } + + def __init__( + self, + net_output: torch.Tensor, + min: Union[torch.Tensor, float] = -1.0, + max: Union[torch.Tensor, float] = 1.0, + event_dims: int = 1, + atol: float = 1e-4, + rtol: float = 1e-4, + **kwargs, + ): + minmax_msg = "max value has been found to be equal or less than min value" + if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor): + if not (max > min).all(): # type: ignore + raise ValueError(minmax_msg) + elif isinstance(max, Number) and isinstance(min, Number): + if max <= min: # type: ignore + raise ValueError(minmax_msg) + else: + if not all(max > min): # type: ignore + raise ValueError(minmax_msg) + + self.min = min + self.max = max + loc = self.update(net_output) + + t = D.TanhTransform() + non_trivial_min = (isinstance(min, torch.Tensor) and (min != 1.0).any()) or ( + not isinstance(min, torch.Tensor) and min != 1.0 + ) + non_trivial_max = (isinstance(max, torch.Tensor) and (max != 1.0).any()) or ( + not isinstance(max, torch.Tensor) and max != 1.0 + ) + if non_trivial_max or non_trivial_min: + t = D.ComposeTransform( # type: ignore + [ + t, + D.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2), + ] + ) + event_shape = net_output.shape[-event_dims:] + batch_shape = net_output.shape[:-event_dims] + base = Delta( + loc, + atol=atol, + rtol=rtol, + batch_shape=batch_shape, + event_shape=event_shape, + **kwargs, + ) + + super().__init__(base, t) + + def update(self, net_output: torch.Tensor) -> Optional[torch.Tensor]: + loc = net_output + loc = loc + (self.max - self.min) / 2 + self.min + if hasattr(self, "base_dist"): + self.base_dist.update(loc) + else: + return loc + + @property + def mode(self) -> torch.Tensor: + mode = self.base_dist.param + for t in self.transforms: + mode = t(mode) + return mode + + @property + def mean(self) -> torch.Tensor: + raise AttributeError("TanhDelta mean has not analytical form.") + + +def uniform_sample_delta(dist: Delta, size=torch.Size([])) -> torch.Tensor: + return torch.randn_like(dist.sample(size)) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py new file mode 100644 index 00000000000..4eaa5d6c857 --- /dev/null +++ b/torchrl/modules/distributions/discrete.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
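A minimal usage sketch of the continuous distributions defined above (the import path is an assumption; adjust it to wherever TanhNormal is actually exposed):

    >>> import torch
    >>> from torchrl.modules.distributions import TanhNormal  # assumed import path
    >>> loc, scale = torch.zeros(3, 4), 0.5 * torch.ones(3, 4)
    >>> dist = TanhNormal(loc, scale, min=-1.0, max=1.0)
    >>> action = dist.rsample()           # values lie strictly inside (-1, 1)
    >>> log_prob = dist.log_prob(action)  # one value per batch element (event_dims=1)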
+ +from typing import Optional, Sequence, Union + +import torch +from torch import distributions as D + +__all__ = [ + "OneHotCategorical", +] + + +def _treat_categorical_params( + params: Optional[torch.Tensor] = None, +) -> Optional[torch.Tensor]: + if params is None: + return None + if params.shape[-1] == 1: + params = params[..., 0] + return params + + +def rand_one_hot(values: torch.Tensor, do_softmax: bool = True) -> torch.Tensor: + if do_softmax: + values = values.softmax(-1) + out = values.cumsum(-1) > torch.rand_like(values[..., :1]) + out = (out.cumsum(-1) == 1).to(torch.long) + return out + + +class OneHotCategorical(D.Categorical): + """One-hot categorical distribution. + + This class behaves excacly as torch.distributions.Categorical except that it reads and produces one-hot encodings + of the discrete tensors. + + """ + + def __init__( + self, + logits: Optional[torch.Tensor] = None, + probs: Optional[torch.Tensor] = None, + **kwargs + ) -> None: + logits = _treat_categorical_params(logits) + probs = _treat_categorical_params(probs) + super().__init__(probs=probs, logits=logits, **kwargs) + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + return super().log_prob(value.argmax(dim=-1)) + + @property + def mode(self) -> torch.Tensor: + if hasattr(self, "logits"): + return (self.logits == self.logits.max(-1, True)[0]).to(torch.long) + else: + return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) + + def sample( + self, sample_shape: Union[torch.Size, Sequence] = torch.Size([]) + ) -> torch.Tensor: + out = super().sample(sample_shape=sample_shape) + out = torch.nn.functional.one_hot(out, self.logits.shape[-1]).to(torch.long) + return out diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py new file mode 100644 index 00000000000..794b5928442 --- /dev/null +++ b/torchrl/modules/distributions/truncated_normal.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
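Similarly, a short sketch of the OneHotCategorical distribution defined above (import path assumed):

    >>> import torch
    >>> from torchrl.modules.distributions import OneHotCategorical  # assumed import path
    >>> dist = OneHotCategorical(logits=torch.randn(2, 4))
    >>> action = dist.sample()            # one-hot tensor of shape [2, 4]
    >>> log_prob = dist.log_prob(action)  # shape [2]
    >>> greedy = dist.mode                # one-hot argmax of the logits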
+ + +# from https://github.com/toshas/torch_truncnorm + +import math +from numbers import Number + +import torch +from torch.distributions import constraints, Distribution +from torch.distributions.utils import broadcast_all + +CONST_SQRT_2 = math.sqrt(2) +CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) +CONST_INV_SQRT_2 = 1 / math.sqrt(2) +CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) +CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) + + +class TruncatedStandardNormal(Distribution): + """ + Truncated Standard Normal distribution + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """ + + arg_constraints = { + "a": constraints.real, + "b": constraints.real, + } + has_rsample = True + eps = 1e-6 + + def __init__(self, a, b, validate_args=None): + self.a, self.b = broadcast_all(a, b) + if isinstance(a, Number) and isinstance(b, Number): + batch_shape = torch.Size() + else: + batch_shape = self.a.size() + super(TruncatedStandardNormal, self).__init__( + batch_shape, validate_args=validate_args + ) + if self.a.dtype != self.b.dtype: + raise ValueError("Truncation bounds types are different") + if any( + (self.a >= self.b) + .view( + -1, + ) + .tolist() + ): + raise ValueError("Incorrect truncation range") + # eps = torch.finfo(self.a.dtype).eps * 10 + eps = self.eps + self._dtype_min_gt_0 = eps + self._dtype_max_lt_1 = 1 - eps + self._little_phi_a = self._little_phi(self.a) + self._little_phi_b = self._little_phi(self.b) + self._big_phi_a = self._big_phi(self.a) + self._big_phi_b = self._big_phi(self.b) + self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps) + self._log_Z = self._Z.log() + little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan) + little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan) + self._lpbb_m_lpaa_d_Z = ( + self._little_phi_b * little_phi_coeff_b + - self._little_phi_a * little_phi_coeff_a + ) / self._Z + self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z + self._variance = ( + 1 + - self._lpbb_m_lpaa_d_Z + - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2 + ) + self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z + + @constraints.dependent_property + def support(self): + return constraints.interval(self.a, self.b) + + @property + def mean(self): + return self._mean + + @property + def variance(self): + return self._variance + + @property + def entropy(self): + return self._entropy + + @property + def auc(self): + return self._Z + + @staticmethod + def _little_phi(x): + return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI + + def _big_phi(self, x): + phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) + return phi.clamp(self.eps, 1 - self.eps) + + @staticmethod + def _inv_big_phi(x): + return CONST_SQRT_2 * (2 * x - 1).erfinv() + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) + + def icdf(self, value): + y = self._big_phi_a + value * self._Z + y = y.clamp(self.eps, 1 - self.eps) + return self._inv_big_phi(y) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value ** 2) * 0.5 + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + p = torch.empty(shape, device=self.a.device).uniform_( + self._dtype_min_gt_0, self._dtype_max_lt_1 + ) + return self.icdf(p) + + +class TruncatedNormal(TruncatedStandardNormal): + """ + Truncated Normal 
distribution + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """ + + has_rsample = True + + def __init__(self, loc, scale, a, b, validate_args=None): + scale = scale.clamp_min(self.eps) + self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) + self._non_std_a = a + self._non_std_b = b + a = (a - self.loc) / self.scale + b = (b - self.loc) / self.scale + super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args) + self._log_scale = self.scale.log() + self._mean = self._mean * self.scale + self.loc + self._variance = self._variance * self.scale ** 2 + self._entropy += self._log_scale + + def _to_std_rv(self, value): + return (value - self.loc) / self.scale + + def _from_std_rv(self, value): + return value * self.scale + self.loc + + def cdf(self, value): + return super(TruncatedNormal, self).cdf(self._to_std_rv(value)) + + def icdf(self, value): + sample = self._from_std_rv(super().icdf(value)) + + # clamp data but keep gradients + sample_clip = torch.stack( + [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0 + ).max(0)[0] + sample_clip = torch.stack( + [sample_clip, self._non_std_b.detach().expand_as(sample)], 0 + ).min(0)[0] + sample.data.copy_(sample_clip) + return sample + + def log_prob(self, value): + value = self._to_std_rv(value) + return super(TruncatedNormal, self).log_prob(value) - self._log_scale diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py new file mode 100644 index 00000000000..f0743287930 --- /dev/null +++ b/torchrl/modules/models/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .models import * +from .exploration import * +from .utils import * diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py new file mode 100644 index 00000000000..05f6348fd5e --- /dev/null +++ b/torchrl/modules/models/exploration.py @@ -0,0 +1,336 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional, Sequence, Union + +import torch +from torch import nn +from torch.nn.modules.lazy import LazyModuleMixin +from torch.nn.parameter import UninitializedBuffer, UninitializedParameter + +__all__ = ["NoisyLinear", "NoisyLazyLinear", "reset_noise"] + +from torchrl.data.utils import DEVICE_TYPING +from torchrl.modules.utils import inv_softplus + + +class NoisyLinear(nn.Linear): + """ + Noisy Linear Layer, as presented in "Noisy Networks for Exploration", https://arxiv.org/abs/1706.10295v3 + + A Noisy Linear Layer is a linear layer with parametric noise added to the weights. This induced stochasticity can + be used in RL networks for the agent's policy to aid efficient exploration. The parameters of the noise are learned + with gradient descent along with any other remaining network weights. Factorized Gaussian + noise is the type of noise usually employed. + + + Args: + in_features (int): input features dimension + out_features (int): out features dimension + bias (bool): if True, a bias term will be added to the matrix multiplication: Ax + b. + default: True + device (str, int or torch.device, optional): device of the layer. + default: "cpu" + dtype (torch.dtype, optional): dtype of the parameters. 
+ default: None + std_init (scalar): initial value of the Gaussian standard deviation before optimization. + default: 1.0 + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[torch.dtype] = None, + std_init: float = 0.1, + ): + nn.Module.__init__(self) + self.in_features = int(in_features) + self.out_features = int(out_features) + self.std_init = std_init + + self.weight_mu = nn.Parameter( + torch.empty( + out_features, + in_features, + device=device, + dtype=dtype, + requires_grad=True, + ) + ) + self.weight_sigma = nn.Parameter( + torch.empty( + out_features, + in_features, + device=device, + dtype=dtype, + requires_grad=True, + ) + ) + self.register_buffer( + "weight_epsilon", + torch.empty(out_features, in_features, device=device, dtype=dtype), + ) + if bias: + self.bias_mu = nn.Parameter( + torch.empty( + out_features, + device=device, + dtype=dtype, + requires_grad=True, + ) + ) + self.bias_sigma = nn.Parameter( + torch.empty( + out_features, + device=device, + dtype=dtype, + requires_grad=True, + ) + ) + self.register_buffer( + "bias_epsilon", + torch.empty(out_features, device=device, dtype=dtype), + ) + else: + self.bias_mu = None # type: ignore + self.reset_parameters() + self.reset_noise() + + def reset_parameters(self) -> None: + mu_range = 1 / math.sqrt(self.in_features) + self.weight_mu.data.uniform_(-mu_range, mu_range) + self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features)) + if self.bias_mu is not None: + self.bias_mu.data.uniform_(-mu_range, mu_range) + self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features)) + + def reset_noise(self) -> None: + epsilon_in = self._scale_noise(self.in_features) + epsilon_out = self._scale_noise(self.out_features) + self.weight_epsilon.copy_(epsilon_out.outer(epsilon_in)) # type: ignore + if self.bias_mu is not None: + self.bias_epsilon.copy_(epsilon_out) # type: ignore + + def _scale_noise(self, size: Union[int, torch.Size, Sequence]) -> torch.Tensor: + if isinstance(size, int): + size = (size,) + x = torch.randn(*size, device=self.weight_mu.device) + return x.sign().mul_(x.abs().sqrt_()) + + @property + def weight(self) -> torch.Tensor: # type: ignore + if self.training: + return self.weight_mu + self.weight_sigma * self.weight_epsilon + else: + return self.weight_mu + + @property + def bias(self) -> Optional[torch.Tensor]: # type: ignore + if self.bias_mu is not None: + if self.training: + return self.bias_mu + self.bias_sigma * self.bias_epsilon + else: + return self.bias_mu + else: + return None + + +class NoisyLazyLinear(LazyModuleMixin, NoisyLinear): + """ + Noisy Lazy Linear Layer. + + This class makes the Noisy Linear layer lazy, in that the in_feature argument does not need to be passed at + initialization (but is inferred after the first call to the layer). + + For more context on noisy layers, see the NoisyLinear class. + + Args: + out_features (int): out features dimension + bias (bool): if True, a bias term will be added to the matrix multiplication: Ax + b. + default: True + device (str, int or torch.device, optional): device of the layer. + default: "cpu" + dtype (torch.dtype, optional): dtype of the parameters. + default: None + std_init (scalar): initial value of the Gaussian standard deviation before optimization. 
+ default: 1.0 + """ + + def __init__( + self, + out_features: int, + bias: bool = True, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[torch.dtype] = None, + std_init: float = 0.1, + ): + super().__init__(0, 0, False) + self.out_features = out_features + self.std_init = std_init + + self.weight_mu = UninitializedParameter( + device=device, dtype=dtype + ) # type: ignore + self.weight_sigma = UninitializedParameter( + device=device, dtype=dtype + ) # type: ignore + self.register_buffer( + "weight_epsilon", + UninitializedBuffer(device=device, dtype=dtype) + # type: ignore + ) + if bias: + self.bias_mu = UninitializedParameter( + device=device, dtype=dtype + ) # type: ignore + self.bias_sigma = UninitializedParameter( + device=device, dtype=dtype + ) # type: ignore + self.register_buffer( + "bias_epsilon", + UninitializedBuffer(device=device, dtype=dtype) + # type: ignore + ) + else: + self.bias_mu = None # type: ignore + self.reset_parameters() + + def reset_parameters(self) -> None: + if not self.has_uninitialized_params() and self.in_features != 0: + super().reset_parameters() + + def reset_noise(self) -> None: + if not self.has_uninitialized_params() and self.in_features != 0: + super().reset_noise() + + def initialize_parameters( + self, input: torch.Tensor + ) -> None: # type: ignore[override] + if self.has_uninitialized_params(): + with torch.no_grad(): + self.in_features = input.shape[-1] + self.weight_mu.materialize( + (self.out_features, self.in_features) + ) # type: ignore + self.weight_sigma.materialize( + (self.out_features, self.in_features) + ) # type: ignore + self.weight_epsilon.materialize( + (self.out_features, self.in_features) + ) # type: ignore + if self.bias_mu is not None: + self.bias_mu.materialize((self.out_features,)) # type: ignore + self.bias_sigma.materialize((self.out_features,)) # type: ignore + self.bias_epsilon.materialize((self.out_features,)) # type: ignore + self.reset_parameters() + self.reset_noise() + + @property + def weight(self) -> torch.Tensor: # type: ignore + if not self.has_uninitialized_params() and self.in_features != 0: + return super().weight + + @property + def bias(self) -> torch.Tensor: # type: ignore + if not self.has_uninitialized_params() and self.in_features != 0: + return super().bias # type: ignore + + +def reset_noise(layer: nn.Module) -> None: + if hasattr(layer, "reset_noise"): + layer.reset_noise() # type: ignore + + +class gSDEWrapper(nn.Module): + """A gSDE exploration wrapper as presented in "Smooth Exploration for + Robotic Reinforcement Learning" by Antonin Raffin, Jens Kober, + Freek Stulp (https://arxiv.org/abs/2005.05719) + + gSDEWrapper encapsulates nn.Module that outputs the average of a + normal distribution and adds a state-dependent exploration noise to it. + It outputs the mean, scale (standard deviation) of the normal + distribution as well as the chosen action. + + For now, only vector states are considered, but the distribution can + read other inputs (e.g. hidden states etc.) + + When used, the gSDEWrapper should also be accompanied by a few + configuration changes: the exploration mode of the policy should be set + to "net_output", meaning that the action from the ProbabilisticTDModule + will be retrieved directly from the network output and not simulated + from the constructed distribution. Second, the noise input should be + created through a `torchrl.envs.transforms.gSDENoise` instance, + which will reset this noise parameter each time the environment is reset. 
+ Finally, a regular normal distribution should be used to sample the + actions, the `ProbabilisticTDModule` should be created + in safe mode (in order for the action to be clipped in the desired + range) and its input keys should include `"_eps_gSDE"` which is the + default gSDE noise key: + + >>> actor = ProbabilisticActor( + ... wrapped_module, + ... in_keys=["observation", "_eps_gSDE"] + ... spec, + ... distribution_class=IndependentNormal, + ... safe=True) + + Args: + policy_model (nn.Module): a model that reads observations and + outputs a distribution average. + action_dim (int): the dimension of the action. + state_dim (int): the state dimension. + sigma_init (float): the initial value of the standard deviation. The + softplus non-linearity is used to map the log_sigma parameter to a + positive value. + + Examples: + >>> batch, state_dim, action_dim = 3, 7, 5 + >>> model = nn.Linear(state_dim, action_dim) + >>> wrapped_model = gSDEWrapper(model, action_dim=action_dim, + ... state_dim=state_dim) + >>> state = torch.randn(batch, state_dim) + >>> eps_gSDE = torch.randn(batch, action_dim, state_dim) + >>> # the module takes inputs (state, *additional_vectors, noise_param) + >>> mu, sigma, action = wrapped_model(state, eps_gSDE) + >>> print(mu.shape, sigma.shape, action.shape) + torch.Size([3, 5]) torch.Size([3, 5]) torch.Size([3, 5]) + """ + + def __init__( + self, + policy_model: nn.Module, + action_dim: int, + state_dim: int, + sigma_init: float = None, + ) -> None: + super().__init__() + self.policy_model = policy_model + self.action_dim = action_dim + self.state_dim = state_dim + if sigma_init is None: + sigma_init = inv_softplus(math.sqrt(1 / state_dim)) + self.register_parameter( + "log_sigma", + nn.Parameter(torch.zeros((action_dim, state_dim), requires_grad=True)), + ) + self.register_buffer("sigma_init", torch.tensor(sigma_init)) + + def forward(self, state, *tensors): + *tensors, gSDE_noise = tensors + sigma = torch.nn.functional.softplus(self.log_sigma + self.sigma_init) + if gSDE_noise is None: + gSDE_noise = torch.randn_like(sigma) + gSDE_noise = sigma * gSDE_noise + eps = (gSDE_noise @ state.unsqueeze(-1)).squeeze(-1) + mu = self.policy_model(state, *tensors) + action = mu + eps + sigma = (sigma * state.unsqueeze(-2)).pow(2).sum(-1).clamp_min(1e-5).sqrt() + if not torch.isfinite(sigma).all(): + print("inf sigma") + return mu, sigma, action diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py new file mode 100644 index 00000000000..277d2d82b04 --- /dev/null +++ b/torchrl/modules/models/models.py @@ -0,0 +1,983 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from numbers import Number +from typing import Dict, List, Optional, Sequence, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torchrl.modules.models.utils import ( + _find_depth, + LazyMapping, + SquashDims, + Squeeze2dLayer, +) + +__all__ = [ + "MLP", + "ConvNet", + "DuelingCnnDQNet", + "DistributionalDQNnet", + "DdpgCnnActor", + "DdpgCnnQNet", + "DdpgMlpActor", + "DdpgMlpQNet", + "LSTMNet", +] + + +class MLP(nn.Sequential): + """ + + A multi-layer perceptron. + If MLP receives more than one input, it concatenates them all along the last dimension before passing the + resulting tensor through the network. 
This is aimed at allowing for a seamless interface with calls of the type of + + >>> model(state, action) # compute state-action value + + In the future, this feature may be moved to the ProbabilisticTDModule, though it would require it to handle + different cases (vectors, images, ...) + + Args: + in_features (int, optional): number of input features; + out_features (int, list of int): number of output features. If iterable of integers, the output is reshaped to + the desired shape; + depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the + desired input and output size. A length of 1 will create 2 linear layers etc. If no depth is indicated, + the depth information should be contained in the num_cells argument (see below). If num_cells is an + iterable and depth is indicated, both should match: len(num_cells) must be equal to depth. + num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If + an integer is provided, every layer will have the same number of cells. If an iterable is provided, + the linear layers out_features will match the content of num_cells. + default: 32; + activation_class (Type): activation class to be used. + default: nn.Tanh + activation_kwargs (dict, optional): kwargs to be used with the activation class; + norm_class (Type, optional): normalization class, if any. + norm_kwargs (dict, optional): kwargs to be used with the normalization layers; + bias_last_layer (bool): if True, the last Linear layer will have a bias parameter. + default: True; + single_bias_last_layer (bool): if True, the last dimension of the bias of the last layer will be a singleton + dimension. + default: True; + layer_class (Type): class to be used for the linear layers; + layer_kwargs (dict, optional): kwargs for the linear layers; + activate_last_layer (bool): whether the MLP output should be activated. This is useful when the MLP output + is used as the input for another module. + default: False. 
+ + Examples: + >>> # All of the following examples provide valid, working MLPs + >>> mlp = MLP(in_features=3, out_features=6, depth=0) # MLP consisting of a single 3 x 6 linear layer + >>> print(mlp) + MLP( + (0): Linear(in_features=3, out_features=6, bias=True) + ) + >>> mlp = MLP(in_features=3, out_features=6, depth=4, num_cells=32) + >>> print(mlp) + MLP( + (0): Linear(in_features=3, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=32, bias=True) + (3): Tanh() + (4): Linear(in_features=32, out_features=32, bias=True) + (5): Tanh() + (6): Linear(in_features=32, out_features=32, bias=True) + (7): Tanh() + (8): Linear(in_features=32, out_features=6, bias=True) + ) + >>> mlp = MLP(out_features=6, depth=4, num_cells=32) # LazyLinear for the first layer + >>> print(mlp) + MLP( + (0): LazyLinear(in_features=0, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=32, bias=True) + (3): Tanh() + (4): Linear(in_features=32, out_features=32, bias=True) + (5): Tanh() + (6): Linear(in_features=32, out_features=32, bias=True) + (7): Tanh() + (8): Linear(in_features=32, out_features=6, bias=True) + ) + >>> mlp = MLP(out_features=6, num_cells=[32, 33, 34, 35]) # defines the depth by the num_cells arg + >>> print(mlp) + MLP( + (0): LazyLinear(in_features=0, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=33, bias=True) + (3): Tanh() + (4): Linear(in_features=33, out_features=34, bias=True) + (5): Tanh() + (6): Linear(in_features=34, out_features=35, bias=True) + (7): Tanh() + (8): Linear(in_features=35, out_features=6, bias=True) + ) + >>> mlp = MLP(out_features=(6, 7), num_cells=[32, 33, 34, 35]) # returns a view of the output tensor with shape [*, 6, 7] + >>> print(mlp) + MLP( + (0): LazyLinear(in_features=0, out_features=32, bias=True) + (1): Tanh() + (2): Linear(in_features=32, out_features=33, bias=True) + (3): Tanh() + (4): Linear(in_features=33, out_features=34, bias=True) + (5): Tanh() + (6): Linear(in_features=34, out_features=35, bias=True) + (7): Tanh() + (8): Linear(in_features=35, out_features=42, bias=True) + ) + >>> from torchrl.modules import NoisyLinear + >>> mlp = MLP(out_features=(6, 7), num_cells=[32, 33, 34, 35], layer_class=NoisyLinear) # uses NoisyLinear layers + >>> print(mlp) + MLP( + (0): NoisyLazyLinear(in_features=0, out_features=32, bias=False) + (1): Tanh() + (2): NoisyLinear(in_features=32, out_features=33, bias=True) + (3): Tanh() + (4): NoisyLinear(in_features=33, out_features=34, bias=True) + (5): Tanh() + (6): NoisyLinear(in_features=34, out_features=35, bias=True) + (7): Tanh() + (8): NoisyLinear(in_features=35, out_features=42, bias=True) + ) + + """ + + def __init__( + self, + in_features: Optional[int] = None, + out_features: Union[int, Sequence[int]] = None, + depth: Optional[int] = None, + num_cells: Optional[Union[Sequence, int]] = None, + activation_class: Type = nn.Tanh, + activation_kwargs: Optional[dict] = None, + norm_class: Optional[Type] = None, + norm_kwargs: Optional[dict] = None, + bias_last_layer: bool = True, + single_bias_last_layer: bool = False, + layer_class: Type = nn.Linear, + layer_kwargs: Optional[dict] = None, + activate_last_layer: bool = False, + ): + if out_features is None: + raise ValueError("out_feature must be specified for MLP.") + + default_num_cells = 32 + if num_cells is None: + if depth is None: + num_cells = [default_num_cells] * 3 + depth = 3 + else: + num_cells = [default_num_cells] * depth + + self.in_features = 
in_features + + _out_features_num = out_features + if not isinstance(out_features, Number): + _out_features_num = np.prod(out_features) + self.out_features = out_features + self._out_features_num = _out_features_num + self.activation_class = activation_class + self.activation_kwargs = ( + activation_kwargs if activation_kwargs is not None else dict() + ) + self.norm_class = norm_class + self.norm_kwargs = norm_kwargs if norm_kwargs is not None else dict() + self.bias_last_layer = bias_last_layer + self.single_bias_last_layer = single_bias_last_layer + self.layer_class = layer_class + self.layer_kwargs = layer_kwargs if layer_kwargs is not None else dict() + self.activate_last_layer = activate_last_layer + if single_bias_last_layer: + raise NotImplementedError + + if not (isinstance(num_cells, Sequence) or depth is not None): + raise RuntimeError( + "If num_cells is provided as an integer, \ + depth must be provided too." + ) + self.num_cells = ( + list(num_cells) if isinstance(num_cells, Sequence) else [num_cells] * depth + ) + self.depth = depth if depth is not None else len(self.num_cells) + if not (len(self.num_cells) == depth or depth is None): + raise RuntimeError( + "depth and num_cells length conflict, \ + consider matching or specifying a constan num_cells argument together with a a desired depth" + ) + layers = self._make_net() + super().__init__(*layers) + + def _make_net(self) -> List[nn.Module]: + layers = [] + in_features = [self.in_features] + self.num_cells + out_features = self.num_cells + [self._out_features_num] + for i, (_in, _out) in enumerate(zip(in_features, out_features)): + _bias = self.bias_last_layer if i == self.depth else True + if _in is not None: + layers.append( + self.layer_class(_in, _out, bias=_bias, **self.layer_kwargs) + ) + else: + try: + lazy_version = LazyMapping[self.layer_class] + except KeyError: + raise KeyError( + f"The lazy version of {self.layer_class.__name__} is not implemented yet. " + "Consider providing the input feature dimensions explicitely when creating an MLP module" + ) + layers.append(lazy_version(_out, bias=_bias, **self.layer_kwargs)) + + if i < self.depth or self.activate_last_layer: + layers.append(self.activation_class(**self.activation_kwargs)) + if self.norm_class is not None: + layers.append(self.norm_class(**self.norm_kwargs)) + return layers + + def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: + if len(inputs) > 1: + inputs = (torch.cat([*inputs], -1),) + + out = super().forward(*inputs) + if not isinstance(self.out_features, Number): + out = out.view(*out.shape[:-1], *self.out_features) + return out + + +class ConvNet(nn.Sequential): + """ + A convolutional neural network. + + Args: + in_features (int, optional): number of input features; + depth (int, optional): depth of the network. A depth of 1 will produce a single linear layer network with the + desired input size, and with an output size equal to the last element of the num_cells argument. + If no depth is indicated, the depth information should be contained in the num_cells argument (see below). + If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to + the depth. + num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If + an integer is provided, every layer will have the same number of cells. If an iterable is provided, + the linear layers out_features will match the content of num_cells. 
+ default: [32, 32, 32]; + kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the conv network. If iterable, the length must match the + depth, defined by the num_cells or depth arguments. + strides (int or Sequence[int]): Stride(s) of the conv network. If iterable, the length must match the + depth, defined by the num_cells or depth arguments. + activation_class (Type): activation class to be used. + default: nn.Tanh + activation_kwargs (dict, optional): kwargs to be used with the activation class; + norm_class (Type, optional): normalization class, if any; + norm_kwargs (dict, optional): kwargs to be used with the normalization layers; + bias_last_layer (bool): if True, the last Linear layer will have a bias parameter. + default: True; + aggregator_class (Type): aggregator to use at the end of the chain. + default: SquashDims; + aggregator_kwargs (dict, optional): kwargs for the aggregator_class; + squeeze_output (bool): whether the output should be squeezed of its singleton dimensions. + default: True. + + Examples: + >>> # All of the following examples provide valid, working MLPs + >>> cnet = ConvNet(in_features=3, depth=1, num_cells=[32,]) # MLP consisting of a single 3 x 6 linear layer + >>> print(cnet) + ConvNet( + (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1)) + (1): ELU(alpha=1.0) + (2): SquashDims() + ) + >>> cnet = ConvNet(in_features=3, depth=4, num_cells=32) + >>> print(cnet) + ConvNet( + (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1)) + (3): ELU(alpha=1.0) + (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1)) + (5): ELU(alpha=1.0) + (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1)) + (7): ELU(alpha=1.0) + (8): SquashDims() + ) + >>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35]) # defines the depth by the num_cells arg + >>> print(cnet) + ConvNet( + (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 33, kernel_size=(3, 3), stride=(1, 1)) + (3): ELU(alpha=1.0) + (4): Conv2d(33, 34, kernel_size=(3, 3), stride=(1, 1)) + (5): ELU(alpha=1.0) + (6): Conv2d(34, 35, kernel_size=(3, 3), stride=(1, 1)) + (7): ELU(alpha=1.0) + (8): SquashDims() + ) + >>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35], kernel_sizes=[3, 4, 5, (2, 3)]) # defines kernels, possibly rectangular + >>> print(cnet) + ConvNet( + (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 33, kernel_size=(4, 4), stride=(1, 1)) + (3): ELU(alpha=1.0) + (4): Conv2d(33, 34, kernel_size=(5, 5), stride=(1, 1)) + (5): ELU(alpha=1.0) + (6): Conv2d(34, 35, kernel_size=(2, 3), stride=(1, 1)) + (7): ELU(alpha=1.0) + (8): SquashDims() + ) + + """ + + def __init__( + self, + in_features: Optional[int] = None, + depth: Optional[int] = None, + num_cells: Union[Sequence, int] = [32, 32, 32], + kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 3, + strides: Union[Sequence, int] = 1, + actionvation_class: Type = nn.ELU, + activation_kwargs: Optional[dict] = None, + norm_class: Type = None, + norm_kwargs: Optional[dict] = None, + bias_last_layer: bool = True, + aggregator_class: Type = SquashDims, + aggregator_kwargs: Optional[dict] = None, + squeeze_output: bool = False, + ): + + self.in_features = in_features + self.activation_class = actionvation_class + self.activation_kwargs = ( + activation_kwargs if activation_kwargs is not None else dict() + ) + self.norm_class = norm_class + self.norm_kwargs 
= norm_kwargs if norm_kwargs is not None else dict() + self.bias_last_layer = bias_last_layer + self.aggregator_class = aggregator_class + self.aggregator_kwargs = ( + aggregator_kwargs if aggregator_kwargs is not None else {"ndims_in": 3} + ) + self.squeeze_output = squeeze_output + # self.single_bias_last_layer = single_bias_last_layer + + depth = _find_depth(depth, num_cells, kernel_sizes, strides) + self.depth = depth + assert depth > 0, "Null depth is not permitted with ConvNet." + + for _field, _value in zip( + ["num_cells", "kernel_sizes", "strides"], + [num_cells, kernel_sizes, strides], + ): + _depth = depth + setattr( + self, + _field, + (_value if isinstance(_value, Sequence) else [_value] * _depth), + ) + if not (isinstance(_value, Sequence) or _depth is not None): + raise RuntimeError( + f"If {_field} is provided as an integer, " + "depth must be provided too." + ) + if not (len(getattr(self, _field)) == _depth or _depth is None): + raise RuntimeError( + f"depth={depth} and {_field}={len(getattr(self, _field))} length conflict, " + + f"consider matching or specifying a constan {_field} argument together with a a desired depth" + ) + + self.out_features = self.num_cells[-1] + + self.depth = len(self.kernel_sizes) + layers = self._make_net() + super().__init__(*layers) + + def _make_net(self) -> nn.Module: + layers = [] + in_features = [self.in_features] + self.num_cells[: self.depth] + out_features = self.num_cells + [self.out_features] + kernel_sizes = self.kernel_sizes + strides = self.strides + for i, (_in, _out, _kernel, _stride) in enumerate( + zip(in_features, out_features, kernel_sizes, strides) + ): + _bias = (i < len(in_features) - 1) or self.bias_last_layer + if _in is not None: + layers.append( + nn.Conv2d( + _in, + _out, + kernel_size=_kernel, + stride=_stride, + bias=_bias, + ) + ) + else: + layers.append( + nn.LazyConv2d(_out, kernel_size=_kernel, stride=_stride, bias=_bias) + ) + + layers.append(self.activation_class(**self.activation_kwargs)) + if self.norm_class is not None: + layers.append(self.norm_class(**self.norm_kwargs)) + + if self.aggregator_class is not None: + layers.append(self.aggregator_class(**self.aggregator_kwargs)) + + if self.squeeze_output: + layers.append(Squeeze2dLayer()) + return layers + + +class DuelingMlpDQNet(nn.Module): + """ + Creates a Dueling MLP Q-network, as presented in + https://arxiv.org/abs/1511.06581 + + Args: + out_features (int): number of features for the advantage network + out_features_value (int): number of features for the value network + mlp_kwargs_feature (dict, optional): kwargs for the feature network. + Default is + + >>> mlp_kwargs_feature = { + ... 'num_cells': [256, 256], + ... 'activation_class': nn.ELU, + ... 'out_features': 256, + ... 'activate_last_layer': True, + ... } + + mlp_kwargs_output (dict, optional): kwargs for the advantage and + value networks. + Default is + + >>> mlp_kwargs_output = { + ... "depth": 1, + ... "activation_class": nn.ELU, + ... "num_cells": 512, + ... "bias_last_layer": True, + ... 
} + + """ + + def __init__( + self, + out_features: int, + out_features_value: int = 1, + mlp_kwargs_feature: Optional[dict] = None, + mlp_kwargs_output: Optional[dict] = None, + ): + super(DuelingMlpDQNet, self).__init__() + + mlp_kwargs_feature = ( + mlp_kwargs_feature if mlp_kwargs_feature is not None else dict() + ) + _mlp_kwargs_feature = { + "num_cells": [256, 256], + "out_features": 256, + "activation_class": nn.ELU, + "activate_last_layer": True, + } + _mlp_kwargs_feature.update(mlp_kwargs_feature) + self.features = MLP(**_mlp_kwargs_feature) # type: ignore + + _mlp_kwargs_output = { + "depth": 1, + "activation_class": nn.ELU, + "num_cells": 512, + "bias_last_layer": True, + } + mlp_kwargs_output = ( + mlp_kwargs_output if mlp_kwargs_output is not None else dict() + ) + _mlp_kwargs_output.update(mlp_kwargs_output) + self.out_features = out_features + self.out_features_value = out_features_value + self.advantage = MLP(out_features=out_features, **_mlp_kwargs_output) # type: ignore + self.value = MLP(out_features=out_features_value, **_mlp_kwargs_output) # type: ignore + for layer in self.modules(): + if isinstance(layer, (nn.Conv2d, nn.Linear)) and isinstance( + layer.bias, torch.Tensor + ): + layer.bias.data.zero_() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.features(x) + advantage = self.advantage(x) + value = self.value(x) + return value + advantage - advantage.mean(dim=-1, keepdim=True) + + +class DuelingCnnDQNet(nn.Module): + """ + Creates a Dueling CNN Q-network, as presented in https://arxiv.org/abs/1511.06581 + + Args: + out_features (int): number of features for the advantage network + out_features_value (int): number of features for the value network + cnn_kwargs (dict, optional): kwargs for the feature network. + Default is + + >>> cnn_kwargs = { + ... 'num_cells': [32, 64, 64], + ... 'strides': [4, 2, 1], + ... 'kernels': [8, 4, 3], + ... } + + mlp_kwargs (dict, optional): kwargs for the advantage and value network. + Default is + + >>> mlp_kwargs = { + ... "depth": 1, + ... "activation_class": nn.ELU, + ... "num_cells": 512, + ... "bias_last_layer": True, + ... } + + """ + + def __init__( + self, + out_features: int, + out_features_value: int = 1, + cnn_kwargs: Optional[dict] = None, + mlp_kwargs: Optional[dict] = None, + ): + super(DuelingCnnDQNet, self).__init__() + + cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else dict() + _cnn_kwargs = { + "num_cells": [32, 64, 64], + "strides": [4, 2, 1], + "kernel_sizes": [8, 4, 3], + } + _cnn_kwargs.update(cnn_kwargs) + self.features = ConvNet(**_cnn_kwargs) # type: ignore + + _mlp_kwargs = { + "depth": 1, + "activation_class": nn.ELU, + "num_cells": 512, + "bias_last_layer": True, + } + mlp_kwargs = mlp_kwargs if mlp_kwargs is not None else dict() + _mlp_kwargs.update(mlp_kwargs) + self.out_features = out_features + self.out_features_value = out_features_value + self.advantage = MLP(out_features=out_features, **_mlp_kwargs) # type: ignore + self.value = MLP(out_features=out_features_value, **_mlp_kwargs) # type: ignore + for layer in self.modules(): + if isinstance(layer, (nn.Conv2d, nn.Linear)) and isinstance( + layer.bias, torch.Tensor + ): + layer.bias.data.zero_() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.features(x) + advantage = self.advantage(x) + value = self.value(x) + return value + advantage - advantage.mean(dim=-1, keepdim=True) + + +class DistributionalDQNnet(nn.Module): + """ + Distributional Deep Q-Network. 
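+    It wraps a Q-value network whose output carries an extra atom dimension and
+    applies a log-softmax over that dimension, so the output can be read as
+    log-probabilities over the value atoms for each action.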
+ + Args: + DQNet (nn.Module): Q-Network with output length equal to the number of atoms: + output.shape = [*batch, atoms, actions]. + + """ + + _wrong_out_feature_dims_error = ( + "DistributionalDQNnet requires dqn output to be at least " + "2-dimensional, with dimensions *Batch x #Atoms x #Actions. Got {0} " + "instead." + ) + + def __init__(self, DQNet: nn.Module): + super().__init__() + if not ( + not isinstance(DQNet.out_features, Number) and len(DQNet.out_features) > 1 + ): + raise RuntimeError(self._wrong_out_feature_dims_error) + self.dqn = DQNet + + def forward(self, x: torch.Tensor) -> torch.Tensor: + q_values = self.dqn(x) + if q_values.ndimension() < 2: + raise RuntimeError( + self._wrong_out_feature_dims_error.format(q_values.shape) + ) + return F.log_softmax(q_values, dim=-2) + + +def ddpg_init_last_layer(last_layer: nn.Module, scale: float = 6e-4) -> None: + last_layer.weight.data.copy_( # type: ignore + torch.rand_like(last_layer.weight.data) * scale + - scale / 2 + # type: ignore + ) + if last_layer.bias is not None: + last_layer.bias.data.copy_( # type: ignore + torch.rand_like(last_layer.bias.data) * scale + - scale / 2 + # type: ignore + ) + + +class DdpgCnnActor(nn.Module): + """ + DDPG Convolutional Actor class, as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + https://arxiv.org/pdf/1509.02971.pdf + + The DDPG Convolutional Actor takes as input an observation (some simple transformation of the observed pixels) and + returns an action vector from it. + It is trained to maximise the value returned by the DDPG Q Value network. + + Args: + action_dim (int): length of the action vector. + conv_net_kwargs (dict, optional): kwargs for the ConvNet. + default: { + 'in_features': None, + 'num_cells': [32, 64, 64], + 'kernel_sizes': [8, 4, 3], + 'strides': [4, 2, 1], + 'actionvation_class': nn.ELU, + 'activation_kwargs': {'inplace': True}, + 'norm_class': None, + 'aggregator_class': SquashDims, + 'aggregator_kwargs': {"ndims_in": 3}, + 'squeeze_output': True, + } + mlp_net_kwargs: kwargs for MLP. 
+ Default: { + 'in_features': None, + 'out_features': action_dim, + 'depth': 2, + 'num_cells': 200, + 'activation_class': nn.ELU, + 'activation_kwargs': {'inplace': True}, + 'bias_last_layer': True, + } + """ + + def __init__( + self, + action_dim: int, + conv_net_kwargs: Optional[dict] = None, + mlp_net_kwargs: Optional[dict] = None, + ): + super().__init__() + conv_net_default_kwargs = { + "in_features": None, + "num_cells": [32, 64, 64], + "kernel_sizes": [8, 4, 3], + "strides": [4, 2, 1], + "actionvation_class": nn.ELU, + "activation_kwargs": {"inplace": True}, + "norm_class": None, + "aggregator_class": SquashDims, + "aggregator_kwargs": {"ndims_in": 3}, + "squeeze_output": False, + } + conv_net_kwargs = conv_net_kwargs if conv_net_kwargs is not None else dict() + conv_net_default_kwargs.update(conv_net_kwargs) + mlp_net_default_kwargs = { + "in_features": None, + "out_features": action_dim, + "depth": 2, + "num_cells": 200, + "activation_class": nn.ELU, + "activation_kwargs": {"inplace": True}, + "bias_last_layer": True, + } + mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else dict() + mlp_net_default_kwargs.update(mlp_net_kwargs) + self.convnet = ConvNet(**conv_net_default_kwargs) # type: ignore + self.mlp = MLP(**mlp_net_default_kwargs) # type: ignore + ddpg_init_last_layer(self.mlp[-1], 6e-4) + + def forward(self, observation: torch.Tensor) -> torch.Tensor: + action = self.mlp(self.convnet(observation)) + return action + + +class DdpgMlpActor(nn.Module): + """ + DDPG Actor class, as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + https://arxiv.org/pdf/1509.02971.pdf + + The DDPG Actor takes as input an observation vector and returns an action from it. + It is trained to maximise the value returned by the DDPG Q Value network. + + Args: + action_dim (int): length of the action vector + mlp_net_kwargs (dict, optional): kwargs for MLP. + Default: { + 'in_features': None, + 'out_features': action_dim, + 'depth': 2, + 'num_cells': [400, 300], + 'activation_class': nn.ELU, + 'activation_kwargs': {'inplace': True}, + 'bias_last_layer': True, + } + """ + + def __init__(self, action_dim: int, mlp_net_kwargs: Optional[dict] = None): + super().__init__() + mlp_net_default_kwargs = { + "in_features": None, + "out_features": action_dim, + "depth": 2, + "num_cells": [400, 300], + "activation_class": nn.ELU, + "activation_kwargs": {"inplace": True}, + "bias_last_layer": True, + } + mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else dict() + mlp_net_default_kwargs.update(mlp_net_kwargs) + self.mlp = MLP(**mlp_net_default_kwargs) # type: ignore + ddpg_init_last_layer(self.mlp[-1], 6e-3) + + def forward(self, observation: torch.Tensor) -> torch.Tensor: + action = self.mlp(observation) + return action + + +class DdpgCnnQNet(nn.Module): + """ + DDPG Convolutional Q-value class, as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + https://arxiv.org/pdf/1509.02971.pdf + + The DDPG Q-value network takes as input an observation and an action, and returns a scalar from it. + + Args: + conv_net_kwargs (dict, optional): kwargs for the convolutional network. + default: { + 'in_features': None, + 'num_cells': [32, 32, 32], + 'kernel_sizes': 3, + 'strides': 1, + 'actionvation_class': nn.ELU, + 'activation_kwargs': {'inplace': True}, + 'norm_class': None, + 'aggregator_class': SquashDims, + 'aggregator_kwargs': {"ndims_in": 3}, + 'squeeze_output': True, + } + mlp_net_kwargs (dict, optional): kwargs for MLP. 
+ Default: { + 'in_features': None, + 'out_features': 1, + 'depth': 2, + 'num_cells': 200, + 'activation_class': nn.ELU, + 'activation_kwargs': {'inplace': True}, + 'bias_last_layer': True, + } + """ + + def __init__( + self, + conv_net_kwargs: Optional[dict] = None, + mlp_net_kwargs: Optional[dict] = None, + ): + super().__init__() + conv_net_default_kwargs = { + "in_features": None, + "num_cells": [32, 32, 32], + "kernel_sizes": 3, + "strides": 1, + "actionvation_class": nn.ELU, + "activation_kwargs": {"inplace": True}, + "norm_class": None, + "aggregator_class": SquashDims, + "aggregator_kwargs": {"ndims_in": 3}, + "squeeze_output": False, + } + conv_net_kwargs = conv_net_kwargs if conv_net_kwargs is not None else dict() + conv_net_default_kwargs.update(conv_net_kwargs) + mlp_net_default_kwargs = { + "in_features": None, + "out_features": 1, + "depth": 2, + "num_cells": 200, + "activation_class": nn.ELU, + "activation_kwargs": {"inplace": True}, + "bias_last_layer": True, + } + mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else dict() + mlp_net_default_kwargs.update(mlp_net_kwargs) + self.convnet = ConvNet(**conv_net_default_kwargs) # type: ignore + self.mlp = MLP(**mlp_net_default_kwargs) # type: ignore + ddpg_init_last_layer(self.mlp[-1], 6e-4) + + def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + value = self.mlp(torch.cat([self.convnet(observation), action], -1)) + return value + + +class DdpgMlpQNet(nn.Module): + """ + DDPG Q-value MLP class, as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + https://arxiv.org/pdf/1509.02971.pdf + + The DDPG Q-value network takes as input an observation and an action, and returns a scalar from it. + Because actions are integrated later than observations, two networks are created. + + Args: + mlp_net_kwargs_net1 (dict, optional): kwargs for MLP. 
+ Default: { + 'in_features': None, + 'out_features': 400, + 'depth': 0, + 'num_cells': [], + 'activation_class': nn.ELU, + 'activation_kwargs': {'inplace': True}, + 'bias_last_layer': True, + 'activate_last_layer': True, + } + mlp_net_kwargs_net2 + Default: { + 'in_features': None, + 'out_features': 1, + 'depth': 1, + 'num_cells': [300, ], + 'activation_class': nn.ELU, + 'activation_kwargs': {'inplace': True}, + 'bias_last_layer': True, + } + """ + + def __init__( + self, + mlp_net_kwargs_net1: Optional[dict] = None, + mlp_net_kwargs_net2: Optional[dict] = None, + ): + super().__init__() + mlp1_net_default_kwargs = { + "in_features": None, + "out_features": 400, + "depth": 0, + "num_cells": [], + "activation_class": nn.ELU, + "activation_kwargs": {"inplace": True}, + "bias_last_layer": True, + "activate_last_layer": True, + } + mlp_net_kwargs_net1: Dict = ( + mlp_net_kwargs_net1 if mlp_net_kwargs_net1 is not None else dict() + ) + mlp1_net_default_kwargs.update(mlp_net_kwargs_net1) + self.mlp1 = MLP(**mlp1_net_default_kwargs) # type: ignore + + mlp2_net_default_kwargs = { + "in_features": None, + "out_features": 1, + "depth": 1, + "num_cells": [ + 300, + ], + "activation_class": nn.ELU, + "activation_kwargs": {"inplace": True}, + "bias_last_layer": True, + } + mlp_net_kwargs_net2 = ( + mlp_net_kwargs_net2 if mlp_net_kwargs_net2 is not None else dict() + ) + mlp2_net_default_kwargs.update(mlp_net_kwargs_net2) + self.mlp2 = MLP(**mlp2_net_default_kwargs) # type: ignore + ddpg_init_last_layer(self.mlp2[-1], 6e-3) + + def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + value = self.mlp2(torch.cat([self.mlp1(observation), action], -1)) + return value + + +class LSTMNet(nn.Module): + """ + An embedder for an LSTM followed by an MLP. + The forward method returns the hidden states of the current state (input hidden states) and the output, as + the environment returns the 'observation' and 'next_observation'. 
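+
+    Examples:
+        >>> net = LSTMNet(out_features=4,
+        ...     lstm_kwargs={"input_size": 8, "hidden_size": 8},
+        ...     mlp_kwargs={"num_cells": [64], "out_features": 8})
+        >>> obs = torch.randn(3, 10, 7)  # batch x time x feature
+        >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(obs)
+        >>> print(y.shape)
+        torch.Size([3, 10, 4])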
+ + """ + + def __init__(self, out_features, lstm_kwargs: Dict, mlp_kwargs: Dict) -> None: + super().__init__() + lstm_kwargs.update({"batch_first": True}) + self.mlp = MLP(**mlp_kwargs) # type: ignore + self.lstm = nn.LSTM(**lstm_kwargs) + self.linear = nn.LazyLinear(out_features) + + def _lstm( + self, + input: torch.Tensor, + hidden0_in: Optional[torch.Tensor] = None, + hidden1_in: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + squeeze = False + if input.ndimension() == 2: + squeeze = True + input = input.unsqueeze(1).contiguous() + batch, steps = input.shape[:2] + + if hidden1_in is None and hidden0_in is None: + shape = (batch, steps) if not squeeze else (batch,) + hidden0_in, hidden1_in = [ + torch.zeros( + *shape, + self.lstm.num_layers, + self.lstm.hidden_size, + device=input.device, + dtype=input.dtype, + ) + for _ in range(2) + ] + elif hidden1_in is None or hidden0_in is None: + raise RuntimeError( + f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" + ) + + # we only need the first hidden state + if not squeeze: + _hidden0_in = hidden0_in[:, 0] + _hidden1_in = hidden1_in[:, 0] + else: + _hidden0_in = hidden0_in + _hidden1_in = hidden1_in + hidden = ( + _hidden0_in.transpose(-3, -2).contiguous(), + _hidden1_in.transpose(-3, -2).contiguous(), + ) + + y0, hidden = self.lstm(input, hidden) + # dim 0 in hidden is num_layers, but that will conflict with tensordict + hidden = tuple(_h.transpose(0, 1) for _h in hidden) + y = self.linear(y0) + + out = [y, hidden0_in, hidden1_in, *hidden] + if squeeze: + out[0] = out[0].squeeze(1) + else: + # we pad the hidden states with zero to make tensordict happy + for i in range(3, 5): + out[i] = torch.stack( + [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)] + + [out[i]], + 1, + ) + return tuple(out) # type: ignore + + def forward( + self, + input: torch.Tensor, + hidden0_in: Optional[torch.Tensor] = None, + hidden1_in: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + input = self.mlp(input) + return self._lstm(input, hidden0_in, hidden1_in) diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py new file mode 100644 index 00000000000..44d9e7b3013 --- /dev/null +++ b/torchrl/modules/models/recipes/impala.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
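+
+"""Prototype IMPALA-style architecture recipes.
+
+The sketch below is illustrative only (input sizes and argument values are
+arbitrary); it shows the call convention of ImpalaNet with ``batch_first=True``
+inputs shaped ``(batch, time, C, H, W)``:
+
+    >>> import torch
+    >>> net = ImpalaNet(num_actions=6, use_lstm=False)
+    >>> x = torch.randn(2, 3, 3, 84, 84)  # (batch, time, C, H, W)
+    >>> reward = torch.randn(2, 3)
+    >>> done = torch.zeros(2, 3, dtype=torch.bool)
+    >>> (action, policy_logits, baseline), core_state = net(x, reward, done)
+"""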
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchrl.data.tensordict.tensordict import _TensorDict
+
+
+# TODO: code small architecture ref in Impala paper
+
+
+class _ResNetBlock(nn.Module):
+    def __init__(
+        self,
+        num_ch,
+    ):
+        super(_ResNetBlock, self).__init__()
+        resnet_block = []
+        resnet_block.append(nn.ReLU(inplace=True))
+        resnet_block.append(
+            nn.LazyConv2d(
+                out_channels=num_ch,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+        )
+        resnet_block.append(nn.ReLU(inplace=True))
+        resnet_block.append(
+            nn.Conv2d(
+                in_channels=num_ch,
+                out_channels=num_ch,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+        )
+        self.seq = nn.Sequential(*resnet_block)
+
+    def forward(self, x):
+        # residual connection; avoid the in-place `x += ...`, which would modify a
+        # tensor that autograd still needs for the convolution backward pass
+        x = x + self.seq(x)
+        return x
+
+
+class _ConvNetBlock(nn.Module):
+    def __init__(self, num_ch):
+        super().__init__()
+
+        conv = nn.LazyConv2d(
+            out_channels=num_ch,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+        mp = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.feats_conv = nn.Sequential(conv, mp)
+        self.resnet1 = _ResNetBlock(num_ch=num_ch)
+        self.resnet2 = _ResNetBlock(num_ch=num_ch)
+
+    def forward(self, x):
+        x = self.feats_conv(x)
+        x = self.resnet1(x)
+        x = self.resnet2(x)
+        return x
+
+
+class ImpalaNet(nn.Module):
+    def __init__(
+        self,
+        num_actions,
+        channels=(16, 32, 32),
+        out_features=256,
+        use_lstm=False,
+        batch_first=True,
+        clamp_reward=True,
+        one_hot=False,
+    ):
+        super().__init__()
+        self.batch_first = batch_first
+        self.use_lstm = use_lstm
+        self.clamp_reward = clamp_reward
+        self.one_hot = one_hot
+        self.num_actions = num_actions
+
+        layers = [_ConvNetBlock(num_ch) for num_ch in channels]
+        layers += [nn.ReLU(inplace=True)]
+        self.convs = nn.Sequential(*layers)
+        self.fc = nn.Sequential(nn.LazyLinear(out_features), nn.ReLU(inplace=True))
+
+        # FC output size + last reward.
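+        # (the reward of the previous step is concatenated to the visual features
+        # in `forward` via `torch.cat([x, reward], dim=-1)`, hence the extra unit)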
+ core_output_size = out_features + 1 + + if use_lstm: + self.core = nn.LSTM( + core_output_size, + out_features, + num_layers=1, + batch_first=batch_first, + ) + core_output_size = out_features + + self.policy = nn.Linear(core_output_size, self.num_actions) + self.baseline = nn.Linear(core_output_size, 1) + + def forward(self, x, reward, done, core_state=None, mask=None): + if self.batch_first: + B, T, *x_shape = x.shape + batch_shape = torch.Size([B, T]) + else: + T, B, *x_shape = x.shape + batch_shape = torch.Size([T, B]) + if mask is None: + x = x.view(-1, *x.shape[-3:]) + else: + x = x[mask] + if x.ndimension() != 4: + raise RuntimeError( + f"masked input should have 4 dimensions but got {x.ndimension()} instead" + ) + x = self.convs(x) + x = x.view(B * T, -1) + x = self.fc(x) + + if mask is None: + if self.batch_first: + x = x.view(B, T, -1) + else: + x = x.view(T, B, -1) + else: + x = self._allocate_masked_x(x, mask) + + if self.clamp_reward: + reward = torch.clamp(reward, -1, 1) + reward = reward.unsqueeze(-1) + + core_input = torch.cat([x, reward], dim=-1) + + if self.use_lstm: + core_output, _ = self.core(core_input, core_state) + else: + core_output = core_input + + policy_logits = self.policy(core_output) + baseline = self.baseline(core_output) + + softmax_vals = F.softmax(policy_logits, dim=-1) + action = torch.multinomial( + softmax_vals.view(-1, softmax_vals.shape[-1]), num_samples=1 + ).view(softmax_vals.shape[:-1]) + if self.one_hot: + action = F.one_hot(action, policy_logits.shape[-1]) + + if policy_logits.shape[:2] != batch_shape: + raise RuntimeError("policy_logits and batch-shape mismatch") + if baseline.shape[:2] != batch_shape: + raise RuntimeError("baseline and batch-shape mismatch") + if action.shape[:2] != batch_shape: + raise RuntimeError("action and batch-shape mismatch") + + return (action, policy_logits, baseline), core_state + + def _allocate_masked_x(self, x, mask): + x_empty = torch.zeros( + *mask.shape[:2], x.shape[-1], device=x.device, dtype=x.dtype + ) + x_empty[mask] = x + return x_empty + + +class ImpalaNetTensorDict(ImpalaNet): + observation_key = "observation_pixels" + + def forward(self, tensor_dict: _TensorDict): # type: ignore + x = tensor_dict.get(self.observation_key) + done = tensor_dict.get("done").squeeze(-1) + reward = tensor_dict.get("reward").squeeze(-1) + mask = tensor_dict.get("mask").squeeze(-1) + core_state = ( + tensor_dict.get("core_state") + if "core_state" in tensor_dict.keys() + else None + ) + return super().forward(x, reward, done, core_state=core_state, mask=mask) diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py new file mode 100644 index 00000000000..cb849288f36 --- /dev/null +++ b/torchrl/modules/models/utils.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence + +import torch +from torch import nn + +from .exploration import NoisyLazyLinear, NoisyLinear + +LazyMapping = { + nn.Linear: nn.LazyLinear, + NoisyLinear: NoisyLazyLinear, +} + +__all__ = [ + "SqueezeLayer", + "Squeeze2dLayer", +] + + +class SqueezeLayer(nn.Module): + """ + Squeezing layer. + Squeezes some given singleton dimensions of an input tensor. 
+ + Args: + dims (iterable): dimensions to be squeezed + default: (-1,) + + """ + + def __init__(self, dims: Sequence[int] = (-1,)): + super().__init__() + self.dims = dims + + def forward(self, input: torch.Tensor) -> torch.Tensor: + for dim in self.dims: + input = input.squeeze(dim) + return input + + +class Squeeze2dLayer(SqueezeLayer): + """ + Squeezing layer for convolutional neural networks. + Squeezes the last two singleton dimensions of an input tensor. + + """ + + def __init__(self): + super().__init__((-1, -2)) + + +class SquashDims(nn.Module): + """ + A squashing layer. + Flattens the N last dimensions of an input tensor. + + Args: + ndims_in (int): number of dimensions to be flattened. + default = 3 + """ + + def __init__(self, ndims_in: int = 3): + super().__init__() + self.ndims_in = ndims_in + + def forward(self, value: torch.Tensor) -> torch.Tensor: + value = value.flatten(-self.ndims_in, -1) + return value + + +def _find_depth(depth: Optional[int], *list_or_ints: Sequence): + if depth is None: + for item in list_or_ints: + if isinstance(item, (list, tuple)): + depth = len(item) + if depth is None: + raise Exception( + f"depth=None requires one of the input args (kernel_sizes, strides, num_cells) to be a a list or tuple. Got {tuple(type(item) for item in list_or_ints)}" + ) + return depth diff --git a/torchrl/modules/td_module/__init__.py b/torchrl/modules/td_module/__init__.py new file mode 100644 index 00000000000..aeea731263c --- /dev/null +++ b/torchrl/modules/td_module/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .actors import * +from .common import * +from .exploration import * diff --git a/torchrl/modules/td_module/actors.py b/torchrl/modules/td_module/actors.py new file mode 100644 index 00000000000..455fe85d531 --- /dev/null +++ b/torchrl/modules/td_module/actors.py @@ -0,0 +1,818 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence, Tuple + +import torch +from torch import nn + +from torchrl.modules.models.models import DistributionalDQNnet +from torchrl.modules.td_module.common import ( + ProbabilisticTDModule, + TDModule, + TDModuleWrapper, + TDSequence, +) + +__all__ = [ + "Actor", + "ProbabilisticActor", + "ActorValueOperator", + "ValueOperator", + "QValueActor", + "ActorCriticOperator", + "ActorCriticWrapper", + "DistributionalQValueActor", +] + +from torchrl.data import UnboundedContinuousTensorSpec + + +class Actor(TDModule): + """General class for deterministic actors in RL. + + The Actor class comes with default values for the in_keys and out_keys + arguments (["observation"] and ["action"], respectively). + + Examples: + >>> from torchrl.data import TensorDict, + ... NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import Actor + >>> import torch + >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) + >>> action_spec = NdUnboundedContinuousTensorSpec(4) + >>> module = torch.nn.Linear(4, 4) + >>> td_module = Actor( + ... module=module, + ... spec=action_spec, + ... 
) + >>> td_module(td) + >>> print(td.get("action")) + + """ + + def __init__( + self, + *args, + in_keys: Optional[Sequence[str]] = None, + out_keys: Optional[Sequence[str]] = None, + **kwargs, + ): + if in_keys is None: + in_keys = ["observation"] + if out_keys is None: + out_keys = ["action"] + + super().__init__( + *args, + in_keys=in_keys, + out_keys=out_keys, + **kwargs, + ) + + +class ProbabilisticActor(ProbabilisticTDModule): + """ + General class for probabilistic actors in RL. + The Actor class comes with default values for the in_keys and out_keys + arguments (["observation"] and ["action"], respectively). + + Examples: + >>> from torchrl.data import TensorDict, NdBoundedTensorSpec + >>> from torchrl.modules import Actor, TanhNormal + >>> import torch, functorch + >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) + >>> action_spec = NdBoundedTensorSpec(shape=torch.Size([4]), + ... minimum=-1, maximum=1) + >>> module = torch.nn.Linear(4, 8) + >>> fmodule, params, buffers = functorch.make_functional_with_buffers( + ... module) + >>> td_module = ProbabilisticActor( + ... module=fmodule, + ... spec=action_spec, + ... distribution_class=TanhNormal, + ... ) + >>> td_module(td, params=params, buffers=buffers) + >>> print(td.get("action")) + + """ + + def __init__( + self, + *args, + in_keys: Optional[Sequence[str]] = None, + out_keys: Optional[Sequence[str]] = None, + **kwargs, + ): + if in_keys is None: + in_keys = ["observation"] + if out_keys is None: + out_keys = ["action"] + + super().__init__( + *args, + in_keys=in_keys, + out_keys=out_keys, + **kwargs, + ) + + +class ValueOperator(TDModule): + """ + General class for value functions in RL. + + The ValueOperator class comes with default values for the in_keys and + out_keys arguments (["observation"] and ["state_value"] or + ["state_action_value"], respectively and depending on whether the "action" + key is part of the in_keys list). + + Examples: + >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import ValueOperator + >>> import torch, functorch + >>> from torch import nn + >>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,]) + >>> class CustomModule(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = torch.nn.Linear(6, 1) + ... def forward(self, obs, action): + ... return self.linear(torch.cat([obs, action], -1)) + >>> module = CustomModule() + >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> td_module = ValueOperator( + ... in_keys=["observation", "action"], + ... module=fmodule, + ... ) + >>> td_module(td, params=params, buffers=buffers) + >>> print(td) + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 2]), dtype=torch.float32), + state_action_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + + """ + + def __init__( + self, + module: nn.Module, + in_keys: Optional[Sequence[str]] = None, + out_keys: Optional[Sequence[str]] = None, + ) -> None: + + if in_keys is None: + in_keys = ["observation"] + if out_keys is None: + out_keys = ( + ["state_value"] if "action" not in in_keys else ["state_action_value"] + ) + value_spec = UnboundedContinuousTensorSpec() + super().__init__( + spec=value_spec, + module=module, + in_keys=in_keys, + out_keys=out_keys, + ) + + +class QValueHook: + """ + Q-Value hook for Q-value policies. 
+ Given a the output of a regular nn.Module, representing the values of the different discrete actions available, + a QValueHook will transform these values into their argmax component (i.e. the resulting greedy action). + Currently, this is returned as a one-hot encoding. + + Args: + action_space (str): Action space. Must be one of "one-hot", "mult_one_hot" or "binary". + var_nums (int, optional): if action_space == "mult_one_hot", this value represents the cardinality of each + action component. + + Examples: + >>> import functorch + >>> from torchrl.data import TensorDict, OneHotDiscreteTensorSpec + >>> from torchrl.modules.td_module.actors import QValueHook, Actor + >>> from torch import nn + >>> from torchrl.data import OneHotDiscreteTensorSpec, TensorDict + >>> import torch, functorch + >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) + >>> module = nn.Linear(4, 4) + >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> hook = QValueHook("one_hot") + >>> _ = fmodule.register_forward_hook(hook) + >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> qvalue_actor = Actor(module=fmodule, spec=action_spec, out_keys=["action", "action_value"]) + >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> print(td) + TensorDict( + fields={observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + action: Tensor(torch.Size([5, 4]), dtype=torch.int64), + action_value: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([5]), + device=cpu) + + """ + + def __init__( + self, + action_space: str, + var_nums: Optional[int] = None, + ): + self.action_space = action_space + self.var_nums = var_nums + self.fun_dict = { + "one_hot": self._one_hot, + "mult_one_hot": self._mult_one_hot, + "binary": self._binary, + } + if action_space not in self.fun_dict: + raise ValueError( + f"action_space must be one of {list(self.fun_dict.keys())}" + ) + + def __call__( + self, net: nn.Module, observation: torch.Tensor, values: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + action = self.fun_dict[self.action_space](values) + chosen_action_value = (action * values).sum(-1, True) + return action, values, chosen_action_value + + @staticmethod + def _one_hot(value: torch.Tensor) -> torch.Tensor: + out = (value == value.max(dim=-1, keepdim=True)[0]).to(torch.long) + return out + + def _mult_one_hot(self, value: torch.Tensor, support: torch.Tensor) -> torch.Tensor: + values = value.split(self.var_nums, dim=-1) + return torch.cat( + [ + QValueHook._one_hot( + _value, + ) + for _value in values + ], + -1, + ) + + @staticmethod + def _binary(value: torch.Tensor, support: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class DistributionalQValueHook(QValueHook): + """Distributional Q-Value hook for Q-value policies. + + Given a the output of a mapping operator, representing the values of the different discrete actions available, + a DistributionalQValueHook will transform these values into their argmax component using the provided support. + Currently, this is returned as a one-hot encoding. + For more details regarding Distributional DQN, refer to "A Distributional Perspective on Reinforcement Learning", + https://arxiv.org/pdf/1707.06887.pdf + + Args: + action_space (str): Action space. Must be one of "one_hot", "mult_one_hot" or "binary". + support (torch.Tensor): support of the action values. 
+ var_nums (int, optional): if action_space == "mult_one_hot", this value represents the cardinality of each + action component. + + Examples: + >>> from torchrl.data import TensorDict, OneHotDiscreteTensorSpec + >>> from torchrl.modules.td_module.actors import DistributionalQValueHook, Actor + >>> from torch import nn + >>> import torch, functorch + >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) + >>> nbins = 3 + >>> class CustomDistributionalQval(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(4, nbins*4) + ... + ... def forward(self, x): + ... return self.linear(x).view(-1, nbins, 4).log_softmax(-2) + ... + >>> module = CustomDistributionalQval() + >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins)) + >>> _ = fmodule.register_forward_hook(hook) + >>> qvalue_actor = Actor(module=fmodule, spec=action_spec, out_keys=["action", "action_value"]) + >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> print(td) + TensorDict( + fields={observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + action: Tensor(torch.Size([5, 4]), dtype=torch.int64), + action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([5]), + device=cpu) + + """ + + def __init__( + self, + action_space: str, + support: torch.Tensor, + var_nums: Optional[int] = None, + ): + self.action_space = action_space + self.support = support + self.var_nums = var_nums + self.fun_dict = { + "one_hot": self._one_hot, + "mult_one_hot": self._mult_one_hot, + "binary": self._binary, + } + + def __call__( + self, net: nn.Module, observation: torch.Tensor, values: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + action = self.fun_dict[self.action_space](values, self.support) + return action, values + + def _support_expected( + self, log_softmax_values: torch.Tensor, support: torch.Tensor + ) -> torch.Tensor: + support = support.to(log_softmax_values.device) + if log_softmax_values.shape[-2] != support.shape[-1]: + raise RuntimeError( + "Support length and number of atoms in module output should match, " + f"got self.support.shape={support.shape} and module(...).shape={log_softmax_values.shape}" + ) + if (log_softmax_values > 0).any(): + raise ValueError( + f"input to QValueHook must be log-softmax values (which are expected to be non-positive numbers). " + f"got a maximum value of {log_softmax_values.max():4.4f}" + ) + return (log_softmax_values.exp() * support.unsqueeze(-1)).sum(-2) + + def _one_hot(self, value: torch.Tensor, support: torch.Tensor) -> torch.Tensor: + if not isinstance(value, torch.Tensor): + raise TypeError(f"got value of type {value.__class__.__name__}") + if not isinstance(support, torch.Tensor): + raise TypeError(f"got support of type {support.__class__.__name__}") + value = self._support_expected(value, support) + out = (value == value.max(dim=-1, keepdim=True)[0]).to(torch.long) + return out + + def _mult_one_hot(self, value: torch.Tensor, support: torch.Tensor) -> torch.Tensor: + values = value.split(self.var_nums, dim=-1) + return torch.cat( + [ + self._one_hot(_value, _support) + for _value, _support in zip(values, support) + ], + -1, + ) + + @staticmethod + def _binary(value: torch.Tensor, support: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class QValueActor(Actor): + """ + DQN Actor subclass. 
+ This class hooks the module such that it returns a one-hot encoding of the argmax value. + + Examples: + >>> from torchrl.data import TensorDict, OneHotDiscreteTensorSpec + >>> from torchrl.modules.td_module.actors import QValueActor + >>> from torch import nn + >>> import torch, functorch + >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) + >>> module = nn.Linear(4, 4) + >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> qvalue_actor = QValueActor(module=fmodule, spec=action_spec) + >>> _ = qvalue_actor(td, params=params, buffers=buffers) + >>> print(td) + TensorDict( + fields={ + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + action: Tensor(torch.Size([5, 4]), dtype=torch.int64), + action_value: Tensor(torch.Size([5, 4]), dtype=torch.float32), + chosen_action_value: Tensor(torch.Size([5, 1]), dtype=torch.float32)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + + """ + + def __init__(self, *args, action_space: int = "one_hot", **kwargs): + out_keys = [ + "action", + "action_value", + "chosen_action_value", + ] + super().__init__(*args, out_keys=out_keys, **kwargs) + self.action_space = action_space + self.module.register_forward_hook(QValueHook(self.action_space)) + + +class DistributionalQValueActor(QValueActor): + """ + Distributional DQN Actor subclass. + This class hooks the module such that it returns a one-hot encoding of the argmax value on its support. + + Examples: + >>> from torchrl.data import TensorDict, OneHotDiscreteTensorSpec + >>> from torchrl.modules import DistributionalQValueActor, MLP + >>> from torch import nn + >>> import torch, functorch + >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) + >>> nbins = 3 + >>> module = MLP(out_features=(nbins, 4), depth=2) + >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) + >>> _ = qvalue_actor(td) + >>> print(td) + TensorDict( + fields={ + observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + action: Tensor(torch.Size([5, 4]), dtype=torch.int64), + action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + + """ + + def __init__( + self, + *args, + support: torch.Tensor, + action_space: str = "one_hot", + **kwargs, + ): + out_keys = [ + "action", + "action_value", + ] + super(QValueActor, self).__init__(*args, out_keys=out_keys, **kwargs) + self.action_space = action_space + + self.register_buffer("support", support) + self.action_space = action_space + if not isinstance(self.module, DistributionalDQNnet): + self.module = DistributionalDQNnet(self.module) + self.module.register_forward_hook( + DistributionalQValueHook(self.action_space, self.support) + ) + + +class ActorValueOperator(TDSequence): + """ + Actor-value operator. + + This class wraps together an actor and a value model that share a common observation embedding network: + + .. 
aafig:: + :aspect: 60 + :scale: 120 + :proportional: + :textual: + + +-------------+ + |"Observation"| + +-------------+ + | + v + +--------------+ + |"hidden state"| + +--------------+ + | | | + v | v + actor | critic + | | | + v | v + +--------+|+-------+ + |"action"|||"value"| + +--------+|+-------+ + + To facilitate the workflow, this class comes with a get_policy_operator() and get_value_operator() methods, which + will both return a stand-alone TDModule with the dedicated functionality. + + Args: + common_operator (TDModule): a common operator that reads observations and produces a hidden variable + policy_operator (TDModule): a policy operator that reads the hidden variable and returns an action + value_operator (TDModule): a value operator, that reads the hidden variable and returns a value + + Examples: + >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec + >>> from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal, ActorValueOperator + >>> import torch + >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) + >>> module_hidden = torch.nn.Linear(4, 4) + >>> td_module_hidden = TDModule( + ... module=module_hidden, + ... spec=spec_hidden, + ... in_keys=["observation"], + ... out_keys=["hidden"], + ... ) + >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) + >>> module_action = torch.nn.Linear(4, 8) + >>> td_module_action = ProbabilisticActor( + ... module=module_action, + ... spec=spec_action, + ... in_keys=["hidden"], + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... ) + >>> module_value = torch.nn.Linear(4, 1) + >>> td_module_value = ValueOperator( + ... module=module_value, + ... in_keys=["hidden"], + ... ) + >>> td_module = ActorValueOperator(td_module_hidden, td_module_action, td_module_value) + >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) + >>> td_clone = td_module(td.clone()) + >>> print(td_clone) + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + >>> td_clone = td_module.get_policy_operator()(td.clone()) + >>> print(td_clone) # no value + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + >>> td_clone = td_module.get_value_operator()(td.clone()) + >>> print(td_clone) # no action + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + """ + + def __init__( + self, + common_operator: TDModule, + policy_operator: TDModule, + value_operator: TDModule, + ): + super().__init__( + common_operator, + policy_operator, + value_operator, + ) + + def get_policy_operator(self) -> TDSequence: + """ + + Returns a stand-alone policy operator that maps an observation to an action. 
+ + """ + return TDSequence(self.module[0], self.module[1]) + + def get_value_operator(self) -> TDSequence: + """ + + Returns a stand-alone value network operator that maps an observation to a value estimate. + + """ + return TDSequence(self.module[0], self.module[2]) + + +class ActorCriticOperator(ActorValueOperator): + """ + Actor-critic operator. + + This class wraps together an actor and a value model that share a common observation embedding network: + + .. aafig:: + :aspect: 60 + :scale: 120 + :proportional: + :textual: + + +-----------+ + |Observation| + +-----------+ + | + v + actor + | + v + +------+ + |action| --> critic + +------+ | + v + +-----+ + |value| + +-----+ + + To facilitate the workflow, this class comes with a get_policy_operator() method, which + will both return a stand-alone TDModule with the dedicated functionality. The get_critic_operator will return the + parent object, as the value is computed based on the policy output. + + Args: + common_operator (TDModule): a common operator that reads observations and produces a hidden variable + policy_operator (TDModule): a policy operator that reads the hidden variable and returns an action + value_operator (TDModule): a value operator, that reads the hidden variable and returns a value + + Examples: + >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec + >>> from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal, ActorCriticOperator + >>> import torch + >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) + >>> module_hidden = torch.nn.Linear(4, 4) + >>> td_module_hidden = TDModule( + ... module=module_hidden, + ... spec=spec_hidden, + ... in_keys=["observation"], + ... out_keys=["hidden"], + ... ) + >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) + >>> module_action = torch.nn.Linear(4, 8) + >>> td_module_action = ProbabilisticActor( + ... module=module_action, + ... spec=spec_action, + ... in_keys=["hidden"], + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... ) + >>> module_value = torch.nn.Linear(4, 1) + >>> td_module_value = ValueOperator( + ... module=module_value, + ... in_keys=["hidden"], + ... 
) + >>> td_module = ActorCriticOperator(td_module_hidden, td_module_action, td_module_value) + >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) + >>> td_clone = td_module(td.clone()) + >>> print(td_clone) + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + >>> td_clone = td_module.get_policy_operator()(td.clone()) + >>> print(td_clone) # no value + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + >>> td_clone = td_module.get_critic_operator()(td.clone()) + >>> print(td_clone) # no action + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + """ + + def get_critic_operator(self) -> TDModuleWrapper: + """ + + Returns a stand-alone critic network operator that maps a state-action pair to a critic estimate. + + """ + return self + + def get_value_operator(self) -> TDModuleWrapper: + raise RuntimeError( + "value_operator is the term used for operators that associate a value with a " + "state/observation. This class computes the value of a state-action pair: to get the " + "network computing this value, please call td_sequence.get_critic_operator()" + ) + + +class ActorCriticWrapper(TDSequence): + """ + Actor-value operator without common module. + + This class wraps together an actor and a value model that do not share a common observation embedding network: + + .. aafig:: + :aspect: 60 + :scale: 120 + :proportional: + :textual: + + +-----------+ + |Observation| + +-----------+ + | | | + v | v + actor | critic + | | | + v | v + +------+|+-------+ + |action||| value | + +------+|+-------+ + + To facilitate the workflow, this class comes with a get_policy_operator() and get_value_operator() methods, which + will both return a stand-alone TDModule with the dedicated functionality. + + Args: + policy_operator (TDModule): a policy operator that reads the hidden variable and returns an action + value_operator (TDModule): a value operator, that reads the hidden variable and returns a value + + Examples: + >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec, NdBoundedTensorSpec + >>> from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal, ActorCriticWrapper + >>> import torch + >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) + >>> module_action = torch.nn.Linear(4, 8) + >>> td_module_action = ProbabilisticActor( + ... module=module_action, + ... spec=spec_action, + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... ) + >>> module_value = torch.nn.Linear(4, 1) + >>> td_module_value = ValueOperator( + ... module=module_value, + ... 
in_keys=["observation"], + ... ) + >>> td_module = ActorCriticWrapper(td_module_action, td_module_value) + >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) + >>> td_clone = td_module(td.clone()) + >>> print(td_clone) + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + >>> td_clone = td_module.get_policy_operator()(td.clone()) + >>> print(td_clone) # no value + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action: Tensor(torch.Size([3, 4]), dtype=torch.float32), + action_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + >>> td_clone = td_module.get_value_operator()(td.clone()) + >>> print(td_clone) # no action + TensorDict( + fields={observation: Tensor(torch.Size([3, 4]), dtype=torch.float32), + state_value: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + """ + + def __init__( + self, + policy_operator: TDModule, + value_operator: TDModule, + ): + super().__init__( + policy_operator, + value_operator, + ) + + def get_policy_operator(self) -> TDSequence: + """ + + Returns a stand-alone policy operator that maps an observation to an action. + + """ + return self.module[0] + + def get_value_operator(self) -> TDSequence: + """ + + Returns a stand-alone value network operator that maps an observation to a value estimate. + + """ + return self.module[1] diff --git a/torchrl/modules/td_module/common.py b/torchrl/modules/td_module/common.py new file mode 100644 index 00000000000..687409e09e0 --- /dev/null +++ b/torchrl/modules/td_module/common.py @@ -0,0 +1,1021 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +from copy import copy, deepcopy +from textwrap import indent +from typing import ( + Any, + Callable, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +import functorch +import torch +from functorch import FunctionalModule, FunctionalModuleWithBuffers, vmap +from functorch._src.make_functional import _swap_state +from torch import distributions as d, nn, Tensor + +from torchrl.data import ( + CompositeSpec, + DEVICE_TYPING, + TensorSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.envs.utils import exploration_mode +from torchrl.modules.distributions import Delta, distributions_maps + +__all__ = [ + "TDModule", + "ProbabilisticTDModule", + "TDSequence", + "TDModuleWrapper", +] + + +def _forward_hook_safe_action(module, tensor_dict_in, tensor_dict_out): + if not module.spec.is_in(tensor_dict_out.get(module.out_keys[0])): + try: + tensor_dict_out.set_( + module.out_keys[0], + module.spec.project(tensor_dict_out.get(module.out_keys[0])), + ) + except RuntimeError: + tensor_dict_out.set( + module.out_keys[0], + module.spec.project(tensor_dict_out.get(module.out_keys[0])), + ) + + +class TDModule(nn.Module): + """A TDModule, for TensorDict module, is a python wrapper around a `nn.Module` that reads and writes to a + TensorDict, instead of reading and returning tensors. + + Args: + module (nn.Module): a nn.Module used to map the input to the output parameter space. Can be a functional + module (FunctionalModule or FunctionalModuleWithBuffers), in which case the `forward` method will expect + the params (and possibly) buffers keyword arguments. + spec (TensorSpec): specs of the output tensor. If the module outputs multiple output tensors, + spec characterize the space of the first output tensor. + in_keys (iterable of str): keys to be read from input tensordict and passed to the module. If it + contains more than one element, the values will be passed in the order given by the in_keys iterable. + out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the + number of tensors returned by the embedded module. + safe (bool): if True, the value of the output is checked against the input spec. Out-of-domain sampling can + occur because of exploration policies or numerical under/overflow issues. + If this value is out of bounds, it is projected back onto the desired space using the `TensorSpec.project` + method. Default is `False`. + + Embedding a neural network in a TDModule only requires to specify the input and output keys. The domain spec can + be passed along if needed. TDModule support functional and regular `nn.Module` objects. In the functional + case, the 'params' (and 'buffers') keyword argument must be specified: + + Examples: + >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import TDModule + >>> import torch, functorch + >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) + >>> spec = NdUnboundedContinuousTensorSpec(8) + >>> module = torch.nn.GRUCell(4, 8) + >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> td_fmodule = TDModule( + ... module=fmodule, + ... spec=spec, + ... in_keys=["input", "hidden"], + ... out_keys=["output"], + ... 
) + >>> td_functional = td_fmodule(td.clone(), params=params, buffers=buffers) + >>> print(td_functional) + TensorDict( + fields={input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + In the stateful case: + >>> td_module = TDModule( + ... module=module, + ... spec=spec, + ... in_keys=["input", "hidden"], + ... out_keys=["output"], + ... ) + >>> td_stateful = td_module(td.clone()) + >>> print(td_stateful) + TensorDict( + fields={input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + One can use a vmap operator to call the functional module. In this case the tensordict is expanded to match the + batch size (i.e. the tensordict isn't modified in-place anymore): + >>> # Model ensemble using vmap + >>> params_repeat = tuple(param.expand(4, *param.shape).contiguous().normal_() for param in params) + >>> buffers_repeat = tuple(param.expand(4, *param.shape).contiguous().normal_() for param in buffers) + >>> td_vmap = td_fmodule(td.clone(), params=params_repeat, buffers=buffers_repeat, vmap=True) + >>> print(td_vmap) + TensorDict( + fields={input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32), + output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([4, 3]), + device=cpu) + + """ + + def __init__( + self, + module: Union[ + FunctionalModule, FunctionalModuleWithBuffers, TDModule, nn.Module + ], + spec: Optional[TensorSpec], + in_keys: Iterable[str], + out_keys: Iterable[str], + safe: bool = False, + ): + + super().__init__() + + if not out_keys: + raise RuntimeError(f"out_keys were not passed to {self.__class__.__name__}") + if not in_keys: + raise RuntimeError(f"in_keys were not passed to {self.__class__.__name__}") + self.out_keys = out_keys + self.in_keys = in_keys + + self._spec = spec + self.safe = safe + if safe: + if spec is None: + raise RuntimeError( + "`TDModule(spec=None, safe=True)` is not a valid configuration as the tensor " + "specs are not specified" + ) + self.register_forward_hook(_forward_hook_safe_action) + + self.module = module + + def __setattr__(self, key: str, attribute: Any) -> None: + if key == "spec" and isinstance(attribute, TensorSpec): + self._spec = attribute + return + super().__setattr__(key, attribute) + + @property + def spec(self) -> TensorSpec: + return self._spec + + @spec.setter + def _spec_set(self, spec: TensorSpec) -> None: + if not isinstance(spec, TensorSpec): + raise RuntimeError( + f"Trying to set an object of type {type(spec)} as a tensorspec." 
+ ) + self._spec = spec + + def _write_to_tensor_dict( + self, + tensor_dict: _TensorDict, + tensors: List, + tensor_dict_out: Optional[_TensorDict] = None, + out_keys: Optional[Iterable[str]] = None, + vmap: Optional[int] = None, + ) -> _TensorDict: + + if out_keys is None: + out_keys = self.out_keys + if ( + (tensor_dict_out is None) + and vmap + and (isinstance(vmap, bool) or vmap[-1] is None) + ): + dim = tensors[0].shape[0] + shape = [dim, *tensor_dict.shape] + tensor_dict_out = TensorDict( + {key: val.expand(dim, *val.shape) for key, val in tensor_dict.items()}, + shape, + ) + elif tensor_dict_out is None: + tensor_dict_out = tensor_dict + for _out_key, _tensor in zip(out_keys, tensors): + tensor_dict_out.set(_out_key, _tensor) + return tensor_dict_out + + def _make_vmap(self, kwargs, n_input): + if "vmap" in kwargs and kwargs["vmap"]: + if not isinstance(kwargs["vmap"], (tuple, bool)): + raise RuntimeError( + "vmap argument must be a boolean or a tuple of dim expensions." + ) + _buffers = "buffers" in kwargs + _vmap = ( + kwargs["vmap"] + if isinstance(kwargs["vmap"], tuple) + else (0, 0, *(None,) * n_input) + if _buffers + else (0, *(None,) * n_input) + ) + return _vmap + + def _call_module( + self, tensors: Sequence[Tensor], **kwargs + ) -> Union[Tensor, Sequence[Tensor]]: + err_msg = "Did not find the {0} keyword argument to be used with the functional module." + if isinstance(self.module, (FunctionalModule, FunctionalModuleWithBuffers)): + _vmap = self._make_vmap(kwargs, len(tensors)) + if _vmap: + module = vmap(self.module, _vmap) + else: + module = self.module + + if isinstance(self.module, FunctionalModule): + if "params" not in kwargs: + raise KeyError(err_msg.format("params")) + kwargs_pruned = { + key: item + for key, item in kwargs.items() + if key not in ("params", "vmap") + } + return module(kwargs["params"], *tensors, **kwargs_pruned) + + elif isinstance(self.module, FunctionalModuleWithBuffers): + if "params" not in kwargs: + raise KeyError(err_msg.format("params")) + if "buffers" not in kwargs: + raise KeyError(err_msg.format("buffers")) + + kwargs_pruned = { + key: item + for key, item in kwargs.items() + if key not in ("params", "buffers", "vmap") + } + return module( + kwargs["params"], kwargs["buffers"], *tensors, **kwargs_pruned + ) + else: + out = self.module(*tensors, **kwargs) + return out + + def forward( + self, + tensor_dict: _TensorDict, + tensor_dict_out: Optional[_TensorDict] = None, + **kwargs, + ) -> _TensorDict: + tensors = tuple(tensor_dict.get(in_key) for in_key in self.in_keys) + tensors = self._call_module(tensors, **kwargs) + if not isinstance(tensors, tuple): + tensors = (tensors,) + tensor_dict_out = self._write_to_tensor_dict( + tensor_dict, + tensors, + tensor_dict_out, + vmap=kwargs.get("vmap", False), + ) + return tensor_dict_out + + def random(self, tensor_dict: _TensorDict) -> _TensorDict: + """Samples a random element in the target space, irrespective of any input. If multiple output keys are present, + only the first will be written in the input `tensordict`. + + Args: + tensor_dict (_TensorDict): tensordict where the output value should be written. + + Returns: + the original tensordict with a new/updated value for the output key. 
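+
+        Examples:
+            A minimal sketch (module, spec and keys are illustrative): the wrapped
+            module itself is not called, only the spec is sampled.
+
+            >>> import torch
+            >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec
+            >>> td = TensorDict({"input": torch.randn(3, 4)}, [3])
+            >>> spec = NdUnboundedContinuousTensorSpec(8)
+            >>> td_module = TDModule(
+            ...     module=torch.nn.Linear(4, 8),
+            ...     spec=spec,
+            ...     in_keys=["input"],
+            ...     out_keys=["output"],
+            ... )
+            >>> td = td_module.random(td)
+            >>> assert td.get("output").shape == torch.Size([3, 8])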
+ + """ + key0 = self.out_keys[0] + tensor_dict.set(key0, self.spec.rand(tensor_dict.batch_size)) + return tensor_dict + + def random_sample(self, tensordict: _TensorDict) -> _TensorDict: + """see TDModule.random(...)""" + return self.random(tensordict) + + @property + def device(self): + for p in self.parameters(): + return p.device + return torch.device("cpu") + + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> TDModule: + if self.spec is not None: + self.spec = self.spec.to(dest) + out = super().to(dest) + return out # type: ignore + + def __repr__(self) -> str: + fields = indent( + f"module={self.module}, \n" + f"device={self.device}, \n" + f"in_keys={self.in_keys}, \n" + f"out_keys={self.out_keys}", + 4 * " ", + ) + + return f"{self.__class__.__name__}(\n{fields})" + + def make_functional_with_buffers(self, clone: bool = False): + """ + Transforms a stateful module in a functional module and returns its parameters and buffers. + Unlike functorch.make_functional_with_buffers, this method supports lazy modules. + + Returns: + A tuple of parameter and buffer tuples + + Examples: + >>> from torchrl.data import NdUnboundedContinuousTensorSpec, TensorDict + >>> lazy_module = nn.LazyLinear(4) + >>> spec = NdUnboundedContinuousTensorSpec(18) + >>> td_module = TDModule(lazy_module, spec, ["some_input"], + ... ["some_output"]) + >>> _, (params, buffers) = td_module.make_functional_with_buffers() + >>> print(params[0].shape) # the lazy module has been initialized + torch.Size([4, 18]) + >>> print(td_module( + ... TensorDict({'some_input': torch.randn(18)}, batch_size=[]), + ... params=params, + ... buffers=buffers)) + TensorDict( + fields={ + some_input: Tensor(torch.Size([18]), dtype=torch.float32), + some_output: Tensor(torch.Size([4]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + """ + if clone: + self_copy = deepcopy(self) + else: + self_copy = self + + if isinstance( + self_copy.module, + (TDModule, FunctionalModule, FunctionalModuleWithBuffers), + ): + raise RuntimeError( + "TDModule.make_functional_with_buffers requires the module to be a regular nn.Module. " + f"Found type {type(self_copy.module)}" + ) + + # check if there is a non-initialized lazy module + for m in self_copy.module.modules(): + if hasattr(m, "has_uninitialized_params") and m.has_uninitialized_params(): + pseudo_input = self_copy.spec.rand() + self_copy.module(pseudo_input) + break + + fmodule, params, buffers = functorch.make_functional_with_buffers( + self_copy.module + ) + self_copy.module = fmodule + + # Erase meta params + for _ in fmodule.parameters(): + none_state = [None for _ in params + buffers] + if hasattr(fmodule, "all_names_map"): + # functorch >= 0.2.0 + _swap_state(fmodule.stateless_model, fmodule.all_names_map, none_state) + else: + # functorch < 0.2.0 + _swap_state(fmodule.stateless_model, fmodule.split_names, none_state) + + break + + return self_copy, (params, buffers) + + +class ProbabilisticTDModule(TDModule): + """ + A probabilistic TD Module. + ProbabilisticTDModule is a special case of a TDModule where the output is sampled given some rule, specified by + the input `default_interaction_mode` argument and the `exploration_mode()` global function. + + A ProbabilisticTDModule instance has two main features: + - It reads and writes TensorDict objects + - It uses a real mapping R^n -> R^m to create a distribution in R^d from which values can be sampled or computed. 
+    When the __call__ / forward method is called, a distribution is created, and a value computed (using the 'mean',
+    'mode', 'median' attribute or the 'rsample', 'sample' method).
+
+    By default, the ProbabilisticTDModule distribution class is a Delta distribution, making ProbabilisticTDModule a
+    simple wrapper around a deterministic mapping function (i.e. it can be used interchangeably with its parent
+    TDModule).
+
+    Args:
+        module (nn.Module): a nn.Module used to map the input to the output parameter space. Can be a functional
+            module (FunctionalModule or FunctionalModuleWithBuffers), in which case the `forward` method will expect
+            the params (and possibly buffers) keyword arguments.
+        spec (TensorSpec): specs of the first output tensor. Used when calling td_module.random() to generate random
+            values in the target space.
+        in_keys (iterable of str): keys to be read from input tensordict and passed to the module. If it
+            contains more than one element, the values will be passed in the order given by the in_keys iterable.
+        out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the
+            number of tensors returned by the distribution sampling method plus the extra tensors returned by the
+            module.
+        distribution_class (Type, optional): a torch.distributions.Distribution class to be used for sampling.
+            Default is Delta.
+        distribution_kwargs (dict, optional): kwargs to be passed to the distribution.
+        default_interaction_mode (str, optional): default method to be used to retrieve the output value. Should be one of:
+            'mode', 'median', 'mean' or 'random' (in which case the value is sampled randomly from the distribution).
+            Default is 'mode'.
+            Note: When a sample is drawn, the `ProbabilisticTDModule` instance will first look for the interaction mode
+            dictated by the `exploration_mode()` global function. If this returns `None` (its default value),
+            then the `default_interaction_mode` of the `ProbabilisticTDModule` instance will be used.
+            Note that DataCollector instances will set the exploration mode to `"random"` by default
+            (via `set_exploration_mode`).
+        return_log_prob (bool, optional): if True, the log-probability of the distribution sample will be written in the
+            tensordict with the key `f'{out_keys[0]}_log_prob'`. Default is `False`.
+        safe (bool, optional): if True, the value of the sample is checked against the input spec. Out-of-domain sampling can
+            occur because of exploration policies or numerical under/overflow issues. As for the `spec` argument,
+            this check will only occur for the distribution sample, but not the other tensors returned by the input
+            module. If the sample is out of bounds, it is projected back onto the desired space using the
+            `TensorSpec.project` method.
+            Default is `False`.
+        save_dist_params (bool, optional): if True, the parameters of the distribution (i.e. the output of the module)
+            will be written to the tensordict along with the sample. Those parameters can be used to
+            re-compute the original distribution later on (e.g. to compute the divergence between the distribution
+            used to sample the action and the updated distribution in PPO).
+            Default is `False`.
+        cache_dist (bool, optional): if True, the distribution instance built at the first call will be cached and
+            updated in place with the new parameters at subsequent calls (through the distribution `update` method)
+            instead of being re-instantiated every time. Only distribution classes exposing such an `update` method
+            support this.
+ Default is `False`. + + Examples: + >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import ProbabilisticTDModule, TanhNormal + >>> import functorch, torch + >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) + >>> spec = NdUnboundedContinuousTensorSpec(4) + >>> module = torch.nn.GRUCell(4, 8) + >>> module_func, params, buffers = functorch.make_functional_with_buffers(module) + >>> td_module = ProbabilisticTDModule( + ... module=module_func, + ... spec=spec, + ... in_keys=["input"], + ... out_keys=["output"], + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... ) + >>> _ = td_module(td, params=params, buffers=buffers) + >>> print(td) + TensorDict( + fields={ + input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + output: Tensor(torch.Size([3, 4]), dtype=torch.float32), + output_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False) + + >>> # In the vmap case, the tensordict is again expended to match the batch: + >>> params = tuple(p.expand(4, *p.shape).contiguous().normal_() for p in params) + >>> buffers = tuple(b.expand(4, *b.shape).contiguous().normal_() for p in buffers) + >>> td_vmap = td_module(td, params=params, buffers=buffers, vmap=True) + >>> print(td_vmap) + TensorDict( + fields={ + input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32), + output: Tensor(torch.Size([3, 4]), dtype=torch.float32), + output_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False) + + """ + + def __init__( + self, + module: Union[Callable[[Tensor], Tensor], nn.Module], + spec: TensorSpec, + in_keys: Sequence[str], + out_keys: Sequence[str], + distribution_class: Type = Delta, + distribution_kwargs: Optional[dict] = None, + default_interaction_mode: str = "mode", + _n_empirical_est: int = 1000, + return_log_prob: bool = False, + safe: bool = False, + save_dist_params: bool = False, + cache_dist: bool = True, + ): + + super().__init__( + spec=spec, + module=module, + out_keys=out_keys, + in_keys=in_keys, + safe=safe, + ) + + self.save_dist_params = save_dist_params + self._n_empirical_est = _n_empirical_est + self.cache_dist = cache_dist + self._dist = None + + if isinstance(distribution_class, str): + distribution_class = distributions_maps.get(distribution_class.lower()) + self.distribution_class = distribution_class + self.distribution_kwargs = ( + distribution_kwargs if distribution_kwargs is not None else dict() + ) + self.return_log_prob = return_log_prob + + self.default_interaction_mode = default_interaction_mode + self.interact = False + + def get_dist( + self, + tensor_dict: _TensorDict, + **kwargs, + ) -> Tuple[torch.distributions.Distribution, ...]: + """Calls the module using the tensors retrieved from the 'in_keys' attribute and returns a distribution + using its output. + + Args: + tensor_dict (_TensorDict): tensordict with the input values for the creation of the distribution. + + Returns: + a distribution along with other tensors returned by the module. 
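+
+        Examples:
+            A short sketch (illustrative; it assumes the functional `td_module`,
+            `params` and `buffers` from the class-level example above):
+
+            >>> dist, *other_tensors = td_module.get_dist(td, params=params, buffers=buffers)
+            >>> sample = dist.rsample()  # or use dist.mode / dist.mean, depending on the distribution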
+ + """ + tensors = [tensor_dict.get(key, None) for key in self.in_keys] + out_tensors = self._call_module(tensors, **kwargs) + if isinstance(out_tensors, Tensor): + out_tensors = (out_tensors,) + if self.save_dist_params: + for i, _tensor in enumerate(out_tensors): + tensor_dict.set(f"{self.out_keys[0]}_dist_param_{i}", _tensor) + dist, num_params = self.build_dist_from_params(out_tensors) + tensors = out_tensors[num_params:] + + return (dist, *tensors) + + def build_dist_from_params( + self, params: Tuple[Tensor, ...] + ) -> Tuple[d.Distribution, int]: + """Given a tuple of temsors, returns a distribution object and the number of parameters used for it. + + Args: + params (Tuple[Tensor, ...]): tensors to be used for the distribution construction. + + Returns: + a distribution object and the number of parameters used for its construction. + + """ + num_params = ( + getattr(self.distribution_class, "num_params") + if hasattr(self.distribution_class, "num_params") + else 1 + ) + if self.cache_dist and self._dist is not None: + self._dist.update(*params[:num_params]) + dist = self._dist + else: + dist = self.distribution_class( + *params[:num_params], **self.distribution_kwargs + ) + if self.cache_dist: + self._dist = dist + return dist, num_params + + def forward( + self, + tensor_dict: _TensorDict, + tensor_dict_out: Optional[_TensorDict] = None, + **kwargs, + ) -> _TensorDict: + + dist, *tensors = self.get_dist(tensor_dict, **kwargs) + out_tensor = self._dist_sample( + dist, *tensors, interaction_mode=exploration_mode() + ) + tensor_dict_out = self._write_to_tensor_dict( + tensor_dict, + [out_tensor] + list(tensors), + tensor_dict_out, + vmap=kwargs.get("vmap", 0), + ) + if self.return_log_prob: + log_prob = dist.log_prob(out_tensor) + tensor_dict_out.set("_".join([self.out_keys[0], "log_prob"]), log_prob) + return tensor_dict_out + + def log_prob(self, tensor_dict: _TensorDict, **kwargs) -> _TensorDict: + """ + Samples/computes an action using the module and writes this value onto the input tensordict along + with its log-probability. + + Args: + tensor_dict (_TensorDict): tensordict containing the in_keys specified in the initializer. + + Returns: + the same tensordict with the out_keys values added/updated as well as a + f"{out_keys[0]}_log_prob" key containing the log-probability of the first output. 
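+
+        Examples:
+            A short sketch (illustrative; it reuses the functional `td_module`, `params`
+            and `buffers` from the class-level example, where `out_keys=["output"]`):
+
+            >>> td = td_module(td, params=params, buffers=buffers)
+            >>> td = td_module.log_prob(td, params=params, buffers=buffers)
+            >>> log_prob = td.get("output_log_prob")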
+ + """ + dist, *_ = self.get_dist(tensor_dict, **kwargs) + lp = dist.log_prob(tensor_dict.get(self.out_keys[0])) + tensor_dict.set(self.out_keys[0] + "_log_prob", lp) + return tensor_dict + + def _dist_sample( + self, + dist: d.Distribution, + *tensors: Tensor, + interaction_mode: bool = None, + eps: float = None, + ) -> Tensor: + if interaction_mode is None: + interaction_mode = self.default_interaction_mode + if not isinstance(dist, d.Distribution): + raise TypeError(f"type {type(dist)} not recognised by _dist_sample") + + if interaction_mode == "mode": + if hasattr(dist, "mode"): + return dist.mode + else: + raise NotImplementedError( + f"method {type(dist)}.mode is not implemented" + ) + + elif interaction_mode == "median": + if hasattr(dist, "median"): + return dist.median + else: + raise NotImplementedError( + f"method {type(dist)}.median is not implemented" + ) + + elif interaction_mode == "mean": + try: + return dist.mean + except AttributeError: + if dist.has_rsample: + return dist.rsample((self._n_empirical_est,)).mean(0) + else: + return dist.sample((self._n_empirical_est,)).mean(0) + + elif interaction_mode == "random": + if dist.has_rsample: + return dist.rsample() + else: + return dist.sample() + elif interaction_mode == "net_output": + if len(tensors) > 1: + raise RuntimeError( + "Multiple values passed to _dist_sample when trying to return a single action " + "tensor." + ) + return tensors[0] + else: + raise NotImplementedError(f"unknown interaction_mode {interaction_mode}") + + @property + def device(self): + for p in self.parameters(): + return p.device + return torch.device("cpu") + + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ProbabilisticTDModule: + if self.spec is not None: + self.spec = self.spec.to(dest) + out = super().to(dest) + return out + + def __deepcopy__(self, memodict={}): + self._dist = None + cls = self.__class__ + result = cls.__new__(cls) + memodict[id(self)] = result + for k, v in self.__dict__.items(): + setattr(result, k, deepcopy(v, memodict)) + return result + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(module={self.module}, distribution_class={self.distribution_class}, device={self.device})" + + +class TDSequence(TDModule): + """ + A sequence of TDModules. + Similarly to `nn.Sequence` which passes a tensor through a chain of mappings that read and write a single tensor + each, this module will read and write over a tensordict by querying each of the input modules. + When calling a `TDSequence` instance with a functional module, it is expected that the parameter lists (and + buffers) will be concatenated in a single list. + + Args: + modules (iterable of TDModules): ordered sequence of TDModule instances to be run sequentially. + + TDSequence supportse functional, modular and vmap coding: + Examples: + >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec + >>> from torchrl.modules import ProbabilisticTDModule, TanhNormal, TDSequence + >>> import torch, functorch + >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) + >>> spec1 = NdUnboundedContinuousTensorSpec(4) + >>> module1 = torch.nn.Linear(4, 8) + >>> fmodule1, params1, buffers1 = functorch.make_functional_with_buffers(module1) + >>> td_module1 = ProbabilisticTDModule( + ... module=fmodule1, + ... spec=spec1, + ... in_keys=["input"], + ... out_keys=["hidden"], + ... distribution_class=TanhNormal, + ... return_log_prob=True, + ... 
) + >>> spec2 = NdUnboundedContinuousTensorSpec(8) + >>> module2 = torch.nn.Linear(4, 8) + >>> fmodule2, params2, buffers2 = functorch.make_functional_with_buffers(module2) + >>> td_module2 = TDModule( + ... module=fmodule2, + ... spec=spec2, + ... in_keys=["hidden"], + ... out_keys=["output"], + ... ) + >>> td_module = TDSequence(td_module1, td_module2) + >>> params = params1 + params2 + >>> buffers = buffers1 + buffers2 + >>> _ = td_module(td, params=params, buffers=buffers) + >>> print(td) + TensorDict( + fields={input: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32), + hidden_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32), + output: Tensor(torch.Size([3, 8]), dtype=torch.float32)}, + shared=False, + batch_size=torch.Size([3]), + device=cpu) + + >>> # The module spec aggregates all the input specs: + >>> print(td_module.spec) + CompositeSpec( + hidden: NdUnboundedContinuousTensorSpec( + shape=torch.Size([4]),space=None,device=cpu,dtype=torch.float32,domain=continuous), + output: NdUnboundedContinuousTensorSpec( + shape=torch.Size([8]),space=None,device=cpu,dtype=torch.float32,domain=continuous)) + + In the vmap case: + >>> params = tuple(p.expand(4, *p.shape).contiguous().normal_() for p in params) + >>> buffers = tuple(b.expand(4, *b.shape).contiguous().normal_() for p in buffers) + >>> td_vmap = td_module(td, params=params, buffers=buffers, vmap=True) + >>> print(td_vmap) + + + """ + + def __init__( + self, + *modules: TDModule, + ): + in_keys_tmp = [] + out_keys = [] + for module in modules: + in_keys_tmp += module.in_keys + out_keys += module.out_keys + in_keys = [] + for in_key in in_keys_tmp: + if (in_key not in in_keys) and (in_key not in out_keys): + in_keys.append(in_key) + if not len(in_keys): + raise RuntimeError( + "in_keys empty. Please ensure that there is at least one input " + "key that is not part of the output key set." 
+ ) + out_keys = [ + out_key + for i, out_key in enumerate(out_keys) + if out_key not in out_keys[i + 1 :] + ] + + super().__init__( + spec=None, + module=nn.ModuleList(list(modules)), + in_keys=in_keys, + out_keys=out_keys, + ) + + @property + def param_len(self) -> List[int]: + param_list = [] + prev = 0 + for module in self.module: + param_list.append(len(module.module.param_names) + prev) + prev = param_list[-1] + return param_list + + @property + def buffer_len(self) -> List[int]: + buffer_list = [] + prev = 0 + for module in self.module: + buffer_list.append(len(module.module.buffer_names) + prev) + prev = buffer_list[-1] + return buffer_list + + def _split_param( + self, param_list: Iterable[Tensor], params_or_buffers: str + ) -> Iterable[Iterable[Tensor]]: + if params_or_buffers == "params": + list_out = self.param_len + elif params_or_buffers == "buffers": + list_out = self.buffer_len + list_in = [0] + list_out[:-1] + out = [] + for a, b in zip(list_in, list_out): + out.append(param_list[a:b]) + return out + + def forward(self, tensor_dict: _TensorDict, **kwargs) -> _TensorDict: + if "params" in kwargs and "buffers" in kwargs: + param_splits = self._split_param(kwargs["params"], "params") + buffer_splits = self._split_param(kwargs["buffers"], "buffers") + kwargs_pruned = { + key: item + for key, item in kwargs.items() + if key not in ("params", "buffers") + } + for i, (module, param, buffer) in enumerate( + zip(self.module, param_splits, buffer_splits) + ): # type: ignore + if "vmap" in kwargs_pruned and i > 0: + # the tensordict is already expended + kwargs_pruned["vmap"] = (0, 0, *(0,) * len(module.in_keys)) + tensor_dict = module( + tensor_dict, params=param, buffers=buffer, **kwargs_pruned + ) + + elif "params" in kwargs: + param_splits = self._split_param(kwargs["params"], "params") + kwargs_pruned = { + key: item for key, item in kwargs.items() if key not in ("params",) + } + for i, (module, param) in enumerate( + zip(self.module, param_splits) + ): # type: ignore + if "vmap" in kwargs_pruned and i > 0: + # the tensordict is already expended + kwargs_pruned["vmap"] = (0, *(0,) * len(module.in_keys)) + tensor_dict = module(tensor_dict, params=param, **kwargs_pruned) + + elif not len(kwargs): + for module in self.module: # type: ignore + tensor_dict = module(tensor_dict) + else: + raise RuntimeError( + "TDSequence does not support keyword arguments other than 'params', 'buffers' and 'vmap'" + ) + + return tensor_dict + + def __len__(self): + return len(self.module) # type: ignore + + @property + def spec(self): + kwargs = {} + for layer in self.module: # type: ignore + out_key = layer.out_keys[0] + spec = layer.spec + if spec is None: + # By default, we consider that unspecified specs are unbounded. + spec = UnboundedContinuousTensorSpec() + if not isinstance(spec, TensorSpec): + raise RuntimeError( + f"TDSequence.spec requires all specs to be valid TensorSpec objects. Got " + f"{type(layer.spec)}" + ) + kwargs[out_key] = spec + return CompositeSpec(**kwargs) + + def make_functional_with_buffers(self, clone: bool = False): + """ + Transforms a stateful module in a functional module and returns its parameters and buffers. + Unlike functorch.make_functional_with_buffers, this method supports lazy modules. 
+ + Returns: + A tuple of parameter and buffer tuples + + Examples: + >>> from torchrl.data import NdUnboundedContinuousTensorSpec, TensorDict + >>> lazy_module1 = nn.LazyLinear(4) + >>> lazy_module2 = nn.LazyLinear(3) + >>> spec1 = NdUnboundedContinuousTensorSpec(18) + >>> spec2 = NdUnboundedContinuousTensorSpec(4) + >>> td_module1 = TDModule(spec1, lazy_module1, ["some_input"], ["hidden"]) + >>> td_module2 = TDModule(spec2, lazy_module2, ["hidden"], ["some_output"]) + >>> td_module = TDSequence(td_module1, td_module2) + >>> _, (params, buffers) = td_module.make_functional_with_buffers() + >>> print(params[0].shape) # the lazy module has been initialized + torch.Size([4, 18]) + >>> print(td_module( + ... TensorDict({'some_input': torch.randn(18)}, batch_size=[]), + ... params=params, + ... buffers=buffers)) + TensorDict( + fields={ + some_input: Tensor(torch.Size([18]), dtype=torch.float32), + hidden: Tensor(torch.Size([4]), dtype=torch.float32), + some_output: Tensor(torch.Size([3]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + """ + if clone: + self_copy = copy(self) + self_copy.module = copy(self_copy.module) + else: + self_copy = self + params = [] + buffers = [] + for i, module in enumerate(self.module): # type: ignore + self_copy.module[i], ( + _params, + _buffers, + ) = module.make_functional_with_buffers() + params.extend(_params) + buffers.extend(_buffers) + return self_copy, (params, buffers) + + +class TDModuleWrapper(nn.Module): + """ + Wrapper calss for TDModule objects. + Once created, a TDModuleWrapper will behave exactly as the TDModule it contains except for the methods that are + overwritten. + + Args: + probabilistic_operator (TDModule): operator to be wrapped. + + Examples: + >>> # This class can be used for exploration wrappers + >>> import functorch + >>> from torchrl.modules import TDModuleWrapper, TDModule + >>> from torchrl.data import TensorDict, NdUnboundedContinuousTensorSpec, expand_as_right + >>> import torch + >>> + >>> class EpsilonGreedyExploration(TDModuleWrapper): + ... eps = 0.5 + ... def forward(self, tensordict, params, buffers): + ... rand_output_clone = self.random(tensordict.clone()) + ... det_output_clone = self.td_module(tensordict.clone(), params, buffers) + ... rand_output_idx = torch.rand(tensordict.shape, device=rand_output_clone.device) < self.eps + ... for key in self.out_keys: + ... _rand_output = rand_output_clone.get(key) + ... _det_output = det_output_clone.get(key) + ... rand_output_idx_expand = expand_as_right(rand_output_idx, _rand_output).to(_rand_output.dtype) + ... tensordict.set(key, + ... rand_output_idx_expand * _rand_output + (1-rand_output_idx_expand) * _det_output) + ... 
return tensordict + >>> + >>> td = TensorDict({"input": torch.zeros(10, 4)}, [10]) + >>> module = torch.nn.Linear(4, 4, bias=False) # should return a zero tensor if input is a zero tensor + >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) + >>> spec = NdUnboundedContinuousTensorSpec(4) + >>> tdmodule = TDModule(module=fmodule, spec=spec, in_keys=["input"], out_keys=["output"]) + >>> tdmodule_wrapped = EpsilonGreedyExploration(tdmodule) + >>> tdmodule_wrapped(td, params=params, buffers=buffers) + >>> print(td.get("output")) + """ + + def __init__(self, probabilistic_operator: TDModule): + super().__init__() + self.td_module = probabilistic_operator + if len(self.td_module._forward_hooks): + for pre_hook in self.td_module._forward_hooks: + self.register_forward_hook(self.td_module._forward_hooks[pre_hook]) + + def __getattr__(self, name: str) -> Any: + try: + return super().__getattr__(name) + except AttributeError: + if name not in self.__dict__: + return getattr(self._modules["td_module"], name) + else: + raise AttributeError( + f"attribute {name} not recognised in {type(self).__name__}" + ) + + def forward(self, *args, **kwargs): + return self.td_module.forward(*args, **kwargs) diff --git a/torchrl/modules/td_module/exploration.py b/torchrl/modules/td_module/exploration.py new file mode 100644 index 00000000000..d85cdd7cbcd --- /dev/null +++ b/torchrl/modules/td_module/exploration.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Union + +import numpy as np +import torch + +from torchrl.data.utils import expand_as_right +from torchrl.envs.utils import exploration_mode +from torchrl.modules.td_module.common import ( + _forward_hook_safe_action, + TDModule, + TDModuleWrapper, +) + +__all__ = ["EGreedyWrapper", "OrnsteinUhlenbeckProcessWrapper"] + +from torchrl.data.tensordict.tensordict import _TensorDict + + +class EGreedyWrapper(TDModuleWrapper): + """ + Epsilon-Greedy PO wrapper. + + Args: + policy (TDModule): a deterministic policy. + eps_init (scalar): initial epsilon value. + default: 1.0 + eps_end (scalar): final epsilon value. 
+ default: 0.1 + annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value + + Examples: + >>> from torchrl.modules import EGreedyWrapper, Actor + >>> from torchrl.data import NdBoundedTensorSpec, TensorDict + >>> import torch + >>> torch.manual_seed(0) + >>> spec = NdBoundedTensorSpec(-1, 1, torch.Size([4])) + >>> module = torch.nn.Linear(4, 4, bias=False) + >>> policy = Actor(spec, module=module) + >>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2) + >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + >>> print(explorative_policy(td).get("action")) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [-0.6986, -0.9366, -0.5837, 0.8596], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=) + + """ + + def __init__( + self, + policy: TDModule, + eps_init: float = 1.0, + eps_end: float = 0.1, + annealing_num_steps: int = 1000, + ): + super().__init__(policy) + self.register_buffer("eps_init", torch.tensor([eps_init])) + self.register_buffer("eps_end", torch.tensor([eps_end])) + if self.eps_end > self.eps_init: + raise RuntimeError("eps should decrease over time or be constant") + self.annealing_num_steps = annealing_num_steps + self.register_buffer("eps", torch.tensor([eps_init])) + + def step(self, frames: int = 1) -> None: + """A step of epsilon decay. + After self.annealing_num_steps, this function is a no-op. + + Args: + frames (int): number of frames since last step. + + """ + for _ in range(frames): + self.eps.data[0] = max( + self.eps_end.item(), + ( + self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps + ).item(), + ) + + def forward(self, tensordict: _TensorDict) -> _TensorDict: + tensordict = self.td_module.forward(tensordict) + if exploration_mode() == "random" or exploration_mode() is None: + out = tensordict.get(self.td_module.out_keys[0]) + eps = self.eps.item() + cond = (torch.rand(tensordict.shape, device=tensordict.device) < eps).to( + out.dtype + ) + cond = expand_as_right(cond, out) + out = ( + cond * self.td_module.spec.rand(tensordict.shape).to(out.device) + + (1 - cond) * out + ) + tensordict.set(self.td_module.out_keys[0], out) + return tensordict + + +class OrnsteinUhlenbeckProcessWrapper(TDModuleWrapper): + """ + Ornstein-Uhlenbeck exploration policy wrapper as presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", + https://arxiv.org/pdf/1509.02971.pdf. + + The OU exploration is to be used with continuous control policies and introduces a auto-correlated exploration + noise. This enables a sort of 'structured' exploration. + + Noise equation: + noise = prev_noise + theta * (mu - prev_noise) * dt + current_sigma * sqrt(dt) * W + Sigma equation: + current_sigma = (-(sigma - sigma_min) / (n_steps_annealing) * n_steps + sigma).clamp_min(sigma_min) + + To keep track of the steps and noise from sample to sample, an `"ou_prev_noise{id}"` and `"ou_steps{id}"` keys + will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset, + indicating that a new trajectory is being collected. If not, and is the same tensordict is used for consecutive + trajectories, the step count will keep on increasing across rollouts. 
Note that the collector classes take care of + zeroing the tensordict at reset time. + + Args: + policy (TDModule): a policy + eps_init (scalar): initial epsilon value, determining the amount of noise to be added. + default: 1.0 + eps_end (scalar): final epsilon value, determining the amount of noise to be added. + default: 0.1 + annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value. + default: 1000 + theta (scalar): theta factor in the noise equation + default: 0.15 + mu (scalar): OU average (mu in the noise equation). + default: 0.0 + sigma (scalar): sigma value in the sigma equation. + default: 0.2 + dt (scalar): dt in the noise equation. + default: 0.01 + x0 (Tensor, ndarray, optional): initial value of the process. + default: 0.0 + sigma_min (number, optional): sigma_min in the sigma equation. + default: None + n_steps_annealing (int): number of steps for the sigma annealing. + default: 1000 + key (str): key of the action to be modified. + default: "action" + safe (bool): if True, actions that are out of bounds given the action specs will be projected in the space + given the `TensorSpec.project` heuristic. + default: True + + Examples: + >>> from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor + >>> from torchrl.data import NdBoundedTensorSpec, TensorDict + >>> import torch + >>> torch.manual_seed(0) + >>> spec = NdBoundedTensorSpec(-1, 1, torch.Size([4])) + >>> module = torch.nn.Linear(4, 4, bias=False) + >>> policy = Actor(spec, module=module) + >>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy) + >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + >>> print(explorative_policy(td)) + """ + + def __init__( + self, + policy: TDModule, + eps_init: float = 1.0, + eps_end: float = 0.1, + annealing_num_steps: int = 1000, + theta: float = 0.15, + mu: float = 0.0, + sigma: float = 0.2, + dt: float = 1e-2, + x0: Optional[Union[torch.Tensor, np.ndarray]] = None, + sigma_min: Optional[float] = None, + n_steps_annealing: int = 1000, + key: str = "action", + safe: bool = True, + ): + super().__init__(policy) + self.ou = _OrnsteinUhlenbeckProcess( + theta=theta, + mu=mu, + sigma=sigma, + dt=dt, + x0=x0, + sigma_min=sigma_min, + n_steps_annealing=n_steps_annealing, + key=key, + ) + self.register_buffer("eps_init", torch.tensor([eps_init])) + self.register_buffer("eps_end", torch.tensor([eps_end])) + if self.eps_end > self.eps_init: + raise ValueError( + "eps should decrease over time or be constant, " + f"got eps_init={eps_init} and eps_end={eps_end}" + ) + self.annealing_num_steps = annealing_num_steps + self.register_buffer("eps", torch.tensor([eps_init])) + self.out_keys = list(self.td_module.out_keys) + [self.ou.out_keys] + self.safe = safe + if self.safe: + self.register_forward_hook(_forward_hook_safe_action) + + def step(self, frames: int = 1) -> None: + """Updates the eps noise factor. + + Args: + frames (int): number of frames of the current batch (corresponding to the number of updates to be made). + + """ + for _ in range(frames): + if self.annealing_num_steps > 0: + self.eps.data[0] = max( + self.eps_end.item(), + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ).item(), + ) + else: + raise ValueError( + f"{self.__class__.__name__}.step() called when " + f"self.annealing_num_steps={self.annealing_num_steps}. Expected a strictly positive " + f"number of frames." 
+ ) + + def forward(self, tensor_dict: _TensorDict) -> _TensorDict: + tensor_dict = super().forward(tensor_dict) + if exploration_mode() == "random" or exploration_mode() is None: + tensor_dict = self.ou.add_sample(tensor_dict, self.eps.item()) + return tensor_dict + + +# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab +class _OrnsteinUhlenbeckProcess: + def __init__( + self, + theta: float, + mu: float = 0.0, + sigma: float = 0.2, + dt: float = 1e-2, + x0: Optional[Union[torch.Tensor, np.ndarray]] = None, + sigma_min: Optional[float] = None, + n_steps_annealing: int = 1000, + key: str = "action", + ): + self.mu = mu + self.sigma = sigma + + if sigma_min is not None: + self.m = -float(sigma - sigma_min) / float(n_steps_annealing) + self.c = sigma + self.sigma_min = sigma_min + else: + self.m = 0.0 + self.c = sigma + self.sigma_min = sigma + + self.theta = theta + self.mu = mu + self.dt = dt + self.x0 = x0 if x0 is not None else 0.0 + self.key = key + self._noise_key = "_ou_prev_noise" + self._steps_key = "_ou_steps" + self.out_keys = [self.key, self.noise_key, self.steps_key] + + @property + def noise_key(self): + return self._noise_key # + str(id(self)) + + @property + def steps_key(self): + return self._steps_key # + str(id(self)) + + def _make_noise_pair(self, tensor_dict: _TensorDict) -> None: + tensor_dict.set( + self.noise_key, + torch.zeros(tensor_dict.get(self.key).shape, device=tensor_dict.device), + ) + tensor_dict.set( + self.steps_key, + torch.zeros( + torch.Size([*tensor_dict.batch_size, 1]), + dtype=torch.long, + device=tensor_dict.device, + ), + ) + + def add_sample(self, tensor_dict: _TensorDict, eps: float = 1.0) -> _TensorDict: + + if self.noise_key not in set(tensor_dict.keys()): + self._make_noise_pair(tensor_dict) + + prev_noise = tensor_dict.get(self.noise_key) + prev_noise = prev_noise + self.x0 + + n_steps = tensor_dict.get(self.steps_key) + + noise = ( + prev_noise + + self.theta * (self.mu - prev_noise) * self.dt + + self.current_sigma(n_steps) + * np.sqrt(self.dt) + * torch.randn_like(prev_noise) + ) + tensor_dict.set_(self.noise_key, noise - self.x0) + tensor_dict.set_(self.key, tensor_dict.get(self.key) + eps * noise) + tensor_dict.set_(self.steps_key, n_steps + 1) + return tensor_dict + + def current_sigma(self, n_steps: torch.Tensor) -> torch.Tensor: + sigma = (self.m * n_steps + self.c).clamp_min(self.sigma_min) + return sigma diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py new file mode 100644 index 00000000000..862eb75da97 --- /dev/null +++ b/torchrl/modules/utils/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .mappings import * diff --git a/torchrl/modules/utils/functorch.py b/torchrl/modules/utils/functorch.py new file mode 100644 index 00000000000..1895bf42144 --- /dev/null +++ b/torchrl/modules/utils/functorch.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
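+
+# Helpers for extracting functional sub-modules out of a functorch
+# ``FunctionalModuleWithBuffers`` parent: ``make_split_names_dict`` rebuilds the nested
+# parameter/buffer name tree from the flat split names, and ``get_params_of_module`` /
+# ``get_submodule_functional`` use it to select, from the parent's flat parameter and buffer
+# sequences, the entries that belong to a given sub-module and to wrap that sub-module as a
+# functional module.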
+ +import functorch + + +def get_params_of_module(module, cf, p, b): + split_name_dict = make_split_names_dict(cf.split_names, p, b) + names, values = _get_params_of_module(module, cf.stateless_model, split_name_dict) + S = set(values) + S_param = S.intersection(set(p)) + S_buffer = S.intersection(set(b)) + + name_p_dict = {_p: _name for _p, _name in zip(values, names)} + param_names, params = zip(*[(name_p_dict[_p], _p) for _p in p if _p in S_param]) + buffer_names, buffers = zip(*[(name_p_dict[_b], _b) for _b in b if _b in S_buffer]) + + fmodule = functorch.FunctionalModuleWithBuffers(module, param_names, buffer_names) + return fmodule, params, buffers + + +def _get_params_of_module(module, target, split_name_dict): + if target is module: + return _get_params(split_name_dict) + else: + found = False + for name in split_name_dict: + sub_target = getattr(target, name) + sub_split_name_dict = split_name_dict[name] + if isinstance(sub_split_name_dict, dict): + out = _get_params_of_module(module, sub_target, sub_split_name_dict) + if out: + return out + return found + + +def _get_params(dictionary): + out = [] + for key, value in dictionary.items(): + if not isinstance(value, dict): + out.append((key, value)) + else: + _out = [ + (".".join([key, _key]), _val) + for (_key, _val) in zip(*_get_params(value)) + ] + out += _out + return tuple(zip(*out)) + + +def get_item(d, name, p): + _d = d[name[0]] + if isinstance(_d, dict): + get_item(_d, name[1:], p) + else: + p.append(_d) + + +def populate_params(split_names, d, p=None): + if p is None: + p = [] + for name in split_names: + get_item(d, name, p) + return p + + +class apply_to_class: + def __init__(self, layer_type): + self.layer_type = layer_type + + def __call__(self, func): + def new_func(cf, p, b, **kwargs): + split_name_dict = make_split_names_dict(cf.split_names, p, b) + d = self.dispatch_to_layers( + func, self.layer_type, split_name_dict, cf.stateless_model + ) + new_p = populate_params(cf.split_names, d) + new_p, new_b = new_p[: len(p)], new_p[len(p) :] + return new_p, new_b + + return new_func + + @staticmethod + def dispatch_to_layers(func, layer_type, split_name_dict, cf): + if isinstance(cf, layer_type): + split_name_dict.update(func(cf, split_name_dict)) + + for layer_name in split_name_dict: + layer_or_param = getattr(cf, layer_name) + if isinstance(layer_or_param, nn.Module): + split_name_dict[layer_name] = apply_to_class.dispatch_to_layers( + func, + layer_type, + split_name_dict[layer_name], + layer_or_param, + ) + return split_name_dict + + +def make_split_names_dict(split_names, p, b=[], split_name_dict=None): + if split_name_dict is None: + split_name_dict = dict() + + _firsts = dict() + for name, param in zip(split_names, list(p) + list(b)): + if len(name) > 1: + layer_list = _firsts.get(name[0], []) + layer_list.append((name[1:], param)) + _firsts[name[0]] = layer_list + else: + split_name_dict[name[0]] = param + for key in _firsts: + _names, _params = zip(*_firsts[key]) + _dict = make_split_names_dict(_names, _params) + split_name_dict[key] = _dict + return split_name_dict + + +def get_submodule_functional(module, cf): + p = [i for i, _ in enumerate(cf.param_names)] + b = [len(p) + i for i, _ in enumerate(cf.buffer_names)] + split_name_dict = make_split_names_dict(cf.split_names, p, b) + names, values = _get_params_of_module(module, cf.stateless_model, split_name_dict) + S = set(values) + S_param = S.intersection(set(p)) + S_buffer = S.intersection(set(b)) + + name_p_dict = {_p: _name for _p, _name in zip(values, 
names)} + param_names, params = zip(*[(name_p_dict[_p], _p) for _p in p if _p in S_param]) + if len(S_buffer): + buffer_names, _ = zip(*[(name_p_dict[_b], _b) for _b in b if _b in S_buffer]) + else: + buffer_names, _ = tuple(), tuple() + + fmodule = functorch.FunctionalModuleWithBuffers(module, param_names, buffer_names) + return fmodule diff --git a/torchrl/modules/utils/mappings.py b/torchrl/modules/utils/mappings.py new file mode 100644 index 00000000000..149c81d43e4 --- /dev/null +++ b/torchrl/modules/utils/mappings.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch +from torch import nn + +__all__ = ["mappings", "inv_softplus", "biased_softplus"] + + +def inv_softplus(bias: float): + """ + inverse softplus function. + + """ + return torch.tensor(bias).expm1().clamp_min(1e-6).log().item() + + +class biased_softplus(nn.Module): + """ + A biased softplus layer. + Args: + bias (scalar): 'bias' of the softplus transform. If bias=1.0, then a _bias shift will be computed such that + softplus(0.0 + _bias) = bias. + min_val (scalar): minimum value of the transform. + default: 0.1 + """ + + def __init__(self, bias: float, min_val: float = 0.01): + super().__init__() + self.bias = inv_softplus(bias - min_val) + self.min_val = min_val + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.softplus(x + self.bias) + self.min_val + + +def expln(x): + """ + A smooth, continuous positive mapping presented in "State-Dependent + Exploration for Policy Gradient Methods" + https://people.idsia.ch/~juergen/ecml2008rueckstiess.pdf + + """ + out = torch.empty_like(x) + idx_neg = x <= 0 + out[idx_neg] = x[idx_neg].exp() + out[~idx_neg] = x[~idx_neg].log1p() + 1 + return out + + +def mappings(key: str) -> Callable: + """ + Given an input string, return a surjective function f(x): R -> R^+ + + Args: + key (str): one of "softplus", "exp", "relu", "expln", + or "biased_softplus". + + Returns: + a Callable + + """ + _mappings = { + "softplus": torch.nn.functional.softplus, + "exp": torch.exp, + "relu": torch.relu, + "biased_softplus": biased_softplus(1.0), + "expln": expln, + } + if key in _mappings: + return _mappings[key] + elif key.startswith("biased_softplus"): + return biased_softplus(float(key.split("_")[-1])) + else: + raise NotImplementedError(f"Unknown mapping {key}") diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py new file mode 100644 index 00000000000..458379ebfb8 --- /dev/null +++ b/torchrl/objectives/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .costs import * +from .returns import * diff --git a/torchrl/objectives/costs/__init__.py b/torchrl/objectives/costs/__init__.py new file mode 100644 index 00000000000..bbfbc776cae --- /dev/null +++ b/torchrl/objectives/costs/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
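+
+# Expose the loss modules of the individual algorithms (DDPG, DQN, PPO, SAC, REDQ) along with
+# the shared cost utilities.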
+
+from .ddpg import *
+from .dqn import *
+from .ppo import *
+from .sac import *
+from .redq import *
+from .utils import *
diff --git a/torchrl/objectives/costs/common.py b/torchrl/objectives/costs/common.py
new file mode 100644
index 00000000000..81478eaf761
--- /dev/null
+++ b/torchrl/objectives/costs/common.py
@@ -0,0 +1,230 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+__all__ = ["_LossModule"]
+
+from typing import Iterator, Optional, Tuple
+
+import functorch
+import torch
+from functorch._src.make_functional import _swap_state
+from torch import nn
+from torch.nn import Parameter
+
+from torchrl.data.tensordict.tensordict import _TensorDict
+from torchrl.modules import TDModule
+
+
+class _LossModule(nn.Module):
+    """
+    A parent class for RL losses.
+    _LossModule inherits from nn.Module. It is designed to read an input TensorDict and return another tensordict
+    with loss keys named "loss_*".
+    Splitting the loss into its components allows the agent to log each loss value throughout
+    training. Other scalars present in the output tensordict will be logged too.
+    """
+
+    def forward(self, tensordict: _TensorDict) -> _TensorDict:
+        """Reads an input TensorDict and returns another tensordict
+        with loss keys named "loss*".
+        Splitting the loss into its components allows the agent to log each loss value throughout
+        training. Other scalars present in the output tensordict will be logged too.
+
+        Args:
+            tensordict: an input tensordict with the values required to compute the loss.
+
+        Returns:
+            A new tensordict with no batch dimension containing various loss scalars which will be named "loss*". It
+            is essential that the losses are returned with this name as they will be read by the agent before
+            backpropagation.
+        """
+        raise NotImplementedError
+
+    def convert_to_functional(
+        self,
+        module: TDModule,
+        module_name: str,
+        expand_dim: Optional[int] = None,
+        create_target_params: bool = False,
+    ) -> None:
+        # To make it robust to device casting, we must register list of
+        # tensors as lazy calls to `getattr(self, name_of_tensor)`.
+ # Otherwise, casting the module to a device will keep old references + # to uncast tensors + + network_orig = module + if hasattr(module, "make_functional_with_buffers"): + functional_module, ( + _, + module_buffers, + ) = module.make_functional_with_buffers(clone=True) + else: + ( + functional_module, + module_params, + module_buffers, + ) = functorch.make_functional_with_buffers(module) + for _ in functional_module.parameters(): + # Erase meta params + none_state = [None for _ in module_params + module_buffers] + if hasattr(functional_module, "all_names_map"): + # functorch >= 0.2.0 + _swap_state( + functional_module.stateless_model, + functional_module.all_names_map, + none_state, + ) + else: + # functorch < 0.2.0 + _swap_state( + functional_module.stateless_model, + functional_module.split_names, + none_state, + ) + break + del module_params + + param_name = module_name + "_params" + + # we keep the original parameters and not the copy returned by functorch + params = network_orig.parameters() + + # unless we need to expand them, in that case we'll delete the weights to make sure that the user does not + # run anything with them expecting them to be updated + params = list(params) + module_buffers = list(module_buffers) + + if expand_dim: + for i, p in enumerate(params): + p = p.repeat(expand_dim, *[1 for _ in p.shape]) + p = nn.Parameter( + p.uniform_(p.min().item(), p.max().item()).requires_grad_() + ) + params[i] = p + + for i, b in enumerate(module_buffers): + b = b.expand(expand_dim, *b.shape).clone() + module_buffers[i] = b + + # delete weights of original model as they do not correspond to the optimized weights + network_orig.to("meta") + + setattr(self, param_name, nn.ParameterList(params)) + + # we register each buffer independently + for i, p in enumerate(module_buffers): + _name = module_name + f"_buffer_{i}" + self.register_buffer(_name, p) + # replace buffer by its name + module_buffers[i] = _name + buffer_name = module_name + "_buffers" + setattr( + self.__class__, + buffer_name, + property(lambda _self: [getattr(_self, _name) for _name in module_buffers]), + ) + + # we set the functional module + setattr(self, module_name, functional_module) + + name_params_target = "_target_" + param_name + name_buffers_target = "_target_" + buffer_name + if create_target_params: + target_params = [p.detach().clone() for p in getattr(self, param_name)] + for i, p in enumerate(target_params): + name = "_".join([name_params_target, str(i)]) + self.register_buffer(name, p) + target_params[i] = name + setattr( + self.__class__, + name_params_target, + property( + lambda _self: [getattr(_self, _name) for _name in target_params] + ), + ) + + target_buffers = [p.detach().clone() for p in getattr(self, buffer_name)] + for i, p in enumerate(target_buffers): + name = "_".join([name_buffers_target, str(i)]) + self.register_buffer(name, p) + target_buffers[i] = name + setattr( + self.__class__, + name_buffers_target, + property( + lambda _self: [getattr(_self, _name) for _name in target_buffers] + ), + ) + + else: + setattr(self.__class__, name_params_target, None) + setattr(self.__class__, name_buffers_target, None) + + setattr( + self.__class__, + name_params_target[1:], + property(lambda _self: self._target_param_getter(module_name)), + ) + setattr( + self.__class__, + name_buffers_target[1:], + property(lambda _self: self._target_buffer_getter(module_name)), + ) + + def _target_param_getter(self, network_name): + target_name = "_target_" + network_name + "_params" + param_name = network_name 
+ "_params" + if hasattr(self, target_name): + target_params = getattr(self, target_name) + if target_params is not None: + return tuple(target_params) + else: + return tuple(p.detach() for p in getattr(self, param_name)) + + else: + raise RuntimeError( + f"{self.__class__.__name__} does not have the target param {target_name}" + ) + + def _target_buffer_getter(self, network_name): + target_name = "_target_" + network_name + "_buffers" + buffer_name = network_name + "_buffers" + if hasattr(self, target_name): + target_buffers = getattr(self, target_name) + if target_buffers is not None: + return tuple(target_buffers) + else: + return tuple(p.detach() for p in getattr(self, buffer_name)) + + else: + raise RuntimeError( + f"{self.__class__.__name__} does not have the target buffer {target_name}" + ) + + def _networks(self) -> Iterator[nn.Module]: + for item in self.__dir__(): + if isinstance(item, nn.Module): + yield item + + @property + def device(self) -> torch.device: + for p in self.parameters(): + return p.device + return torch.device("cpu") + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + for name, param in self.named_parameters(recurse=recurse): + yield param + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, Parameter]]: + for name, param in super().named_parameters(prefix=prefix, recurse=recurse): + if not name.startswith("_target"): + yield name, param + + def reset(self) -> None: + # mainly used for PPO with KL target + pass diff --git a/torchrl/objectives/costs/ddpg.py b/torchrl/objectives/costs/ddpg.py new file mode 100644 index 00000000000..c4db21ccb5e --- /dev/null +++ b/torchrl/objectives/costs/ddpg.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Tuple + +import torch + +from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.modules import TDModule +from torchrl.modules.td_module.actors import ActorCriticWrapper +from torchrl.objectives.costs.utils import ( + distance_loss, + hold_out_params, + next_state_value, +) +from .common import _LossModule + + +class DDPGLoss(_LossModule): + """ + The DDPG Loss class. + Args: + actor_network (TDModule): a policy operator. + value_network (TDModule): a Q value operator. + gamma (scalar): a discount factor for return computation. + device (str, int or torch.device, optional): a device where the losses will be computed, if it can't be found + via the value operator. + loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". + """ + + delay_actor: bool = False + delay_value: bool = False + + def __init__( + self, + actor_network: TDModule, + value_network: TDModule, + gamma: float, + loss_function: str = "l2", + ) -> None: + super().__init__() + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + ) + self.convert_to_functional( + value_network, + "value_network", + create_target_params=self.delay_value, + ) + + self.actor_in_keys = actor_network.in_keys + + self.gamma = gamma + self.loss_funtion = loss_function + + def forward(self, input_tensor_dict: _TensorDict) -> TensorDict: + """Computes the DDPG losses given a tensordict sampled from the replay buffer. 
+ This function will also write a "td_error" key that can be used by prioritized replay buffers to assign + a priority to items in the tensordict. + + Args: + input_tensor_dict (_TensorDict): a tensordict with keys ["done", "reward"] and the in_keys of the actor + and value networks. + + Returns: + a tuple of 2 tensors containing the DDPG loss. + + """ + if not input_tensor_dict.device == self.device: + raise RuntimeError( + f"Got device={input_tensor_dict.device} but actor_network.device={self.device} " + f"(self.device={self.device})" + ) + + loss_value, td_error, pred_val, target_value = self._loss_value( + input_tensor_dict, + ) + td_error = td_error.detach() + td_error = td_error.unsqueeze(input_tensor_dict.ndimension()) + td_error = td_error.to(input_tensor_dict.device) + input_tensor_dict.set( + "td_error", + td_error, + inplace=True, + ) + loss_actor = self._loss_actor(input_tensor_dict) + return TensorDict( + source={ + "loss_actor": loss_actor.mean(), + "loss_value": loss_value.mean(), + "pred_value": pred_val.mean().detach(), + "target_value": target_value.mean().detach(), + "pred_value_max": pred_val.max().detach(), + "target_value_max": target_value.max().detach(), + }, + batch_size=[], + ) + + def _loss_actor( + self, + tensor_dict: _TensorDict, + ) -> torch.Tensor: + td_copy = tensor_dict.select(*self.actor_in_keys).detach() + td_copy = self.actor_network( + td_copy, + params=self.actor_network_params, + buffers=self.actor_network_buffers, + ) + with hold_out_params(self.value_network_params) as params: + td_copy = self.value_network( + td_copy, params=params, buffers=self.value_network_buffers + ) + return -td_copy.get("state_action_value") + + def _loss_value( + self, + tensor_dict: _TensorDict, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # value loss + td_copy = tensor_dict.select(*self.value_network.in_keys).detach() + self.value_network( + td_copy, + params=self.value_network_params, + buffers=self.value_network_buffers, + ) + pred_val = td_copy.get("state_action_value").squeeze(-1) + + actor_critic = ActorCriticWrapper(self.actor_network, self.value_network) + target_params = list(self.target_actor_network_params) + list( + self.target_value_network_params + ) + target_buffers = list(self.target_actor_network_buffers) + list( + self.target_value_network_buffers + ) + target_value = next_state_value( + tensor_dict, + actor_critic, + gamma=self.gamma, + params=target_params, + buffers=target_buffers, + ) + + # td_error = pred_val - target_value + loss_value = distance_loss( + pred_val, target_value, loss_function=self.loss_funtion + ) + + return loss_value, abs(pred_val - target_value), pred_val, target_value + + +class DoubleDDPGLoss(DDPGLoss): + """ + A Double DDPG loss class. + As for Double DQN loss, this class separates the target value/actor networks from the value/actor networks used for + data collection. Those target networks should be updated from their original counterparts with some delay using + dedicated classes (SoftUpdate and HardUpdate in objectives.cost.utils). + Note that the original networks will be copied at initialization using the copy.deepcopy method: in some rare cases + this may lead to unexpected behaviours (for instance if the networks change in a way that won't be reflected by their + state_dict). Please report any such bug if encountered. 
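+
+    Example (illustrative sketch; ``actor_network`` and ``value_network`` are assumed to be TDModule
+    instances producing the "action" and "state_action_value" entries used by DDPG, and
+    ``replay_tensordict`` a batch sampled from the replay buffer with the "reward" and "done" keys):
+        >>> loss_module = DoubleDDPGLoss(actor_network, value_network, gamma=0.99)
+        >>> loss_td = loss_module(replay_tensordict)
+        >>> loss = loss_td.get("loss_actor") + loss_td.get("loss_value")
+        >>> loss.backward()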
+
+    """
+
+    delay_actor: bool = True
+    delay_value: bool = True
diff --git a/torchrl/objectives/costs/dqn.py b/torchrl/objectives/costs/dqn.py
new file mode 100644
index 00000000000..800511410ec
--- /dev/null
+++ b/torchrl/objectives/costs/dqn.py
@@ -0,0 +1,316 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from torchrl.data import TensorDict
+from torchrl.envs.utils import step_tensor_dict
+from torchrl.modules import (
+    DistributionalQValueActor,
+    QValueActor,
+)
+from ...data.tensordict.tensordict import _TensorDict
+from .common import _LossModule
+from .utils import distance_loss, next_state_value
+
+__all__ = [
+    "DQNLoss",
+    "DoubleDQNLoss",
+    "DistributionalDQNLoss",
+    "DistributionalDoubleDQNLoss",
+]
+
+
+class DQNLoss(_LossModule):
+    """
+    The DQN Loss class.
+    Args:
+        value_network (QValueActor): a Q value operator.
+        gamma (scalar): a discount factor for return computation.
+        loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+        priority_key (str, optional): key where the priority (TD-error) is written back into the input
+            tensordict for prioritized replay buffers.
+            default: "td_error"
+    """
+
+    delay_value: bool = False
+
+    def __init__(
+        self,
+        value_network: QValueActor,
+        gamma: float,
+        loss_function: str = "l2",
+        priority_key: str = "td_error",
+    ) -> None:
+
+        super().__init__()
+        self.convert_to_functional(
+            value_network,
+            "value_network",
+            create_target_params=self.delay_value,
+        )
+
+        self.value_network_in_keys = value_network.in_keys
+        if not isinstance(value_network, QValueActor):
+            raise TypeError(
+                f"DQNLoss requires value_network to be of QValueActor type, got {type(value_network)}"
+            )
+        self.gamma = gamma
+        self.loss_function = loss_function
+        self.priority_key = priority_key
+
+    def forward(self, input_tensor_dict: _TensorDict) -> TensorDict:
+        """
+        Computes the DQN loss given a tensordict sampled from the replay buffer.
+        This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
+        a priority to items in the tensordict.
+
+        Args:
+            input_tensor_dict (_TensorDict): a tensordict with keys ["done", "reward", "action"] and the in_keys of
+                the value network.
+
+        Returns:
+            a TensorDict containing the DQN loss (a scalar stored under the "loss" key).
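+
+        Example (illustrative sketch; ``loss_module`` is a DQNLoss instance and ``replay_td`` a batch
+        sampled from the replay buffer with the keys listed above):
+            >>> loss_td = loss_module(replay_td)
+            >>> loss_td.get("loss").backward()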
+ + """ + + device = self.device if self.device is not None else input_tensor_dict.device + tensor_dict = input_tensor_dict.to(device) + if tensor_dict.device != device: + raise RuntimeError( + f"device {device} was expected for " + f"{tensor_dict.__class__.__name__} but {tensor_dict.device} was found" + ) + + for k, t in tensor_dict.items(): + if t.device != device: + raise RuntimeError( + f"found key value pair {k}-{t.shape} " + f"with device {t.device} when {device} was required" + ) + + action = tensor_dict.get("action") + + action = action.to(torch.float) + td_copy = tensor_dict.clone() + if td_copy.device != tensor_dict.device: + raise RuntimeError(f"{tensor_dict} and {td_copy} have different devices") + self.value_network( + td_copy, + params=self.value_network_params, + buffers=self.value_network_buffers, + ) + + pred_val = td_copy.get("action_value") + pred_val_index = (pred_val * action).sum(-1) + + with torch.no_grad(): + target_value = next_state_value( + tensor_dict, + self.value_network, + gamma=self.gamma, + params=self.target_value_network_params, + buffers=self.target_value_network_buffers, + next_val_key="chosen_action_value", + ) + priority_tensor = abs(pred_val_index - target_value) + priority_tensor = priority_tensor.detach().unsqueeze(-1) + priority_tensor = priority_tensor.to(input_tensor_dict.device) + + input_tensor_dict.set( + self.priority_key, + priority_tensor, + inplace=True, + ) + loss = distance_loss(pred_val_index, target_value, self.loss_function) + return TensorDict({"loss": loss.mean()}, []) + + +class DoubleDQNLoss(DQNLoss): + """ + A Double DQN loss class. + This class duplicates the value network into a new target value network, which differs from the value networks used + for data collection in that it has a similar weight configuration but delayed of a certain number of optimization + steps. The target network should be updated from its original counterpart with some delay using dedicated classes + (SoftUpdate and HardUpdate in objectives.cost.utils). + More information on double DQN can be found in "Deep Reinforcement Learning with Double Q-learning", + https://arxiv.org/abs/1509.06461. + + Note that the original network will be copied at initialization using the copy.deepcopy method: in some rare cases + this may lead to unexpected behaviours (for instance if the network changes in a way that won't be reflected by its + state_dict). Please report any such bug if encountered. + + """ + + delay_value: bool = True + + +class DistributionalDQNLoss(_LossModule): + """ + A distributional DQN loss class. + Distributional DQN uses a value network that outputs a distribution of + values over a discrete support of discounted returns (unlike regular DQN + where the value network outputs a single point prediction of the + disctounted return). + + For more details regarding Distributional DQN, refer to "A Distributional + Perspective on Reinforcement Learning", + https://arxiv.org/pdf/1707.06887.pdf + + Args: + value_network (DistributionalQValueActor): the distributional Q + value operator. + gamma (scalar): a discount factor for return computation. 
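+        priority_key (str, optional): key under which the per-sample loss is written back into the
+            input tensordict so that prioritized replay buffers can update their priorities.
+            default: "td_error"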
+ """ + + delay_value: bool = False + + def __init__( + self, + value_network: DistributionalQValueActor, + gamma: float, + priority_key: str = "td_error", + ): + super().__init__() + self.gamma = gamma + self.priority_key = priority_key + if not isinstance(value_network, DistributionalQValueActor): + raise TypeError( + "Expected value_network to be of type " + "DistributionalQValueActor " + f"but got {type(value_network)}" + ) + self.convert_to_functional( + value_network, + "value_network", + create_target_params=self.delay_value, + ) + + def forward(self, input_tensor_dict: _TensorDict) -> TensorDict: + # from https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/agent.py + device = self.device + tensor_dict = TensorDict( + source=input_tensor_dict, batch_size=input_tensor_dict.batch_size + ).to(device) + + if tensor_dict.batch_dims != 1: + raise RuntimeError( + f"{self.__class__.__name___} expects a 1-dimensional " + "tensor_dict as input" + ) + batch_size = tensor_dict.batch_size[0] + support = self.value_network.support + atoms = support.numel() + Vmin = support.min().item() + Vmax = support.max().item() + delta_z = (Vmax - Vmin) / (atoms - 1) + + action = tensor_dict.get("action") + reward = tensor_dict.get("reward") + done = tensor_dict.get("done") + + steps_to_next_obs = tensor_dict.get("steps_to_next_obs", 1) + discount = self.gamma ** steps_to_next_obs + + # Calculate current state probabilities (online network noise already + # sampled) + td_clone = tensor_dict.clone() + self.value_network( + td_clone, + params=self.value_network_params, + buffers=self.value_network_buffers, + ) # Log probabilities log p(s_t, ·; θonline) + action_log_softmax = td_clone.get("action_value") + action_expand = action.unsqueeze(-2).expand_as(action_log_softmax) + log_ps_a = action_log_softmax.masked_select(action_expand.to(torch.bool)) + log_ps_a = log_ps_a.view(batch_size, atoms) # log p(s_t, a_t; θonline) + + with torch.no_grad(): + # Calculate nth next state probabilities + next_td = step_tensor_dict(tensor_dict) + self.value_network( + next_td, + params=self.value_network_params, + buffers=self.value_network_buffers, + ) # Probabilities p(s_t+n, ·; θonline) + argmax_indices_ns = next_td.get("action").argmax(-1) # one-hot encoding + + self.value_network( + next_td, + params=self.target_value_network_params, + buffers=self.target_value_network_buffers, + ) # Probabilities p(s_t+n, ·; θtarget) + pns = next_td.get("action_value").exp() + # Double-Q probabilities + # p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) + pns_a = pns[range(batch_size), :, argmax_indices_ns] + + # Compute Tz (Bellman operator T applied to z) + # Tz = R^n + (γ^n)z (accounting for terminal states) + if isinstance(discount, torch.Tensor): + discount = discount.to("cpu") + done = done.to("cpu") + reward = reward.to("cpu") + support = support.to("cpu") + pns_a = pns_a.to("cpu") + Tz = reward + (1 - done.to(reward.dtype)) * discount * support + if Tz.shape != torch.Size([batch_size, atoms]): + raise RuntimeError( + "Tz shape must be torch.Size([batch_size, atoms]), " + f"got Tz.shape={Tz.shape} and batch_size={batch_size}, " + f"atoms={atoms}" + ) + # Clamp between supported values + Tz = Tz.clamp_(min=Vmin, max=Vmax) + if not torch.isfinite(Tz).all(): + raise RuntimeError("Tz has some non-finite elements") + # Compute L2 projection of Tz onto fixed support z + b = (Tz - Vmin) / delta_z # b = (Tz - Vmin) / Δz + low, up = b.floor().to(torch.int64), b.ceil().to(torch.int64) + # Fix disappearing 
probability mass when l = b = u (b is int) + low[(up > 0) & (low == up)] -= 1 + up[(low < (atoms - 1)) & (low == up)] += 1 + + # Distribute probability of Tz + m = torch.zeros(batch_size, atoms) + offset = torch.linspace( + 0, + ((batch_size - 1) * atoms), + batch_size, + dtype=torch.int64, + # device=device, + ) + offset = offset.unsqueeze(1).expand(batch_size, atoms) + index = (low + offset).view(-1) + tensor = (pns_a * (up.float() - b)).view(-1) + # m_l = m_l + p(s_t+n, a*)(u - b) + m.view(-1).index_add_(0, index, tensor) + index = (up + offset).view(-1) + tensor = (pns_a * (b - low.float())).view(-1) + # m_u = m_u + p(s_t+n, a*)(b - l) + m.view(-1).index_add_(0, index, tensor) + + # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) + loss = -torch.sum(m.to(device) * log_ps_a, 1) + input_tensor_dict.set( + self.priority_key, + loss.detach().unsqueeze(1).to(input_tensor_dict.device), + inplace=True, + ) + loss_td = TensorDict({"loss": loss.mean()}, []) + return loss_td + + +class DistributionalDoubleDQNLoss(DistributionalDQNLoss): + """ + A distributional, double DQN loss class. + This class mixes distributional and double DQN losses. + + For more details regarding Distributional DQN, refer to "A Distributional + Perspective on Reinforcement Learning", + https://arxiv.org/pdf/1707.06887.pdf + More information on double DQN can be found in "Deep Reinforcement + Learning with Double Q-learning", https://arxiv.org/abs/1509.06461. + + """ + + delay_value: bool = True diff --git a/torchrl/objectives/costs/functional.py b/torchrl/objectives/costs/functional.py new file mode 100644 index 00000000000..99f9a439647 --- /dev/null +++ b/torchrl/objectives/costs/functional.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def cross_entropy_loss( + log_policy: torch.Tensor, action: torch.Tensor, inplace: bool = False +) -> torch.Tensor: + """ + Returns the cross entropy loss defined as the log-softmax value indexed by the action index. + Supports discrete (integer) actions or one-hot encodings. + + Args: + log_policy: Tensor of the log_softmax values of the policy. + action: Integer or one-hot representation of the actions undertaken. Must have a shape log_policy.shape[:-1] + (integer representation) or log_policy.shape (one-hot). + inplace: fills log_policy in-place with 0.0 at non-selected actions before summing along the last dimensions. + This is usually faster but it will change the value of log-policy in place, which may lead to unwanted + behaviours. 
+ + Returns: + + """ + if action.shape == log_policy.shape: + if action.dtype not in (torch.bool, torch.long, torch.uint8): + raise TypeError( + f"Cross-entropy loss with {action.dtype} dtype is not permitted" + ) + if not ((action == 1).sum(-1) == 1).all(): + raise RuntimeError( + "Expected the action tensor to be a one hot encoding of the actions taken, " + "but got more/less than one non-null boolean index on the last dimension" + ) + if inplace: + cross_entropy = log_policy.masked_fill_(action, 0.0).sum(-1) + else: + cross_entropy = (log_policy * action).sum(-1) + elif action.shape == log_policy.shape[:-1]: + cross_entropy = torch.gather(log_policy, dim=-1, index=action[..., None]) + cross_entropy.squeeze_(-1) + else: + raise RuntimeError( + f"unexpected action shape in cross_entropy_loss with log_policy.shape={log_policy.shape} and" + f"action.shape={action.shape}" + ) + return cross_entropy diff --git a/torchrl/objectives/costs/impala.py b/torchrl/objectives/costs/impala.py new file mode 100644 index 00000000000..279bd6b2198 --- /dev/null +++ b/torchrl/objectives/costs/impala.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.modules import ProbabilisticTDModule +from torchrl.objectives.returns.vtrace import vtrace + + +class QValEstimator: + def __init__(self, value_model: ProbabilisticTDModule): + self.value_model = value_model + + @property + def device(self) -> torch.device: + return next(self.value_model.parameters()).device + + def forward(self, tensordict: _TensorDict) -> None: + tensordict_device = tensordict.to(self.device) + self.value_model_device(tensordict_device) # udpates the value key + gamma = tensordict_device.get("gamma") + reward = tensordict_device.get("reward") + next_value = torch.cat( + [ + tensordict_device.get("value")[:, 1:], + torch.ones_like(reward[:, :1]), + ], + 1, + ) + q_value = reward + gamma * next_value + tensordict_device.set("q_value", q_value) + + +class VTraceEstimator: + def forward(self, tensordict: _TensorDict) -> _TensorDict: + tensordict_device = tensordict.to(device) + rewards = tensordict_device.get("reward") + vals = tensordict_device.get("value") + log_mu = tensordict_device.get("log_mu") + log_pi = tensordict_device.get("log_pi") + gamma = tensordict_device.get("gamma") + v_trace, rho = vtrace( + rewards, + vals, + log_pi, + log_mu, + gamma, + rho_bar=self.rho_bar, + c_bar=self.c_bar, + ) + tensordict_device.set("v_trace", v_trace) + tensordict_device.set("rho", rho) + return tensordict_device + + +class ImpalaLoss: + def forward(self, tensordict): + tensordict_device = tensordict.to(device) + self.q_val_estimator(tensordict_device) + self.v_trace_estimator(tensordict_device) diff --git a/torchrl/objectives/costs/ppo.py b/torchrl/objectives/costs/ppo.py new file mode 100644 index 00000000000..aa46bf85abb --- /dev/null +++ b/torchrl/objectives/costs/ppo.py @@ -0,0 +1,399 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+import math
+from typing import Callable, Optional, Tuple
+
+import torch
+from torch import distributions as d
+
+from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict
+from torchrl.envs.utils import step_tensor_dict
+from torchrl.modules import ProbabilisticTDModule, TDModule
+
+__all__ = ["PPOLoss", "ClipPPOLoss", "KLPENPPOLoss"]
+
+from torchrl.objectives.costs.utils import distance_loss
+from .common import _LossModule
+
+
+class PPOLoss(_LossModule):
+    """
+    A parent PPO loss class.
+
+    PPO (Proximal Policy Optimisation) is a model-free, online RL algorithm that makes use of a recorded (batch of)
+    trajectories to perform several optimization steps, while actively preventing the updated policy from deviating
+    too much from its original parameter configuration.
+
+    PPO loss can be found in different flavours, depending on the way the constrained optimisation is implemented:
+    ClipPPOLoss and KLPENPPOLoss.
+    Unlike its subclasses, this class does not implement any regularisation and should therefore be used cautiously.
+
+    For more details regarding PPO, refer to: "Proximal Policy Optimization Algorithms",
+    https://arxiv.org/abs/1707.06347
+
+    Args:
+        actor (Actor): policy operator.
+        critic (ProbabilisticTDModule): value operator.
+        advantage_key (str): the input tensordict key where the advantage is expected to be written.
+            default: "advantage"
+        entropy_bonus (bool): if True, an entropy bonus will be added to the loss to favour exploratory policies.
+        samples_mc_entropy (int): if the distribution retrieved from the policy operator does not have a closed form
+            formula for the entropy, a Monte-Carlo estimate will be used. samples_mc_entropy will control how many
+            samples will be used to compute this estimate.
+            default: 1
+        entropy_factor (scalar): entropy multiplier when computing the total loss.
+            default: 0.01
+        critic_factor (scalar): critic loss multiplier when computing the total loss.
+            default: 1.0
+        gamma (scalar): a discount factor for return computation.
+        loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
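+
+    Example (illustrative sketch; `actor`, `critic` and the sampled `batch_td` tensordict are assumed to be
+        built elsewhere):
+        >>> loss_module = PPOLoss(actor, critic, gamma=0.99)
+        >>> loss_td = loss_module(batch_td)
+        >>> loss = (
+        ...     loss_td.get("loss_objective")
+        ...     + loss_td.get("loss_critic")
+        ...     + loss_td.get("loss_entropy")
+        ... )
+        >>> loss.backward()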
+ + """ + + def __init__( + self, + actor: ProbabilisticTDModule, + critic: TDModule, + advantage_key: str = "advantage", + entropy_bonus: bool = True, + samples_mc_entropy: int = 1, + entropy_factor: float = 0.01, + critic_factor: float = 1.0, + gamma: float = 0.99, + loss_critic_type: str = "smooth_l1", + advantage_module: Optional[Callable[[_TensorDict], _TensorDict]] = None, + ): + super().__init__() + self.actor = actor + self.critic = critic + self.advantage_key = advantage_key + self.samples_mc_entropy = samples_mc_entropy + self.entropy_bonus = entropy_bonus and entropy_factor + self.entropy_factor = entropy_factor + self.critic_factor = critic_factor + self.gamma = gamma + self.loss_critic_type = loss_critic_type + self.advantage_module = advantage_module + + def reset(self) -> None: + pass + + def get_entropy_bonus(self, dist: Optional[d.Distribution] = None) -> torch.Tensor: + try: + entropy = dist.entropy() + except NotImplementedError: + x = dist.rsample((self.samples_mc_entropy,)) + entropy = -dist.log_prob(x) + return entropy.unsqueeze(-1) + + def _log_weight( + self, tensor_dict: _TensorDict + ) -> Tuple[torch.Tensor, d.Distribution]: + # current log_prob of actions + action = tensor_dict.get("action") + if action.requires_grad: + raise RuntimeError("tensor_dict stored action requires grad.") + tensor_dict_clone = tensor_dict.select(*self.actor.in_keys).clone() + + dist, *_ = self.actor.get_dist(tensor_dict_clone) + log_prob = dist.log_prob(action) + log_prob = log_prob.unsqueeze(-1) + + prev_log_prob = tensor_dict.get("action_log_prob") + if prev_log_prob.requires_grad: + raise RuntimeError("tensor_dict prev_log_prob requires grad.") + + log_weight = log_prob - prev_log_prob + return log_weight, dist + + def loss_critic(self, tensor_dict: _TensorDict) -> torch.Tensor: + + if "value_target" in tensor_dict.keys(): + value_target = tensor_dict.get("value_target") + if value_target.requires_grad: + raise RuntimeError( + "value_target retrieved from tensor_dict requires grad." + ) + + else: + with torch.no_grad(): + reward = tensor_dict.get("reward") + next_td = step_tensor_dict(tensor_dict) + next_value = self.critic(next_td).get("state_value") + value_target = reward + next_value * self.gamma + tensor_dict_select = tensor_dict.select(*self.critic.in_keys).clone() + value = self.critic(tensor_dict_select).get("state_value") + loss_value = distance_loss( + value, value_target, loss_function=self.loss_critic_type + ) + return self.critic_factor * loss_value + + def forward(self, tensor_dict: _TensorDict) -> _TensorDict: + if self.advantage_module is not None: + tensor_dict = self.advantage_module(tensor_dict) + tensor_dict = tensor_dict.clone() + advantage = tensor_dict.get(self.advantage_key) + log_weight, dist = self._log_weight(tensor_dict) + neg_loss = (log_weight.exp() * advantage).mean() + print(log_weight) + td_out = TensorDict({"loss_objective": -neg_loss.mean()}, []) + if self.entropy_bonus: + entropy = self.get_entropy_bonus(dist) + td_out.set("entropy", entropy.mean().detach()) # for logging + td_out.set("loss_entropy", -self.entropy_factor * entropy.mean()) + if self.critic_factor: + loss_critic = self.loss_critic(tensor_dict).mean() + td_out.set("loss_critic", loss_critic.mean()) + return td_out + + +class ClipPPOLoss(PPOLoss): + """ + Clipped PPO loss. + + The clipped importance weighted loss is computed as follows: + loss = -min( weight * advantage, min(max(weight, 1-eps), 1+eps) * advantage) + + Args: + actor (Actor): policy operator. 
+ critic (ProbabilisticTDModule): value operator. + advantage_key (str): the input tensordict key where the advantage is expected to be written. + default: "advantage" + clip_epsilon (scalar): weight clipping threshold in the clipped PPO loss equation. + default: 0.2 + entropy_bonus (bool): if True, an entropy bonus will be added to the loss to favour exploratory policies. + samples_mc_entropy (int): if the distribution retrieved from the policy operator does not have a closed form + formula for the entropy, a Monte-Carlo estimate will be used. samples_mc_entropy will control how many + samples will be used to compute this estimate. + default: 1 + entropy_factor (scalar): entropy multiplier when computing the total loss. + default: 0.01 + critic_factor (scalar): critic loss multiplier when computing the total loss. + default: 1.0 + gamma (scalar): a discount factor for return computation. + loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". + + """ + + def __init__( + self, + actor: ProbabilisticTDModule, + critic: TDModule, + advantage_key: str = "advantage", + clip_epsilon: float = 0.2, + entropy_bonus: bool = True, + samples_mc_entropy: int = 1, + entropy_factor: float = 0.01, + critic_factor: float = 1.0, + gamma: float = 0.99, + loss_critic_type: str = "l2", + **kwargs, + ): + super(ClipPPOLoss, self).__init__( + actor, + critic, + advantage_key, + entropy_bonus=entropy_bonus, + samples_mc_entropy=samples_mc_entropy, + entropy_factor=entropy_factor, + critic_factor=critic_factor, + gamma=gamma, + loss_critic_type=loss_critic_type, + **kwargs, + ) + self.clip_epsilon = clip_epsilon + self._clip_bounds = ( + math.log1p(-self.clip_epsilon), + math.log1p(self.clip_epsilon), + ) + + def forward(self, tensor_dict: _TensorDict) -> _TensorDict: + if self.advantage_module is not None: + tensor_dict = self.advantage_module(tensor_dict) + tensor_dict = tensor_dict.clone() + for key, value in tensor_dict.items(): + if value.requires_grad: + raise RuntimeError( + f"The key {key} returns a value that requires a gradient, consider detaching." + ) + advantage = tensor_dict.get(self.advantage_key) + log_weight, dist = self._log_weight(tensor_dict) + # ESS for logging + with torch.no_grad(): + # In theory, ESS should be computed on particles sampled from the same source. Here we sample according + # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion + # of the weights. 
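+            # With importance weights w_i = exp(lw_i), ESS = (sum_i w_i) ** 2 / sum_i w_i ** 2,
+            # computed in log-space as exp(2 * logsumexp(lw) - logsumexp(2 * lw)) and normalised
+            # by the batch size when written to the output tensordict below.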
+ lw = log_weight.squeeze() + ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp() + batch = log_weight.shape[0] + + if not advantage.shape == log_weight.shape: + raise RuntimeError( + f"advantage.shape and log_weight.shape do not match (got {advantage.shape} " + f"and {log_weight.shape})" + ) + gain1 = log_weight.exp() * advantage + log_weight_clip = torch.empty_like(log_weight) + # log_weight_clip.data.clamp_(*self._clip_bounds) + idx_pos = advantage >= 0 + log_weight_clip[idx_pos] = log_weight[idx_pos].clamp_max(self._clip_bounds[1]) + log_weight_clip[~idx_pos] = log_weight[~idx_pos].clamp_min(self._clip_bounds[0]) + + gain2 = log_weight_clip.exp() * advantage + gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0] + td_out = TensorDict({"loss_objective": -gain.mean()}, []) + + if self.entropy_bonus: + entropy = self.get_entropy_bonus(dist) + td_out.set("entropy", entropy.mean().detach()) # for logging + td_out.set("loss_entropy", -self.entropy_factor * entropy.mean()) + if self.critic_factor: + loss_critic = self.loss_critic(tensor_dict) + td_out.set("loss_critic", loss_critic.mean()) + td_out.set("ESS", ess.mean() / batch) + return td_out + + +class KLPENPPOLoss(PPOLoss): + """ + KL Penalty PPO loss. + + The KL penalty loss has the following formula: + loss = loss - beta * KL(old_policy, new_policy) + The "beta" parameter is adapted on-the-fly to match a target KL divergence between the new and old policy, thus + favouring a certain level of distancing between the two while still preventing them to be too much apart. + + Args: + actor (Actor): policy operator. + critic (ProbabilisticTDModule): value operator. + advantage_key (str): the input tensordict key where the advantage is expected to be written. + default: "advantage" + dtarg (scalar): target KL divergence. + beta (scalar): initial KL divergence multiplier. + default: 1.0 + increment (scalar): how much beta should be incremented if KL > dtarg. Valid range: increment >= 1.0 + default: 2.0 + decrement (scalar): how much beta should be decremented if KL < dtarg. Valid range: decrement <= 1.0 + default: 0.5 + entropy_bonus (bool): if True, an entropy bonus will be added to the loss to favour exploratory policies. + samples_mc_entropy (int): if the distribution retrieved from the policy operator does not have a closed form + formula for the entropy, a Monte-Carlo estimate will be used. samples_mc_entropy will control how many + samples will be used to compute this estimate. + default: 1 + entropy_factor (scalar): entropy multiplier when computing the total loss. + default: 0.01 + critic_factor (scalar): critic loss multiplier when computing the total loss. + default: 1.0 + gamma (scalar): a discount factor for return computation. + loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". 
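+
+    The penalty coefficient is adapted after each call following the heuristic of the PPO paper:
+        if kl > 1.5 * dtarg:  beta <- beta * increment
+        elif kl < dtarg / 1.5:  beta <- beta * decrement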
+ + """ + + def __init__( + self, + actor: ProbabilisticTDModule, + critic: TDModule, + advantage_key="advantage", + dtarg: float = 0.01, + beta: float = 1.0, + increment: float = 2, + decrement: float = 0.5, + samples_mc_kl: int = 1, + entropy_bonus: bool = True, + samples_mc_entropy: int = 1, + entropy_factor: float = 0.01, + critic_factor: float = 1.0, + gamma: float = 0.99, + loss_critic_type: str = "l2", + **kwargs, + ): + super(KLPENPPOLoss, self).__init__( + actor, + critic, + advantage_key, + entropy_bonus=entropy_bonus, + samples_mc_entropy=samples_mc_entropy, + entropy_factor=entropy_factor, + critic_factor=critic_factor, + gamma=gamma, + loss_critic_type=loss_critic_type, + **kwargs, + ) + + self.dtarg = dtarg + self._beta_init = beta + self.beta = beta + + if increment < 1.0: + raise ValueError( + f"increment should be >= 1.0 in KLPENPPOLoss, got {increment:4.4f}" + ) + self.increment = increment + if decrement > 1.0: + raise ValueError( + f"decrement should be <= 1.0 in KLPENPPOLoss, got {decrement:4.4f}" + ) + self.decrement = decrement + self.samples_mc_kl = samples_mc_kl + + def forward(self, tensor_dict: _TensorDict) -> TensorDict: + if self.advantage_module is not None: + tensor_dict = self.advantage_module(tensor_dict) + tensor_dict = tensor_dict.clone() + advantage = tensor_dict.get(self.advantage_key) + log_weight, dist = self._log_weight(tensor_dict) + neg_loss = log_weight.exp() * advantage + + tensor_dict_clone = tensor_dict.select(*self.actor.in_keys).clone() + params = [] + out_key = self.actor.out_keys[0] + i = 0 + while True: + key = f"{out_key}_dist_param_{i}" + if key in tensor_dict.keys(): + params.append(tensor_dict.get(key)) + i += 1 + else: + break + + if i == 0: + raise Exception( + "No parameter was found for the policy distribution. Consider building the policy with save_dist_params=True" + ) + previous_dist, *_ = self.actor.build_dist_from_params(params) + current_dist, *_ = self.actor.get_dist(tensor_dict_clone) + try: + kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) + except NotImplementedError: + x = previous_dist.sample((self.samples_mc_kl,)) + kl = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0) + kl = kl.unsqueeze(-1) + neg_loss = neg_loss - self.beta * kl + if kl.mean() > self.dtarg * 1.5: + self.beta *= self.increment + elif kl.mean() < self.dtarg / 1.5: + self.beta *= self.decrement + td_out = TensorDict( + { + "loss_objective": -neg_loss.mean(), + "kl": kl.detach().mean(), + }, + [], + ) + + if self.entropy_bonus: + entropy = self.get_entropy_bonus(dist) + td_out.set("entropy", entropy.mean().detach()) # for logging + td_out.set("loss_entropy", -self.entropy_factor * entropy.mean()) + + if self.critic_factor: + loss_critic = self.loss_critic(tensor_dict) + td_out.set("loss_critic", loss_critic.mean()) + + return td_out + + def reset(self) -> None: + self.beta = self._beta_init diff --git a/torchrl/objectives/costs/redq.py b/torchrl/objectives/costs/redq.py new file mode 100644 index 00000000000..db6b7395348 --- /dev/null +++ b/torchrl/objectives/costs/redq.py @@ -0,0 +1,498 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from numbers import Number +from typing import Tuple, Union + +import numpy as np +import torch +from torch import Tensor + +from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.envs.utils import set_exploration_mode, step_tensor_dict +from torchrl.modules import TDModule +from torchrl.objectives.costs.common import _LossModule +from torchrl.objectives.costs.utils import ( + distance_loss, + hold_out_params, + next_state_value as get_next_state_value, +) + +__all__ = ["REDQLoss", "DoubleREDQLoss"] + + +class REDQLoss_deprecated(_LossModule): + """ + REDQ Loss module. + REDQ (RANDOMIZED ENSEMBLED DOUBLE Q-LEARNING: LEARNING FAST WITHOUT A MODEL + https://openreview.net/pdf?id=AY8zfZm0tDd) generalizes the idea of using an ensemble of Q-value functions to + train a SAC-like algorithm. + + Args: + actor_network (TDModule): the actor to be trained + qvalue_network (TDModule): a single Q-value network that will be multiplicated as many times as needed. + num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. + sub_sample_len (int, optional): number of Q-value networks to be subsampled to evaluate the next state value + Default is 2. + gamma (Number, optional): gamma decay factor. Default is 0.99. + priotity_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", + "l1", Default is "smooth_l1". + alpha_init (Number, optional): initial value of the alpha factor. Default is 1.0. + fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is `False`. + target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". 
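+
+    Example (illustrative sketch; `actor_network`, `qvalue_network` and the sampled `batch_td` are assumed to be
+        built elsewhere):
+        >>> loss_module = REDQLoss_deprecated(actor_network, qvalue_network, num_qvalue_nets=10)
+        >>> loss_td = loss_module(batch_td)
+        >>> loss = loss_td.get("loss_actor") + loss_td.get("loss_qvalue") + loss_td.get("loss_alpha")
+        >>> loss.backward()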
+ + """ + + delay_actor: bool = False + delay_qvalue: bool = False + + def __init__( + self, + actor_network: TDModule, + qvalue_network: TDModule, + num_qvalue_nets: int = 10, + sub_sample_len: int = 2, + gamma: Number = 0.99, + priotity_key: str = "td_error", + loss_function: str = "smooth_l1", + alpha_init: Number = 1.0, + fixed_alpha: bool = False, + target_entropy: Union[str, Number] = "auto", + ): + super().__init__() + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + ) + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + ) + self.num_qvalue_nets = num_qvalue_nets + self.sub_sample_len = max(1, min(sub_sample_len, num_qvalue_nets - 1)) + self.gamma = gamma + self.priority_key = priotity_key + self.loss_function = loss_function + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + target_entropy = -float(np.prod(actor_network.spec.shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + + @property + def alpha(self): + # keep alpha is a reasonable range + self.log_alpha.data.clamp_(-20, 1.0) + + with torch.no_grad(): + alpha = self.log_alpha.detach().exp() + return alpha + + def forward(self, tensordict: _TensorDict) -> _TensorDict: + loss_actor, action_log_prob = self._actor_loss(tensordict) + + loss_qval = self._qvalue_loss(tensordict) + loss_alpha = self._loss_alpha(action_log_prob) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha, + "entropy": -action_log_prob.mean(), + }, + [], + ) + + return td_out + + def _actor_loss(self, tensordict: _TensorDict) -> Tuple[Tensor, Tensor]: + obs_keys = self.actor_network.in_keys + tensordict_clone = tensordict.select(*obs_keys) # to avoid overwriting keys + with set_exploration_mode("random"): + self.actor_network( + tensordict_clone, + params=self.actor_network_params, + buffers=self.actor_network_buffers, + ) + + with hold_out_params(self.qvalue_network_params) as params: + tensordict_expand = self.qvalue_network( + tensordict_clone.select(*self.qvalue_network.in_keys), + tensor_dict_out=TensorDict( + {}, [self.num_qvalue_nets, *tensordict_clone.shape] + ), + params=params, + buffers=self.qvalue_network_buffers, + vmap=True, + ) + state_action_value = tensordict_expand.get("state_action_value").squeeze(-1) + loss_actor = -( + state_action_value + - self.alpha * tensordict_clone.get("action_log_prob").squeeze(-1) + ).mean(0) + return loss_actor, tensordict_clone.get("action_log_prob") + + def _qvalue_loss(self, tensordict: _TensorDict) -> Tensor: + tensordict_save = tensordict + + next_obs_keys = [key for key in tensordict.keys() if key.startswith("next_obs")] + obs_keys = [key for key in tensordict.keys() if key.startswith("obs")] + tensordict = tensordict.select( + 
"reward", "done", *next_obs_keys, *obs_keys, "action" + ) + + selected_models_idx = torch.randperm(self.num_qvalue_nets)[ + : self.sub_sample_len + ].sort()[0] + with torch.no_grad(): + selected_q_params = [ + p[selected_models_idx] for p in self.target_qvalue_network_params + ] + selected_q_buffers = [ + b[selected_models_idx] for b in self.target_qvalue_network_buffers + ] + + next_td = step_tensor_dict(tensordict).select( + *self.actor_network.in_keys + ) # next_observation -> + # observation + # select pseudo-action + with set_exploration_mode("random"): + self.actor_network( + next_td, + params=list(self.target_actor_network_params), + buffers=self.target_actor_network_buffers, + ) + action_log_prob = next_td.get("action_log_prob") + # get q-values + next_td = self.qvalue_network( + next_td, + tensor_dict_out=TensorDict({}, [self.sub_sample_len, *next_td.shape]), + params=selected_q_params, + buffers=selected_q_buffers, + vmap=True, + ) + state_value = ( + next_td.get("state_action_value") - self.alpha * action_log_prob + ) + state_value = state_value.min(0)[0] + + tensordict.set("next_state_value", state_value) + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=state_value, + ) + tensordict_expand = self.qvalue_network( + tensordict.select(*self.qvalue_network.in_keys), + tensor_dict_out=TensorDict({}, [self.num_qvalue_nets, *tensordict.shape]), + params=list(self.qvalue_network_params), + buffers=self.qvalue_network_buffers, + vmap=True, + ) + pred_val = tensordict_expand.get("state_action_value").squeeze(-1) + td_error = abs(pred_val - target_value) + loss_qval = distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ).mean(0) + tensordict_save.set("td_error", td_error.detach().max(0)[0]) + return loss_qval + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss + + +class DoubleREDQLoss_deprecated(REDQLoss_deprecated): + delay_qvalue: bool = True + + +class REDQLoss(_LossModule): + """ + REDQ Loss module. + + REDQ (RANDOMIZED ENSEMBLED DOUBLE Q-LEARNING: LEARNING FAST WITHOUT A MODEL + https://openreview.net/pdf?id=AY8zfZm0tDd) generalizes the idea of using an ensemble of Q-value functions to + train a SAC-like algorithm. + + Args: + actor_network (TDModule): the actor to be trained + qvalue_network (TDModule): a single Q-value network that will be multiplicated as many times as needed. + num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. + sub_sample_len (int, optional): number of Q-value networks to be subsampled to evaluate the next state value + Default is 2. + gamma (Number, optional): gamma decay factor. Default is 0.99. + priotity_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2", + "l1", Default is "smooth_l1". + alpha_init (Number, optional): initial value of the alpha factor. Default is 1.0. 
+ fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is `False`. + target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto". + + """ + + delay_actor: bool = False + delay_qvalue: bool = False + + def __init__( + self, + actor_network: TDModule, + qvalue_network: TDModule, + num_qvalue_nets: int = 10, + sub_sample_len: int = 2, + gamma: Number = 0.99, + priotity_key: str = "td_error", + loss_function: str = "smooth_l1", + alpha_init: Number = 1.0, + fixed_alpha: bool = False, + target_entropy: Union[str, Number] = "auto", + ): + super().__init__() + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + ) + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + ) + self.num_qvalue_nets = num_qvalue_nets + self.sub_sample_len = max(1, min(sub_sample_len, num_qvalue_nets - 1)) + self.gamma = gamma + self.priority_key = priotity_key + self.loss_function = loss_function + + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + target_entropy = -float(np.prod(actor_network.spec.shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + + @property + def alpha(self): + with torch.no_grad(): + alpha = self.log_alpha.detach().exp() + return alpha + + def forward(self, tensordict: _TensorDict) -> _TensorDict: + obs_keys = self.actor_network.in_keys + next_obs_keys = [key for key in tensordict.keys() if key.startswith("next_obs")] + tensordict_select = tensordict.select( + "reward", "done", *next_obs_keys, *obs_keys, "action" + ) + selected_models_idx = torch.randperm(self.num_qvalue_nets)[ + : self.sub_sample_len + ].sort()[0] + selected_q_params = [ + p[selected_models_idx] for p in self.target_qvalue_network_params + ] + selected_q_buffers = [ + b[selected_models_idx] for b in self.target_qvalue_network_buffers + ] + + actor_params = [ + torch.stack([p1, p2], 0) + for p1, p2 in zip( + self.actor_network_params, self.target_actor_network_params + ) + ] + actor_buffers = [ + torch.stack([p1, p2], 0) + for p1, p2 in zip( + self.actor_network_buffers, self.target_actor_network_buffers + ) + ] + + tensordict_actor_grad = tensordict_select.select( + *obs_keys + ) # to avoid overwriting keys + next_td_actor = step_tensor_dict(tensordict_select).select( + *self.actor_network.in_keys + ) # next_observation -> + tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) + + with set_exploration_mode("random"): + tensordict_actor = self.actor_network( + tensordict_actor, + params=actor_params, + buffers=actor_buffers, + vmap=(0, 0, 0), + ) + + # repeat tensordict_actor to match the qvalue size + tensordict_qval = torch.cat( + [ + tensordict_actor[0] + .select(*self.qvalue_network.in_keys) + .expand(self.num_qvalue_nets), # for actor loss + tensordict_actor[1] + .select(*self.qvalue_network.in_keys) + .expand(self.sub_sample_len), # for next value estimation + 
tensordict_select.select(*self.qvalue_network.in_keys).expand( + self.num_qvalue_nets + ), # for qvalue loss + ], + 0, + ) + + # cat params + q_params_detach = hold_out_params(self.qvalue_network_params).params + qvalue_params = [ + torch.cat([p1, p2, p3], 0) + for p1, p2, p3 in zip( + q_params_detach, selected_q_params, self.qvalue_network_params + ) + ] + qvalue_buffers = [ + torch.cat([p1, p2, p3], 0) + for p1, p2, p3 in zip( + self.qvalue_network_buffers, + selected_q_buffers, + self.qvalue_network_buffers, + ) + ] + + tensordict_qval = self.qvalue_network( + tensordict_qval, + tensor_dict_out=TensorDict({}, tensordict_qval.shape), + params=qvalue_params, + buffers=qvalue_buffers, + vmap=(0, 0, 0, 0), + ) + + state_action_value = tensordict_qval.get("state_action_value").squeeze(-1) + ( + state_action_value_actor, + next_state_action_value_qvalue, + state_action_value_qvalue, + ) = state_action_value.split( + [self.num_qvalue_nets, self.sub_sample_len, self.num_qvalue_nets], + dim=0, + ) + action_log_prob = tensordict_actor.get("action_log_prob").squeeze(-1) + ( + action_log_prob_actor, + next_action_log_prob_qvalue, + ) = action_log_prob.unbind(0) + + loss_actor = -( + state_action_value_actor - self.alpha * action_log_prob_actor + ).mean(0) + + next_state_value = ( + next_state_action_value_qvalue - self.alpha * next_action_log_prob_qvalue + ) + next_state_value = next_state_value.min(0)[0] + + target_value = get_next_state_value( + tensordict, + gamma=self.gamma, + pred_next_val=next_state_value, + ) + pred_val = state_action_value_qvalue + td_error = abs(pred_val - target_value) + loss_qval = distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ).mean(0) + + tensordict.set("td_error", td_error.detach().max(0)[0]) + + loss_alpha = self._loss_alpha(action_log_prob) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + { + "loss_actor": loss_actor.mean(), + "loss_qvalue": loss_qval.mean(), + "loss_alpha": loss_alpha.mean(), + "alpha": self.alpha, + "entropy": -action_log_prob.mean(), + }, + [], + ) + + return td_out + + def _loss_alpha(self, log_pi: Tensor) -> Tensor: + if torch.is_grad_enabled() and not log_pi.requires_grad: + raise RuntimeError( + "expected log_pi to require gradient for the alpha loss)" + ) + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss + + +class DoubleREDQLoss(REDQLoss): + delay_qvalue: bool = True diff --git a/torchrl/objectives/costs/sac.py b/torchrl/objectives/costs/sac.py new file mode 100644 index 00000000000..a60dd2032cf --- /dev/null +++ b/torchrl/objectives/costs/sac.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+import math
+from numbers import Number
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict
+from torchrl.modules import TDModule
+from torchrl.modules.td_module.actors import (
+    ActorCriticWrapper,
+    ProbabilisticActor,
+)
+from torchrl.objectives.costs.utils import distance_loss, next_state_value
+from .common import _LossModule
+
+__all__ = ["SACLoss", "DoubleSACLoss"]
+
+
+class SACLoss(_LossModule):
+    """
+    TorchRL implementation of the SAC loss, as presented in "Soft Actor-Critic: Off-Policy Maximum Entropy Deep
+    Reinforcement Learning with a Stochastic Actor" https://arxiv.org/pdf/1801.01290.pdf
+
+    Args:
+        actor_network (ProbabilisticActor): stochastic actor
+        qvalue_network (TDModule): Q(s, a) parametric model
+        value_network (TDModule): V(s) parametric model
+        num_qvalue_nets (int, optional): number of Q-value networks to be
+            trained. The minimum Q-value prediction across the networks is
+            used wherever a single estimate is needed.
+            Default is 2.
+        gamma (number, optional): discount for return computation.
+            Default is 0.99
+        priority_key (str, optional): tensordict key where to write the
+            priority (for prioritized replay buffer usage). Default is
+            `"td_error"`.
+        loss_function (str, optional): loss function to be used with
+            the value function loss. Default is `"smooth_l1"`.
+        alpha_init (float, optional): initial entropy multiplier.
+            Default is 1.0.
+        fixed_alpha (bool, optional): if True, alpha will be fixed to its
+            initial value. Otherwise, alpha will be optimized to
+            match the 'target_entropy' value.
+            Default is `False`.
+        target_entropy (float or str, optional): Target entropy for the
+            stochastic policy. Default is "auto", where target entropy is
+            computed as `-prod(n_actions)`.
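+
+    Example (illustrative sketch; `actor`, `qvalue`, `value` and the sampled `batch_td` are assumed to be built
+        elsewhere):
+        >>> loss_module = SACLoss(actor, qvalue, value, gamma=0.99)
+        >>> loss_td = loss_module(batch_td)
+        >>> loss = (
+        ...     loss_td.get("loss_actor")
+        ...     + loss_td.get("loss_qvalue")
+        ...     + loss_td.get("loss_value")
+        ...     + loss_td.get("loss_alpha")
+        ... )
+        >>> loss.backward()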
+ """ + + delay_actor: bool = False + delay_qvalue: bool = False + delay_value: bool = False + + def __init__( + self, + actor_network: ProbabilisticActor, + qvalue_network: TDModule, + value_network: TDModule, + num_qvalue_nets: int = 2, + gamma: Number = 0.99, + priotity_key: str = "td_error", + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + fixed_alpha: bool = False, + target_entropy: Union[str, float] = "auto", + ) -> None: + super().__init__() + + # Actor + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + ) + + # Value + self.convert_to_functional( + value_network, + "value_network", + create_target_params=self.delay_value, + ) + + # Q value + self.num_qvalue_nets = num_qvalue_nets + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + ) + + self.gamma = gamma + self.priority_key = priotity_key + self.loss_function = loss_function + self.register_buffer("alpha_init", torch.tensor(alpha_init)) + self.fixed_alpha = fixed_alpha + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + if target_entropy == "auto": + target_entropy = -float(np.prod(actor_network.spec.shape)) + self.register_buffer( + "target_entropy", torch.tensor(target_entropy, device=device) + ) + + @property + def device(self) -> torch.device: + for p in self.actor_network_params: + return p.device + for p in self.qvalue_network_params: + return p.device + for p in self.value_network_params: + return p.device + raise RuntimeError( + "At least one of the networks of SACLoss must have trainable " "parameters." 
+        )
+
+    def forward(self, tensordict: _TensorDict) -> _TensorDict:
+        if tensordict.ndimension() > 1:
+            tensordict = tensordict.view(-1)
+
+        device = self.device
+        td_device = tensordict.to(device)
+
+        loss_actor = self._loss_actor(td_device)
+        loss_qvalue, priority = self._loss_qvalue(td_device)
+        loss_value = self._loss_value(td_device)
+        loss_alpha = self._loss_alpha(td_device)
+        tensordict.set(self.priority_key, priority)
+        if (loss_actor.shape != loss_qvalue.shape) or (
+            loss_actor.shape != loss_value.shape
+        ):
+            raise RuntimeError(
+                f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
+            )
+        return TensorDict(
+            {
+                "loss_actor": loss_actor.mean(),
+                "loss_qvalue": loss_qvalue.mean(),
+                "loss_value": loss_value.mean(),
+                "loss_alpha": loss_alpha.mean(),
+                "alpha": self._alpha,
+                "entropy": td_device.get("_log_prob").mean().detach(),
+            },
+            [],
+        )
+
+    def _loss_actor(self, tensordict: _TensorDict) -> Tensor:
+        # KL loss
+        dist = self.actor_network.get_dist(
+            tensordict,
+            params=list(self.actor_network_params),
+            buffers=list(self.actor_network_buffers),
+        )[0]
+        a_reparm = dist.rsample()
+        log_prob = dist.log_prob(a_reparm)
+
+        td_q = tensordict.select(*self.qvalue_network.in_keys)
+        td_q.set("action", a_reparm)
+        td_q = self.qvalue_network(
+            td_q,
+            params=list(self.target_qvalue_network_params),
+            buffers=list(self.qvalue_network_buffers),
+            vmap=True,
+        )
+        min_q_logprob = td_q.get("state_action_value").min(0)[0].squeeze(-1)
+
+        if log_prob.shape != min_q_logprob.shape:
+            raise RuntimeError(
+                f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
+            )
+
+        # write log_prob in tensordict for alpha loss
+        tensordict.set("_log_prob", log_prob.detach())
+        return self._alpha * log_prob - min_q_logprob
+
+    def _loss_qvalue(self, tensordict: _TensorDict) -> Tuple[Tensor, Tensor]:
+        actor_critic = ActorCriticWrapper(self.actor_network, self.value_network)
+        params = list(self.target_actor_network_params) + list(
+            self.target_value_network_params
+        )
+        buffers = list(self.target_actor_network_buffers) + list(
+            self.target_value_network_buffers
+        )
+        target_value = next_state_value(
+            tensordict,
+            actor_critic,
+            gamma=self.gamma,
+            next_val_key="state_value",
+            params=params,
+            buffers=buffers,
+        )
+
+        # value loss
+        qvalue_network = self.qvalue_network
+
+        # Q-nets must be trained independently: as such, we split the data into num_qvalue_nets chunks and train
+        # each q-net on one chunk of the data.
+        shape = tensordict.shape
+        if shape[0] % self.num_qvalue_nets != 0:
+            raise RuntimeError(
+                f"Batch size={tensordict.shape} is incompatible "
+                f"with num_qvalue_nets={self.num_qvalue_nets}."
+ ) + tensordict_chunks = torch.stack( + tensordict.chunk(self.num_qvalue_nets, dim=0), 0 + ) + target_chunks = torch.stack(target_value.chunk(self.num_qvalue_nets, dim=0), 0) + + # if vmap=True, it is assumed that the input tensordict must be cast to the param shape + tensordict_chunks = qvalue_network( + tensordict_chunks, + params=list(self.qvalue_network_params), + buffers=list(self.qvalue_network_buffers), + vmap=( + 0, + 0, + 0, + 0, + ), + ) + pred_val = tensordict_chunks.get("state_action_value").squeeze(-1) + loss_value = distance_loss( + pred_val, target_chunks, loss_function=self.loss_function + ).view(*shape) + priority_value = torch.cat(abs(pred_val - target_chunks).unbind(0), 0) + + return loss_value, priority_value + + def _loss_value(self, tensordict: _TensorDict) -> Tensor: + # value loss + td_copy = tensordict.select(*self.value_network.in_keys).detach() + self.value_network( + td_copy, + params=list(self.value_network_params), + buffers=list(self.value_network_buffers), + ) + pred_val = td_copy.get("state_value").squeeze(-1) + + action_dist = self.actor_network.get_dist( + td_copy, + params=list(self.target_actor_network_params), + buffers=list(self.target_actor_network_buffers), + )[ + 0 + ] # resample an action + action = action_dist.rsample() + td_copy.set("action", action, inplace=False) + + qval_net = self.qvalue_network + td_copy = qval_net( + td_copy, + params=list(self.target_qvalue_network_params), + buffers=list(self.target_qvalue_network_buffers), + vmap=True, + ) + + min_qval = td_copy.get("state_action_value").squeeze(-1).min(0)[0] + + log_p = action_dist.log_prob(action) + if log_p.shape != min_qval.shape: + raise RuntimeError( + f"Losses shape mismatch: {min_qval.shape} and {log_p.shape}" + ) + target_val = min_qval - self._alpha * log_p + + loss_value = distance_loss( + pred_val, target_val, loss_function=self.loss_function + ) + return loss_value + + def _loss_alpha(self, tensordict: _TensorDict) -> Tensor: + log_pi = tensordict.get("_log_prob") + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_pi) + return alpha_loss + + @property + def _alpha(self): + with torch.no_grad(): + alpha = self.log_alpha.detach().exp() + return alpha + + +class DoubleSACLoss(SACLoss): + """ + A Double SAC loss class. + As for Double DDPG/DQN losses, this class separates the target critic/value/actor networks from the + critic/value/actor networks used for data collection. Those target networks should be updated from their original + counterparts with some delay using dedicated classes (SoftUpdate and HardUpdate in objectives.cost.utils). + Note that the original networks will be copied at initialization using the copy.deepcopy method: in some rare cases + this may lead to unexpected behaviours (for instance if the networks change in a way that won't be reflected by their + state_dict). Please report any such bug if encountered. + + """ + + def __init__(self, *args, delay_actor=False, delay_qvalue=False, **kwargs): + self.delay_actor = delay_actor + self.delay_qvalue = delay_qvalue + self.delay_value = True + super().__init__(*args, **kwargs) diff --git a/torchrl/objectives/costs/utils.py b/torchrl/objectives/costs/utils.py new file mode 100644 index 00000000000..23b917a2dc1 --- /dev/null +++ b/torchrl/objectives/costs/utils.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from collections import OrderedDict +from typing import Iterable, Optional, Union + +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.envs.utils import step_tensor_dict +from torchrl.modules import TDModule + +__all__ = ["SoftUpdate", "HardUpdate", "distance_loss", "hold_out_params"] + + +class _context_manager: + def __init__(self, value=True): + self.value = value + self.prev = [] + + def __call__(self, func): + @functools.wraps(func) + def decorate_context(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return decorate_context + + +def distance_loss( + v1: torch.Tensor, + v2: torch.Tensor, + loss_function: str, + strict_shape: bool = True, +) -> torch.Tensor: + """ + Computes a distance loss between two tensors. + + Args: + v1 (Tensor): a tensor with a shape compatible with v2 + v2 (Tensor): a tensor with a shape compatible with v1 + loss_function (str): One of "l2", "l1" or "smooth_l1" representing which loss function is to be used. + strict_shape (bool): if False, v1 and v2 are allowed to have a different shape. + Default is `True`. + + Returns: + A tensor of the shape v1.view_as(v2) or v2.view_as(v1) with values equal to the distance loss between the + two. + + """ + if v1.shape != v2.shape and strict_shape: + raise RuntimeError( + f"The input tensors have shapes {v1.shape} and {v2.shape} which are incompatible." + ) + + if loss_function == "l2": + value_loss = F.mse_loss( + v1, + v2, + reduction="none", + ) + + elif loss_function == "l1": + value_loss = F.l1_loss( + v1, + v2, + reduction="none", + ) + + elif loss_function == "smooth_l1": + value_loss = F.smooth_l1_loss( + v1, + v2, + reduction="none", + ) + else: + raise NotImplementedError(f"Unknown loss {loss_function}") + return value_loss + + +class ValueLoss: + value_network: nn.Module + target_value_network: nn.Module + + +class _TargetNetUpdate: + """ + An abstract class for target network update in Double DQN/DDPG. + + Args: + loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated. + + """ + + def __init__( + self, + loss_module: Union["DQNLoss", "DDPGLoss", "SACLoss"], + ): + + _target_names = [] + # for properties + for name in loss_module.__class__.__dict__: + if ( + name.startswith("_target_") + and (name.endswith("params") or name.endswith("buffers")) + and (getattr(loss_module, name) is not None) + ): + _target_names.append(name) + + # for regular lists: raise an exception + for name in loss_module.__dict__: + if ( + name.startswith("_target_") + and (name.endswith("params") or name.endswith("buffers")) + and (getattr(loss_module, name) is not None) + ): + raise RuntimeError( + "Your module seems to have a _target tensor list contained " + "in a non-dynamic structure (such as a list). If the " + "module is cast onto a device, the reference to these " + "tensors will be lost." 
+ ) + + _source_names = ["".join(name.split("_target_")) for name in _target_names] + + for _source in _source_names: + try: + getattr(loss_module, _source) + except AttributeError: + raise RuntimeError( + f"Incongruent target and source parameter lists: " + f"{_source} is not an attribute of the loss_module" + ) + + self._target_names = _target_names + self._source_names = _source_names + self.loss_module = loss_module + self.initialized = False + + @property + def _targets(self): + return OrderedDict( + {name: getattr(self.loss_module, name) for name in self._target_names} + ) + + @property + def _sources(self): + return OrderedDict( + {name: getattr(self.loss_module, name) for name in self._source_names} + ) + + def init_(self) -> None: + for source, target in zip(self._sources.values(), self._targets.values()): + for p_source, p_target in zip(source, target): + if p_target.requires_grad: + raise RuntimeError("the target parameter is part of a graph.") + p_target.data.copy_(p_source.data) + self.initialized = True + + def step(self) -> None: + if not self.initialized: + raise Exception( + f"{self.__class__.__name__} must be " + f"initialized (`{self.__class__.__name__}.init_()`) before calling step()" + ) + + for source, target in zip(self._sources.values(), self._targets.values()): + for p_source, p_target in zip(source, target): + if p_target.requires_grad: + raise RuntimeError("the target parameter is part of a graph.") + self._step(p_source, p_target) + + def _step(self, p_source: Tensor, p_target: Tensor) -> None: + raise NotImplementedError + + def __repr__(self) -> str: + string = ( + f"{self.__class__.__name__}(sources={[name for name in self._sources]}, targets=" + f"{[name for name in self._targets]})" + ) + return string + + +class SoftUpdate(_TargetNetUpdate): + """ + A soft-update class for target network update in Double DQN/DDPG. + This was proposed in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf + + Args: + loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated. + eps (scalar): epsilon in the update equation: + param = prev_param * eps + new_param * (1-eps) + default: 0.999 + """ + + def __init__( + self, + loss_module: Union["DQNLoss", "DDPGLoss", "SACLoss"], + eps: float = 0.999, + ): + if not (eps < 1.0 and eps > 0.0): + raise ValueError( + f"Got eps = {eps} when it was supposed to be between 0 and 1." + ) + super(SoftUpdate, self).__init__(loss_module) + self.eps = eps + + def _step(self, p_source: Tensor, p_target: Tensor) -> None: + p_target.data.copy_(p_target.data * self.eps + p_source.data * (1 - self.eps)) + + +class HardUpdate(_TargetNetUpdate): + """ + A hard-update class for target network update in Double DQN/DDPG (by contrast with soft updates). + This was proposed in the original Double DQN paper: "Deep Reinforcement Learning with Double Q-learning", + https://arxiv.org/abs/1509.06461. + + Args: + loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated. + value_network_update_interval (scalar): how often the target network should be updated. 
+ default: 1000 + """ + + def __init__( + self, + loss_module: Union["DQNLoss", "DDPGLoss", "SACLoss"], + value_network_update_interval: float = 1000, + ): + super(HardUpdate, self).__init__(loss_module) + self.value_network_update_interval = value_network_update_interval + self.counter = 0 + + def _step(self, p_source: Tensor, p_target: Tensor) -> None: + if self.counter == self.value_network_update_interval: + p_target.data.copy_(p_source.data) + + def step(self) -> None: + super().step() + if self.counter == self.value_network_update_interval: + self.counter = 0 + else: + self.counter += 1 + + +class hold_out_net(_context_manager): + """Context manager to hold a network out of a computational graph.""" + + def __init__(self, network: nn.Module) -> None: + self.network = network + try: + self.p_example = next(network.parameters()) + except StopIteration: + raise RuntimeError( + "hold_out_net requires the network parameter set to be " "non-empty." + ) + self._prev_state = [] + + def __enter__(self) -> None: + self._prev_state.append(self.p_example.requires_grad) + self.network.requires_grad_(False) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.network.requires_grad_(self._prev_state.pop()) + + +class hold_out_params(_context_manager): + """Context manager to hold a list of parameters out of a computational graph.""" + + def __init__(self, params: Iterable[Tensor]) -> None: + self.params = tuple(p.detach() for p in params) + + def __enter__(self) -> None: + return self.params + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + + +@torch.no_grad() +def next_state_value( + tensor_dict: _TensorDict, + operator: Optional[TDModule] = None, + next_val_key: str = "state_action_value", + gamma: float = 0.99, + pred_next_val: Optional[Tensor] = None, + **kwargs, +) -> torch.Tensor: + """ + Computes the next state value (without gradient) to compute a target for the MSE loss + L = Sum[ (q_value - target_value)^2 ] + The target value is computed as + r + gamma ** n_steps_to_next * value_next_state + If the reward is the immediate reward, n_steps_to_next=1. If N-steps rewards are used, n_steps_to_next is gathered + from the input tensordict. + + Args: + tensor_dict (_TensorDict): Tensordict containing a reward and done key (and a n_steps_to_next key for n-steps + rewards). + operator (ProbabilisticTDModule, optional): the value function operator. Should write a 'next_val_key' + key-value in the input tensordict when called. It does not need to be provided if pred_next_val is given. + next_val_key (str, optional): key where the next value will be written. + Default: 'state_action_value' + gamma (float, optional): return discount rate. + default: 0.99 + pred_next_val (Tensor, optional): the next state value can be provided if it is not computed with the operator. + + Returns: + a Tensor of the size of the input tensordict containing the predicted value state. 
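+
+    Example (numbers for illustration only): with the default gamma=0.99 and steps_to_next_obs=1,
+        a reward of 1.0, done=False and a predicted next value of 10.0 give a target of
+        1.0 + 0.99 * 10.0 = 10.9; with done=True the target collapses to the reward, 1.0.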
+ """ + try: + steps_to_next_obs = tensor_dict.get("steps_to_next_obs").squeeze(-1) + except KeyError: + steps_to_next_obs = 1 + + rewards = tensor_dict.get("reward").squeeze(-1) + done = tensor_dict.get("done").squeeze(-1) + + if pred_next_val is None: + next_td = step_tensor_dict(tensor_dict) # next_observation -> observation + next_td = next_td.select(*operator.in_keys) + operator(next_td, **kwargs) + pred_next_val_detach = next_td.get(next_val_key).squeeze(-1) + else: + pred_next_val_detach = pred_next_val.squeeze(-1) + done = done.to(torch.float) + target_value = (1 - done) * pred_next_val_detach + rewards = rewards.to(torch.float) + target_value = rewards + (gamma ** steps_to_next_obs) * target_value + return target_value diff --git a/torchrl/objectives/returns/__init__.py b/torchrl/objectives/returns/__init__.py new file mode 100644 index 00000000000..344570b4c76 --- /dev/null +++ b/torchrl/objectives/returns/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .gae import * +from .pg import * +from .returns import * +from .vtrace import * diff --git a/torchrl/objectives/returns/functional.py b/torchrl/objectives/returns/functional.py new file mode 100644 index 00000000000..445f6f92922 --- /dev/null +++ b/torchrl/objectives/returns/functional.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + + +def generalized_advantage_estimate( + gamma: float, + lamda: float, + state_value: torch.Tensor, + next_state_value: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get generalized advantage estimate of a trajectory + Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION" + https://arxiv.org/pdf/1506.02438.pdf for more context. + + Args: + gamma (scalar): exponential mean discount. + lamda (scalar): trajectory discount. + state_value (Tensor): value function result with old_state input. + must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor + next_state_value (Tensor): value function result with new_state input. + must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor + reward (Tensor): agent reward of taking actions in the environment. + must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor + done (Tensor): boolean flag for end of episode. + """ + not_done = 1 - done.to(next_state_value.dtype) + batch_size, time_steps = not_done.shape[:2] + device = state_value.device + advantage = torch.zeros(batch_size, time_steps + 1, 1, device=device) + + for t in reversed(range(time_steps)): + delta = ( + reward[:, t] + + (gamma * next_state_value[:, t] * not_done[:, t]) + - state_value[:, t] + ) + advantage[:, t] = delta + (gamma * lamda * advantage[:, t + 1] * not_done[:, t]) + + value_target = advantage[:, :time_steps] + state_value + + return advantage[:, :time_steps], value_target diff --git a/torchrl/objectives/returns/gae.py b/torchrl/objectives/returns/gae.py new file mode 100644 index 00000000000..29b9f1c4bef --- /dev/null +++ b/torchrl/objectives/returns/gae.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +import torch + +# for value, log_policy, reward, entropy in list(zip(values, log_policies, rewards, entropies))[::-1]: +# gae = gae * opt.gamma * opt.tau +# gae = gae + reward + opt.gamma * next_value.detach() - value.detach() +# next_value = value +# actor_loss = actor_loss + log_policy * gae +# R = R * opt.gamma + reward +# critic_loss = critic_loss + (R - value) ** 2 / 2 +# entropy_loss = entropy_loss + entropy +from torchrl.envs.utils import step_tensor_dict + +# from https://github.com/H-Huang/rpc-rl-experiments/blob/6621f0aadb347d1c4e24bcf46517ac36907401ff/a3c/process.py#L14 +# TODO: create function / object that vectorises that +# actor_loss = 0 +# critic_loss = 0 +# entropy_loss = 0 +# next_value = R +from ...data.tensordict.tensordict import _TensorDict +from ...modules import ProbabilisticTDModule +from .functional import generalized_advantage_estimate + + +# +# def gae(values: torch.Tensor, log_prob_actions: torch.Tensor, rewards: torch.Tensor, entropies: torch.Tensor, +# gamma: Union[Number, torch.Tensor], tau: float) -> torch.Tensor: +# """ +# +# Args: +# values: +# log_prob_actions: +# rewards: +# entropies: +# gamma: +# tau: +# +# Returns: +# +# """ +# gaes = [] +# for value, log_policy, reward, entropy in list( +# zip(values, log_prob_actions, rewards, entropies) +# )[::-1]: +# if next_value is None: +# next_value = torch.zeros_like(value) +# gae = gae * gamma * tau +# gae = gae + reward + gamma * next_value.detach() - value.detach() +# next_value = value +# gaes.append(gae) +# return torch.stack(gae) +# + + +class GAE: + """ + A class wrapper around the generalized advantage estimate functional. + Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION" + https://arxiv.org/pdf/1506.02438.pdf for more context. + + Args: + gamma (scalar): exponential mean discount. + lamda (scalar): trajectory discount. + critic (ProbabilisticTDModule): value operator used to retrieve the value estimates. + average_rewards (bool): if True, rewards will be standardized before the GAE is computed. + gradient_mode (bool): if True, gradients are propagated throught the computation of the value function. + Default is `False`. + """ + + def __init__( + self, + gamma: Union[float, torch.Tensor], + lamda: float, + critic: ProbabilisticTDModule, + average_rewards: bool = False, + gradient_mode: bool = False, + ): + self.gamma = gamma + self.lamda = lamda + self.critic = critic + self.average_rewards = average_rewards + self.gradient_mode = gradient_mode + + def __call__(self, tensor_dict: _TensorDict) -> _TensorDict: + """Computes the GAE given the data in tensor_dict. + + Args: + tensor_dict (_TensorDict): A TensorDict containing the data (observation, action, reward, done state) + necessary to compute the value estimates and the GAE. 
+ + Returns: + An updated TensorDict with an "advantage" and a "value_target" keys + + """ + with torch.set_grad_enabled(self.gradient_mode): + if tensor_dict.batch_dims < 2: + raise RuntimeError( + "Expected input tensordict to have at least two dimensions, got" + f"tensor_dict.batch_size = {tensor_dict.batch_size}" + ) + reward = tensor_dict.get("reward") + if self.average_rewards: + reward = reward - reward.mean() + reward = reward / reward.std().clamp_min(1e-4) + tensor_dict.set_( + "reward", reward + ) # we must update the rewards if they are used later in the code + + gamma, lamda = self.gamma, self.lamda + self.critic(tensor_dict) + value = tensor_dict.get("state_value") + + step_td = step_tensor_dict(tensor_dict) + self.critic(step_td) + next_value = step_td.get("state_value") + + done = tensor_dict.get("done") + + adv, value_target = generalized_advantage_estimate( + gamma, lamda, value, next_value, reward, done + ) + tensor_dict.set("advantage", adv) + tensor_dict.set("value_target", value_target) + return tensor_dict diff --git a/torchrl/objectives/returns/pg.py b/torchrl/objectives/returns/pg.py new file mode 100644 index 00000000000..d62fe90a685 --- /dev/null +++ b/torchrl/objectives/returns/pg.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# implements a function that takes a sequence of returns and multiply its by the policy log_prob to get a differentiable objective diff --git a/torchrl/objectives/returns/returns.py b/torchrl/objectives/returns/returns.py new file mode 100644 index 00000000000..af8997f969f --- /dev/null +++ b/torchrl/objectives/returns/returns.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +import torch +from torch import nn + + +def bellman_max( + next_observation: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, + gamma: Union[float, torch.Tensor], + value_model: nn.Module, +): + qmax = value_model(next_observation).max(dim=-1)[0] + nonterminal_target = reward + gamma * qmax + terminal_target = reward + target = done * terminal_target + (~done) * nonterminal_target + return target diff --git a/torchrl/objectives/returns/vtrace.py b/torchrl/objectives/returns/vtrace.py new file mode 100644 index 00000000000..11f92ac1b18 --- /dev/null +++ b/torchrl/objectives/returns/vtrace.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from typing import Tuple, Union + +import torch + + +def c_val( + log_pi: torch.Tensor, + log_mu: torch.Tensor, + c: Union[float, torch.Tensor] = 1, +) -> torch.Tensor: + return (log_pi - log_mu).clamp_max(math.log(c)).exp() + + +def dv_val( + rewards: torch.Tensor, + vals: torch.Tensor, + gamma: Union[float, torch.Tensor], + rho_bar: Union[float, torch.Tensor], + log_pi: torch.Tensor, + log_mu: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + rho = c_val(log_pi, log_mu, rho_bar) + next_vals = torch.cat([vals[:, 1:], torch.zeros_like(vals[:, :1])], 1) + dv = rho * (rewards + gamma * next_vals - vals) + return dv, rho + + +def vtrace( + rewards: torch.Tensor, + vals: torch.Tensor, + log_pi: torch.Tensor, + log_mu: torch.Tensor, + gamma: Union[torch.Tensor, float], + rho_bar: Union[float, torch.Tensor] = 1.0, + c_bar: Union[float, torch.Tensor] = 1.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + T = vals.shape[1] + if not isinstance(gamma, torch.Tensor): + gamma = torch.full_like(vals, gamma) + + dv, rho = dv_val(rewards, vals, gamma, rho_bar, log_pi, log_mu) + c = c_val(log_pi, log_mu, c_bar) + + v_out = [] + v_out.append(vals[:, -1] + dv[:, -1]) + for t in range(T - 2, -1, -1): + _v_out = ( + vals[:, t] + dv[:, t] + gamma[:, t] * c[:, t] * (v_out[-1] - vals[:, t + 1]) + ) + v_out.append(_v_out) + v_out = torch.stack(list(reversed(v_out)), 1) + return v_out, rho diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py new file mode 100644 index 00000000000..551024679b7 --- /dev/null +++ b/torchrl/record/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .recorder import * +from .rendering import * diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py new file mode 100644 index 00000000000..69b947ebed2 --- /dev/null +++ b/torchrl/record/recorder.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence + +import torch + +from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.envs.transforms import ObservationTransform, Transform + +__all__ = ["VideoRecorder", "TensorDictRecorder"] + + +class VideoRecorder(ObservationTransform): + """ + Video Recorder transform. + Will record a series of observations from an environment and write them + to a TensorBoard SummaryWriter object when needed. + + Args: + writer (SummaryWriter): a tb.SummaryWriter instance where the video + should be written. + tag (str): the video tag in the writer. + keys (Sequence[str], optional): keys to be read to produce the video. + Default is `"next_observation_pixels"`. + skip (int): frame interval in the output video. + Default is 2. 
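+
+    Example (illustrative sketch, assuming a TensorBoard `SummaryWriter` and an
+    environment whose tensordicts carry a "next_observation_pixels" entry):
+        >>> from torch.utils.tensorboard import SummaryWriter
+        >>> writer = SummaryWriter("path/to/logs")
+        >>> recorder = VideoRecorder(writer=writer, tag="rollout_video", skip=2)
+        >>> # append `recorder` to the transforms of a pixel-rendering environment,
+        >>> # collect a rollout, then flush the recorded frames to TensorBoard:
+        >>> recorder.dump()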
+    """
+
+    def __init__(
+        self,
+        writer: "SummaryWriter",
+        tag: str,
+        keys: Optional[Sequence[str]] = None,
+        skip: int = 2,
+        **kwargs,
+    ) -> None:
+        if keys is None:
+            keys = ["next_observation_pixels"]
+
+        super().__init__(keys=keys)
+        video_kwargs = {"fps": 6}
+        video_kwargs.update(kwargs)
+        self.video_kwargs = video_kwargs
+        self.iter = 0
+        self.skip = skip
+        self.writer = writer
+        self.tag = tag
+        self.count = 0
+        self.obs = []
+        try:
+            import moviepy  # noqa
+        except ImportError:
+            raise ImportError("moviepy not found, VideoRecorder cannot be created")
+
+    def _apply(self, observation: torch.Tensor) -> torch.Tensor:
+        if not (observation.shape[-1] == 3 or observation.ndimension() == 2):
+            raise RuntimeError(f"Invalid observation shape, got: {observation.shape}")
+        observation_trsf = observation
+        self.count += 1
+        if self.count % self.skip == 0:
+            if observation.ndimension() == 2:
+                observation_trsf = observation.unsqueeze(-3)
+            else:
+                if observation.ndimension() != 3:
+                    raise RuntimeError(
+                        "observation is expected to have 3 dimensions, "
+                        f"got {observation.ndimension()} instead"
+                    )
+                if observation_trsf.shape[-1] != 3:
+                    raise RuntimeError(
+                        "observation is expected to have 3 channels in its last dimension, "
+                        f"got shape {observation_trsf.shape} instead"
+                    )
+                observation_trsf = observation_trsf.permute(2, 0, 1)
+            self.obs.append(observation_trsf.cpu().to(torch.uint8))
+        return observation
+
+    def dump(self) -> None:
+        """Writes the video to the self.writer attribute."""
+        self.writer.add_video(
+            tag=f"{self.tag}",
+            vid_tensor=torch.stack(self.obs, 0).unsqueeze(0),
+            global_step=self.iter,
+            **self.video_kwargs,
+        )
+        self.iter += 1
+        self.count = 0
+        self.obs = []
+
+
+class TensorDictRecorder(Transform):
+    """
+    TensorDict recorder.
+    When the `dump` method is called, this class will save a stack of the tensordicts resulting from `env.step(td)`
+    to a file whose prefix is defined by the out_file_base argument.
+
+    Args:
+        out_file_base (str): a string defining the prefix of the file where the tensordict will be written.
+        skip_reset (bool): if True, the first TensorDict of the list will be discarded (usually the tensordict
+            resulting from the call to `env.reset()`).
+            Default is True.
+        skip (int): frame interval for the saved tensordict.
+            Default is 4.
+        keys (Sequence[str], optional): keys of the tensordict to be saved.
+            Default is `[]` (the whole tensordict is saved).
+
+    """
+
+    def __init__(
+        self,
+        out_file_base: str,
+        skip_reset: bool = True,
+        skip: int = 4,
+        keys: Optional[Sequence[str]] = None,
+    ) -> None:
+        if keys is None:
+            keys = []
+
+        super().__init__(keys=keys)
+        self.iter = 0
+        self.out_file_base = out_file_base
+        self.td = []
+        self.skip_reset = skip_reset
+        self.skip = skip
+        self.count = 0
+
+    def _call(self, td: _TensorDict) -> _TensorDict:
+        self.count += 1
+        if self.count % self.skip == 0:
+            _td = td
+            if self.keys:
+                _td = td.select(*self.keys).clone()
+            self.td.append(_td)
+        return td
+
+    def dump(self) -> None:
+        td = self.td
+        if self.skip_reset:
+            td = td[1:]
+        torch.save(
+            torch.stack(td, 0).contiguous(),
+            f"{self.out_file_base}_tensor_dict.t",
+        )
+        self.iter += 1
+        self.count = 0
+        del self.td
+        self.td = []
diff --git a/torchrl/record/rendering.py b/torchrl/record/rendering.py
new file mode 100644
index 00000000000..7bec24cb17b
--- /dev/null
+++ b/torchrl/record/rendering.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.