diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..ae319c70
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,23 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..7a4a3ea2
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..fc03786d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,85 @@
+# `dm_control`: The DeepMind Control Suite and Control Package
+
+This package contains:
+
+- A set of Python Reinforcement Learning environments powered by the MuJoCo
+ physics engine. See the `suite` subdirectory.
+
+- Libraries that provide Python bindings to the MuJoCo physics engine.
+
+If you use this package, please cite our accompanying [tech report](tech_report.pdf).
+
+## Installation and requirements
+
+Follow these steps to install `dm_control`:
+
+1. Download MuJoCo Pro 1.50 from the download page on the [MuJoCo website](http://www.mujoco.org/).
+ MuJoCo Pro must be installed before `dm_control`, since `dm_control`'s
+ install script generates Python [`ctypes`](https://docs.python.org/2/library/ctypes.html)
+ bindings based on MuJoCo's header files. By default, `dm_control` assumes
+ that the MuJoCo Zip archive is extracted as `~/.mujoco/mjpro150`.
+
+2. Install the `dm_control` Python package by running
+ `pip install git+git://github.com/deepmind/dm_control.git`
+ (PyPI package coming soon) or by cloning the repository and running
+ `pip install /path/to/dm_control/`
+ At installation time, `dm_control` looks for the MuJoCo headers from Step 1
+ in `~/.mujoco/mjpro150/include`; however, this path can be configured with the
+ `headers-dir` command line argument.
+
+3. Install a license key for MuJoCo, required by `dm_control` at runtime. See
+ the [MuJoCo license key page](https://www.roboti.us/license.html) for further
+ details. By default, `dm_control` looks for the MuJoCo license key file at
+ `~/.mujoco/mjkey.txt`.
+
+4. If the license key (e.g. `mjkey.txt`) or the shared library provided by
+ MuJoCo Pro (e.g. `libmujoco150.so` or `libmujoco150.dylib`) are installed at
+ non-default paths, specify their locations using the `MJKEY_PATH` and
+ `MJLIB_PATH` environment variables respectively.
+
+## Additional instructions for Homebrew users on macOS
+
+1. The above instructions using `pip` should work, provided that you
+ use a Python interpreter that is installed by Homebrew (rather than the
+ system-default one).
+
+2. To get OpenGL working, install the `glfw` package from Homebrew by running
+ `brew install glfw`.
+
+3. Before running, the `DYLD_LIBRARY_PATH` environment variable needs to be
+ updated with the path to the GLFW library. This can be done by running
+ `export DYLD_LIBRARY_PATH=$(brew --prefix)/lib:$DYLD_LIBRARY_PATH`.
+
+## Control Suite quickstart
+
+```python
+from dm_control import suite
+import numpy as np
+
+# Load one task:
+env = suite.load(domain_name="cartpole", task_name="swingup")
+
+# Iterate over a task set:
+for domain_name, task_name in suite.BENCHMARKING:
+  env = suite.load(domain_name, task_name)
+
+# Step through an episode and print out reward, discount and observation.
+action_spec = env.action_spec()
+time_step = env.reset()
+while not time_step.last():
+  action = np.random.uniform(action_spec.minimum,
+                             action_spec.maximum,
+                             size=action_spec.shape)
+  time_step = env.step(action)
+  print(time_step.reward, time_step.discount, time_step.observation)
+```
+
+See our [tech report](tech_report.pdf) for further details.
+
+## Illustration video
+
+Below is a video montage of solved Control Suite tasks, with reward
+visualisation enabled.
+
+[Video montage of solved Control Suite tasks](https://www.youtube.com/watch?v=rAai4QzcYbs)
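Steps 3 and 4 of the README's installation instructions describe how the license key and shared library are located: default paths, overridable through `MJKEY_PATH` and `MJLIB_PATH`. The sketch below shows one way to check which paths would be picked up before running anything. It is not `dm_control`'s own lookup code. The helper name `resolve_mujoco_paths` is hypothetical, and the default library location under `bin/` is an assumption about the layout of the MuJoCo Pro 1.50 archive.

```python
import os

# Default key path as stated in the README. The library default is an
# assumption about where the MuJoCo Pro 1.50 archive keeps its binaries.
DEFAULT_KEY_PATH = "~/.mujoco/mjkey.txt"
DEFAULT_LIB_PATH = "~/.mujoco/mjpro150/bin/libmujoco150.so"


def resolve_mujoco_paths(environ=None):
  """Returns (key_path, lib_path), honouring MJKEY_PATH and MJLIB_PATH."""
  environ = os.environ if environ is None else environ
  key_path = environ.get("MJKEY_PATH", DEFAULT_KEY_PATH)
  lib_path = environ.get("MJLIB_PATH", DEFAULT_LIB_PATH)
  return os.path.expanduser(key_path), os.path.expanduser(lib_path)


if __name__ == "__main__":
  labels = ("license key", "shared library")
  for label, path in zip(labels, resolve_mujoco_paths()):
    status = "found" if os.path.isfile(path) else "missing"
    print("{}: {} ({})".format(label, path, status))
```

Passing a plain dict instead of `os.environ` makes the helper easy to exercise without touching the real environment.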
diff --git a/all_domains.png b/all_domains.png
new file mode 100644
index 00000000..10c22fae
Binary files /dev/null and b/all_domains.png differ
diff --git a/dm_control/__init__.py b/dm_control/__init__.py
new file mode 100644
index 00000000..1ebb270f
--- /dev/null
+++ b/dm_control/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/dm_control/autowrap/__init__.py b/dm_control/autowrap/__init__.py
new file mode 100644
index 00000000..f9817d49
--- /dev/null
+++ b/dm_control/autowrap/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
diff --git a/dm_control/autowrap/autowrap.py b/dm_control/autowrap/autowrap.py
new file mode 100644
index 00000000..7da4378e
--- /dev/null
+++ b/dm_control/autowrap/autowrap.py
@@ -0,0 +1,138 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+r"""Automatically generates ctypes Python bindings for MuJoCo.
+
+Parses mjdata.h, mjmodel.h, mjrender.h, mjvisualize.h, mjxmacro.h and mujoco.h;
+generates the following Python source files:
+
+ constants.py: constants
+ enums.py: enums
+ sizes.py: size information for dynamically-shaped arrays
+ types.py: ctypes declarations for structs
+ wrappers.py: low-level Python wrapper classes for structs (these implement
+ getter/setter methods for struct members where applicable)
+ functions.py: ctypes function declarations for MuJoCo API functions
+
+Example usage:
+
+ autowrap --header_paths='/path/to/mjmodel.h /path/to/mjdata.h ...' \
+ --output_dir=/path/to/mjbindings
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+
+# Internal dependencies.
+
+from absl import app
+from absl import flags
+from absl import logging
+
+from dm_control.autowrap import binding_generator
+from dm_control.autowrap import codegen_util
+
+import six
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_spaceseplist(
+ "header_paths", None,
+ "Space-separated list of paths to MuJoCo header files.")
+
+flags.DEFINE_string("output_dir", None,
+ "Path to output directory for wrapper source files.")
+
+
+def main(unused_argv):
+  # Get the path to the xmacro header file.
+  xmacro_hdr_path = None
+  for path in FLAGS.header_paths:
+    if path.endswith("mjxmacro.h"):
+      xmacro_hdr_path = path
+      break
+  if xmacro_hdr_path is None:
+    logging.fatal("List of inputs must contain a path to mjxmacro.h")
+
+  srcs = codegen_util.UniqueOrderedDict()
+  for p in sorted(FLAGS.header_paths):
+    with open(p, "r") as f:
+      srcs[p] = f.read()
+
+  # consts_dict should be a codegen_util.UniqueOrderedDict.
+  # This is a temporary workaround due to the fact that the parser does not yet
+  # handle nested `#if defined(predicate)` blocks, which results in some
+  # constants being parsed twice. We therefore can't enforce the uniqueness of
+  # the keys in `consts_dict`. As of MuJoCo v1.30 there is only a single problem
+  # block beginning on line 10 in mujoco.h, and a single constant is affected
+  # (MJAPI).
+  consts_dict = collections.OrderedDict()
+
+  # These are commented in `mjdata.h` but have no macros in `mjxmacro.h`.
+  hints_dict = codegen_util.UniqueOrderedDict({"buffer": ("nbuffer",),
+                                               "stack": ("nstack",)})
+
+  parser = binding_generator.BindingGenerator(
+      consts_dict=consts_dict, hints_dict=hints_dict)
+
+  # Parse enums.
+  for pth, src in six.iteritems(srcs):
+    if pth != xmacro_hdr_path:
+      parser.parse_enums(src)
+
+  # Parse constants and type declarations.
+  for pth, src in six.iteritems(srcs):
+    if pth != xmacro_hdr_path:
+      parser.parse_consts_typedefs(src)
+
+  # Get shape hints from mjxmacro.h.
+  parser.parse_hints(srcs[xmacro_hdr_path])
+
+  # Parse structs.
+  for pth, src in six.iteritems(srcs):
+    if pth != xmacro_hdr_path:
+      parser.parse_structs(src)
+
+  # Parse functions.
+  for pth, src in six.iteritems(srcs):
+    if pth != xmacro_hdr_path:
+      parser.parse_functions(src)
+
+  # Parse global strings and function pointers.
+  for pth, src in six.iteritems(srcs):
+    if pth != xmacro_hdr_path:
+      parser.parse_global_strings(src)
+      parser.parse_function_pointers(src)
+
+  # Create the output directory if it doesn't already exist.
+  if not os.path.exists(FLAGS.output_dir):
+    os.makedirs(FLAGS.output_dir)
+
+  # Generate Python source files and write them to the output directory.
+  parser.write_consts(os.path.join(FLAGS.output_dir, "constants.py"))
+  parser.write_enums(os.path.join(FLAGS.output_dir, "enums.py"))
+  parser.write_types(os.path.join(FLAGS.output_dir, "types.py"))
+  parser.write_wrappers(os.path.join(FLAGS.output_dir, "wrappers.py"))
+  parser.write_funcs_and_globals(os.path.join(FLAGS.output_dir, "functions.py"))
+  parser.write_index_dict(os.path.join(FLAGS.output_dir, "sizes.py"))
+
+if __name__ == "__main__":
+  flags.mark_flag_as_required("header_paths")
+  flags.mark_flag_as_required("output_dir")
+  app.run(main)
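The comment in `main` explains that `consts_dict` is a plain `collections.OrderedDict` because some constants are parsed twice. A minimal sketch of why that matters, assuming `codegen_util.UniqueOrderedDict` rejects duplicate keys (its implementation is not part of this file). The class below is a hypothetical stand-in, and `record_constants` and `SOME_CONST` are made-up names for illustration only.

```python
import collections


class UniqueOrderedDict(collections.OrderedDict):
  """Hypothetical stand-in: an OrderedDict that rejects duplicate keys."""

  def __setitem__(self, key, value):
    if key in self:
      raise ValueError("Key {!r} already exists.".format(key))
    super(UniqueOrderedDict, self).__setitem__(key, value)


def record_constants(target, parsed_pairs):
  """Stores (name, value) pairs in `target` in the order they were parsed."""
  for name, value in parsed_pairs:
    target[name] = value
  return target


# MJAPI appearing twice mimics a constant that the parser encounters in more
# than one branch of a conditional block.
parsed = [("MJAPI", ""), ("SOME_CONST", "1"), ("MJAPI", "")]

# A plain OrderedDict silently keeps one entry per name.
lenient = record_constants(collections.OrderedDict(), parsed)

# The strict variant refuses the second MJAPI entry.
try:
  record_constants(UniqueOrderedDict(), parsed)
  strict_failed = False
except ValueError:
  strict_failed = True
```

Under this assumption, switching `consts_dict` back to the strict dictionary would only be safe once the parser no longer emits the duplicate.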
diff --git a/dm_control/autowrap/binding_generator.py b/dm_control/autowrap/binding_generator.py
new file mode 100644
index 00000000..023f71c8
--- /dev/null
+++ b/dm_control/autowrap/binding_generator.py
@@ -0,0 +1,526 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Parses MuJoCo header files and generates Python bindings."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import pprint
+import textwrap
+
+# Internal dependencies.
+
+from absl import logging
+
+from dm_control.autowrap import c_declarations
+from dm_control.autowrap import codegen_util
+from dm_control.autowrap import header_parsing
+
+import pyparsing
+import six
+
+# Absolute path to the top-level module.
+_MODULE = "dm_control.mujoco.wrapper"
+
+# Imports used in all generated source files.
+_BOILERPLATE_IMPORTS = [
+ "from __future__ import absolute_import",
+ "from __future__ import division",
+ "from __future__ import print_function\n",
+]
+
+
+class Error(Exception):
+  pass
+
+
+class BindingGenerator(object):
+  """Parses declarations from MuJoCo headers and generates Python bindings."""
+
+  def __init__(self, enums_dict=None, consts_dict=None, typedefs_dict=None,
+               hints_dict=None, structs_dict=None, funcs_dict=None,
+               strings_dict=None, func_ptrs_dict=None, index_dict=None):
+    """Constructs a new BindingGenerator instance.
+
+    The optional arguments listed below can be used to pass in dict-like
+    objects specifying pre-defined declarations. By default empty
+    UniqueOrderedDicts will be instantiated and then populated according to the
+    contents of the headers.
+
+    Args:
+      enums_dict: nested mappings from {enum_name: {member_name: value}}
+      consts_dict: mapping from {const_name: value}
+      typedefs_dict: mapping from {type_name: ctypes_typename}
+      hints_dict: mapping from {var_name: shape_tuple}
+      structs_dict: mapping from {struct_name: Struct_instance}
+      funcs_dict: mapping from {func_name: Function_instance}
+      strings_dict: mapping from {var_name: StaticStringArray_instance}
+      func_ptrs_dict: mapping from {var_name: FunctionPtr_instance}
+      index_dict: mapping from {lowercase_struct_name: {var_name: shape_tuple}}
+    """
+    self.enums_dict = (enums_dict if enums_dict is not None
+                       else codegen_util.UniqueOrderedDict())
+    self.consts_dict = (consts_dict if consts_dict is not None
+                        else codegen_util.UniqueOrderedDict())
+    self.typedefs_dict = (typedefs_dict if typedefs_dict is not None
+                          else codegen_util.UniqueOrderedDict())
+    self.hints_dict = (hints_dict if hints_dict is not None
+                       else codegen_util.UniqueOrderedDict())
+    self.structs_dict = (structs_dict if structs_dict is not None
+                         else codegen_util.UniqueOrderedDict())
+    self.funcs_dict = (funcs_dict if funcs_dict is not None
+                       else codegen_util.UniqueOrderedDict())
+    self.strings_dict = (strings_dict if strings_dict is not None
+                         else codegen_util.UniqueOrderedDict())
+    self.func_ptrs_dict = (func_ptrs_dict if func_ptrs_dict is not None
+                           else codegen_util.UniqueOrderedDict())
+    self.index_dict = (index_dict if index_dict is not None
+                       else codegen_util.UniqueOrderedDict())
+
+  def get_consts_and_enums(self):
+    consts_and_enums = self.consts_dict.copy()
+    for enum in six.itervalues(self.enums_dict):
+      consts_and_enums.update(enum)
+    return consts_and_enums
+
+  def resolve_size(self, old_size):
+    """Resolves an array size identifier.
+
+    The following conversions will be attempted:
+
+      * If `old_size` is an integer it will be returned as-is.
+      * If `old_size` is a string of the form `"3"` it will be cast to an int.
+      * If `old_size` is a string in `self.consts_dict` then the value of the
+        constant will be returned.
+      * If `old_size` is a string of the form `"3*constant_name"` then the
+        result of `3*constant_value` will be returned.
+      * If `old_size` is a string that does not specify an int constant and
+        cannot be cast to an int (e.g. an identifier for a dynamic dimension,
+        such as `"ncontact"`) then it will be returned as-is.
+
+    Args:
+      old_size: An int or string.
+
+    Returns:
+      An int or string.
+    """
+    if isinstance(old_size, int):
+      return old_size  # If it's already an int then there's nothing left to do
+    elif "*" in old_size:
+      # If it's a string specifying a product (such as "2*mjMAXLINEPNT"),
+      # recursively resolve the components to ints and calculate the result.
+      size = 1
+      for part in old_size.split("*"):
+        dim = self.resolve_size(part)
+        assert isinstance(dim, int)
+        size *= dim
+      return size
+    else:
+      # Recursively dereference any sizes declared in header macros
+      size = codegen_util.recursive_dict_lookup(old_size,
+                                                self.get_consts_and_enums())
+      # Try to coerce the result to an int, return a string if this fails
+      return codegen_util.try_coerce_to_num(size, try_types=(int,))
+
+  def get_shape_tuple(self, old_size, squeeze=False):
+    """Generates a shape tuple from parser results.
+
+    Args:
+      old_size: Either a `pyparsing.ParseResults`, or a valid int or string
+        input to `self.resolve_size` (see method docstring for further details).
+      squeeze: If True, any dimensions that are statically defined as 1 will be
+        removed from the shape tuple.
+
+    Returns:
+      A shape tuple containing ints for dimensions that are statically defined,
+      and string size identifiers for dimensions that can only be determined at
+      runtime.
+    """
+    if isinstance(old_size, pyparsing.ParseResults):
+      # For multi-dimensional arrays, convert each dimension separately
+      shape = tuple(self.resolve_size(dim) for dim in old_size)
+    else:
+      shape = (self.resolve_size(old_size),)
+    if squeeze:
+      shape = tuple(d for d in shape if d != 1)  # Remove singleton dimensions
+    return shape
+
+  def resolve_typename(self, old_ctypes_typename):
+    """Gets a qualified ctypes typename from typedefs_dict and C_TO_CTYPES."""
+
+    # Recursively dereference any typenames declared in self.typedefs_dict
+    new_ctypes_typename = codegen_util.recursive_dict_lookup(
+        old_ctypes_typename, self.typedefs_dict)
+
+    # Try to convert to a ctypes native typename
+    new_ctypes_typename = header_parsing.C_TO_CTYPES.get(
+        new_ctypes_typename, new_ctypes_typename)
+
+    if new_ctypes_typename == old_ctypes_typename:
+      logging.warning("Could not resolve typename '%s'", old_ctypes_typename)
+
+    return new_ctypes_typename
+
+  def get_type_from_token(self, token, parent=None):
+    """Accepts a token returned by a parser, returns a subclass of CDeclBase."""
+
+    comment = codegen_util.mangle_comment(token.comment)
+    is_const = token.is_const == "const"
+
+    # A new struct declaration
+    if token.members:
+
+      name = token.name
+
+      # If the name is empty, see if there is a type declaration that matches
+      # this struct's typename
+      if not name:
+        for k, v in six.iteritems(self.typedefs_dict):
+          if v == token.typename:
+            name = k
+
+      # Anonymous structs need a dummy typename
+      typename = token.typename
+      if not typename:
+        if parent:
+          typename = token.name
+        else:
+          raise Error(
+              "Anonymous structs that aren't members of a named struct are not "
+              "supported (name = '{token.name}').".format(token=token))
+
+      # Mangle the name if it contains any protected keywords
+      name = codegen_util.mangle_varname(name)
+
+      members = codegen_util.UniqueOrderedDict()
+      sub_structs = codegen_util.UniqueOrderedDict()
+      out = c_declarations.Struct(name, typename, members, sub_structs, comment,
+                                  parent, is_const)
+
+      # Map the old typename to the mangled typename in typedefs_dict
+      self.typedefs_dict[typename] = out.ctypes_typename
+
+      # Add members
+      for sub_token in token.members:
+
+        # Recurse into nested structs
+        member = self.get_type_from_token(sub_token, parent=out)
+        out.members[member.name] = member
+
+        # Nested sub-structures need special treatment
+        if isinstance(member, c_declarations.Struct):
+          out.sub_structs[member.name] = member
+
+      # Add to dict of structs
+      self.structs_dict[out.ctypes_typename] = out
+
+    else:
+
+      name = codegen_util.mangle_varname(token.name)
+      typename = self.resolve_typename(token.typename)
+
+      # 1D array with size defined at compile time
+      if token.size:
+        shape = self.get_shape_tuple(token.size)
+        if typename in header_parsing.CTYPES_TO_NUMPY:
+          out = c_declarations.StaticNDArray(name, typename, shape, comment,
+                                             parent, is_const)
+        else:
+          out = c_declarations.StaticPtrArray(name, typename, shape, comment,
+                                              parent, is_const)
+      elif token.ptr:
+
+        # Pointer to a numpy-compatible type, could be an array or a scalar
+        if typename in header_parsing.CTYPES_TO_NUMPY:
+
+          # Multidimensional array (one or more dimensions might be undefined)
+          if name in self.hints_dict:
+
+            # Dynamically-sized dimensions have string identifiers
+            shape = self.hints_dict[name]
+            if any(isinstance(d, str) for d in shape):
+              out = c_declarations.DynamicNDArray(name, typename, shape,
+                                                  comment, parent, is_const)
+            else:
+              out = c_declarations.StaticNDArray(name, typename, shape, comment,
+                                                 parent, is_const)
+
+          # This must be a pointer to a scalar primitive
+          else:
+            out = c_declarations.ScalarPrimitivePtr(name, typename, comment,
+                                                    parent, is_const)
+
+        # Pointer to struct or other arbitrary type
+        else:
+          out = c_declarations.ScalarPrimitivePtr(name, typename, comment,
+                                                  parent, is_const)
+
+      # A struct we've already encountered
+      elif typename in self.structs_dict:
+        s = self.structs_dict[typename]
+        out = c_declarations.Struct(name, s.typename, s.members, s.sub_structs,
+                                    comment, parent)
+
+      # Presumably this is a scalar primitive
+      else:
+        out = c_declarations.ScalarPrimitive(name, typename, comment, parent,
+                                             is_const)
+
+    return out
+
+ # Parsing functions.
+ # ----------------------------------------------------------------------------
+
+ def parse_hints(self, xmacro_src):
+ """Parses mjxmacro.h, updates self.hints_dict and self.index_dict."""
+ parser = header_parsing.XMACRO
+ for tokens, _, _ in parser.scanString(xmacro_src):
+ for xmacro in tokens:
+ for member in xmacro.members:
+ # "Squeeze out" singleton dimensions.
+ shape = self.get_shape_tuple(member.dims, squeeze=True)
+ self.hints_dict.update({member.name: shape})
+
+ if codegen_util.is_macro_pointer(xmacro.name):
+ struct_name = codegen_util.macro_struct_name(xmacro.name)
+ if struct_name not in self.index_dict:
+ self.index_dict[struct_name] = {}
+
+ self.index_dict[struct_name].update({member.name: shape})
+
+ def parse_enums(self, src):
+ """Parses mj*.h, updates self.enums_dict."""
+ parser = header_parsing.ENUM_DECL
+ for tokens, _, _ in parser.scanString(src):
+ for enum in tokens:
+ members = codegen_util.UniqueOrderedDict()
+ value = -1  # An implicit first member must be 0, as in C.
+ for member in enum.members:
+ # Leftward bitshift
+ if member.bit_lshift_a:
+ value = int(member.bit_lshift_a) << int(member.bit_lshift_b)
+ # Assignment
+ elif member.value:
+ value = int(member.value)
+ # Implicit count
+ else:
+ value += 1
+ members.update({member.name: value})
+ self.enums_dict.update({enum.name: members})
+
+ def parse_consts_typedefs(self, src):
+ """Updates self.consts_dict, self.typedefs_dict."""
+ parser = (header_parsing.COND_DECL | header_parsing.UNCOND_DECL)
+ for tokens, _, _ in parser.scanString(src):
+ self.recurse_into_conditionals(tokens)
+
+ def recurse_into_conditionals(self, tokens):
+ """Called recursively within nested #if(n)def... #else... #endif blocks."""
+ for token in tokens:
+ # Another nested conditional block
+ if token.predicate:
+ if (token.predicate in self.get_consts_and_enums()
+ and self.get_consts_and_enums()[token.predicate]):
+ self.recurse_into_conditionals(token.if_true)
+ else:
+ self.recurse_into_conditionals(token.if_false)
+ # One or more declarations
+ else:
+ if token.typename:
+ self.typedefs_dict.update({token.name: token.typename})
+ elif token.value:
+ value = codegen_util.try_coerce_to_num(token.value)
+ # Avoid adding function aliases.
+ if isinstance(value, str):
+ continue
+ else:
+ self.consts_dict.update({token.name: value})
+ else:
+ self.consts_dict.update({token.name: True})
+
+ def parse_structs(self, src):
+ """Updates self.structs_dict."""
+ parser = header_parsing.NESTED_STRUCTS
+ for tokens, _, _ in parser.scanString(src):
+ for token in tokens:
+ self.get_type_from_token(token)
+
+ def parse_functions(self, src):
+ """Updates self.funcs_dict."""
+ parser = header_parsing.MJAPI_FUNCTION_DECL
+ for tokens, _, _ in parser.scanString(src):
+ for token in tokens:
+ name = codegen_util.mangle_varname(token.name)
+ comment = codegen_util.mangle_comment(token.comment)
+ args = codegen_util.UniqueOrderedDict()
+ for arg in token.arguments:
+ a = self.get_type_from_token(arg)
+ args[a.name] = a
+ r = self.get_type_from_token(token.return_value)
+ f = c_declarations.Function(name, args, r, comment)
+ self.funcs_dict[f.name] = f
+
+ def parse_global_strings(self, src):
+ """Updates self.strings_dict."""
+ parser = header_parsing.MJAPI_STRING_ARRAY
+ for token, _, _ in parser.scanString(src):
+ name = codegen_util.mangle_varname(token.name)
+ shape = self.get_shape_tuple(token.dims)
+ self.strings_dict[name] = c_declarations.StaticStringArray(
+ name, shape, symbol_name=token.name)
+
+ def parse_function_pointers(self, src):
+ """Updates self.func_ptrs_dict."""
+ parser = header_parsing.MJAPI_FUNCTION_PTR
+ for token, _, _ in parser.scanString(src):
+ name = codegen_util.mangle_varname(token.name)
+ self.func_ptrs_dict[name] = c_declarations.FunctionPtr(
+ name, symbol_name=token.name)
+
+ # Code generation methods
+ # ----------------------------------------------------------------------------
+
+ def make_header(self, imports=()):
+ """Returns a header string for an auto-generated Python source file."""
+ docstring = textwrap.dedent("""
+ \"\"\"Automatically generated by {scriptname:}.
+
+ MuJoCo header version: {mujoco_version:}
+ \"\"\"
+ """.format(scriptname=os.path.split(__file__)[-1],
+ mujoco_version=self.consts_dict["mjVERSION_HEADER"]))
+ docstring = docstring[1:] # Strip the leading line break.
+ return "\n".join(
+ [docstring] + _BOILERPLATE_IMPORTS + list(imports) + ["\n"])
+
+ def write_consts(self, fname):
+ imports = [
+ "# pylint: disable=invalid-name",
+ ]
+ with open(fname, "w") as f:
+ f.write(self.make_header(imports))
+ f.write(codegen_util.comment_line("Constants") + "\n")
+ for name, value in six.iteritems(self.consts_dict):
+ f.write("{0} = {1}\n".format(name, value))
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_enums(self, fname):
+ with open(fname, "w") as f:
+ imports = [
+ "import collections",
+ "# pylint: disable=invalid-name",
+ "# pylint: disable=line-too-long",
+ ]
+ f.write(self.make_header(imports))
+ f.write(codegen_util.comment_line("Enums"))
+ for enum_name, members in six.iteritems(self.enums_dict):
+ fields = ["\"{}\"".format(name) for name in six.iterkeys(members)]
+ values = [str(value) for value in six.itervalues(members)]
+ s = textwrap.dedent("""
+ {0} = collections.namedtuple(
+ "{0}",
+ [{1}]
+ )({2})
+ """).format(enum_name, ",\n ".join(fields), ", ".join(values))
+ f.write(s)
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_types(self, fname):
+ imports = [
+ "import ctypes",
+ ]
+ with open(fname, "w") as f:
+ f.write(self.make_header(imports))
+ f.write(codegen_util.comment_line("ctypes struct declarations"))
+ for struct in six.itervalues(self.structs_dict):
+ f.write("\n" + struct.ctypes_struct_decl)
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_wrappers(self, fname):
+ with open(fname, "w") as f:
+ imports = [
+ "import ctypes",
+ "# Internal dependencies.",
+ "# pylint: disable=undefined-variable",
+ "# pylint: disable=wildcard-import",
+ "from {} import util".format(_MODULE),
+ "from {}.mjbindings.types import *".format(_MODULE),
+ ]
+ f.write(self.make_header(imports))
+ f.write(codegen_util.comment_line("Low-level wrapper classes"))
+ for struct in six.itervalues(self.structs_dict):
+ f.write("\n" + struct.wrapper_class)
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_funcs_and_globals(self, fname):
+ """Writes ctypes declarations for functions and global data."""
+ imports = [
+ "import collections",
+ "import ctypes",
+ "# Internal dependencies.",
+ "# pylint: disable=undefined-variable",
+ "# pylint: disable=wildcard-import",
+ "from {} import util".format(_MODULE),
+ "from {}.mjbindings.types import *".format(_MODULE),
+ "import numpy as np",
+ "# pylint: disable=line-too-long",
+ "# pylint: disable=invalid-name",
+ "# common_typos_disable",
+ ]
+ with open(fname, "w") as f:
+ f.write(self.make_header(imports))
+ f.write("mjlib = util.get_mjlib()\n")
+
+ f.write("\n" + codegen_util.comment_line("ctypes function declarations"))
+ for function in six.itervalues(self.funcs_dict):
+ f.write("\n" + function.ctypes_func_decl(cdll_name="mjlib"))
+
+ # Only require strings for UI purposes.
+ f.write("\n" + codegen_util.comment_line("String arrays") + "\n")
+ for string_arr in six.itervalues(self.strings_dict):
+ f.write(string_arr.ctypes_var_decl(cdll_name="mjlib"))
+
+ f.write("\n" + codegen_util.comment_line("Function pointers"))
+
+ fields = [repr(name) for name in self.func_ptrs_dict.keys()]
+ values = [func_ptr.ctypes_var_decl(cdll_name="mjlib")
+ for func_ptr in self.func_ptrs_dict.values()]
+ f.write(textwrap.dedent("""
+ function_pointers = collections.namedtuple(
+ 'FunctionPointers',
+ [{0}]
+ )({1})
+ """).format(",\n ".join(fields), ",\n ".join(values)))
+
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_index_dict(self, fname):
+ pp = pprint.PrettyPrinter()
+ output_string = pp.pformat(dict(self.index_dict))
+ indent = codegen_util.Indenter()
+ imports = [
+ "# pylint: disable=bad-continuation",
+ "# pylint: disable=line-too-long",
+ ]
+ with open(fname, "w") as f:
+ f.write(self.make_header(imports))
+ f.write("array_sizes = (\n")
+ with indent:
+ f.write(output_string)
+ f.write("\n)")
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
diff --git a/dm_control/autowrap/c_declarations.py b/dm_control/autowrap/c_declarations.py
new file mode 100644
index 00000000..9bdd3eb0
--- /dev/null
+++ b/dm_control/autowrap/c_declarations.py
@@ -0,0 +1,425 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Python representations of C declarations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+
+# Internal dependencies.
+
+from dm_control.autowrap import codegen_util
+from dm_control.autowrap import header_parsing
+
+import six
+
+
+class CDeclBase(object):
+ """Base class for Python representations of C declarations."""
+
+ def __init__(self, **attrs):
+ self._attrs = attrs
+ for k, v in six.iteritems(attrs):
+ setattr(self, k, v)
+
+ def __repr__(self):
+ """Pretty string representation."""
+ attr_str = ", ".join("{0}={1!r}".format(k, v)
+ for k, v in six.iteritems(self._attrs))
+ return "{0}({1})".format(type(self).__name__, attr_str)
+
+ @property
+ def docstring(self):
+ """Auto-generate a docstring for self."""
+ return "\n".join(textwrap.wrap(self.comment, 74))
+
+ @property
+ def ctypes_typename(self):
+ """ctypes typename."""
+ return self.typename
+
+ @property
+ def ctypes_ptr(self):
+ """String representation of self as a ctypes pointer."""
+ return header_parsing.CTYPES_PTRS.get(
+ self.ctypes_typename, "ctypes.POINTER({})".format(self.ctypes_typename))
+
+ @property
+ def np_dtype(self):
+ """Get a numpy dtype name for self, fall back on self.ctypes_typename."""
+ return header_parsing.CTYPES_TO_NUMPY.get(self.ctypes_typename,
+ self.ctypes_typename)
+
+ @property
+ def np_flags(self):
+ """Tuple of strings specifying numpy.ndarray flags."""
+ return ("C", "W")
+
+
+class Struct(CDeclBase):
+ """C struct declaration."""
+
+ def __init__(self, name, typename, members, sub_structs, comment="",
+ parent=None, is_const=None):
+ super(Struct, self).__init__(name=name,
+ typename=typename,
+ members=members,
+ sub_structs=sub_structs,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_struct_decl(self):
+ """Generates a ctypes.Structure declaration for self."""
+ indent = codegen_util.Indenter()
+ s = textwrap.dedent("""
+ class {0.ctypes_typename:}(ctypes.Structure):
+ \"\"\"{0.docstring:}\"\"\"
+ """.format(self))
+ with indent:
+ if self.members:
+ s += indent("\n_fields_ = [\n")
+ with indent:
+ with indent:
+ s += ",\n".join(indent(m.ctypes_field_decl)
+ for m in six.itervalues(self.members))
+ s += indent("\n]\n")
+ return s
+
+ @property
+ def ctypes_typename(self):
+ """Mangles ctypes.Structure typenames to distinguish them from wrappers."""
+ return codegen_util.mangle_struct_typename(self.typename)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name:}', {0.ctypes_typename:})".format(self) # pylint: disable=missing-format-attribute
+
+ @property
+ def wrapper_name(self):
+ return codegen_util.camel_case(self.typename) + "Wrapper"
+
+ @property
+ def wrapper_class(self):
+ """Generates a Python class containing getter/setter methods for members."""
+ indent = codegen_util.Indenter()
+ s = textwrap.dedent("""
+ class {0.wrapper_name}(util.WrapperBase):
+ \"\"\"{0.docstring:}\"\"\"
+ """.format(self))
+ with indent:
+ s += "".join(indent(m.getters_setters)
+ for m in six.itervalues(self.members))
+ return s
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with a getter method for self (no setter)."""
+ return textwrap.dedent("""
+ @util.CachedProperty
+ def {0.name:}(self):
+ \"\"\"{0.docstring:}\"\"\"
+ return {0.wrapper_name}(ctypes.pointer(self._ptr.contents.{0.name}))
+ """.format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """String representation of self as a ctypes function argument."""
+ return self.ctypes_typename
+
+
+class ScalarPrimitive(CDeclBase):
+ """A scalar value corresponding to a C primitive type."""
+
+ def __init__(self, name, typename, comment="", parent=None, is_const=None):
+ super(ScalarPrimitive, self).__init__(name=name,
+ typename=typename,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name:}', {0.ctypes_typename:})".format(self) # pylint: disable=missing-format-attribute
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with getter & setter methods for self."""
+ return textwrap.dedent("""
+ @property
+ def {0.name:}(self):
+ \"\"\"{0.docstring:}\"\"\"
+ return self._ptr.contents.{0.name:}
+
+ @{0.name:}.setter
+ def {0.name:}(self, value):
+ self._ptr.contents.{0.name:} = value
+ """.format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """String representation of self as a ctypes function argument."""
+ return self.ctypes_typename
+
+
+class ScalarPrimitivePtr(CDeclBase):
+ """Pointer to a ScalarPrimitive."""
+
+ def __init__(self, name, typename, comment="", parent=None, is_const=None):
+ super(ScalarPrimitivePtr, self).__init__(name=name,
+ typename=typename,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name:}', {0.ctypes_ptr:})".format(self) # pylint: disable=missing-format-attribute
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with getter & setter methods for self."""
+ return textwrap.dedent("""
+ @property
+ def {0.name:}(self):
+ \"\"\"{0.docstring:}\"\"\"
+ return self._ptr.contents.{0.name:}
+
+ @{0.name:}.setter
+ def {0.name:}(self, value):
+ self._ptr.contents.{0.name:} = value
+ """.format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """Generates string representation of self as a ctypes function argument."""
+ # we assume that every pointer that maps to a numpy dtype corresponds to an
+ # array argument/return value
+ if self.ctypes_typename in header_parsing.CTYPES_TO_NUMPY:
+ return ("util.ndptr(dtype={0.np_dtype}, flags={0.np_flags!s:})"
+ "".format(self)) # pylint: disable=missing-format-attribute
+ else:
+ return self.ctypes_ptr
+
+
+class StaticPtrArray(CDeclBase):
+ """Array of arbitrary pointers whose size can be inferred from the headers."""
+
+ def __init__(self, name, typename, shape, comment="", parent=None,
+ is_const=None):
+ super(StaticPtrArray, self).__init__(name=name,
+ typename=typename,
+ shape=shape,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ if self.typename in header_parsing.CTYPES_PTRS:
+ return "('{0.name:}', {0.ctypes_ptr:} * {1:})".format( # pylint: disable=missing-format-attribute
+ self, " * ".join(str(d) for d in self.shape))
+ else:
+ return "('{0.name:}', {0.ctypes_typename:} * {1:})".format( # pylint: disable=missing-format-attribute
+ self, " * ".join(str(d) for d in self.shape))
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with a getter method for self (no setter)."""
+ return textwrap.dedent("""
+ @property
+ def {0.name:}(self):
+ \"\"\"{0.docstring:}\"\"\"
+ return self._ptr.contents.{0.name:}
+ """.format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """Generates string representation of self as a ctypes function argument."""
+ return "{0.ctypes_typename:}".format(self)
+
+
+class StaticNDArray(CDeclBase):
+ """Numeric array whose dimensions can all be inferred from the headers."""
+
+ def __init__(self, name, typename, shape, comment="", parent=None,
+ is_const=None):
+ super(StaticNDArray, self).__init__(name=name,
+ typename=typename,
+ shape=shape,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name:}', {0.ctypes_typename:} * ({1:}))".format( # pylint: disable=missing-format-attribute
+ self, " * ".join(str(d) for d in self.shape))
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with a getter method for self (no setter)."""
+ return textwrap.dedent("""
+ @util.CachedProperty
+ def {0.name:}(self):
+ \"\"\"{0.docstring:}\"\"\"
+ return util.buf_to_npy(self._ptr.contents.{0.name:}, {0.shape!s:})
+ """.format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """Generates string representation of self as a ctypes function argument."""
+ return ("util.ndptr(shape={0.shape}, dtype={0.np_dtype}, " # pylint: disable=missing-format-attribute
+ "flags={0.np_flags!s})".format(self))
+
+
+class DynamicNDArray(CDeclBase):
+ """Numeric array where one or more dimensions are determined at runtime."""
+
+ def __init__(self, name, typename, shape, comment="", parent=None,
+ is_const=None):
+ super(DynamicNDArray, self).__init__(name=name,
+ typename=typename,
+ shape=shape,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def runtime_shape_str(self):
+ """String representation of shape tuple at runtime."""
+ rs = []
+ for d in self.shape:
+ # dynamically-sized dimension
+ if isinstance(d, str):
+ if self.parent and d in self.parent.members:
+ rs.append("self.{}".format(d))
+ else:
+ rs.append("self._model.{}".format(d))
+ # static dimension
+ else:
+ rs.append(str(d))
+ return str(tuple(rs)).replace("'", "") # strip quotes from string rep
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name:}', {0.ctypes_ptr})".format(self) # pylint: disable=missing-format-attribute
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with a getter method for self (no setter)."""
+ return textwrap.dedent("""
+ @util.CachedProperty
+ def {0.name:}(self):
+ \"\"\"{0.docstring:}\"\"\"
+ return util.buf_to_npy(self._ptr.contents.{0.name:},
+ {0.runtime_shape_str:})
+ """.format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """Generates string representation of self as a ctypes function argument."""
+ return ("util.ndptr(dtype={0.np_dtype}, flags={0.np_flags!s:})"
+ "".format(self)) # pylint: disable=missing-format-attribute
+
+
+class Function(CDeclBase):
+ """A function declaration including input type(s) and return type."""
+
+ def __init__(self, name, arguments, return_value, comment=""):
+ super(Function, self).__init__(name=name,
+ arguments=arguments,
+ return_value=return_value,
+ comment=comment)
+
+ def ctypes_func_decl(self, cdll_name):
+ """Generates a ctypes function declaration."""
+ indent = codegen_util.Indenter()
+ # triple-quoted docstring
+ s = ("{0:}.{1.name:}.__doc__ = \"\"\"\n{1.docstring:}\"\"\"\n" # pylint: disable=missing-format-attribute
+ ).format(cdll_name, self)
+ # arguments
+ s += "{0:}.{1.name:}.argtypes = [".format(cdll_name, self) # pylint: disable=missing-format-attribute
+ if len(self.arguments) > 1:
+ s += "\n"
+ with indent:
+ with indent:
+ s += ",\n".join(indent(a.arg) for a in six.itervalues(self.arguments))
+ s += "\n"
+ else:
+ s += ", ".join(indent(a.arg) for a in six.itervalues(self.arguments))
+ s += "]\n"
+ # return value
+ s += "{0:}.{1.name:}.restype = {2:}\n".format( # pylint: disable=missing-format-attribute
+ cdll_name, self, self.return_value.arg)
+ return s
+
+ @property
+ def docstring(self):
+ """Generates a docstring."""
+ indent = codegen_util.Indenter()
+ s = "\n".join(textwrap.wrap(self.comment, 80)) + "\n\nArgs:\n"
+ with indent:
+ for a in six.itervalues(self.arguments):
+ s += indent("{a.name:}: {a.arg:}{const:}\n".format(
+ a=a, const=(" <const>" if a.is_const else "")))
+ s += "Returns:\n"
+ with indent:
+ s += indent("{0.return_value.arg}{1:}\n".format( # pylint: disable=missing-format-attribute
+ self, (" <const>" if self.return_value.is_const else "")))
+ return s
+
+
+class StaticStringArray(CDeclBase):
+ """A string array of fixed dimensions exported by MuJoCo."""
+
+ def __init__(self, name, shape, symbol_name):
+ super(StaticStringArray, self).__init__(name=name,
+ shape=shape,
+ symbol_name=symbol_name)
+
+ def ctypes_var_decl(self, cdll_name=""):
+ """Generates a ctypes export statement."""
+
+ ptr_str = "ctypes.c_char_p"
+ for dim in self.shape[::-1]:
+ ptr_str = "({0} * {1!s})".format(ptr_str, dim)
+
+ return "{0} = {1}.in_dll({2}, {3!r})\n".format(
+ self.name, ptr_str, cdll_name, self.symbol_name)
+
+
+class FunctionPtr(CDeclBase):
+ """A pointer to an externally defined C function."""
+
+ def __init__(self, name, symbol_name, type_name=None):
+ super(FunctionPtr, self).__init__(
+ name=name, symbol_name=symbol_name, type_name=type_name)
+
+ def ctypes_var_decl(self, cdll_name=""):
+ """Generates a ctypes export statement."""
+
+ return "ctypes.c_void_p.in_dll({0}, {1!r})".format(
+ cdll_name, self.symbol_name)
diff --git a/dm_control/autowrap/codegen_util.py b/dm_control/autowrap/codegen_util.py
new file mode 100644
index 00000000..5328319c
--- /dev/null
+++ b/dm_control/autowrap/codegen_util.py
@@ -0,0 +1,152 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Misc helper functions needed by autowrap.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import keyword
+import re
+
+# Internal dependencies.
+import six
+from six.moves import builtins
+
+_MJXMACRO_SUFFIX = "_POINTERS"
+_PYTHON_RESERVED_KEYWORDS = set(keyword.kwlist + dir(builtins))
+if not six.PY2:
+ _PYTHON_RESERVED_KEYWORDS.add("buffer")
+
+
+class Indenter(object):
+ r"""Callable context manager for tracking string indentation levels.
+
+ Args:
+ level: The initial indentation level.
+ indent_str: The string used to indent each line.
+
+ Example usage:
+
+ ```python
+ idt = Indenter()
+ s = idt("level 0\n")
+ with idt:
+ s += idt("level 1\n")
+ with idt:
+ s += idt("level 2\n")
+ s += idt("level 1 again\n")
+ s += idt("back to level 0\n")
+ print(s)
+ ```
+ """
+
+ def __init__(self, level=0, indent_str=" "):
+ self.indent_str = indent_str
+ self.level = level
+
+ def __enter__(self):
+ self.level += 1
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ self.level -= 1
+
+ def __call__(self, string):
+ return indent(string, self.level, self.indent_str)
+
+
+def indent(s, n=1, indent_str=" "):
+ """Inserts `n * indent_str` at the start of each non-empty line in `s`."""
+ p = n * indent_str
+ return "".join((p + l) if l.lstrip() else l for l in s.splitlines(True))
+
+
+class UniqueOrderedDict(collections.OrderedDict):
+ """Subclass of `OrderedDict` that enforces the uniqueness of keys."""
+
+ def __setitem__(self, k, v):
+ if k in self:
+ raise ValueError("Key '{}' already exists.".format(k))
+ super(UniqueOrderedDict, self).__setitem__(k, v)
+
+
+def macro_struct_name(name, suffix=None):
+ """Converts mjxmacro struct names, e.g. "MJDATA_POINTERS" to "mjdata"."""
+ if suffix is None:
+ suffix = _MJXMACRO_SUFFIX
+ return name[:-len(suffix)].lower()
+
+
+def is_macro_pointer(name):
+ """Returns True if the mjxmacro struct name contains pointer sizes."""
+ return name.endswith(_MJXMACRO_SUFFIX)
+
+
+def mangle_varname(s):
+ """Append underscores to ensure that `s` is not a reserved Python keyword."""
+ while s in _PYTHON_RESERVED_KEYWORDS:
+ s += "_"
+ return s
+
+
+def mangle_struct_typename(s):
+ """Strip leading underscores and make uppercase."""
+ return s.lstrip("_").upper()
+
+
+def mangle_comment(s):
+ """Strip extraneous whitespace, add full-stops at end of each line."""
+ if not isinstance(s, str):
+ return "\n".join(mangle_comment(line) for line in s)
+ elif not s:
+ return "."
+ else:
+ return ".\n".join(" ".join(line.split()) for line in s.splitlines()) + "."
+
+
+def camel_case(s):
+ """Convert a snake_case string (maybe with lowerCaseFirst) to CamelCase."""
+ tokens = re.sub(r"([A-Z])", r" \1", s.replace("_", " ")).split()
+ return "".join(w.title() for w in tokens)
+
+
+def try_coerce_to_num(s, try_types=(int, float)):
+ """Coerces `s` to a number; returns None if empty, or `s` on failure."""
+ if not s:
+ return None
+ for try_type in try_types:
+ try:
+ return try_type(s.rstrip("UuFf"))
+ except (ValueError, AttributeError):
+ continue
+ return s
+
+
+def recursive_dict_lookup(key, try_dict, max_depth=10):
+ """Recursively map dictionary keys to values."""
+ if max_depth < 0:
+ raise KeyError("Maximum recursion depth exceeded")
+ if key in try_dict:
+ key = try_dict[key]
+ return recursive_dict_lookup(key, try_dict, max_depth - 1)
+ return key
+
+
+def comment_line(string, width=79, fill_char="-"):
+ """Wraps `string` in a padded comment line."""
+ return "# {0:{2}^{1}}\n".format(string, width - 2, fill_char)
diff --git a/dm_control/autowrap/header_parsing.py b/dm_control/autowrap/header_parsing.py
new file mode 100644
index 00000000..4851fbc1
--- /dev/null
+++ b/dm_control/autowrap/header_parsing.py
@@ -0,0 +1,280 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""pyparsing definitions and helper functions for parsing MuJoCo headers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+import pyparsing as pp
+import six
+
+# NB: Don't enable parser memoization (`pp.ParserElement.enablePackrat()`),
+# since this results in a ~6x slowdown.
+
+C_TO_CTYPES = {
+ # integers
+ "int": "ctypes.c_int",
+ "unsigned int": "ctypes.c_uint",
+ "char": "ctypes.c_char",
+ "unsigned char": "ctypes.c_ubyte",
+ "size_t": "ctypes.c_size_t",
+ # floats
+ "float": "ctypes.c_float",
+ "double": "ctypes.c_double",
+ # pointers
+ "void": "None",
+}
+
+CTYPES_PTRS = {"None": "ctypes.c_void_p"}
+
+CTYPES_TO_NUMPY = {
+ # integers
+ "ctypes.c_int": "np.intc",
+ "ctypes.c_uint": "np.uintc",
+ "ctypes.c_ubyte": "np.ubyte",
+ # floats
+ "ctypes.c_float": "np.float32",
+ "ctypes.c_double": "np.float64",
+}
+
+# Helper functions for constructing recursive parsers.
+# ------------------------------------------------------------------------------
+
+
+def _nested_scopes(opening, closing, body):
+ """Constructs a parser for (possibly nested) scopes."""
+ scope = pp.Forward()
+ scope << pp.Group( # pylint: disable=expression-not-assigned
+ opening +
+ pp.ZeroOrMore(body | scope)("members") +
+ closing)
+ return scope
+
+
+def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
+ """Constructs a parser for (possibly nested) if...(else)...endif blocks."""
+ ifelse = pp.Forward()
+ ifelse << pp.Group( # pylint: disable=expression-not-assigned
+ if_ +
+ pred("predicate") +
+ pp.ZeroOrMore(match_if_true | ifelse)("if_true") +
+ pp.Optional(else_ +
+ pp.ZeroOrMore(match_if_false | ifelse)("if_false")) +
+ endif)
+ return ifelse
+
+
+# Some common string patterns to suppress.
+# ------------------------------------------------------------------------------
+(X, LPAREN, RPAREN, LBRACK, RBRACK, LBRACE, RBRACE, SEMI, COMMA, EQUAL, FSLASH,
+ BSLASH) = map(pp.Suppress, "X()[]{};,=/\\")
+EOL = pp.LineEnd().suppress()
+
+# Comments, continuation.
+# ------------------------------------------------------------------------------
+COMMENT = pp.Combine(
+ pp.Suppress("//") +
+ pp.Optional(pp.White()).suppress() +
+ pp.SkipTo(pp.LineEnd()))
+
+MULTILINE_COMMENT = pp.delimitedList(
+ COMMENT.copy().setWhitespaceChars(" \t"), delim=EOL)
+
+CONTINUATION = (BSLASH + pp.LineEnd()).suppress()
+
+# Preprocessor directives.
+# ------------------------------------------------------------------------------
+DEFINE = pp.Keyword("#define").suppress()
+IFDEF = pp.Keyword("#ifdef").suppress()
+IFNDEF = pp.Keyword("#ifndef").suppress()
+ELSE = pp.Keyword("#else").suppress()
+ENDIF = pp.Keyword("#endif").suppress()
+
+# Variable names, types, literals etc.
+# ------------------------------------------------------------------------------
+NAME = pp.Word(pp.alphanums + "_")
+INT = pp.Word(pp.nums + "UuLl")
+FLOAT = pp.Word(pp.nums + ".+-EeFf")
+NUMBER = FLOAT | INT
+
+# Dimensions can be of the form `[3]`, `[constant_name]` or `[2*constant_name]`
+ARRAY_DIM = pp.Combine(
+ LBRACK +
+ (INT | NAME) +
+ pp.Optional(pp.Literal("*")) +
+ pp.Optional(INT | NAME) +
+ RBRACK)
+
+PTR = pp.Literal("*")
+EXTERN = pp.Keyword("extern")
+NATIVE_TYPENAME = pp.MatchFirst(
+ [pp.Keyword(n) for n in six.iterkeys(C_TO_CTYPES)])
+
+# Macros.
+# ------------------------------------------------------------------------------
+
+HDR_GUARD = DEFINE + "THIRD_PARTY_MUJOCO_HDRS_"
+
+# e.g. "#define mjUSEDOUBLE"
+DEF_FLAG = pp.Group(
+ DEFINE +
+ NAME("name") +
+ (COMMENT("comment") | EOL)).ignore(HDR_GUARD)
+
+# e.g. "#define mjMINVAL 1E-14 // minimum value in any denominator"
+DEF_CONST = pp.Group(
+ DEFINE +
+ NAME("name") +
+ (NUMBER | NAME)("value") +
+ (COMMENT("comment") | EOL))
+
+# e.g. "X( mjtNum*, name_textadr, ntext, 1 )"
+XMEMBER = pp.Group(
+ X +
+ LPAREN +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")) +
+ COMMA +
+ NAME("name") +
+ COMMA +
+ pp.delimitedList((INT | NAME), delim=COMMA)("dims") +
+ RPAREN)
+
+XMACRO = pp.Group(
+ pp.Optional(COMMENT("comment")) +
+ DEFINE +
+ NAME("name") +
+ CONTINUATION +
+ pp.delimitedList(XMEMBER, delim=CONTINUATION)("members"))
+
+
+# Type/variable declarations.
+# ------------------------------------------------------------------------------
+TYPEDEF = pp.Keyword("typedef").suppress()
+STRUCT = pp.Keyword("struct").suppress()
+ENUM = pp.Keyword("enum").suppress()
+
+# e.g. "typedef unsigned char mjtByte; // used for true/false"
+TYPE_DECL = pp.Group(
+ TYPEDEF +
+ pp.Optional(STRUCT) +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")) +
+ NAME("name") +
+ SEMI +
+ pp.Optional(COMMENT("comment")))
+
+# Declarations of flags/constants/types.
+UNCOND_DECL = DEF_FLAG | DEF_CONST | TYPE_DECL
+
+# Declarations inside (possibly nested) #if(n)def... #else... #endif... blocks.
+COND_DECL = _nested_if_else(IFDEF, NAME, ELSE, ENDIF, UNCOND_DECL, UNCOND_DECL)
+# Note: this doesn't work for '#if defined(FLAG)' blocks
+
+# e.g. "mjtNum gravity[3]; // gravitational acceleration"
+STRUCT_MEMBER = pp.Group(
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")) +
+ NAME("name") +
+ pp.ZeroOrMore(ARRAY_DIM)("size") +
+ SEMI +
+ pp.Optional(COMMENT("comment")))
+
+STRUCT_DECL = pp.Group(
+ STRUCT +
+ pp.Optional(NAME("typename")) +
+ pp.Optional(COMMENT("comment")) +
+ LBRACE +
+ pp.OneOrMore(STRUCT_MEMBER)("members") +
+ RBRACE +
+ pp.Optional(NAME("name")) +
+ SEMI)
+
+# Multiple (possibly nested) struct declarations.
+NESTED_STRUCTS = _nested_scopes(
+ opening=(STRUCT +
+ pp.Optional(NAME("typename")) +
+ pp.Optional(COMMENT("comment")) +
+ LBRACE),
+ closing=(RBRACE + pp.Optional(NAME("name")) + SEMI),
+ body=pp.OneOrMore(
+ STRUCT_MEMBER | STRUCT_DECL | COMMENT.suppress())("members"))
+
+BIT_LSHIFT = INT("bit_lshift_a") + pp.Suppress("<<") + INT("bit_lshift_b")
+
+ENUM_LINE = pp.Group(
+ NAME("name") +
+ pp.Optional(EQUAL + (INT("value") ^ BIT_LSHIFT)) +
+ pp.Optional(COMMA) +
+ pp.Optional(COMMENT("comment")))
+
+ENUM_DECL = pp.Group(
+ TYPEDEF +
+ ENUM +
+ NAME("typename") +
+ pp.Optional(COMMENT("comment")) +
+ LBRACE +
+ pp.OneOrMore(ENUM_LINE | COMMENT.suppress())("members") +
+ RBRACE +
+ pp.Optional(NAME("name")) +
+ SEMI)
+
+# Function declarations.
+# ------------------------------------------------------------------------------
+MJAPI = pp.Keyword("MJAPI")
+CONST = pp.Keyword("const")
+
+ARG = pp.Group(
+ pp.Optional(CONST("is_const")) +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")) +
+ NAME("name") +
+ pp.Optional(ARRAY_DIM("size")))
+
+RET = pp.Group(
+ pp.Optional(CONST("is_const")) +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")))
+
+FUNCTION_DECL = (
+ RET("return_value") +
+ NAME("name") +
+ LPAREN +
+ pp.delimitedList(ARG, delim=COMMA)("arguments") +
+ RPAREN +
+ SEMI)
+
+MJAPI_FUNCTION_DECL = pp.Group(
+ pp.Optional(MULTILINE_COMMENT("comment")) +
+ MJAPI +
+ FUNCTION_DECL)
+
+# Global variables.
+# ------------------------------------------------------------------------------
+
+MJAPI_STRING_ARRAY = (
+ MJAPI +
+ EXTERN +
+ CONST +
+ pp.Keyword("char") +
+ PTR +
+ NAME("name") +
+ pp.OneOrMore(ARRAY_DIM)("dims") +
+ SEMI)
+
+MJAPI_FUNCTION_PTR = MJAPI + EXTERN + NAME("typename") + NAME("name") + SEMI
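A note on the `XMEMBER` grammar above: the shape it accepts can be illustrated without pyparsing. The sketch below is not part of the patch; it swaps in a plain standard-library regular expression, and `parse_xmember` and `_XMEMBER_RE` are hypothetical names chosen for illustration. It assumes a single-word type name, so a native type such as `unsigned char` would not match here.

```python
import re

# Hypothetical regex mirroring the shape of XMEMBER:
#   X( <typename>[*], <name>, <dim>[, <dim>...] )
_XMEMBER_RE = re.compile(
    r"X\(\s*(?P<typename>\w+)\s*(?P<ptr>\*)?\s*,"
    r"\s*(?P<name>\w+)\s*,"
    r"\s*(?P<dims>\w+(?:\s*,\s*\w+)*)\s*\)")


def parse_xmember(line):
    """Returns a dict describing one X-macro member, or None if no match."""
    match = _XMEMBER_RE.search(line)
    if match is None:
        return None
    return {
        "typename": match.group("typename"),
        "is_ptr": match.group("ptr") is not None,
        "name": match.group("name"),
        # Dimensions may be integer literals or names of size constants.
        "dims": [dim.strip() for dim in match.group("dims").split(",")],
    }


# The example line quoted in the comment above XMEMBER.
member = parse_xmember("X( mjtNum*, name_textadr, ntext, 1 )")
print(member)
```

As in the grammar, the dimensions are kept as strings, since a dimension may refer to a named constant that can only be resolved later.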
diff --git a/dm_control/mujoco/__init__.py b/dm_control/mujoco/__init__.py
new file mode 100644
index 00000000..442ba6fd
--- /dev/null
+++ b/dm_control/mujoco/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Mujoco implementations of base classes."""
+
+from dm_control.mujoco.engine import action_spec
+
+from dm_control.mujoco.engine import Camera
+from dm_control.mujoco.engine import MovableCamera
+from dm_control.mujoco.engine import Physics
+from dm_control.mujoco.engine import TextOverlay
diff --git a/dm_control/mujoco/engine.py b/dm_control/mujoco/engine.py
new file mode 100644
index 00000000..946fe167
--- /dev/null
+++ b/dm_control/mujoco/engine.py
@@ -0,0 +1,757 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Mujoco `Physics` implementation and helper classes.
+
+The `Physics` class provides the main Python interface to MuJoCo.
+
+MuJoCo models are defined using the MJCF XML format. The `Physics` class
+can load a model from a path to an XML file, an XML string, or from a serialized
+MJB binary format. See the named constructors for each of these cases.
+
+Each `Physics` instance defines a simulated world. To step forward the
+simulation, use the `step` method. To set a control or actuation signal, use the
+`set_control` method, which will apply the provided signal to the actuators in
+subsequent calls to `step`.
+
+Use the `Camera` class to create RGB or depth images. A `Camera` can render its
+viewport to an array using the `render` method, and can query for objects
+visible at specific positions using the `select` method. The `Physics` class
+also provides a `render` method that returns a pixel array directly.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import render
+from dm_control.mujoco import index
+
+from dm_control.mujoco import wrapper
+from dm_control.mujoco.wrapper import util
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+from dm_control.mujoco.wrapper.mjbindings import types
+
+from dm_control.rl import control as _control
+
+import numpy as np
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from dm_control.rl import specs
+
+_FONT_SCALE = 150
+_MAX_WIDTH = 1024
+_MAX_HEIGHT = 1024
+
+_FONT_STYLES = {
+ 'normal': enums.mjtFont.mjFONT_NORMAL,
+ 'shadow': enums.mjtFont.mjFONT_SHADOW,
+ 'big': enums.mjtFont.mjFONT_BIG,
+}
+_GRID_POSITIONS = {
+ 'top left': enums.mjtGridPos.mjGRID_TOPLEFT,
+ 'top right': enums.mjtGridPos.mjGRID_TOPRIGHT,
+ 'bottom left': enums.mjtGridPos.mjGRID_BOTTOMLEFT,
+ 'bottom right': enums.mjtGridPos.mjGRID_BOTTOMRIGHT,
+}
+
+_DIVERGENCE_WARNINGS = [
+ enums.mjtWarning.mjWARN_INERTIA,
+ enums.mjtWarning.mjWARN_BADQPOS,
+ enums.mjtWarning.mjWARN_BADQVEL,
+ enums.mjtWarning.mjWARN_BADQACC,
+ enums.mjtWarning.mjWARN_BADCTRL,
+]
+
+Contexts = collections.namedtuple('Contexts', ['gl', 'mujoco'])
+Selected = collections.namedtuple(
+ 'Selected', ['body', 'geom', 'world_position'])
+NamedIndexStructs = collections.namedtuple(
+ 'NamedIndexStructs', ['model', 'data'])
+Pose = collections.namedtuple(
+ 'Pose', ['lookat', 'distance', 'azimuth', 'elevation'])
+
+
+class Physics(_control.Physics):
+ """Encapsulates a MuJoCo model.
+
+ A MuJoCo model is typically defined by an MJCF XML file [0].
+
+ ```python
+ physics = Physics.from_xml_path('/path/to/model.xml')
+
+ with physics.reset_context():
+ physics.named.data.qpos['hinge'] = np.random.rand()
+
+ # Apply controls and advance the simulation state.
+ physics.set_control(np.random.random_sample(size=N_ACTUATORS))
+ physics.step()
+
+ # Render a camera defined in the XML file to a NumPy array.
+ rgb = physics.render(height=240, width=320, camera_id=0)
+ ```
+
+ [0] http://www.mujoco.org/book/modeling.html
+ """
+
+ def __init__(self, data):
+ """Initializes a new `Physics` instance.
+
+ Args:
+ data: Instance of `wrapper.MjData`.
+ """
+ self._reload_from_data(data)
+
+ def set_control(self, control):
+ """Sets the control signal for the actuators.
+
+ Args:
+ control: NumPy array or array-like actuation values.
+ """
+ self.data.ctrl[:] = np.asarray(control)
+
+ def step(self, n_sub_steps=1):
+ """Advances physics with up-to-date position and velocity dependent fields.
+
+ The actuation can be updated by calling the `set_control` function first.
+
+ Args:
+ n_sub_steps: Optional number of times to advance the physics. Default 1.
+ """
+ # In the case of Euler integration we assume mj_step1 has already been
+ # called for this state, finish the step with mj_step2 and then update all
+ # position and velocity related fields with mj_step1. This ensures that
+ # (most of) mjData is in sync with qpos and qvel. In the case of non-Euler
+ # integrators (e.g. RK4) an additional mj_step1 must be called after the
+ # last mj_step to ensure mjData syncing.
+ for _ in xrange(n_sub_steps):
+ if self.model.opt.integrator == enums.mjtIntegrator.mjINT_EULER:
+ mjlib.mj_step2(self.model.ptr, self.data.ptr)
+ mjlib.mj_step1(self.model.ptr, self.data.ptr)
+ else:
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+
+ if self.model.opt.integrator != enums.mjtIntegrator.mjINT_EULER:
+ mjlib.mj_step1(self.model.ptr, self.data.ptr)
+
+ self.check_divergence()
+
+ def render(self, height=240, width=320, camera_id=-1, overlays=(),
+ depth=False, scene_option=None):
+ """Returns a camera view as a NumPy array of pixel values.
+
+ Args:
+ height: Viewport height (number of pixels). Optional, defaults to 240.
+ width: Viewport width (number of pixels). Optional, defaults to 320.
+ camera_id: Optional camera name or index. Defaults to -1, the free
+ camera, which is always defined. A nonnegative integer or string
+ corresponds to a fixed camera, which must be defined in the model XML.
+ If `camera_id` is a string then the camera must also be named.
+ overlays: An optional sequence of `TextOverlay` instances to draw. Only
+ supported if `depth` is False.
+ depth: If `True`, this method returns a NumPy float array of depth values
+ (in meters). Defaults to `False`, which results in an RGB image.
+ scene_option: An optional `wrapper.MjvOption` instance that can be used to
+ render the scene with custom visualization options. If None then the
+ default options will be used.
+
+ Returns:
+ The rendered RGB or depth image.
+ """
+ camera = Camera(
+ physics=self, height=height, width=width, camera_id=camera_id)
+ return camera.render(
+ overlays=overlays, depth=depth, scene_option=scene_option)
+
+ def get_state(self):
+ """Returns the physics state.
+
+ Returns:
+ NumPy array containing full physics simulation state.
+ """
+ return np.concatenate(self._physics_state_items())
+
+ def set_state(self, physics_state):
+ """Sets the physics state.
+
+ Args:
+ physics_state: NumPy array containing the full physics simulation state.
+
+ Raises:
+ ValueError: If `physics_state` has invalid size.
+ """
+ state_items = self._physics_state_items()
+
+ expected_shape = (sum(item.size for item in state_items),)
+ if expected_shape != physics_state.shape:
+ raise ValueError('Input physics state has shape {}. Expected {}.'.format(
+ physics_state.shape, expected_shape))
+
+ start = 0
+ for state_item in state_items:
+ size = state_item.size
+ state_item[:] = physics_state[start:start + size]
+ start += size
+
+ def copy(self, share_model=False):
+ """Creates a copy of this `Physics` instance.
+
+ Args:
+ share_model: If True, the copy and the original will share a common
+ MjModel instance. By default, both model and data will be copied.
+
+ Returns:
+ A `Physics` instance.
+ """
+ if not share_model:
+ new_model = self.model.copy()
+ else:
+ new_model = self.model
+ new_data = wrapper.MjData(new_model)
+ mjlib.mj_copyData(new_data.ptr, new_data.model.ptr, self.data.ptr)
+ cls = self.__class__
+ new_obj = cls.__new__(cls)
+ new_obj._reload_from_data(new_data) # pylint: disable=protected-access
+ return new_obj
+
+ def reset(self):
+ """Resets internal variables of the physics simulation."""
+ mjlib.mj_resetData(self.model.ptr, self.data.ptr)
+ # Disable actuation since we don't yet have meaningful control inputs.
+ with self.model.disable('actuation'):
+ self.forward()
+
+ def after_reset(self):
+ """Runs after resetting internal variables of the physics simulation."""
+ # Disable actuation since we don't yet have meaningful control inputs.
+ with self.model.disable('actuation'):
+ self.forward()
+
+ def forward(self):
+ """Recomputes the forward dynamics without advancing the simulation."""
+ # Note: `mj_forward` differs from `mj_step1` in that it also recomputes
+ # quantities that depend on acceleration (and therefore on the state of the
+ # controls). For example `mj_forward` updates accelerometer and gyro
+ # readings, whereas `mj_step1` does not.
+ # http://www.mujoco.org/book/programming.html#siForward
+ mjlib.mj_forward(self.model.ptr, self.data.ptr)
+
+ def check_divergence(self):
+ """Raises a `control.PhysicsError` if the simulation state is divergent."""
+ warning_counts = [self.data.warning[i].number for i in _DIVERGENCE_WARNINGS]
+ if any(warning_counts):
+ warning_names = []
+ for i in np.where(warning_counts)[0]:
+ field_idx = _DIVERGENCE_WARNINGS[i]
+ warning_names.append(enums.mjtWarning._fields[field_idx])
+ raise _control.PhysicsError(
+ 'Physics state has diverged. Warning(s) raised: {}'.format(
+ ', '.join(warning_names)))
+
+ def __getstate__(self):
+ return self.data # All state is assumed to reside within `self.data`.
+
+ def __setstate__(self, data):
+ self._reload_from_data(data)
+
+ def _reload_from_model(self, model):
+ """Initializes a new or existing `Physics` from a `wrapper.MjModel`.
+
+ Creates a new `wrapper.MjData` instance, then delegates to
+ `_reload_from_data`.
+
+ Args:
+ model: Instance of `wrapper.MjModel`.
+ """
+ data = wrapper.MjData(model)
+ self._reload_from_data(data)
+
+ def _reload_from_data(self, data):
+ """Initializes a new or existing `Physics` instance from a `wrapper.MjData`.
+
+ Assigns all attributes and sets up rendering contexts and named indexing.
+
+ The default constructor as well as the other `reload_from` methods should
+ delegate to this method.
+
+ Args:
+ data: Instance of `wrapper.MjData`.
+ """
+ self._data = data
+
+ # Forcibly clear the previous context to avoid problems with GL
+ # implementations which do not support multiple contexts on a given device.
+ if hasattr(self, '_contexts'):
+ self._contexts.gl.free_context()
+
+ # Set up rendering context. Need to provide at least one rendering api in
+ # the BUILD target.
+ render_context = render.Renderer(_MAX_WIDTH, _MAX_HEIGHT)
+ mujoco_context = wrapper.MjrContext()
+ with render_context.make_current(_MAX_WIDTH, _MAX_HEIGHT):
+ mjlib.mjr_makeContext(self.model.ptr, mujoco_context.ptr, _FONT_SCALE)
+ mjlib.mjr_setBuffer(
+ enums.mjtFramebuffer.mjFB_OFFSCREEN, mujoco_context.ptr)
+ self._contexts = Contexts(gl=render_context, mujoco=mujoco_context)
+
+ # Call kinematics update to enable rendering.
+ self.after_reset()
+
+ # Set up named indexing.
+ axis_indexers = index.make_axis_indexers(self.model)
+ self._named = NamedIndexStructs(
+ model=index.struct_indexer(self.model, 'mjmodel', axis_indexers),
+ data=index.struct_indexer(self.data, 'mjdata', axis_indexers),)
+
+ @classmethod
+ def from_model(cls, model):
+ """A named constructor from a `wrapper.MjModel` instance."""
+ data = wrapper.MjData(model)
+ return cls(data)
+
+ @classmethod
+ def from_xml_string(cls, xml_string, assets=None):
+ """A named constructor from a string containing an MJCF XML file.
+
+ Args:
+ xml_string: XML string containing an MJCF model description.
+ assets: Optional dict containing external assets referenced by the model
+ (such as additional XML files, textures, meshes etc.), in the form of
+ `{filename: contents_string}` pairs. The keys should correspond to the
+ filenames specified in the model XML.
+
+ Returns:
+ A new `Physics` instance.
+ """
+ model = wrapper.MjModel.from_xml_string(xml_string, assets=assets)
+ return cls.from_model(model)
+
+ @classmethod
+ def from_byte_string(cls, byte_string):
+ """A named constructor from a model binary as a byte string."""
+ model = wrapper.MjModel.from_byte_string(byte_string)
+ return cls.from_model(model)
+
+ @classmethod
+ def from_xml_path(cls, file_path):
+ """A named constructor from a path to an MJCF XML file.
+
+ Args:
+ file_path: String containing path to model definition file.
+
+ Returns:
+ A new `Physics` instance.
+ """
+ model = wrapper.MjModel.from_xml_path(file_path)
+ return cls.from_model(model)
+
+ @classmethod
+ def from_binary_path(cls, file_path):
+ """A named constructor from a path to an MJB model binary file.
+
+ Args:
+ file_path: String containing path to an MJB model binary file.
+
+ Returns:
+ A new `Physics` instance.
+ """
+ model = wrapper.MjModel.from_binary_path(file_path)
+ return cls.from_model(model)
+
+ def reload_from_xml_string(self, xml_string, assets=None):
+ """Reloads the `Physics` instance from a string containing an MJCF XML file.
+
+ After calling this method, the state of the `Physics` instance is the same
+ as a new `Physics` instance created with the `from_xml_string` named
+ constructor.
+
+ Args:
+ xml_string: XML string containing an MJCF model description.
+ assets: Optional dict containing external assets referenced by the model
+ (such as additional XML files, textures, meshes etc.), in the form of
+ `{filename: contents_string}` pairs. The keys should correspond to the
+ filenames specified in the model XML.
+ """
+ new_model = wrapper.MjModel.from_xml_string(xml_string, assets=assets)
+ self._reload_from_model(new_model)
+
+ def reload_from_xml_path(self, file_path):
+ """Reloads the `Physics` instance from a path to an MJCF XML file.
+
+ After calling this method, the state of the `Physics` instance is the same
+ as a new `Physics` instance created with the `from_xml_path`
+ named constructor.
+
+ Args:
+ file_path: String containing path to model definition file.
+ """
+ self._reload_from_model(wrapper.MjModel.from_xml_path(file_path))
+
+ @property
+ def named(self):
+ return self._named
+
+ @property
+ def contexts(self):
+ """Returns a `Contexts` namedtuple, used in `Camera`s and rendering code."""
+ return self._contexts
+
+ @property
+ def model(self):
+ return self._data.model
+
+ @property
+ def data(self):
+ return self._data
+
+ def _physics_state_items(self):
+ """Returns list of arrays making up internal physics simulation state.
+
+ The physics state consists of the state variables, their derivatives and
+ actuation activations.
+
+ Returns:
+ List of NumPy arrays containing full physics simulation state.
+ """
+ return [self.data.qpos[:], self.data.qvel[:], self.data.act[:]]
+
+ # Named views of simulation data.
+
+ def control(self):
+ """Returns MuJoCo actuation vector as defined in the model."""
+ return self.data.ctrl[:]
+
+ def activation(self):
+ """Returns the internal states of 'filter' or 'integrator' actuators.
+
+ For details, please refer to
+ http://www.mujoco.org/book/computation.html#geActuation
+
+ Returns:
+ Activations in a numpy array.
+ """
+ return self.data.act[:]
+
+ def state(self):
+ """Returns the full physics state. Alias for `get_state`."""
+ return np.concatenate(self._physics_state_items())
+
+ def position(self):
+ """Returns generalized positions (system configuration)."""
+ return self.data.qpos[:]
+
+ def velocity(self):
+ """Returns generalized velocities."""
+ return self.data.qvel[:]
+
+ def timestep(self):
+ """Returns the simulation timestep."""
+ return self.model.opt.timestep
+
+ def time(self):
+ """Returns episode time in seconds."""
+ return self.data.time
+
+
+class Camera(object):
+ """Mujoco scene camera.
+
+ Holds rendering properties such as the width and height of the viewport. The
+ camera position and rotation are defined by the Mujoco camera corresponding to
+ the `camera_id`. Multiple `Camera` instances may exist for a single
+ `camera_id`, for example to render the same view at different resolutions.
+ """
+
+ def __init__(self, physics, height=240, width=320, camera_id=-1):
+ """Initializes a new `Camera`.
+
+ Args:
+ physics: Instance of `Physics`.
+ height: Optional image height. Defaults to 240.
+ width: Optional image width. Defaults to 320.
+ camera_id: Optional camera name or index. Defaults to -1, the free
+ camera, which is always defined. A nonnegative integer or string
+ corresponds to a fixed camera, which must be defined in the model XML.
+ If `camera_id` is a string then the camera must also be named.
+
+ Raises:
+ ValueError: If `camera_id` is outside the valid range, or if `width` or
+ `height` exceed the dimensions of MuJoCo's offscreen framebuffer.
+ """
+ buffer_width = physics.model.vis.global_.offwidth
+ buffer_height = physics.model.vis.global_.offheight
+ if width > buffer_width:
+ raise ValueError('Image width {} > framebuffer width {}. Either reduce '
+ 'the image width or specify a larger offscreen '
+ 'framebuffer in the model XML using the clause\n'
+ '<visual>\n'
+ '  <global offwidth="my_width"/>\n'
+ '</visual>'.format(width, buffer_width))
+ if height > buffer_height:
+ raise ValueError('Image height {} > framebuffer height {}. Either reduce '
+ 'the image height or specify a larger offscreen '
+ 'framebuffer in the model XML using the clause\n'
+ '<visual>\n'
+ '  <global offheight="my_height"/>\n'
+ '</visual>'.format(height, buffer_height))
+ if isinstance(camera_id, six.string_types):
+ camera_id = physics.model.name2id(camera_id, 'camera')
+ if camera_id < -1:
+ raise ValueError('camera_id cannot be smaller than -1.')
+ if camera_id >= physics.model.ncam:
+ raise ValueError('model has {} fixed cameras. camera_id={} is invalid.'.
+ format(physics.model.ncam, camera_id))
+
+ self._width = width
+ self._height = height
+ self._physics = physics
+
+ # Variables corresponding to structs needed by Mujoco's rendering functions.
+ self._scene = wrapper.MjvScene()
+ self._scene_option = wrapper.MjvOption()
+
+ self._perturb = wrapper.MjvPerturb()
+ self._perturb.active = 0
+ self._perturb.select = 0
+
+ self._rect = types.MJRRECT(0, 0, self._width, self._height)
+
+ self._render_camera = wrapper.MjvCamera()
+ self._render_camera.fixedcamid = camera_id
+
+ if camera_id == -1:
+ self._render_camera.type_ = enums.mjtCamera.mjCAMERA_FREE
+ else:
+ # As defined in the Mujoco documentation, mjCAMERA_FIXED refers to a
+ # camera explicitly defined in the model.
+ self._render_camera.type_ = enums.mjtCamera.mjCAMERA_FIXED
+
+ # Internal buffers.
+ self._rgb_buffer = np.empty((self._height, self._width, 3), dtype=np.uint8)
+ self._depth_buffer = np.empty((self._height, self._width), dtype=np.float32)
+
+ if self._physics.contexts.mujoco is not None:
+ with self._physics.contexts.gl.make_current(self._width, self._height):
+ mjlib.mjr_setBuffer(enums.mjtFramebuffer.mjFB_OFFSCREEN,
+ self._physics.contexts.mujoco.ptr)
+
+ @property
+ def width(self):
+ """Returns the image width (number of pixels)."""
+ return self._width
+
+ @property
+ def height(self):
+ """Returns the image height (number of pixels)."""
+ return self._height
+
+ @property
+ def option(self):
+ """Returns the camera's visualization options."""
+ return self._scene_option
+
+ def update(self, scene_option=None):
+ """Updates geometry used for rendering.
+
+ Args:
+ scene_option: A custom `wrapper.MjvOption` instance to use to render
+ the scene instead of the default. If None, will use the default.
+ """
+ scene_option = scene_option or self._scene_option
+ mjlib.mjv_updateScene(self._physics.model.ptr, self._physics.data.ptr,
+ scene_option.ptr, self._perturb.ptr,
+ self._render_camera.ptr, enums.mjtCatBit.mjCAT_ALL,
+ self._scene.ptr)
+
+ def render(self, overlays=(), depth=False, scene_option=None):
+ """Renders the camera view as a numpy array of pixel values.
+
+ Args:
+ overlays: An optional sequence of `TextOverlay` instances to draw. Only
+ supported if `depth` is False.
+ depth: An optional boolean. If True, render depth instead of RGB values.
+ scene_option: A custom `wrapper.MjvOption` instance to use to render
+ the scene instead of the default. If None, will use the default.
+
+ Returns:
+ The rendered scene. If `depth` is False this is a NumPy uint8 array of RGB
+ values, otherwise it is a float NumPy array of depth values (in meters).
+
+ Raises:
+ ValueError: If overlays are requested with depth rendering.
+ """
+
+ if depth and overlays:
+ raise ValueError('Overlays are not supported with depth rendering.')
+
+ self.update(scene_option=scene_option)
+
+ with self._physics.contexts.gl.make_current(self._width, self._height):
+ mjlib.mjr_render(self._rect, self._scene.ptr,
+ self._physics.contexts.mujoco.ptr)
+
+ if depth:
+ mjlib.mjr_readPixels(None, self._depth_buffer, self._rect,
+ self._physics.contexts.mujoco.ptr)
+
+ # Get distance of near and far clipping planes.
+ extent = self._physics.model.stat.extent
+ near = self._physics.model.vis.map_.znear * extent
+ far = self._physics.model.vis.map_.zfar * extent
+
+ # Convert from [0 1] to depth in meters, see links below.
+ # http://stackoverflow.com/a/6657284/1461210
+ # https://www.khronos.org/opengl/wiki/Depth_Buffer_Precision
+ self._depth_buffer = near / (1 - self._depth_buffer * (1 - near / far))
+
+ else:
+ for overlay in overlays:
+ overlay.draw(self._physics.contexts.mujoco.ptr, self._rect)
+
+ mjlib.mjr_readPixels(self._rgb_buffer, None, self._rect,
+ self._physics.contexts.mujoco.ptr)
+
+ return np.flipud(self._depth_buffer if depth else self._rgb_buffer)
+
+ def select(self, cursor_position):
+ """Returns bodies and geoms visible at given coordinates in the frame.
+
+ Args:
+ cursor_position: A `tuple` containing x and y coordinates, normalized to
+ between 0 and 1, and where (0, 0) is bottom-left.
+
+ Returns:
+ A `Selected` namedtuple. Fields are None if nothing is selected.
+ """
+ self.update()
+ aspect_ratio = self._width / self._height
+ cursor_x, cursor_y = cursor_position
+ pos = np.empty(3, np.double)
+ selected_geom = mjlib.mjv_select(
+ self._physics.model.ptr,
+ self._physics.data.ptr,
+ self._scene_option.ptr,
+ aspect_ratio,
+ cursor_x,
+ cursor_y,
+ self._scene.ptr,
+ pos)
+
+ if selected_geom == -1: # Nothing was selected.
+ return Selected(body=None, geom=None, world_position=None)
+ else:
+ assert 0 <= selected_geom < self._physics.model.ngeom
+ selected_body = self._physics.model.geom_bodyid[selected_geom]
+ assert 0 <= selected_body < self._physics.model.nbody
+ return Selected(
+ body=selected_body, geom=selected_geom, world_position=pos)
+
+
+class MovableCamera(Camera):
+ """Subclass of `Camera` that can be moved by changing its pose.
+
+ A `MovableCamera` always corresponds to a MuJoCo free camera with id -1.
+ """
+
+ def __init__(self, physics, height=240, width=320):
+ """Initializes a new `MovableCamera`.
+
+ Args:
+ physics: Instance of `Physics`.
+ height: Optional image height. Defaults to 240.
+ width: Optional image width. Defaults to 320.
+ """
+ super(MovableCamera, self).__init__(
+ physics=physics, height=height, width=width, camera_id=-1)
+
+ def get_pose(self):
+ """Returns the pose of the camera.
+
+ Returns:
+ A `Pose` named tuple with fields:
+ lookat: NumPy array specifying lookat point.
+ distance: Float specifying distance to `lookat`.
+ azimuth: Azimuth in degrees.
+ elevation: Elevation in degrees.
+ """
+ return Pose(self._render_camera.lookat, self._render_camera.distance,
+ self._render_camera.azimuth, self._render_camera.elevation)
+
+ def set_pose(self, lookat, distance, azimuth, elevation):
+ """Sets the pose of the camera.
+
+ Args:
+ lookat: NumPy array or list specifying lookat point.
+ distance: Float specifying distance to `lookat`.
+ azimuth: Azimuth in degrees.
+ elevation: Elevation in degrees.
+ """
+ self._render_camera.lookat[:] = lookat
+ self._render_camera.distance = distance
+ self._render_camera.azimuth = azimuth
+ self._render_camera.elevation = elevation
+
+
+class TextOverlay(object):
+ """A text overlay that can be drawn on top of a camera view."""
+
+ __slots__ = ('title', 'body', 'style', 'position')
+
+ def __init__(self, title='', body='', style='normal', position='top left'):
+ """Initializes a new TextOverlay instance.
+
+ Args:
+ title: Title text.
+ body: Body text.
+ style: The font style. Can be either "normal", "shadow", or "big".
+ position: The grid position of the overlay. Can be either "top left",
+ "top right", "bottom left", or "bottom right".
+ """
+ self.title = title
+ self.body = body
+ self.style = _FONT_STYLES[style]
+ self.position = _GRID_POSITIONS[position]
+
+ def draw(self, context, rect):
+ """Draws the overlay.
+
+ Args:
+ context: A `types.MJRCONTEXT` pointer.
+ rect: A `types.MJRRECT`.
+ """
+ mjlib.mjr_overlay(self.style,
+ self.position,
+ rect,
+ util.to_binary_string(self.title),
+ util.to_binary_string(self.body),
+ context)
+
+
+def action_spec(physics):
+ """Returns a `BoundedArraySpec` matching the `Physics` actuators."""
+ num_actions = physics.model.nu
+ is_limited = physics.model.actuator_ctrllimited.ravel().astype(bool)
+ control_range = physics.model.actuator_ctrlrange
+ minima = np.full(num_actions, fill_value=-np.inf, dtype=float)
+ maxima = np.full(num_actions, fill_value=np.inf, dtype=float)
+ minima[is_limited], maxima[is_limited] = control_range[is_limited].T
+
+ return specs.BoundedArraySpec(
+ shape=(num_actions,), dtype=float, minimum=minima, maximum=maxima)
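A note on the depth conversion in `Camera.render` in engine.py above: the formula maps the normalized depth-buffer value in [0, 1] to a distance in meters, so that 0 lands on the near clipping plane and 1 on the far one. The sketch below is not part of the patch. It restates the same expression in plain Python; `depth_buffer_to_meters` is a hypothetical helper name, and the `near` and `far` values are illustrative assumptions rather than values read from a model (in the patch they are `znear` and `zfar` scaled by the model extent).

```python
def depth_buffer_to_meters(depth, near, far):
    """Maps a normalized depth-buffer value in [0, 1] to a distance in meters.

    Same expression as in `Camera.render`:
        near / (1 - depth * (1 - near / far))
    """
    return near / (1.0 - depth * (1.0 - near / far))


# Illustrative clipping planes (assumed values, not taken from any model).
near, far = 0.5, 50.0

for depth in (0.0, 0.5, 1.0):
    print(depth, depth_buffer_to_meters(depth, near, far))
```

The mapping is strongly nonlinear: with these assumed planes, a buffer value of 0.5 is still less than a meter from the camera, which is why depth-buffer precision is concentrated near the near plane.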
diff --git a/dm_control/mujoco/engine_test.py b/dm_control/mujoco/engine_test.py
new file mode 100644
index 00000000..c37dcb50
--- /dev/null
+++ b/dm_control/mujoco/engine_test.py
@@ -0,0 +1,316 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for `engine`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.mujoco import engine
+from dm_control.mujoco import wrapper
+from dm_control.mujoco.testing import assets
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+
+from dm_control.rl import control
+
+import numpy as np
+from six.moves import cPickle
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+MODEL_PATH = assets.get_path('cartpole.xml')
+MODEL_WITH_ASSETS = assets.get_contents('model_with_assets.xml')
+ASSETS = {
+ 'texture.png': assets.get_contents('deepmind.png'),
+ 'mesh.stl': assets.get_contents('cube.stl'),
+ 'included.xml': assets.get_contents('sphere.xml')
+}
+
+
+class MujocoEngineTest(parameterized.TestCase):
+
+ def setUp(self):
+ self._physics = engine.Physics.from_xml_path(MODEL_PATH)
+
+  def _assert_attributes_equal(self, actual_obj, expected_obj, attr_to_compare):
+    for name in attr_to_compare:
+      actual_value = getattr(actual_obj, name)
+      expected_value = getattr(expected_obj, name)
+      try:
+        if isinstance(expected_value, np.ndarray):
+          np.testing.assert_array_equal(actual_value, expected_value)
+        else:
+          self.assertEqual(actual_value, expected_value)
+      except AssertionError as e:
+        raise AssertionError("Attribute '{}' differs from expected value. {}"
+                             "".format(name, e))
+
+ @parameterized.parameters(0, 'cart', u'cart')
+ def testCameraIndexing(self, camera_id):
+ height, width = 480, 640
+ _ = engine.Camera(
+ self._physics, height, width, camera_id=camera_id)
+
+ def testDepthRender(self):
+ plane_and_box = """
+
+
+
+
+
+
+
+ """
+ physics = engine.Physics.from_xml_string(plane_and_box)
+ pixels = physics.render(height=200, width=200, camera_id='top', depth=True)
+ # Nearest pixels should be 2.8m away
+ np.testing.assert_approx_equal(pixels.min(), 2.8, 3)
+ # Furthest pixels should be 3m away (depth is orthographic)
+ np.testing.assert_approx_equal(pixels.max(), 3.0, 3)
+
+ def testTextOverlay(self):
+ height, width = 480, 640
+ overlay = engine.TextOverlay(title='Title', body='Body', style='big',
+ position='bottom right')
+
+ no_overlay = self._physics.render(height, width, camera_id=0)
+ with_overlay = self._physics.render(height, width, camera_id=0,
+ overlays=[overlay])
+ self.assertFalse(np.all(no_overlay == with_overlay),
+ msg='Images are identical with and without text overlay.')
+
+ def testSceneOption(self):
+ height, width = 480, 640
+ scene_option = wrapper.MjvOption()
+ mjlib.mjv_defaultOption(scene_option.ptr)
+
+ # Render geoms as semi-transparent.
+ scene_option.flags[enums.mjtVisFlag.mjVIS_TRANSPARENT] = 1
+
+ no_scene_option = self._physics.render(height, width, camera_id=0)
+ with_scene_option = self._physics.render(height, width, camera_id=0,
+ scene_option=scene_option)
+ self.assertFalse(np.all(no_scene_option == with_scene_option),
+ msg='Images are identical with and without scene option.')
+
+ @parameterized.parameters(((0.5, 0.5), (1, 3)), # pole
+ ((0.5, 0.1), (0, 0)), # ground
+ ((0.9, 0.9), (None, None)), # sky
+ )
+ def testCameraSelection(self, coordinates, expected_selection):
+ height, width = 480, 640
+ camera = engine.Camera(self._physics, height, width, camera_id=0)
+
+ # Test for b/63380170: Enabling visualization of body frames adds
+ # "non-model" geoms to the scene. This means that the indices of geoms
+ # within `camera._scene.geoms` don't match the rows of `model.geom_bodyid`.
+ camera.option.frame = enums.mjtFrame.mjFRAME_BODY
+
+ selected = camera.select(coordinates)
+ self.assertEqual(expected_selection, selected[:2])
+
+ def testMovableCameraSetGetPose(self):
+ height, width = 240, 320
+
+ camera = engine.MovableCamera(self._physics, height, width)
+ image = camera.render().copy()
+
+ pose = camera.get_pose()
+
+ lookat_offset = np.array([0.01, 0.02, -0.03])
+
+ # Would normally pass the new values directly to camera.set_pose instead of
+ # using the namedtuple _replace method, but this makes the asserts at the
+ # end of the test a little cleaner.
+ new_pose = pose._replace(distance=pose.distance * 1.5,
+ lookat=pose.lookat + lookat_offset,
+ azimuth=pose.azimuth + -15,
+ elevation=pose.elevation - 10)
+
+ camera.set_pose(*new_pose)
+
+ self.assertEqual(new_pose.distance, camera.get_pose().distance)
+ self.assertEqual(new_pose.azimuth, camera.get_pose().azimuth)
+ self.assertEqual(new_pose.elevation, camera.get_pose().elevation)
+ np.testing.assert_allclose(new_pose.lookat, camera.get_pose().lookat)
+
+ self.assertFalse(np.all(image == camera.render()))
+
+ def testRenderExceptions(self):
+ max_width = self._physics.model.vis.global_.offwidth
+ max_height = self._physics.model.vis.global_.offheight
+ max_camid = self._physics.model.ncam - 1
+ with self.assertRaisesRegexp(ValueError, 'width'):
+ self._physics.render(max_height, max_width + 1, camera_id=max_camid)
+ with self.assertRaisesRegexp(ValueError, 'height'):
+ self._physics.render(max_height + 1, max_width, camera_id=max_camid)
+ with self.assertRaisesRegexp(ValueError, 'camera_id'):
+ self._physics.render(max_height, max_width, camera_id=max_camid + 1)
+ with self.assertRaisesRegexp(ValueError, 'camera_id'):
+ self._physics.render(max_height, max_width, camera_id=-2)
+
+ def testPhysicsRenderMethod(self):
+ height, width = 240, 320
+ image = self._physics.render(height=height, width=width)
+ self.assertEqual(image.shape, (height, width, 3))
+ depth = self._physics.render(height=height, width=width, depth=True)
+ self.assertEqual(depth.shape, (height, width))
+
+ def testNamedViews(self):
+ self.assertEqual((1,), self._physics.control().shape)
+ self.assertEqual((2,), self._physics.position().shape)
+ self.assertEqual((2,), self._physics.velocity().shape)
+ self.assertEqual((0,), self._physics.activation().shape)
+ self.assertEqual((4,), self._physics.state().shape)
+ self.assertEqual(0., self._physics.time())
+ self.assertEqual(0.01, self._physics.timestep())
+
+ def testSetGetPhysicsState(self):
+ physics_state = self._physics.get_state()
+ self._physics.set_state(physics_state)
+
+ new_physics_state = np.random.random_sample(physics_state.shape)
+ self._physics.set_state(new_physics_state)
+
+ np.testing.assert_allclose(new_physics_state,
+ self._physics.get_state())
+
+ def testSetInvalidPhysicsState(self):
+ badly_shaped_state = np.repeat(self._physics.get_state(), repeats=2)
+
+ with self.assertRaises(ValueError):
+ self._physics.set_state(badly_shaped_state)
+
+ def testNamedIndexing(self):
+ self.assertEqual((3,), self._physics.named.data.xpos['cart'].shape)
+ self.assertEqual((2, 3),
+ self._physics.named.data.xpos[['cart', 'pole']].shape)
+
+ def testReload(self):
+ self._physics.reload_from_xml_path(MODEL_PATH)
+
+ def testLoadAndReloadFromStringWithAssets(self):
+ physics = engine.Physics.from_xml_string(
+ MODEL_WITH_ASSETS, assets=ASSETS)
+ physics.reload_from_xml_string(MODEL_WITH_ASSETS, assets=ASSETS)
+
+ @parameterized.parameters(
+ 'mjWARN_INERTIA',
+ 'mjWARN_BADQPOS',
+ 'mjWARN_BADQVEL',
+ 'mjWARN_BADQACC',
+ )
+ def testDivergenceException(self, warning_name):
+ warning_enum = getattr(enums.mjtWarning, warning_name)
+ with self._physics.reset_context():
+ self._physics.data.warning[warning_enum].number = 1
+ with self.assertRaisesRegexp(control.PhysicsError, warning_name):
+ self._physics.check_divergence()
+ self._physics.reset()
+ self._physics.check_divergence()
+
+ @parameterized.parameters(float('inf'), float('nan'), 1e15)
+ def testBadQpos(self, bad_value):
+ with self._physics.reset_context():
+ self._physics.data.qpos[0] = bad_value
+ mjlib.mj_checkPos(self._physics.model.ptr, self._physics.data.ptr)
+ with self.assertRaises(control.PhysicsError):
+ self._physics.check_divergence()
+ self._physics.reset()
+ mjlib.mj_checkPos(self._physics.model.ptr, self._physics.data.ptr)
+ self._physics.check_divergence()
+
+ def testNanControl(self):
+ with self._physics.reset_context():
+ self._physics.data.ctrl[0] = float('nan')
+
+ # Apply the controls.
+ mjlib.mj_step(self._physics.model.ptr, self._physics.data.ptr)
+ with self.assertRaisesRegexp(control.PhysicsError, 'mjWARN_BADCTRL'):
+ self._physics.check_divergence()
+
+ @parameterized.named_parameters(
+ ('_copy', lambda x: x.copy()),
+ ('_pickle_and_unpickle', lambda x: cPickle.loads(cPickle.dumps(x))),
+ )
+ def testCopyOrPicklePhysics(self, func):
+ for _ in xrange(10):
+ self._physics.step()
+ physics2 = func(self._physics)
+ self.assertNotEqual(physics2.model.ptr, self._physics.model.ptr)
+ self.assertNotEqual(physics2.data.ptr, self._physics.data.ptr)
+ model_attr_to_compare = ('nnames', 'njmax', 'body_pos', 'geom_quat')
+ self._assert_attributes_equal(
+ physics2.model, self._physics.model, model_attr_to_compare)
+ data_attr_to_compare = ('time', 'energy', 'qpos', 'xpos')
+ self._assert_attributes_equal(
+ physics2.data, self._physics.data, data_attr_to_compare)
+ for _ in xrange(10):
+ self._physics.step()
+ physics2.step()
+ self._assert_attributes_equal(
+ physics2.model, self._physics.model, model_attr_to_compare)
+ self._assert_attributes_equal(
+ physics2.data, self._physics.data, data_attr_to_compare)
+
+ def testCopyDataOnly(self):
+ physics2 = self._physics.copy(share_model=True)
+ self.assertEqual(physics2.model.ptr, self._physics.model.ptr)
+ self.assertNotEqual(physics2.data.ptr, self._physics.data.ptr)
+
+ def testForwardDynamicsUpdatedAfterReset(self):
+ gravity = -9.81
+ self._physics.model.opt.gravity[2] = gravity
+ with self._physics.reset_context():
+ pass
+ self.assertAlmostEqual(
+ self._physics.named.data.sensordata['accelerometer'][2], -gravity)
+
+ def testActuationNotAppliedInAfterReset(self):
+ self._physics.data.ctrl[0] = 1.
+ self._physics.after_reset() # Calls `forward()` with actuation disabled.
+ self.assertEqual(self._physics.data.actuator_force[0], 0.)
+ self._physics.forward() # Call `forward` directly with actuation enabled.
+ self.assertEqual(self._physics.data.actuator_force[0], 1.)
+
+ def testActionSpec(self):
+ xml = """
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+    physics = engine.Physics.from_xml_string(xml)
+    spec = engine.action_spec(physics)
+    self.assertEqual(float, spec.dtype)
+    np.testing.assert_array_equal(spec.minimum, [-np.inf, -1.0])
+    np.testing.assert_array_equal(spec.maximum, [np.inf, 2.0])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/mujoco/index.py b/dm_control/mujoco/index.py
new file mode 100644
index 00000000..3a3d3a27
--- /dev/null
+++ b/dm_control/mujoco/index.py
@@ -0,0 +1,641 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Mujoco functions to support named indexing.
+
+The Mujoco name structure works as follows:
+
+In mjxmacro.h, each "X" entry denotes a type (a), a field name (b) and a list
+of dimension size metadata (c) which may contain both numbers and names, for
+example
+
+    X(int,    name_bodyadr, nbody, 1) // or
+    X(mjtNum, body_pos,     nbody, 3)
+      a       b             c ----->
+
+The second declaration states that the field `body_pos` has type `mjtNum` and
+dimension sizes `(nbody, 3)`, i.e. the first axis is indexed by body number.
+These and other named dimensions are sized based on the loaded model. This
+information is parsed and stored in `mjbindings.sizes`.
+
+In mjmodel.h, the struct mjModel contains an array of element name addresses
+for each size name.
+
+    int* name_bodyadr;  // body name pointers (nbody x 1)
+
+By iterating over each of these element name address arrays, we first obtain a
+mapping from size names to a list of element names.
+
+    {'nbody': ['cart', 'pole'], 'njnt': ['free', 'ball', 'hinge'], ...}
+
+In addition to the element names that are derived from the mjModel struct at
+runtime, we also assign hard-coded names to certain dimensions where there is an
+established naming convention (e.g. 'x', 'y', 'z' for dimensions that correspond
+to Cartesian positions).
+
+For some dimensions, a single element name maps to multiple indices within the
+underlying field. For example, a single joint name corresponds to a variable
+number of indices within `qpos` that depends on the number of degrees of freedom
+associated with that joint type. These are referred to as "ragged" dimensions.
+
+In such cases we determine the size of each named element by examining the
+address arrays (e.g. `jnt_qposadr`), and construct a mapping from size name to
+element sizes:
+
+    {'nq': [7, 4, 1], 'nv': [6, 3, 1], ...}
+
+Given these two dictionaries, we then create an `Axis` instance for each size
+name. These objects have a `convert_key_item` method that handles the conversion
+from indexing expressions containing element names to valid numpy indices.
+Different implementations of `Axis` are used to handle "ragged" and "non-ragged"
+dimensions.
+
+    {'nbody': RegularNamedAxis(names=['cart', 'pole']),
+     'nq': RaggedNamedAxis(names=['free', 'ball', 'hinge'], sizes=[7, 4, 1])}
+
+We construct this dictionary once using `make_axis_indexers`.
+
+Finally, for each field we construct a `FieldIndexer` class. A `FieldIndexer`
+instance encapsulates a field together with a list of `Axis` instances (one per
+dimension), and implements the named indexing logic by calling their respective
+`convert_key_item` methods.
+
+Summary of terminology:
+
+* _size name_ or _size_: A dimension size name, e.g. `nbody` or `ngeom`.
+* _element name_ or _name_: A named element in a Mujoco model, e.g. 'cart' or
+  'pole'.
+* _element index_ or _index_: The index of an element name, for a specific size
+  name.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import weakref
+
+# Internal dependencies.
+
+from dm_control.mujoco.wrapper import util
+from dm_control.mujoco.wrapper.mjbindings import sizes
+import numpy as np
+import six
+
+
+# Mapping from {size_name: address_field_name} for ragged dimensions.
+_RAGGED_ADDRS = {
+ 'nq': 'jnt_qposadr',
+ 'nv': 'jnt_dofadr',
+ 'nsensordata': 'sensor_adr',
+ 'nnumericdata': 'numeric_adr',
+}
+
+# Names of columns.
+_COLUMN_NAMES = {
+ 'xyz': ['x', 'y', 'z'],
+ 'quat': ['qw', 'qx', 'qy', 'qz'],
+ 'mat': ['xx', 'xy', 'xz',
+ 'yx', 'yy', 'yz',
+ 'zx', 'zy', 'zz'],
+}
+
+# Mapping from keys of _COLUMN_NAMES to sets of field names whose columns are
+# addressable using those names.
+_COLUMN_ID_TO_FIELDS = {
+ 'xyz': set([
+ 'body_pos',
+ 'body_ipos',
+ 'body_inertia',
+ 'jnt_pos',
+ 'jnt_axis',
+ 'geom_size',
+ 'geom_pos',
+ 'site_size',
+ 'site_pos',
+ 'cam_pos',
+ 'cam_poscom0',
+ 'cam_pos0',
+ 'light_pos',
+ 'light_dir',
+ 'light_poscom0',
+ 'light_pos0',
+ 'light_dir0',
+ 'mesh_vert',
+ 'mesh_normal',
+ 'mocap_pos',
+ 'xpos',
+ 'xipos',
+ 'xanchor',
+ 'xaxis',
+ 'geom_xpos',
+ 'site_xpos',
+ 'cam_xpos',
+ 'light_xpos',
+ 'light_xdir',
+ 'subtree_com',
+ 'wrap_xpos',
+ 'subtree_linvel',
+ 'subtree_angmom',
+ ]),
+ 'quat': set([
+ 'body_quat',
+ 'body_iquat',
+ 'geom_quat',
+ 'site_quat',
+ 'cam_quat',
+ 'mocap_quat',
+ 'xquat',
+ ]),
+ 'mat': set([
+ 'cam_mat0',
+ 'xmat',
+ 'ximat',
+ 'geom_xmat',
+ 'site_xmat',
+ 'cam_xmat',
+ ])
+}
+
+
+def _get_size_name_to_element_names(model):
+  """Returns a dict that maps size names to element names.
+
+  Args:
+    model: An instance of `mjbindings.mjModelWrapper`.
+
+  Returns:
+    A `dict` mapping from a size name (e.g. `'nbody'`) to a list of element
+    names.
+  """
+
+  names = model.names[:model.nnames]
+  size_name_to_element_names = {}
+
+  for field_name in dir(model):
+    if not _is_name_pointer(field_name):
+      continue
+
+    # Get addresses of element names in `model.names` array, e.g.
+    # field name: `name_bodyadr` and name_addresses: `[86, 92, 101]`, and skip
+    # when there are no elements for this type in the model.
+    name_addresses = getattr(model, field_name).ravel()
+    if not name_addresses.size:
+      continue
+
+    # Get the element names.
+    element_names = []
+    for start_index in name_addresses:
+      name = names[start_index:names.find(b'\0', start_index)]
+      element_names.append(util.to_native_string(name))
+
+    # String identifier for the size of the first dimension, e.g. 'nbody'.
+    size_name = _get_size_name(field_name)
+
+    size_name_to_element_names[size_name] = element_names
+
+  # Add custom element names for certain columns.
+  for size_name, element_names in six.iteritems(_COLUMN_NAMES):
+    size_name_to_element_names[size_name] = element_names
+
+  # "Ragged" axes inherit their element names from other "non-ragged" axes.
+  # For example, the element names for the "nv" axis come from "njnt".
+  for size_name, address_field_name in six.iteritems(_RAGGED_ADDRS):
+    donor = 'n' + address_field_name.split('_')[0]
+    if donor in size_name_to_element_names:
+      size_name_to_element_names[size_name] = size_name_to_element_names[donor]
+
+  # Mocap bodies are a special subset of bodies.
+  mocap_body_names = [None] * model.nmocap
+  for body_id, body_name in enumerate(size_name_to_element_names['nbody']):
+    body_mocapid = model.body_mocapid[body_id]
+    if body_mocapid != -1:
+      mocap_body_names[body_mocapid] = body_name
+  assert None not in mocap_body_names
+  size_name_to_element_names['nmocap'] = mocap_body_names
+
+  return size_name_to_element_names
+
+
+def _get_size_name_to_element_sizes(model):
+ """Returns a dict that maps size names to element sizes for ragged axes.
+
+ Args:
+ model: An instance of `mjbindings.mjModelWrapper`.
+
+ Returns:
+ A `dict` mapping from a size name (e.g. `'nv'`) to a numpy array of element
+ sizes. Size names corresponding to non-ragged axes are omitted.
+ """
+
+ size_name_to_element_sizes = {}
+
+ for size_name, address_field_name in six.iteritems(_RAGGED_ADDRS):
+ addresses = getattr(model, address_field_name).ravel()
+ total_length = getattr(model, size_name)
+ element_sizes = np.diff(np.r_[addresses, total_length])
+ size_name_to_element_sizes[size_name] = element_sizes
+
+ return size_name_to_element_sizes
+
+
+def make_axis_indexers(model):
+  """Returns a dict that maps size names to `Axis` indexers.
+
+  Args:
+    model: An instance of `mjbindings.mjModelWrapper`.
+
+  Returns:
+    A `dict` mapping from a size name (e.g. `'nbody'`) to an `Axis`
+    instance.
+  """
+
+  size_name_to_element_names = _get_size_name_to_element_names(model)
+  size_name_to_element_sizes = _get_size_name_to_element_sizes(model)
+
+  # Unrecognized size names are treated as unnamed axes.
+  axis_indexers = collections.defaultdict(UnnamedAxis)
+
+  for size_name in size_name_to_element_names:
+    element_names = size_name_to_element_names[size_name]
+    if size_name in _RAGGED_ADDRS:
+      element_sizes = size_name_to_element_sizes[size_name]
+      indexer = RaggedNamedAxis(element_names, element_sizes)
+    else:
+      indexer = RegularNamedAxis(element_names)
+    axis_indexers[size_name] = indexer
+
+  return axis_indexers
+
+
+def _is_name_pointer(field_name):
+ """Returns True for name pointer field names such as `name_bodyadr`."""
+ # Denotes name pointer fields in mjModel.
+ prefix, suffix = 'name_', 'adr'
+ return field_name.startswith(prefix) and field_name.endswith(suffix)
+
+
+def _get_size_name(field_name, struct_name='mjmodel'):
+ # Look up size name in metadata.
+ return sizes.array_sizes[struct_name][field_name][0]
+
+
+def _validate_key_item(key_item):
+ if isinstance(key_item, (list, np.ndarray)):
+ for sub in key_item:
+ _validate_key_item(sub) # Recurse into nested arrays and lists.
+ elif key_item is Ellipsis:
+ raise IndexError('Ellipsis indexing not supported.')
+ elif key_item is None:
+ raise IndexError('None indexing not supported.')
+ elif key_item in (b'', u''):
+ raise IndexError('Empty strings are not allowed.')
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Axis(object):
+ """Handles the conversion of named indexing expressions into numpy indices."""
+
+ @abc.abstractmethod
+ def convert_key_item(self, key_item):
+ """Converts a (possibly named) indexing expression to a numpy index."""
+
+
+class UnnamedAxis(Axis):
+ """An object representing an axis where the elements are not named."""
+
+ def convert_key_item(self, key_item):
+ """Validate the indexing expression and return it unmodified."""
+ _validate_key_item(key_item)
+ return key_item
+
+
+class RegularNamedAxis(Axis):
+ """Represents an axis where each named element has a fixed size of 1."""
+
+ def __init__(self, names):
+ """Initializes a new `RegularNamedAxis` instance.
+
+ Args:
+ names: A list or array of element names.
+ """
+ self._names = names
+ self._names_to_offsets = {name: offset
+ for offset, name in enumerate(names) if name}
+
+  def convert_key_item(self, key_item):
+    """Converts a named indexing expression to a numpy-friendly index."""
+
+    _validate_key_item(key_item)
+
+    if isinstance(key_item, six.string_types):
+      key_item = self._names_to_offsets[util.to_native_string(key_item)]
+
+    elif isinstance(key_item, (list, np.ndarray)):
+      # Cast lists to numpy arrays.
+      key_item = np.asarray(key_item)
+      original_shape = key_item.shape
+
+      # We assume that either all or none of the items in the array are strings
+      # representing names. If there is a mix, we will let NumPy throw an error
+      # when trying to index with the returned item.
+      if isinstance(key_item.flat[0], six.string_types):
+        key_item = np.array([self._names_to_offsets[util.to_native_string(k)]
+                             for k in key_item.flat])
+        # Ensure the output shape is the same as that of the input.
+        key_item.shape = original_shape
+
+    return key_item
+
+ @property
+ def names(self):
+ """Returns a list of element names."""
+ return self._names
+
+
+class RaggedNamedAxis(Axis):
+ """Represents an axis where the named elements may vary in size."""
+
+ def __init__(self, element_names, element_sizes):
+ """Initializes a new `RaggedNamedAxis` instance.
+
+ Args:
+ element_names: A list or array containing the element names.
+ element_sizes: A list or array containing the size of each element.
+ """
+ names_to_slices = {}
+ names_to_indices = {}
+
+ offset = 0
+ for name, size in zip(element_names, element_sizes):
+ # Don't add unnamed elements to the dicts.
+ if name:
+ names_to_slices[name] = slice(offset, offset + size)
+ names_to_indices[name] = range(offset, offset + size)
+ offset += size
+
+ self._names = element_names
+ self._sizes = element_sizes
+ self._names_to_slices = names_to_slices
+ self._names_to_indices = names_to_indices
+
+ def convert_key_item(self, key):
+ """Converts a named indexing expression to a numpy-friendly index."""
+
+ _validate_key_item(key)
+
+ if isinstance(key, six.string_types):
+ key = self._names_to_slices[util.to_native_string(key)]
+
+ elif isinstance(key, (list, np.ndarray)):
+ # We assume that either all or none of the items in the sequence are
+ # strings representing names. If there is a mix, we will let NumPy throw
+ # an error when trying to index with the returned key.
+ if isinstance(key[0], six.string_types):
+ new_key = []
+ for k in key:
+ idx = self._names_to_indices[util.to_native_string(k)]
+ if isinstance(idx, int):
+ new_key.append(idx)
+ else:
+ new_key.extend(idx)
+ key = new_key
+
+ return key
+
+ @property
+ def names(self):
+ """Returns a list of element names."""
+ return self._names
+
+
+Axes = collections.namedtuple('Axes', ['row', 'col'])
+Axes.__new__.__defaults__ = (None,) # Default value for optional 'col' field
+
+
+class FieldIndexer(object):
+ """An array-like object providing named access to a field in a MuJoCo struct.
+
+ FieldIndexers expose the same attributes and methods as an `np.ndarray`.
+
+ They may be indexed with strings or lists of strings corresponding to element
+ names. They also support standard numpy indexing expressions, with the
+ exception of indices containing `Ellipsis` or `None`.
+ """
+
+ __slots__ = ('_field_name', '_field', '_axes')
+
+ def __init__(self,
+ parent_struct,
+ field_name,
+ axis_indexers):
+ """Initializes a new `FieldIndexer`.
+
+ Args:
+ parent_struct: Wrapped ctypes structure, as generated by `mjbindings`.
+ field_name: String containing field name in `parent_struct`.
+ axis_indexers: A list of `Axis` instances, one per dimension.
+ """
+ self._field_name = field_name
+ self._field = weakref.proxy(getattr(parent_struct, field_name))
+ self._axes = Axes(*axis_indexers)
+
+ def __dir__(self):
+ # Enables IPython tab completion
+ return sorted(set(dir(type(self)) + dir(self._field)))
+
+ def __getattr__(self, name):
+ return getattr(self._field, name)
+
+ def _convert_key(self, key):
+ """Convert a (possibly named) indexing expression to a valid numpy index."""
+ return_tuple = isinstance(key, tuple)
+ if not return_tuple:
+ key = (key,)
+ if len(key) > self._field.ndim:
+ raise IndexError('Index tuple has {} elements, but array has only {} '
+ 'dimensions.'.format(len(key), self._field.ndim))
+ new_key = tuple(axis.convert_key_item(key_item)
+ for axis, key_item in zip(self._axes, key))
+ if not return_tuple:
+ new_key = new_key[0]
+ return new_key
+
+ def __getitem__(self, key):
+ """Converts the key to a numeric index and returns the indexed array.
+
+ Args:
+ key: Indexing expression.
+
+ Raises:
+ IndexError: If an indexing tuple has too many elements, or if it contains
+ `Ellipsis`, `None`, or an empty string.
+
+ Returns:
+ The indexed array.
+ """
+ return self._field[self._convert_key(key)]
+
+ def __setitem__(self, key, value):
+ """Converts the key and assigns to the indexed array.
+
+ Args:
+ key: Indexing expression.
+ value: Value to assign.
+
+ Raises:
+ IndexError: If an indexing tuple has too many elements, or if it contains
+ `Ellipsis`, `None`, or an empty string.
+ """
+ self._field[self._convert_key(key)] = value
+
+ @property
+ def axes(self):
+ """A namedtuple containing the row and column indexers for this field."""
+ return self._axes
+
+ def __repr__(self):
+ """Returns a pretty string representation of the `FieldIndexer`."""
+
+ def get_name_arr_and_len(dim_idx):
+ """Returns a string array of element names and the max name length."""
+ axis = self._axes[dim_idx]
+ size = self._field.shape[dim_idx]
+ try:
+ name_len = max(len(name) for name in axis.names)
+ name_arr = np.zeros(size, dtype='S{}'.format(name_len))
+ for name in axis.names:
+ if name:
+ # Use the `Axis` object to convert the name into a numpy index, then
+ # use this index to write into name_arr.
+ name_arr[axis.convert_key_item(name)] = name
+ except AttributeError:
+ name_arr = np.zeros(size, dtype='S0') # An array of zero-length strings
+ name_len = 0
+ return name_arr, name_len
+
+ row_name_arr, row_name_len = get_name_arr_and_len(0)
+ if self._field.ndim > 1:
+ col_name_arr, col_name_len = get_name_arr_and_len(1)
+ else:
+ col_name_arr, col_name_len = np.zeros(1, dtype='S0'), 0
+
+ idx_len = int(np.log10(max(self._field.shape[0], 1))) + 1
+
+ cls_template = '{class_name:}({field_name:}):'
+ col_template = '{padding:}{col_names:}'
+ row_template = '{idx:{idx_len:}} {row_name:>{row_name_len:}} {row_vals:}'
+
+ lines = []
+
+ # Write the class name and field name.
+ lines.append(cls_template.format(class_name=self.__class__.__name__,
+ field_name=self._field_name))
+
+ # Write a header line containing the column names (if there are any).
+ if col_name_len:
+ col_width = max(col_name_len, 9) + 1
+ extra_indent = 4
+ padding = ' ' * (idx_len + row_name_len + extra_indent)
+ col_names = ''.join(
+ '{name:<{width:}}'
+ .format(name=util.to_native_string(name), width=col_width)
+ for name in col_name_arr)
+ lines.append(col_template.format(padding=padding, col_names=col_names))
+
+ # Write the row names (if there are any) and the formatted array values.
+ if not self._field.shape[0]:
+ lines.append('(empty)')
+ else:
+ for idx, row in enumerate(self._field):
+ row_vals = np.array2string(
+ np.atleast_1d(row),
+ suppress_small=True,
+ formatter={'float_kind': '{: < 9.3g}'.format})
+ lines.append(row_template.format(
+ idx=idx,
+ idx_len=idx_len,
+ row_name=util.to_native_string(row_name_arr[idx]),
+ row_name_len=row_name_len,
+ row_vals=row_vals))
+ return '\n'.join(lines)
+
+
+def struct_indexer(struct, struct_name, size_to_axis_indexer):
+ """Returns a namedtuple with a `FieldIndexer` for each dynamic array field.
+
+ Usage example
+
+ ```python
+ named_data = struct_indexer(mjdata, 'mjdata', size_to_axis_indexer)
+ fingertip_xpos = named_data.xpos['fingertip']
+ elbow_qvel = named_data.qvel['elbow']
+ ```
+
+ Args:
+ struct: Wrapped ctypes structure as generated by `mjbindings`.
+ struct_name: String containing corresponding Mujoco name of struct.
+ size_to_axis_indexer: dict that maps size names to `Axis` instances.
+
+ Returns:
+ A `namedtuple` with a field for every dynamically sized array field mapping
+ to a `FieldIndexer`.
+
+ Raises:
+ ValueError: If `struct_name` is not recognized.
+ """
+ struct_name = struct_name.lower()
+ if struct_name not in sizes.array_sizes:
+ raise ValueError('Unrecognized struct name ' + struct_name)
+
+ array_sizes = sizes.array_sizes[struct_name]
+
+ # Used to create the namedtuple.
+ field_names = []
+ field_indexers = {}
+
+ for field_name in array_sizes:
+
+ # Skip over fields that have sizes but aren't numpy arrays, such as text
+ # fields and contacts (b/34805932).
+ if not isinstance(getattr(struct, field_name), np.ndarray):
+ continue
+
+ size_names = sizes.array_sizes[struct_name][field_name]
+
+ # Here we override the size name in order to enable named column indexing
+ # for certain fields, e.g. 3 becomes "xyz" for field name "xpos".
+ for new_col_size, field_set in six.iteritems(_COLUMN_ID_TO_FIELDS):
+ if field_name in field_set:
+ size_names = (size_names[0], new_col_size)
+ break
+
+ axis_indexers = []
+ for size_name in size_names:
+ axis_indexers.append(size_to_axis_indexer[size_name])
+
+ field_indexers[field_name] = FieldIndexer(
+ parent_struct=struct,
+ field_name=field_name,
+ axis_indexers=axis_indexers)
+
+ field_names.append(field_name)
+
+ struct_indexer_ = collections.namedtuple(struct_name + '_indexer',
+ field_names)
+
+ return struct_indexer_(**field_indexers)
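For illustration only (not part of the patch): a minimal, self-contained sketch of the ragged-axis bookkeeping that `_get_size_name_to_element_sizes` and `RaggedNamedAxis` in `index.py` perform. The helper name `ragged_slices` and the sample joint layout below are assumptions made up for this example; it uses only NumPy, not the dm_control API.

```python
import numpy as np

# Hypothetical stand-ins for model fields: one free, one ball and one hinge
# joint, which occupy 7, 4 and 1 entries of `qpos` respectively.
joint_names = ['free', 'ball', 'hinge']
jnt_qposadr = np.array([0, 7, 11])  # Start address of each joint in `qpos`.
nq = 12                             # Total length of `qpos`.


def ragged_slices(names, addresses, total_length):
  """Maps each element name to the slice it occupies on a ragged axis."""
  # The size of each element is the gap to the next address; the last element
  # runs to the end of the axis.
  sizes = np.diff(np.r_[addresses, total_length])
  slices = {}
  offset = 0
  for name, size in zip(names, sizes):
    if name:  # Unnamed elements are skipped, as in `RaggedNamedAxis`.
      slices[name] = slice(offset, offset + int(size))
    offset += int(size)
  return sizes, slices


sizes, slices = ragged_slices(joint_names, jnt_qposadr, nq)
qpos = np.arange(nq)
ball_qpos = qpos[slices['ball']]
```

Once a joint name has been converted this way, a named lookup on a ragged field such as `qpos` reduces to ordinary NumPy slicing with the resulting `slice` object.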
diff --git a/dm_control/mujoco/index_test.py b/dm_control/mujoco/index_test.py
new file mode 100644
index 00000000..bd956d47
--- /dev/null
+++ b/dm_control/mujoco/index_test.py
@@ -0,0 +1,319 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for index."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.mujoco import index
+from dm_control.mujoco import wrapper
+from dm_control.mujoco.testing import assets
+from dm_control.mujoco.wrapper.mjbindings import sizes
+
+import numpy as np
+import six
+
+MODEL = assets.get_contents('cartpole.xml')
+MODEL_NO_NAMES = assets.get_contents('cartpole_no_names.xml')
+
+FIELD_REPR = {
+    'act': ('FieldIndexer(act):\n'
+            '(empty)'),
+    'qM': ('FieldIndexer(qM):\n'
+           '0  [ 0       ]\n'
+           '1  [ 1       ]\n'
+           '2  [ 2       ]'),
+    'sensordata': ('FieldIndexer(sensordata):\n'
+                   '0 accelerometer [ 0       ]\n'
+                   '1 accelerometer [ 1       ]\n'
+                   '2 accelerometer [ 2       ]\n'
+                   '3     collision [ 3       ]'),
+    'xpos': ('FieldIndexer(xpos):\n'
+             '           x         y         z         \n'
+             '0  world [ 0         1         2       ]\n'
+             '1   cart [ 3         4         5       ]\n'
+             '2   pole [ 6         7         8       ]\n'
+             '3 mocap1 [ 9         10        11      ]\n'
+             '4 mocap2 [ 12        13        14      ]'),
+}
+
+
+class MujocoIndexTest(parameterized.TestCase):
+
+ def setUp(self):
+ self._model = wrapper.MjModel.from_xml_string(MODEL)
+ self._data = wrapper.MjData(self._model)
+
+ self._size_to_axis_indexer = index.make_axis_indexers(self._model)
+
+ self._model_indexers = index.struct_indexer(self._model, 'mjmodel',
+ self._size_to_axis_indexer)
+
+ self._data_indexers = index.struct_indexer(self._data, 'mjdata',
+ self._size_to_axis_indexer)
+
+ def assertIndexExpressionEqual(self, expected, actual):
+ try:
+ if isinstance(expected, tuple):
+ self.assertEqual(len(expected), len(actual))
+ for expected_item, actual_item in zip(expected, actual):
+ self.assertIndexExpressionEqual(expected_item, actual_item)
+ elif isinstance(expected, (list, np.ndarray)):
+ np.testing.assert_array_equal(expected, actual)
+ else:
+ self.assertEqual(expected, actual)
+ except AssertionError:
+ self.fail('Indexing expressions are not equal.\n'
+ 'expected: {!r}\nactual: {!r}'.format(expected, actual))
+
+ @parameterized.parameters(
+ # (field name, named index key, expected integer index key)
+ ('actuator_gear', 'slide', 0),
+ ('dof_armature', 'slider', slice(0, 1, None)),
+ ('dof_armature', ['slider', 'hinge'], [0, 1]),
+ ('numeric_data', 'three_numbers', slice(1, 4, None)),
+ ('numeric_data', ['three_numbers', 'control_timestep'], [1, 2, 3, 0]))
+ def testModelNamedIndexing(self, field_name, key, numeric_key):
+
+ indexer = getattr(self._model_indexers, field_name)
+ field = getattr(self._model, field_name)
+
+    # Convert the named key into its equivalent numeric key.
+    converted_key = indexer._convert_key(key)
+
+    # Explicit check that the converted key matches the numeric key.
+ self.assertIndexExpressionEqual(numeric_key, converted_key)
+
+ # This writes unique values to the underlying buffer to prevent false
+ # negatives.
+ field.flat[:] = np.arange(field.size)
+
+ # Check that the result of named indexing matches the result of numeric
+ # indexing.
+ np.testing.assert_array_equal(field[numeric_key], indexer[key])
+
+ @parameterized.parameters(
+ # (field name, named index key, expected integer index key)
+ ('xpos', 'pole', 2),
+ ('xpos', ['pole', 'cart'], [2, 1]),
+ ('sensordata', 'accelerometer', slice(0, 3, None)),
+ ('sensordata', 'collision', slice(3, 4, None)),
+ ('sensordata', ['accelerometer', 'collision'], [0, 1, 2, 3]),
+ # Slices.
+ ('xpos', (slice(None), 0), (slice(None), 0)),
+ # Custom fixed-size columns.
+ ('xpos', ('pole', 'y'), (2, 1)),
+ ('xmat', ('cart', ['yy', 'zz']), (1, [4, 8])),
+ # Custom indexers for mocap bodies.
+ ('mocap_quat', 'mocap1', 0),
+ ('mocap_pos', (['mocap2', 'mocap1'], 'z'), ([1, 0], 2)),
+ # Two-dimensional named indexing.
+ ('xpos', (['pole', 'cart'], ['x', 'z']), ([2, 1], [0, 2])),
+ ('xpos', ([['pole'], ['cart']], ['x', 'z']), ([[2], [1]], [0, 2])))
+ def testDataNamedIndexing(self, field_name, key, numeric_key):
+
+ indexer = getattr(self._data_indexers, field_name)
+ field = getattr(self._data, field_name)
+
+ # Explicit check that the converted key matches the numeric key.
+ converted_key = indexer._convert_key(key)
+ self.assertIndexExpressionEqual(numeric_key, converted_key)
+
+ # This writes unique values to the underlying buffer to prevent false
+ # negatives.
+ field.flat[:] = np.arange(field.size)
+
+ # Check that the result of named indexing matches the result of numeric
+ # indexing.
+ np.testing.assert_array_equal(field[numeric_key], indexer[key])
+
+ @parameterized.parameters(
+ # (field name, named index key)
+ ('xpos', 'pole'),
+ ('xpos', ['pole', 'cart']),
+ ('xpos', (['pole', 'cart'], 'y')),
+ ('xpos', (['pole', 'cart'], ['x', 'z'])),
+ ('qpos', 'slider'),
+ ('qvel', ['slider', 'hinge']),)
+ def testDataAssignment(self, field_name, key):
+
+ indexer = getattr(self._data_indexers, field_name)
+ field = getattr(self._data, field_name)
+
+ # The result of the indexing expression is either an array or a scalar.
+ index_result = indexer[key]
+ try:
+ # Write a sequence of unique values to prevent false negatives.
+ new_values = np.arange(index_result.size).reshape(index_result.shape)
+ except AttributeError:
+ new_values = 99
+ indexer[key] = new_values
+
+ # Check that the new value(s) can be read back from the underlying buffer.
+ converted_key = indexer._convert_key(key)
+ np.testing.assert_array_equal(new_values, field[converted_key])
+
+ @parameterized.parameters(
+ # (field name, first index key, second index key)
+ ('sensordata', 'accelerometer', 0),
+ ('sensordata', 'accelerometer', [0, 2]),
+ ('sensordata', 'accelerometer', slice(None)),)
+ def testChainedAssignment(self, field_name, first_key, second_key):
+
+ indexer = getattr(self._data_indexers, field_name)
+ field = getattr(self._data, field_name)
+
+ # The result of the indexing expression is either an array or a scalar.
+ index_result = indexer[first_key][second_key]
+ try:
+ # Write a sequence of unique values to prevent false negatives.
+ new_values = np.arange(index_result.size).reshape(index_result.shape)
+ except AttributeError:
+ new_values = 99
+ indexer[first_key][second_key] = new_values
+
+ # Check that the new value(s) can be read back from the underlying buffer.
+ converted_key = indexer._convert_key(first_key)
+ np.testing.assert_array_equal(new_values, field[converted_key][second_key])
+
+ def testNamedColumnFieldNames(self):
+
+ all_fields = set()
+ for struct in six.itervalues(sizes.array_sizes):
+ all_fields.update(struct.keys())
+
+ named_col_fields = set()
+ for field_set in six.itervalues(index._COLUMN_ID_TO_FIELDS):
+ named_col_fields.update(field_set)
+
+ # Check that all of the "named column" fields specified in index are
+ # also found in mjbindings.sizes.
+ self.assertContainsSubset(named_col_fields, all_fields)
+
+ @parameterized.parameters('xpos', 'xmat') # field name
+ def testTooManyIndices(self, field_name):
+ indexer = getattr(self._data_indexers, field_name)
+ with self.assertRaisesRegexp(IndexError, 'Index tuple'):
+ _ = indexer[:, :, :, 'too', 'many', 'elements']
+
+ @parameterized.parameters(
+ # bad item, exception regexp
+ (Ellipsis, 'Ellipsis'),
+ (None, 'None'),
+ (np.newaxis, 'None'),
+ (b'', 'Empty string'),
+ (u'', 'Empty string'))
+ def testBadIndexItems(self, bad_index_item, exception_regexp):
+ indexer = getattr(self._data_indexers, 'xpos')
+ expressions = [
+ bad_index_item,
+ (0, bad_index_item),
+ [bad_index_item],
+ [[bad_index_item]],
+ (0, [bad_index_item]),
+ (0, [[bad_index_item]]),
+ np.array([bad_index_item]),
+ (0, np.array([bad_index_item])),
+ (0, np.array([[bad_index_item]]))
+ ]
+ for expression in expressions:
+ with self.assertRaisesRegexp(IndexError, exception_regexp):
+ _ = indexer[expression]
+
+ @parameterized.parameters('act', 'qM', 'sensordata', 'xpos') # field name
+ def testFieldIndexerRepr(self, field_name):
+
+ indexer = getattr(self._data_indexers, field_name)
+ field = getattr(self._data, field_name)
+
+ # Write a sequence of unique values to prevent false negatives.
+ field.flat[:] = np.arange(field.size)
+
+ # Check that the string representation is as expected.
+ self.assertEqual(FIELD_REPR[field_name], repr(indexer))
+
+ @parameterized.parameters(MODEL, MODEL_NO_NAMES)
+ def testBuildIndexersForEdgeCases(self, xml_string):
+ model = wrapper.MjModel.from_xml_string(xml_string)
+ data = wrapper.MjData(model)
+
+ size_to_axis_indexer = index.make_axis_indexers(model)
+
+ index.struct_indexer(model, 'mjmodel', size_to_axis_indexer)
+ index.struct_indexer(data, 'mjdata', size_to_axis_indexer)
+
+ @parameterized.parameters(
+ name for name in dir(np.ndarray)
+ if not name.startswith('_') # Exclude 'private' attributes
+ and name not in ('ctypes', 'flat') # Can't compare via identity/equality
+ )
+ def testFieldIndexerDelegatesNDArrayAttributes(self, name):
+ field = self._data.xpos
+ field_indexer = self._data_indexers.xpos
+ actual = getattr(field_indexer, name)
+ expected = getattr(field, name)
+ if isinstance(expected, np.ndarray):
+ np.testing.assert_array_equal(actual, expected)
+ else:
+ self.assertEqual(actual, expected)
+
+ # FieldIndexer attributes should be read-only
+ with self.assertRaisesRegexp(AttributeError, name):
+ setattr(field_indexer, name, expected)
+
+ def testFieldIndexerDir(self):
+ expected_subset = dir(self._data.xpos)
+ actual_set = dir(self._data_indexers.xpos)
+ self.assertContainsSubset(expected_subset, actual_set)
+
+
+def _iter_indexers(model, data):
+ size_to_axis_indexer = index.make_axis_indexers(model)
+ for struct, struct_name in ((model, 'mjmodel'), (data, 'mjdata')):
+ indexer = index.struct_indexer(struct, struct_name, size_to_axis_indexer)
+ for field_name, field_indexer in six.iteritems(indexer._asdict()):
+ yield field_name, field_indexer
+
+
+class AllFieldsTest(parameterized.TestCase):
+ """Generic tests covering each FieldIndexer in model and data."""
+
+ # NB: the class must hold references to the model and data instances or they
+ # may be garbage-collected before any indexing is attempted.
+ model = wrapper.MjModel.from_xml_string(MODEL)
+ data = wrapper.MjData(model)
+
+ # Iterates over ('field_name', FieldIndexer) pairs
+ @parameterized.named_parameters(_iter_indexers(model, data))
+ def testReadWrite_(self, field):
+ # Read the contents of the FieldIndexer as a numpy array.
+ old_contents = field[:]
+ # Write unique values to the FieldIndexer and read them back again.
+ # Don't write to non-float fields since these might contain pointers.
+    if np.issubdtype(old_contents.dtype, np.floating):
+ new_contents = np.arange(old_contents.size, dtype=old_contents.dtype)
+ new_contents.shape = old_contents.shape
+ field[:] = new_contents
+ np.testing.assert_array_equal(new_contents, field[:])
+
+
+if __name__ == '__main__':
+ absltest.main()
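
Note on index_test.py: the parameter tables in `testModelNamedIndexing` and `testDataNamedIndexing` pair each named key with the numeric key it should convert to. The sketch below illustrates that mapping for the `sensordata` cases only. It is not the `index` module's implementation: `SENSOR_SLICES` and `convert_key` are hypothetical names, and the name table is hard-coded here rather than built from a compiled model.

```python
# Hypothetical name table mirroring the `sensordata` test parameters above.
SENSOR_SLICES = {'accelerometer': slice(0, 3), 'collision': slice(3, 4)}


def convert_key(key, name_to_slice):
  """Converts a name, or a list of names, to a numeric index expression."""
  if isinstance(key, str):
    # A single name maps to the slice of elements it owns.
    return name_to_slice[key]
  if isinstance(key, list):
    # A list of names maps to the concatenation of their element indices.
    indices = []
    for name in key:
      name_slice = name_to_slice[name]
      indices.extend(range(name_slice.start, name_slice.stop))
    return indices
  # Integers and slices pass through unchanged.
  return key


single = convert_key('accelerometer', SENSOR_SLICES)
several = convert_key(['accelerometer', 'collision'], SENSOR_SLICES)
```

Under these assumptions `single` is `slice(0, 3)` and `several` is `[0, 1, 2, 3]`, matching the expected numeric keys listed for `sensordata` in the test.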
diff --git a/dm_control/mujoco/testing/__init__.py b/dm_control/mujoco/testing/__init__.py
new file mode 100644
index 00000000..1ebb270f
--- /dev/null
+++ b/dm_control/mujoco/testing/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/dm_control/mujoco/testing/assets/__init__.py b/dm_control/mujoco/testing/assets/__init__.py
new file mode 100644
index 00000000..bafa2344
--- /dev/null
+++ b/dm_control/mujoco/testing/assets/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Assets used for testing the MuJoCo bindings."""
+
+import os
+
+from dm_control.utils import resources
+
+_ASSETS_DIR = os.path.dirname(__file__)
+
+
+def get_contents(filename):
+ """Returns the contents of an asset as a string."""
+ return resources.GetResource(os.path.join(_ASSETS_DIR, filename))
+
+
+def get_path(filename):
+ """Returns the path to an asset."""
+ return resources.GetResourceFilename(os.path.join(_ASSETS_DIR, filename))
diff --git a/dm_control/mujoco/testing/assets/cartpole.xml b/dm_control/mujoco/testing/assets/cartpole.xml
new file mode 100644
index 00000000..b15370be
--- /dev/null
+++ b/dm_control/mujoco/testing/assets/cartpole.xml
@@ -0,0 +1,69 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/mujoco/testing/assets/cartpole_no_names.xml b/dm_control/mujoco/testing/assets/cartpole_no_names.xml
new file mode 100644
index 00000000..73d78eb3
--- /dev/null
+++ b/dm_control/mujoco/testing/assets/cartpole_no_names.xml
@@ -0,0 +1,62 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/mujoco/testing/assets/cube.stl b/dm_control/mujoco/testing/assets/cube.stl
new file mode 100644
index 00000000..a5bc8256
Binary files /dev/null and b/dm_control/mujoco/testing/assets/cube.stl differ
diff --git a/dm_control/mujoco/testing/assets/deepmind.png b/dm_control/mujoco/testing/assets/deepmind.png
new file mode 100644
index 00000000..1586759c
Binary files /dev/null and b/dm_control/mujoco/testing/assets/deepmind.png differ
diff --git a/dm_control/mujoco/testing/assets/humanoid.xml b/dm_control/mujoco/testing/assets/humanoid.xml
new file mode 100644
index 00000000..93590881
--- /dev/null
+++ b/dm_control/mujoco/testing/assets/humanoid.xml
@@ -0,0 +1,121 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/mujoco/testing/assets/model_with_assets.xml b/dm_control/mujoco/testing/assets/model_with_assets.xml
new file mode 100644
index 00000000..adfd45b0
--- /dev/null
+++ b/dm_control/mujoco/testing/assets/model_with_assets.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/mujoco/testing/assets/sphere.xml b/dm_control/mujoco/testing/assets/sphere.xml
new file mode 100644
index 00000000..09991911
--- /dev/null
+++ b/dm_control/mujoco/testing/assets/sphere.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/dm_control/mujoco/testing/decorators.py b/dm_control/mujoco/testing/decorators.py
new file mode 100644
index 00000000..78086c2c
--- /dev/null
+++ b/dm_control/mujoco/testing/decorators.py
@@ -0,0 +1,65 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Decorators used in MuJoCo tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import threading
+
+# Internal dependencies.
+
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+def run_threaded(num_threads=4, calls_per_thread=10):
+ """A decorator that executes the same test repeatedly in multiple threads.
+
+ Note: `setUp` and `tearDown` methods will only be called once from the main
+ thread, so all thread-local setup must be done within the test method.
+
+ Args:
+ num_threads: Number of concurrent threads to spawn.
+ calls_per_thread: Number of times each thread should call the test method.
+ Returns:
+ Decorated test method.
+ """
+ def decorator(test_method):
+ """Decorator around the test method."""
+ def decorated_method(self, *args, **kwargs):
+ """Actual method this factory will return."""
+ exceptions = []
+ def worker():
+ try:
+ for _ in xrange(calls_per_thread):
+ test_method(self, *args, **kwargs)
+ except: # pylint: disable=bare-except
+ # Appending to Python list is thread-safe:
+ # http://effbot.org/pyfaq/what-kinds-of-global-value-mutation-are-thread-safe.htm
+ exceptions.append(sys.exc_info())
+ threads = [threading.Thread(target=worker, name='thread_{}'.format(i))
+ for i in xrange(num_threads)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ for exc_class, old_exc, tb in exceptions:
+ six.reraise(exc_class, old_exc, tb)
+ return decorated_method
+ return decorator
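
Note on decorators.py: the pattern used by `run_threaded` (collect exceptions raised in worker threads, then re-raise one in the calling thread) can be tried in isolation. The following is a simplified Python 3 sketch of that pattern, not the module itself: `run_threaded_sketch` and `record_call` are hypothetical names, and it uses `raise ... with_traceback` in place of `six.reraise`.

```python
import sys
import threading


def run_threaded_sketch(num_threads=4, calls_per_thread=10):
  """Returns a decorator that runs a function repeatedly in several threads."""
  def decorator(func):
    def wrapper(*args, **kwargs):
      errors = []

      def worker():
        try:
          for _ in range(calls_per_thread):
            func(*args, **kwargs)
        except Exception:  # Record the failure so the caller can re-raise it.
          errors.append(sys.exc_info())

      threads = [threading.Thread(target=worker) for _ in range(num_threads)]
      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()
      if errors:
        # Surface the first recorded failure in the calling thread.
        _, exc, tb = errors[0]
        raise exc.with_traceback(tb)
    return wrapper
  return decorator


calls = []


@run_threaded_sketch(num_threads=3, calls_per_thread=5)
def record_call():
  calls.append(threading.current_thread().name)


record_call()
total_calls = len(calls)
```

With three threads and five calls per thread, `total_calls` ends up as 15. Without the re-raise step, an exception in a worker thread would never reach the test runner and a failing test would appear to pass.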
diff --git a/dm_control/mujoco/testing/decorators_test.py b/dm_control/mujoco/testing/decorators_test.py
new file mode 100644
index 00000000..bfd86280
--- /dev/null
+++ b/dm_control/mujoco/testing/decorators_test.py
@@ -0,0 +1,65 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests of the decorators module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control.mujoco.testing import decorators
+import mock
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+class RunThreadedTest(absltest.TestCase):
+
+ @mock.patch(decorators.__name__ + ".threading")
+ def test_number_of_threads(self, mock_threading):
+ num_threads = 5
+
+ mock_threads = [mock.MagicMock() for _ in xrange(num_threads)]
+ for thread in mock_threads:
+ thread.start = mock.MagicMock()
+ thread.join = mock.MagicMock()
+
+ mock_threading.Thread = mock.MagicMock(side_effect=mock_threads)
+
+ test_decorator = decorators.run_threaded(num_threads=num_threads)
+ test_runner = test_decorator(mock.MagicMock())
+ test_runner(self)
+
+ for thread in mock_threads:
+ thread.start.assert_called_once()
+ thread.join.assert_called_once()
+
+ def test_number_of_iterations(self):
+ calls_per_thread = 5
+
+ tested_method = mock.MagicMock()
+ test_decorator = decorators.run_threaded(
+ num_threads=1, calls_per_thread=calls_per_thread)
+ test_runner = test_decorator(tested_method)
+ test_runner(self)
+
+ self.assertEqual(calls_per_thread, tested_method.call_count)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/dm_control/mujoco/testing/memory_checker.py b/dm_control/mujoco/testing/memory_checker.py
new file mode 100644
index 00000000..84388eb9
--- /dev/null
+++ b/dm_control/mujoco/testing/memory_checker.py
@@ -0,0 +1,102 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A context manager for checking that MuJoCo memory is freed."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import ctypes
+
+# Internal dependencies.
+
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+import six
+
+
+# Used for overriding MuJoCo's memory handlers
+_LIBC = ctypes.cdll.LoadLibrary("libc.so.6")
+_LIBC.aligned_alloc.argtypes = [ctypes.c_size_t, ctypes.c_size_t]
+_LIBC.aligned_alloc.restype = ctypes.c_void_p
+_LIBC.free.argtypes = [ctypes.c_void_p]
+_LIBC.free.restype = None
+
+# MuJoCo normally pads and aligns memory to multiples of 8 bytes
+_BYTE_ALIGNMENT = 8
+
+# Expose pointers to custom memory handlers.
+mjlib.mju_user_malloc = ctypes.c_void_p.in_dll(mjlib, "mju_user_malloc")
+mjlib.mju_user_free = ctypes.c_void_p.in_dll(mjlib, "mju_user_free")
+
+
+@contextlib.contextmanager
+def assert_mujoco_memory_freed():
+ """Context manager for debugging memory leaks in MuJoCo.
+
+ Yields:
+ None
+
+ Raises:
+ AssertionError: If MuJoCo heap-allocated any memory inside the context
+ manager without freeing it.
+ """
+
+ # NB: The custom memory handlers need to use libc's `aligned_alloc` and `free`
+ # rather than `mju_malloc` and `mju_free`, since these will delegate to
+ # `mju_user_malloc` and `mju_user_free` if they are not NULL.
+
+ remaining_pointers = {}
+
+ @ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_size_t)
+ def debug_malloc(size):
+ if size % _BYTE_ALIGNMENT:
+ size += _BYTE_ALIGNMENT - (size % _BYTE_ALIGNMENT)
+ address = _LIBC.aligned_alloc(_BYTE_ALIGNMENT, size)
+ remaining_pointers[address] = size
+ return address
+
+ @ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+ def debug_free(address):
+ _LIBC.free(address)
+ # Allow freeing of arrays that were allocated outside of the context.
+ remaining_pointers.pop(address, None)
+
+ # Keep the old pointer addresses in case there were already custom memory
+ # handling callbacks defined.
+ old_user_malloc_ptr_value = mjlib.mju_user_malloc.value
+ old_user_free_ptr_value = mjlib.mju_user_free.value
+
+ # Set the new callbacks.
+ mjlib.mju_user_malloc.value = ctypes.cast(debug_malloc, ctypes.c_void_p).value
+ mjlib.mju_user_free.value = ctypes.cast(debug_free, ctypes.c_void_p).value
+
+ try:
+ yield
+ finally:
+ # Make sure we reset the memory handlers, even if an exception is raised.
+ mjlib.mju_user_malloc.value = old_user_malloc_ptr_value
+ mjlib.mju_user_free.value = old_user_free_ptr_value
+
+ if remaining_pointers:
+ n_not_freed = len(remaining_pointers)
+ n_bytes_leaked = sum(six.itervalues(remaining_pointers))
+ details_str = "\n".join(
+ "address: {} size: {} B".format(address, size)
+ for address, size in six.iteritems(remaining_pointers))
+ raise AssertionError(
+ "MuJoCo failed to free {} arrays with a total size of {} B:\n{}"
+ .format(n_not_freed, n_bytes_leaked, details_str))
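
Note on memory_checker.py: the context manager installs Python functions as C callbacks by wrapping them with `ctypes.CFUNCTYPE` and storing the address obtained from `ctypes.cast(..., ctypes.c_void_p).value`. The sketch below shows that mechanism using only the standard library, with no MuJoCo or libc involved. `FREE_FUNC`, `record_free` and `freed_addresses` are hypothetical names chosen for illustration.

```python
import ctypes

freed_addresses = []

# Prototype for a C function `void f(void *)`, the shape of a `free` handler.
FREE_FUNC = ctypes.CFUNCTYPE(None, ctypes.c_void_p)


@FREE_FUNC
def record_free(address):
  """Records the address it was called with instead of freeing anything."""
  freed_addresses.append(address)


# Integer address of the C-callable thunk that ctypes generated for the
# callback. This is the kind of value that can be written into a C function
# pointer variable.
thunk_address = ctypes.cast(record_free, ctypes.c_void_p).value

# Calling the wrapped object goes through the thunk, much as C code would.
record_free(0x1234)
```

The wrapped callback has to stay referenced for as long as C code might call it, since ctypes does not keep it alive on its own. Here `record_free` lives at module level for that reason.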
diff --git a/dm_control/mujoco/testing/memory_checker_test.py b/dm_control/mujoco/testing/memory_checker_test.py
new file mode 100644
index 00000000..6e9de8c7
--- /dev/null
+++ b/dm_control/mujoco/testing/memory_checker_test.py
@@ -0,0 +1,104 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+"""Tests for memory_checker."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control.mujoco.testing import memory_checker
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+class MemoryCheckingContextTest(absltest.TestCase):
+
+ n_arrays = 3
+ bytes_per_array = 13
+ bytes_per_array_padded = 16
+
+ def test_enter_and_exit(self):
+ """Check that pointers are set and reset correctly on entry and exit."""
+ mju_user_malloc = mjlib.mju_user_malloc
+ mju_user_free = mjlib.mju_user_free
+ ptr_value = lambda func: ctypes.cast(func, ctypes.c_void_p).value
+
+ old_mju_user_malloc_ptr_value = ptr_value(mju_user_malloc)
+ old_mju_user_free_ptr_value = ptr_value(mju_user_free)
+
+ class DummyError(RuntimeError):
+ pass
+
+ try:
+ with memory_checker.assert_mujoco_memory_freed():
+ self.assertNotEqual(
+ ptr_value(mju_user_malloc), old_mju_user_malloc_ptr_value)
+ self.assertNotEqual(
+ ptr_value(mju_user_free), old_mju_user_free_ptr_value)
+ raise DummyError("Simulating an exception inside the context manager")
+ except DummyError:
+ pass
+
+ # Check that the pointers to the custom memory handlers were reset as we
+ # exited, even though an exception occurred inside the context.
+ self.assertEqual(ptr_value(mju_user_malloc), old_mju_user_malloc_ptr_value)
+ self.assertEqual(ptr_value(mju_user_free), old_mju_user_free_ptr_value)
+
+ def test_allocate_and_free_inside(self):
+ """Allocating and freeing inside shouldn't raise any exceptions."""
+ with memory_checker.assert_mujoco_memory_freed():
+ allocated = [
+ mjlib.mju_malloc(self.bytes_per_array)
+ for _ in xrange(self.n_arrays)
+ ]
+ for ptr in allocated:
+ mjlib.mju_free(ptr)
+
+ def test_allocate_outside_free_inside(self):
+ """Allocating outside and freeing inside shouldn't raise any exceptions."""
+ allocated = [
+ mjlib.mju_malloc(self.bytes_per_array)
+ for _ in xrange(self.n_arrays)
+ ]
+ with memory_checker.assert_mujoco_memory_freed():
+ for ptr in allocated:
+ mjlib.mju_free(ptr)
+
+ def test_allocate_inside_free_outside(self):
+ """Allocating inside and freeing outside should raise an AssertionError."""
+ with self.assertRaisesRegexp(
+ AssertionError,
+ "MuJoCo failed to free {} arrays with a total size of {} B"
+ .format(self.n_arrays, self.n_arrays * self.bytes_per_array_padded)):
+ with memory_checker.assert_mujoco_memory_freed():
+ allocated = [
+ mjlib.mju_malloc(self.bytes_per_array)
+ for _ in xrange(self.n_arrays)
+ ]
+ for ptr in allocated:
+ mjlib.mju_free(ptr)
+
+
+if __name__ == "__main__":
+ absltest.main()
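
Note on memory_checker_test.py: `test_allocate_inside_free_outside` expects a leak of `n_arrays * bytes_per_array_padded` bytes because each 13-byte request is rounded up to 16 bytes by the 8-byte alignment in `debug_malloc`. The sketch below reproduces that arithmetic with a pure-Python stand-in for the allocator. `FakeHeap`, `padded_size` and `assert_freed` are hypothetical names, and nothing here calls into MuJoCo.

```python
import contextlib

BYTE_ALIGNMENT = 8


def padded_size(size):
  """Rounds `size` up to the next multiple of `BYTE_ALIGNMENT`."""
  remainder = size % BYTE_ALIGNMENT
  return size + BYTE_ALIGNMENT - remainder if remainder else size


class FakeHeap(object):
  """Hypothetical allocator that only tracks which blocks are still live."""

  def __init__(self):
    self.live = {}
    self._next_address = 0x1000

  def malloc(self, size):
    address = self._next_address
    self._next_address += padded_size(size)
    self.live[address] = padded_size(size)
    return address

  def free(self, address):
    # Unknown addresses are ignored, as with blocks allocated elsewhere.
    self.live.pop(address, None)


@contextlib.contextmanager
def assert_freed(heap):
  """Raises AssertionError if blocks are still live when the body exits."""
  yield
  if heap.live:
    raise AssertionError(
        'failed to free {} arrays with a total size of {} B'.format(
            len(heap.live), sum(heap.live.values())))


heap = FakeHeap()
try:
  with assert_freed(heap):
    for _ in range(3):
      heap.malloc(13)
  message = None
except AssertionError as error:
  message = str(error)
```

Three unfreed 13-byte requests are reported as 3 arrays totalling 48 B, which is the figure the test's regular expression is built from.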
diff --git a/dm_control/mujoco/thread_safety_test.py b/dm_control/mujoco/thread_safety_test.py
new file mode 100644
index 00000000..81e83480
--- /dev/null
+++ b/dm_control/mujoco/thread_safety_test.py
@@ -0,0 +1,97 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests to check whether methods of `mujoco.Physics` are threadsafe."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control.mujoco import engine
+from dm_control.mujoco.testing import assets
+from dm_control.mujoco.testing import decorators
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+MODEL = assets.get_contents('cartpole.xml')
+NUM_STEPS = 10
+
+
+class ThreadSafetyTest(absltest.TestCase):
+
+ @decorators.run_threaded()
+ def test_load_physics_from_string(self):
+ engine.Physics.from_xml_string(MODEL)
+
+ @decorators.run_threaded()
+ def test_load_and_reload_physics_from_string(self):
+ physics = engine.Physics.from_xml_string(MODEL)
+ physics.reload_from_xml_string(MODEL)
+
+ @decorators.run_threaded()
+ def test_load_and_step_physics(self):
+ physics = engine.Physics.from_xml_string(MODEL)
+ for _ in xrange(NUM_STEPS):
+ physics.step()
+
+ @decorators.run_threaded()
+ def test_load_and_step_multiple_physics_parallel(self):
+ physics1 = engine.Physics.from_xml_string(MODEL)
+ physics2 = engine.Physics.from_xml_string(MODEL)
+ for _ in xrange(NUM_STEPS):
+ physics1.step()
+ physics2.step()
+
+ @decorators.run_threaded()
+ def test_load_and_step_multiple_physics_sequential(self):
+ physics1 = engine.Physics.from_xml_string(MODEL)
+ for _ in xrange(NUM_STEPS):
+ physics1.step()
+ del physics1
+ physics2 = engine.Physics.from_xml_string(MODEL)
+ for _ in xrange(NUM_STEPS):
+ physics2.step()
+
+ @decorators.run_threaded(calls_per_thread=5)
+ def test_load_physics_and_render(self):
+ physics = engine.Physics.from_xml_string(MODEL)
+
+ # Check that frames aren't repeated - make the cartpole move.
+ physics.set_control([1.0])
+
+ unique_frames = set()
+ for _ in xrange(NUM_STEPS):
+ physics.step()
+ frame = physics.render(width=320, height=240, camera_id=0)
+      unique_frames.add(frame.tobytes())
+
+ self.assertEqual(NUM_STEPS, len(unique_frames))
+
+ @decorators.run_threaded(calls_per_thread=5)
+ def test_render_multiple_physics_instances_per_thread_parallel(self):
+ physics1 = engine.Physics.from_xml_string(MODEL)
+ physics2 = engine.Physics.from_xml_string(MODEL)
+ for _ in xrange(NUM_STEPS):
+ physics1.step()
+ physics1.render(width=320, height=240, camera_id=0)
+ physics2.step()
+ physics2.render(width=320, height=240, camera_id=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/mujoco/wrapper/README.md b/dm_control/mujoco/wrapper/README.md
new file mode 100644
index 00000000..c5ec64cd
--- /dev/null
+++ b/dm_control/mujoco/wrapper/README.md
@@ -0,0 +1,6 @@
+# MuJoCo Wrapper
+This package contains Python bindings for the [MuJoCo physics engine][1] using
+[`ctypes`][2]. The bindings and some higher-level wrapper code can be
+automatically generated by parsing MuJoCo's header files.
+
+See [package documentation](/third_party/py/dm_control/mujoco/wrapper).
diff --git a/dm_control/mujoco/wrapper/__init__.py b/dm_control/mujoco/wrapper/__init__.py
new file mode 100644
index 00000000..7ec384f7
--- /dev/null
+++ b/dm_control/mujoco/wrapper/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Python bindings and wrapper classes for MuJoCo."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from dm_control.mujoco.wrapper import mjbindings
+
+from dm_control.mujoco.wrapper.core import callback_context
+
+from dm_control.mujoco.wrapper.core import Error
+
+from dm_control.mujoco.wrapper.core import get_schema
+
+from dm_control.mujoco.wrapper.core import MjData
+from dm_control.mujoco.wrapper.core import MjModel
+from dm_control.mujoco.wrapper.core import MjrContext
+from dm_control.mujoco.wrapper.core import MjvCamera
+from dm_control.mujoco.wrapper.core import MjvFigure
+from dm_control.mujoco.wrapper.core import MjvOption
+from dm_control.mujoco.wrapper.core import MjvPerturb
+from dm_control.mujoco.wrapper.core import MjvScene
+
+from dm_control.mujoco.wrapper.core import save_last_parsed_model_to_xml
+from dm_control.mujoco.wrapper.core import set_callback
diff --git a/dm_control/mujoco/wrapper/core.py b/dm_control/mujoco/wrapper/core.py
new file mode 100644
index 00000000..64c8a7f3
--- /dev/null
+++ b/dm_control/mujoco/wrapper/core.py
@@ -0,0 +1,745 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Main user-facing classes and utility functions for loading MuJoCo models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import ctypes
+import os
+import weakref
+
+# Internal dependencies.
+
+from absl import logging
+
+from dm_control.mujoco.wrapper import util
+from dm_control.mujoco.wrapper.mjbindings import constants
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.mujoco.wrapper.mjbindings import functions
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+from dm_control.mujoco.wrapper.mjbindings import types
+from dm_control.mujoco.wrapper.mjbindings import wrappers
+
+import six
+
+_NULL = b"\00"
+_FAKE_XML_FILENAME = b"model.xml"
+_FAKE_BINARY_FILENAME = b"model.mjb"
+
+# Global cache used to store finalizers for freeing ctypes pointers.
+# Contains {pointer_address: weakref_object} pairs.
+_FINALIZERS = {}
+
+# Cache of ctypes-wrapped Python callback functions that are called from C. We
+# need to retain references to all wrapped Python callbacks that are currently
+# in use, otherwise they might be garbage collected before they are called.
+_ACTIVE_PYTHON_CALLBACKS = {}
+
+
+class Error(Exception):
+ """Base class for MuJoCo exceptions."""
+ pass
+
+
+if constants.mjVERSION_HEADER != mjlib.mj_version():
+ raise Error("MuJoCo library version ({0}) does not match header version "
+ "({1})".format(constants.mjVERSION_HEADER, mjlib.mj_version()))
+
+_REGISTERED = False
+_ERROR_BUFSIZE = 1000
+
+# This is used to keep track of the `MJMODEL` pointer that was most recently
+# loaded by `_get_model_ptr_from_xml`. Only this model can be saved to XML.
+_LAST_PARSED_MODEL_PTR = None
+
+_NOT_LAST_PARSED_ERROR = (
+ "Only the model that was most recently loaded from an XML file or string "
+ "can be saved to an XML file.")
+
+
+# NB: Python functions that are called from C are defined at module-level to
+# ensure they won't be garbage-collected before they are called.
+@ctypes.CFUNCTYPE(None, ctypes.c_char_p)
+def _warning_callback(message):
+ logging.warn(util.to_native_string(message))
+
+
+@ctypes.CFUNCTYPE(None, ctypes.c_char_p)
+def _error_callback(message):
+ logging.fatal(util.to_native_string(message))
+
+
+# Override MuJoCo's callbacks for handling warnings and errors.
+mjlib.mju_user_warning = ctypes.c_void_p.in_dll(mjlib, "mju_user_warning")
+mjlib.mju_user_error = ctypes.c_void_p.in_dll(mjlib, "mju_user_error")
+mjlib.mju_user_warning.value = ctypes.cast(
+ _warning_callback, ctypes.c_void_p).value
+mjlib.mju_user_error.value = ctypes.cast(
+ _error_callback, ctypes.c_void_p).value
+
+# Decorator that wraps a Python callback with the signature
+# func(const_mjmodel_ptr, mjdata_ptr) -> None
+# and returns a `ctypes.CFunctionType`.
+_WRAP_PYFUNC = ctypes.CFUNCTYPE(None, ctypes.POINTER(types.MJMODEL),
+ ctypes.POINTER(types.MJDATA))
+
+
+def _maybe_register_license(path=None):
+ """Registers the MuJoCo license if not already registered.
+
+ Args:
+ path: Optional custom path to license key file.
+
+ Raises:
+ Error: If the license could not be registered.
+ """
+ global _REGISTERED
+ if not _REGISTERED:
+ if path is None:
+ path = util.get_mjkey_path()
+ result = mjlib.mj_activate(util.to_binary_string(path))
+ if result == 1:
+ _REGISTERED = True
+ elif result == 0:
+ raise Error("Could not register license.")
+ else:
+ raise Error("Unknown registration error (code: {})".format(result))
+
+
+def _str2type(type_str):
+ type_id = mjlib.mju_str2Type(util.to_binary_string(type_str))
+ if not type_id:
+ raise Error("{!r} is not a valid object type name.".format(type_str))
+ return type_id
+
+
+def _type2str(type_id):
+ type_str_ptr = mjlib.mju_type2Str(type_id)
+ if not type_str_ptr:
+ raise Error("{!r} is not a valid object type ID.".format(type_id))
+ return util.to_native_string(ctypes.string_at(type_str_ptr))
+
+
+def set_callback(name, new_callback=None):
+ """Sets a user-defined callback function to modify MuJoCo's behavior.
+
+ Callback functions should have the following signature:
+ func(const_mjmodel_ptr, mjdata_ptr) -> None
+
+ Args:
+ name: Name of the callback to set. Must be a field in
+ `functions.function_pointers`.
+ new_callback: The new callback. This can be one of the following:
+ * A Python callable
+ * A C function exposed by a `ctypes.CDLL` object
+ * An integer specifying the address of a callback function
+ * None, in which case any existing callback of that name is removed
+
+ Returns:
+ Either an integer specifying the address of the previous function used for
+ this callback, or None if the callback has not already been overridden.
+
+ Raises:
+ ValueError: If `name` is not in `functions.function_pointers`.
+ """
+ if name not in functions.function_pointers._fields:
+ raise ValueError("Invalid callback name: {!r}. Must be one of {!r}.".format(
+ name, functions.function_pointers._fields))
+ callback_ptr = getattr(functions.function_pointers, name)
+ try:
+ new_callback_ptr = ctypes.cast(new_callback, ctypes.c_void_p)
+ except ctypes.ArgumentError:
+ # Python callables must be wrapped before casting to `ctypes.c_void_p`.
+ wrapped_callback = _WRAP_PYFUNC(new_callback)
+ new_callback_ptr = ctypes.cast(wrapped_callback, ctypes.c_void_p)
+ # We must retain a reference to the wrapped callback function, otherwise it
+ # might be garbage collected before it is called.
+ _ACTIVE_PYTHON_CALLBACKS[new_callback_ptr.value] = wrapped_callback
+
+ old_callback_address = callback_ptr.value
+
+ # If the old callback was a wrapped Python function then we remove it from the
+ # cache of active callbacks so that it can be garbage collected.
+ if old_callback_address in _ACTIVE_PYTHON_CALLBACKS:
+ del _ACTIVE_PYTHON_CALLBACKS[old_callback_address]
+
+ callback_ptr.value = new_callback_ptr.value
+ return old_callback_address
+
+
+@contextlib.contextmanager
+def callback_context(name, new_callback=None):
+ """Context manager that temporarily overrides a MuJoCo callback function.
+
+ On exit, the callback will be restored to its original value (None if the
+ callback was not already overridden when the context was entered).
+
+ Args:
+ name: Name of the callback to set. Must be a field in
+ `functions.function_pointers`.
+ new_callback: The new callback. This can be one of the following:
+ * A Python callable
+ * A C function exposed by a `ctypes.CDLL` object
+ * An integer specifying the address of a callback function
+ * None, in which case any existing callback of that name is removed
+
+ Yields:
+ None
+ """
+ old_callback = set_callback(name, new_callback)
+ try:
+ yield
+ finally:
+ # Ensure that the callback is reset on exit, even if an exception is raised.
+ set_callback(name, old_callback)
+
+
+def get_schema():
+ """Returns a string containing the schema used by the MuJoCo XML parser."""
+ buf = ctypes.create_string_buffer(100000)
+ mjlib.mj_printSchema(None, buf, len(buf), 0, 0)
+ return buf.value
+
+
+@contextlib.contextmanager
+def _temporary_vfs(filenames_and_contents):
+ """Creates a temporary VFS containing one or more files.
+
+ Args:
+ filenames_and_contents: A dict containing `{filename: contents}` pairs.
+
+ Yields:
+ A `types.MJVFS` instance.
+
+ Raises:
+ Error: If a file cannot be added to the VFS, or if an error occurs when
+ looking up the filename.
+ """
+ vfs = types.MJVFS()
+ mjlib.mj_defaultVFS(vfs)
+ for filename, contents in six.iteritems(filenames_and_contents):
+ filename = util.to_binary_string(filename)
+ contents = util.to_binary_string(contents)
+ _, extension = os.path.splitext(filename)
+ # For XML files we need to append a NULL byte, otherwise MuJoCo's parser
+ # can sometimes read past the end of the string. However, we should *not*
+ # do this for other file types (in particular for STL meshes, where this
+ # causes MuJoCo's compiler to complain that the file size is incorrect).
+ append_null = extension.lower() == b".xml"
+ num_bytes = len(contents) + append_null
+ retcode = mjlib.mj_makeEmptyFileVFS(vfs, filename, num_bytes)
+ if retcode == 1:
+ raise Error("Failed to create {!r}: VFS is full.".format(filename))
+ elif retcode == 2:
+ raise Error("Failed to create {!r}: duplicate filename.".format(filename))
+ file_index = mjlib.mj_findFileVFS(vfs, filename)
+ if file_index == -1:
+ raise Error("Could not find {!r} in the VFS".format(filename))
+ vf = vfs.filedata[file_index]
+ vf_as_char_arr = ctypes.cast(vf, ctypes.POINTER(ctypes.c_char * num_bytes))
+ vf_as_char_arr.contents[:len(contents)] = contents
+ if append_null:
+ vf_as_char_arr.contents[-1] = _NULL
+ try:
+ yield vfs
+ finally:
+ mjlib.mj_deleteVFS(vfs) # Ensure that we free the VFS afterwards.
+
+
+def _create_finalizer(ptr, free_func):
+ """Creates a finalizer for a ctypes pointer.
+
+ Args:
+ ptr: A `ctypes.POINTER` to be freed.
+ free_func: A callable that frees the pointer. It will be called with `ptr`
+ as its only argument when `ptr` is garbage collected.
+ """
+ ptr_type = type(ptr)
+ address = ctypes.addressof(ptr)
+
+ if address not in _FINALIZERS: # Only one finalizer needed per address.
+
+ def callback(dead_ptr_ref):
+ del dead_ptr_ref # Unused weakref to the dead ctypes pointer object.
+ # Temporarily resurrect the dead pointer so that we can free it.
+ temp_ptr = ptr_type.from_address(address)
+ logging.debug("Freeing %s", temp_ptr)
+ free_func(temp_ptr)
+ del _FINALIZERS[address] # Remove the weakref from the global cache.
+
+ # Store weakrefs in a global cache so that they don't get garbage collected
+ # before their referents.
+ _FINALIZERS[address] = weakref.ref(ptr, callback)
+
+
+def _load_xml(filename, vfs_or_none):
+ """Invokes `mj_loadXML` with logging/error handling."""
+ error_buf = ctypes.create_string_buffer(_ERROR_BUFSIZE)
+ model_ptr = mjlib.mj_loadXML(
+ util.to_binary_string(filename),
+ vfs_or_none,
+ error_buf,
+ _ERROR_BUFSIZE)
+ if not model_ptr:
+ raise Error(util.to_native_string(error_buf.value))
+ elif error_buf.value:
+ logging.warn(util.to_native_string(error_buf.value))
+
+ # Free resources when the ctypes pointer is garbage collected.
+ _create_finalizer(model_ptr, mjlib.mj_deleteModel)
+
+ return model_ptr
+
+
+def _get_model_ptr_from_xml(xml_path=None, xml_string=None, assets=None):
+ """Parses a model XML file, compiles it, and returns a pointer to an mjModel.
+
+ Args:
+ xml_path: Path to a model XML file in MJCF or URDF format.
+ xml_string: XML string containing an MJCF or URDF model description.
+ assets: Optional dict containing external assets referenced by the model
+ (such as additional XML files, textures, meshes etc.), in the form of
+ `{filename: contents_string}` pairs. The keys should correspond to the
+ filenames specified in the model XML. Ignored if `xml_string` is None.
+
+ One of `xml_path` or `xml_string` must be specified.
+
+ Returns:
+ A `ctypes.POINTER` to a new `mjbindings.types.MJMODEL` instance.
+
+ Raises:
+ TypeError: If both or neither of `xml_path` and `xml_string` are specified.
+ Error: If the model is not created successfully.
+ """
+ if xml_path is None and xml_string is None:
+ raise TypeError(
+ "At least one of `xml_path` or `xml_string` must be specified.")
+ elif xml_path is not None and xml_string is not None:
+ raise TypeError(
+ "Only one of `xml_path` or `xml_string` may be specified.")
+
+ _maybe_register_license()
+
+ if xml_string is not None:
+ assets = {} if assets is None else assets.copy()
+ # Ensure that the fake XML filename doesn't overwrite an existing asset.
+ xml_path = _FAKE_XML_FILENAME
+ while xml_path in assets:
+ xml_path = b"_" + xml_path
+ assets[xml_path] = xml_string
+ with _temporary_vfs(assets) as vfs:
+ ptr = _load_xml(xml_path, vfs)
+ else:
+ ptr = _load_xml(xml_path, None)
+
+ global _LAST_PARSED_MODEL_PTR
+ _LAST_PARSED_MODEL_PTR = ptr
+
+ return ptr
+
+
+def save_last_parsed_model_to_xml(xml_path, check_model=None):
+ """Writes a description of the most recently loaded model to an MJCF XML file.
+
+ Args:
+ xml_path: Path to the output XML file.
+ check_model: Optional `MjModel` instance. If specified, this model will be
+ checked to see if it is the most recently parsed one, and a ValueError
+ will be raised otherwise.
+
+ Raises:
+ Error: If MuJoCo encounters an error while writing the XML file.
+ ValueError: If `check_model` was passed, and this model is not the most
+ recently parsed one.
+ """
+ if check_model and check_model.ptr is not _LAST_PARSED_MODEL_PTR:
+ raise ValueError(_NOT_LAST_PARSED_ERROR)
+ error_buf = ctypes.create_string_buffer(_ERROR_BUFSIZE)
+ mjlib.mj_saveLastXML(util.to_binary_string(xml_path),
+ _LAST_PARSED_MODEL_PTR,
+ error_buf,
+ _ERROR_BUFSIZE)
+ if error_buf.value:
+ raise Error(util.to_native_string(error_buf.value))
+
+
+def _get_model_ptr_from_binary(binary_path=None, byte_string=None):
+ """Returns a pointer to an mjModel from the contents of a MuJoCo model binary.
+
+ Args:
+ binary_path: Path to an MJB file (as produced by MjModel.save_binary).
+ byte_string: String of bytes (as returned by MjModel.to_bytes).
+
+ One of `binary_path` or `byte_string` must be specified.
+
+ Returns:
+ A `ctypes.POINTER` to a new `mjbindings.types.MJMODEL` instance.
+
+ Raises:
+ TypeError: If both or neither of `byte_string` and `binary_path`
+ are specified.
+ """
+ if binary_path is None and byte_string is None:
+ raise TypeError(
+ "At least one of `byte_string` or `binary_path` must be specified.")
+ elif binary_path is not None and byte_string is not None:
+ raise TypeError(
+ "Only one of `byte_string` or `binary_path` may be specified.")
+
+ _maybe_register_license()
+
+ if byte_string is not None:
+ with _temporary_vfs({_FAKE_BINARY_FILENAME: byte_string}) as vfs:
+ ptr = mjlib.mj_loadModel(_FAKE_BINARY_FILENAME, vfs)
+ else:
+ ptr = mjlib.mj_loadModel(util.to_binary_string(binary_path), None)
+
+ # Free resources when the ctypes pointer is garbage collected.
+ _create_finalizer(ptr, mjlib.mj_deleteModel)
+
+ return ptr
+
+
+# Subclasses implementing constructors/destructors for low-level wrappers.
+# ------------------------------------------------------------------------------
+
+
+class MjModel(wrappers.MjModelWrapper):
+ """Wrapper class for a MuJoCo 'mjModel' instance.
+
+ MjModel encapsulates features of the model that are expected to remain
+ constant. It also contains simulation and visualization options which may be
+ changed occasionally, although this is done explicitly by the user.
+ """
+
+ def __init__(self, model_ptr):
+ """Creates a new MjModel instance from a ctypes pointer.
+
+ Args:
+ model_ptr: A `ctypes.POINTER` to an `mjbindings.types.MJMODEL` instance.
+ """
+ super(MjModel, self).__init__(model_ptr)
+
+ def __getstate__(self):
+ # All of MjModel's state is assumed to reside within the MuJoCo C struct.
+ # However there is no mechanism to prevent users from adding arbitrary
+ # Python attributes to an MjModel instance - these would not be serialized.
+ return self.to_bytes()
+
+ def __setstate__(self, byte_string):
+ model_ptr = _get_model_ptr_from_binary(byte_string=byte_string)
+ self.__init__(model_ptr)
+
+ def __copy__(self):
+ new_model_ptr = mjlib.mj_copyModel(None, self.ptr)
+ return self.__class__(new_model_ptr)
+
+ @classmethod
+ def from_xml_string(cls, xml_string, assets=None):
+ """Creates an `MjModel` instance from a model description XML string.
+
+ Args:
+ xml_string: String containing an MJCF or URDF model description.
+ assets: Optional dict containing external assets referenced by the model
+ (such as additional XML files, textures, meshes etc.), in the form of
+ `{filename: contents_string}` pairs. The keys should correspond to the
+ filenames specified in the model XML.
+
+ Returns:
+ An `MjModel` instance.
+ """
+ model_ptr = _get_model_ptr_from_xml(xml_string=xml_string, assets=assets)
+ return cls(model_ptr)
+
+ @classmethod
+ def from_byte_string(cls, byte_string):
+ """Creates an MjModel instance from a model binary as a string of bytes."""
+ model_ptr = _get_model_ptr_from_binary(byte_string=byte_string)
+ return cls(model_ptr)
+
+ @classmethod
+ def from_xml_path(cls, xml_path):
+ """Creates an MjModel instance from a path to a model XML file."""
+ model_ptr = _get_model_ptr_from_xml(xml_path=xml_path)
+ return cls(model_ptr)
+
+ @classmethod
+ def from_binary_path(cls, binary_path):
+ """Creates an MjModel instance from a path to a compiled model binary."""
+ model_ptr = _get_model_ptr_from_binary(binary_path=binary_path)
+ return cls(model_ptr)
+
+ def save_binary(self, binary_path):
+ """Saves the MjModel instance to a binary file."""
+ mjlib.mj_saveModel(self.ptr, util.to_binary_string(binary_path), None, 0)
+
+ def to_bytes(self):
+ """Serializes the model to a string of bytes."""
+ bufsize = mjlib.mj_sizeModel(self.ptr)
+ buf = ctypes.create_string_buffer(bufsize)
+ mjlib.mj_saveModel(self.ptr, None, buf, bufsize)
+ return buf.raw
+
+ def copy(self):
+ """Returns a copy of this MjModel instance."""
+ return self.__copy__()
+
+ def name2id(self, name, object_type):
+ """Returns the integer ID of a specified MuJoCo object.
+
+ Args:
+ name: String specifying the name of the object to query.
+ object_type: The type of the object. Can be either a lowercase string
+ (e.g. 'body', 'geom') or an `mjtObj` enum value.
+
+ Returns:
+ An integer object ID.
+
+ Raises:
+ Error: If `object_type` is not a valid MuJoCo object type, or if no object
+ with the corresponding name and type was found.
+ """
+ if not isinstance(object_type, int):
+ object_type = _str2type(object_type)
+ obj_id = mjlib.mj_name2id(
+ self.ptr, object_type, util.to_binary_string(name))
+ if obj_id == -1:
+ raise Error("Object of type {!r} with name {!r} does not exist.".format(
+ _type2str(object_type), name))
+ return obj_id
+
+ def id2name(self, object_id, object_type):
+ """Returns the name associated with a MuJoCo object ID, if there is one.
+
+ Args:
+ object_id: Integer ID.
+ object_type: The type of the object. Can be either a lowercase string
+ (e.g. 'body', 'geom') or an `mjtObj` enum value.
+
+ Returns:
+ A string containing the object name, or an empty string if the object ID
+ either doesn't exist or has no name.
+
+ Raises:
+ Error: If `object_type` is not a valid MuJoCo object type.
+ """
+ if not isinstance(object_type, int):
+ object_type = _str2type(object_type)
+ name_ptr = mjlib.mj_id2name(self.ptr, object_type, object_id)
+ if not name_ptr:
+ return ""
+ return util.to_native_string(ctypes.string_at(name_ptr))
+
+ @contextlib.contextmanager
+ def disable(self, *flags):
+ """Context manager for temporarily disabling MuJoCo flags.
+
+ Args:
+ *flags: Positional arguments specifying flags to disable. Can be either
+ lowercase strings (e.g. 'gravity', 'contact') or `mjtDisableBit` enum
+ values.
+
+ Yields:
+ None
+
+ Raises:
+ ValueError: If any item in `flags` is neither a valid name nor a value
+ from `enums.mjtDisableBit`.
+ """
+ old_bitmask = self.opt.disableflags
+ new_bitmask = old_bitmask
+ for flag in flags:
+ if isinstance(flag, six.string_types):
+ try:
+ field_name = "mjDSBL_" + flag.upper()
+ bitmask = getattr(enums.mjtDisableBit, field_name)
+ except AttributeError:
+ valid_names = [field_name.split("_")[1].lower()
+ for field_name in enums.mjtDisableBit._fields[:-1]]
+ raise ValueError("'{}' is not a valid flag name. Valid names: {}"
+ .format(flag, ", ".join(valid_names)))
+ else:
+ if flag not in enums.mjtDisableBit[:-1]:
+ raise ValueError("'{}' is not a value in `enums.mjtDisableBit`. "
+ "Valid values: {}"
+ .format(flag, tuple(enums.mjtDisableBit[:-1])))
+ bitmask = flag
+ new_bitmask |= bitmask
+ self.opt.disableflags = new_bitmask
+ try:
+ yield
+ finally:
+ self.opt.disableflags = old_bitmask
+
+ @property
+ def name(self):
+ """Returns the name of the model."""
+ # The model name is the first null-terminated string in the `names` buffer.
+ return util.to_native_string(
+ ctypes.string_at(ctypes.addressof(self.names.contents)))
+
+
+class MjData(wrappers.MjDataWrapper):
+ """Wrapper class for a MuJoCo 'mjData' instance.
+
+ MjData contains all of the dynamic variables and intermediate results produced
+ by the simulation. These are expected to change on each simulation timestep.
+ """
+
+ def __init__(self, model):
+ """Constructs a new MjData instance.
+
+ Args:
+ model: An MjModel instance.
+ """
+ self._model = model
+
+ # Allocate resources for mjData.
+ data_ptr = mjlib.mj_makeData(model.ptr)
+
+ # Free resources when the ctypes pointer is garbage collected.
+ _create_finalizer(data_ptr, mjlib.mj_deleteData)
+
+ super(MjData, self).__init__(data_ptr, model)
+
+ def __getstate__(self):
+ # Note: we can replace this once a `saveData` MJAPI function exists.
+ # To reconstruct an MjData instance we need three things:
+ # 1. Its parent MjModel instance
+ # 2. A subset of its fixed-size fields whose values aren't determined by
+ # the model
+ # 3. The contents of its internal buffer (all of its pointer fields point
+ # into this)
+ struct_fields = {}
+ for name in ["solver", "timer", "warning"]:
+ new_structs = []
+ for struct in getattr(self, name):
+ new_struct = type(struct)()
+ ctypes.memmove(ctypes.byref(new_struct), ctypes.byref(struct),
+ ctypes.sizeof(struct))
+ new_structs.append(new_struct)
+ struct_fields[name] = new_structs
+ scalar_field_names = ["ncon", "time", "energy"]
+ scalar_fields = {name: getattr(self, name) for name in scalar_field_names}
+ static_fields = {"struct_fields": struct_fields,
+ "scalar_fields": scalar_fields}
+ buffer_contents = ctypes.string_at(self.buffer_, self.nbuffer)
+ return (self._model, static_fields, buffer_contents)
+
+ def __setstate__(self, state_tuple):
+ # Replace this once a `loadData` MJAPI function exists.
+ self._model, static_fields, buffer_contents = state_tuple
+ self.__init__(self.model)
+ for name, contents in six.iteritems(static_fields["struct_fields"]):
+ target_carray = getattr(self, name)
+ for i, struct in enumerate(contents):
+ ctypes.memmove(ctypes.byref(target_carray[i]), ctypes.byref(struct),
+ ctypes.sizeof(struct))
+
+ for name, value in six.iteritems(static_fields["scalar_fields"]):
+ # Array and scalar values must be handled separately.
+ try:
+ getattr(self, name)[:] = value
+ except TypeError:
+ setattr(self, name, value)
+ buf_ptr = (ctypes.c_char * self.nbuffer).from_address(self.buffer_)
+ buf_ptr[:] = buffer_contents
+
+ def __copy__(self):
+ # This makes a shallow copy that shares the same parent MjModel instance.
+ new_obj = self.__class__(self.model)
+ mjlib.mj_copyData(new_obj.ptr, self.model.ptr, self.ptr)
+ return new_obj
+
+ def copy(self):
+ """Returns a copy of this MjData instance with the same parent MjModel."""
+ return self.__copy__()
+
+ @property
+ def model(self):
+ """The parent MjModel for this MjData instance."""
+ return self._model
+
+ @property
+ def contact(self):
+ """Iterator over detected contacts."""
+ return (wrappers.MjContactWrapper(ctypes.pointer(c))
+ for c in super(MjData, self).contact[:self.ncon])
+
+
+# Docstrings for these subclasses are inherited from their Wrapper parent class.
+
+
+class MjvCamera(wrappers.MjvCameraWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJVCAMERA())
+ mjlib.mjv_defaultCamera(ptr)
+ super(MjvCamera, self).__init__(ptr)
+
+
+class MjvOption(wrappers.MjvOptionWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJVOPTION())
+ mjlib.mjv_defaultOption(ptr)
+ super(MjvOption, self).__init__(ptr)
+
+
+class MjrContext(wrappers.MjrContextWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJRCONTEXT())
+ mjlib.mjr_defaultContext(ptr)
+ super(MjrContext, self).__init__(ptr)
+
+
+class MjvScene(wrappers.MjvSceneWrapper): # pylint: disable=missing-docstring
+
+ def __init__(self, max_geom=1000):
+ """Initializes a new `MjvScene` instance.
+
+ Args:
+ max_geom: (optional) An integer specifying the maximum number of geoms
+ that can be represented in the scene.
+ """
+ scene_ptr = ctypes.pointer(types.MJVSCENE())
+
+ # Allocate and initialize resources for the abstract scene.
+ mjlib.mjv_makeScene(scene_ptr, max_geom)
+
+ # Free resources when the ctypes pointer is garbage collected.
+ _create_finalizer(scene_ptr, mjlib.mjv_freeScene)
+
+ super(MjvScene, self).__init__(scene_ptr)
+
+
+class MjvPerturb(wrappers.MjvPerturbWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJVPERTURB())
+ mjlib.mjv_defaultPerturb(ptr)
+ super(MjvPerturb, self).__init__(ptr)
+
+
+class MjvFigure(wrappers.MjvFigureWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJVFIGURE())
+ mjlib.mjv_defaultFigure(ptr)
+ super(MjvFigure, self).__init__(ptr)
diff --git a/dm_control/mujoco/wrapper/core_test.py b/dm_control/mujoco/wrapper/core_test.py
new file mode 100644
index 00000000..5cd457ce
--- /dev/null
+++ b/dm_control/mujoco/wrapper/core_test.py
@@ -0,0 +1,459 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for core.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.mujoco.testing import assets
+from dm_control.mujoco.wrapper import core
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+
+import mock
+import numpy as np
+from six.moves import cPickle
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+HUMANOID_XML_PATH = assets.get_path("humanoid.xml")
+MODEL_WITH_ASSETS = assets.get_contents("model_with_assets.xml")
+ASSETS = {
+ "texture.png": assets.get_contents("deepmind.png"),
+ "mesh.stl": assets.get_contents("cube.stl"),
+ "included.xml": assets.get_contents("sphere.xml")
+}
+
+SCALAR_TYPES = (int, float)
+ARRAY_TYPES = (np.ndarray,)
+
+OUT_DIR = absltest.get_default_test_tmpdir()
+if not os.path.exists(OUT_DIR):
+ os.makedirs(OUT_DIR) # Ensure that the output directory exists.
+
+
+class CoreTest(parameterized.TestCase):
+
+ def setUp(self):
+ self.model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+ self.data = core.MjData(self.model)
+
+ def _assert_attributes_equal(self, actual_obj, expected_obj, attr_to_compare):
+ for name in attr_to_compare:
+ actual_value = getattr(actual_obj, name)
+ expected_value = getattr(expected_obj, name)
+ try:
+ if isinstance(expected_value, np.ndarray):
+ np.testing.assert_array_equal(actual_value, expected_value)
+ else:
+ self.assertEqual(actual_value, expected_value)
+ except AssertionError as e:
+ self.fail("Attribute '{}' differs from expected value: {}"
+ .format(name, str(e)))
+
+ def _assert_structs_equal(self, expected, actual):
+ for name in set(dir(actual) + dir(expected)):
+ if not name.startswith("_"):
+ expected_value = getattr(expected, name)
+ actual_value = getattr(actual, name)
+ self.assertEqual(
+ expected_value,
+ actual_value,
+ msg="struct field '{}' has value {}, expected {}".format(
+ name, actual_value, expected_value))
+
+ def testLoadXML(self):
+ with open(HUMANOID_XML_PATH, "r") as f:
+ xml_string = f.read()
+ model = core.MjModel.from_xml_string(xml_string)
+ core.MjData(model)
+ with self.assertRaises(TypeError):
+ core.MjModel()
+ with self.assertRaises(core.Error):
+ core.MjModel.from_xml_path("/path/to/nonexistent/model/file.xml")
+
+ xml_with_warning = """
+
+
+
+
+
+
+
+
+
+
+
+ """
+ with mock.patch.object(core, "logging") as mock_logging:
+ core.MjModel.from_xml_string(xml_with_warning)
+ mock_logging.warn.assert_called_once_with(
+ "Error: Pre-allocated constraint buffer is full. "
+ "Increase njmax above 2. Time = 0.0000.")
+
+ def testLoadXMLWithAssetsFromString(self):
+ core.MjModel.from_xml_string(MODEL_WITH_ASSETS, assets=ASSETS)
+ with self.assertRaises(core.Error):
+ # Should fail to load without the assets
+ core.MjModel.from_xml_string(MODEL_WITH_ASSETS)
+
+ def testSaveLastParsedModelToXML(self):
+ save_xml_path = os.path.join(OUT_DIR, "tmp_humanoid.xml")
+
+ not_last_parsed = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+ last_parsed = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+
+ # Modify the model before saving it in order to confirm that the changes are
+ # written to the XML.
+ last_parsed.geom_pos.flat[:] = np.arange(last_parsed.geom_pos.size)
+
+ core.save_last_parsed_model_to_xml(save_xml_path, check_model=last_parsed)
+
+ loaded = core.MjModel.from_xml_path(save_xml_path)
+ self._assert_attributes_equal(last_parsed, loaded, ["geom_pos"])
+ core.MjData(loaded)
+
+ # Test that `check_model` results in a ValueError if it is not the most
+ # recently parsed model.
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, core._NOT_LAST_PARSED_ERROR):
+ core.save_last_parsed_model_to_xml(save_xml_path,
+ check_model=not_last_parsed)
+
+ def testBinaryIO(self):
+ bin_path = os.path.join(OUT_DIR, "tmp_humanoid.mjb")
+ self.model.save_binary(bin_path)
+ core.MjModel.from_binary_path(bin_path)
+ byte_string = self.model.to_bytes()
+ core.MjModel.from_byte_string(byte_string)
+
+ def testDimensions(self):
+ self.assertEqual(self.data.qpos.shape[0], self.model.nq)
+ self.assertEqual(self.data.qvel.shape[0], self.model.nv)
+ self.assertEqual(self.model.body_pos.shape, (self.model.nbody, 3))
+
+ def testStep(self):
+ t0 = self.data.time
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(self.data.time, t0 + self.model.opt.timestep)
+ self.assertTrue(np.all(np.isfinite(self.data.qpos[:])))
+ self.assertTrue(np.all(np.isfinite(self.data.qvel[:])))
+
+ def testMultipleData(self):
+ data2 = core.MjData(self.model)
+ self.assertNotEqual(self.data.ptr, data2.ptr)
+ t0 = self.data.time
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(self.data.time, t0 + self.model.opt.timestep)
+ self.assertEqual(data2.time, 0)
+
+ def testMultipleModel(self):
+ model2 = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+ self.assertNotEqual(self.model.ptr, model2.ptr)
+ self.model.opt.timestep += 0.001
+ self.assertEqual(self.model.opt.timestep, model2.opt.timestep + 0.001)
+
+ def testModelName(self):
+ self.assertEqual(self.model.name, "humanoid")
+
+ @parameterized.named_parameters(
+ ("_copy", lambda x: x.copy()),
+ ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ def testCopyOrPickleModel(self, func):
+ timestep = 0.12345
+ self.model.opt.timestep = timestep
+ body_pos = self.model.body_pos + 1
+ self.model.body_pos[:] = body_pos
+ model2 = func(self.model)
+ self.assertNotEqual(model2.ptr, self.model.ptr)
+ self.assertEqual(model2.opt.timestep, timestep)
+ np.testing.assert_array_equal(model2.body_pos, body_pos)
+
+ @parameterized.named_parameters(
+ ("_copy", lambda x: x.copy()),
+ ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ def testCopyOrPickleData(self, func):
+ for _ in xrange(10):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ data2 = func(self.data)
+ attr_to_compare = ("time", "energy", "qpos", "xpos")
+ self.assertNotEqual(data2.ptr, self.data.ptr)
+ self._assert_attributes_equal(data2, self.data, attr_to_compare)
+ for _ in xrange(10):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mjlib.mj_step(data2.model.ptr, data2.ptr)
+ self._assert_attributes_equal(data2, self.data, attr_to_compare)
+
+ @parameterized.named_parameters(
+ ("_copy", lambda x: x.copy()),
+ ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ def testCopyOrPickleStructs(self, func):
+ for _ in xrange(10):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ data2 = func(self.data)
+ self.assertNotEqual(data2.ptr, self.data.ptr)
+ for name in ["warning", "timer", "solver"]:
+ self._assert_structs_equal(getattr(self.data, name), getattr(data2, name))
+ for _ in xrange(10):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mjlib.mj_step(data2.model.ptr, data2.ptr)
+ for expected, actual in zip(self.data.timer, data2.timer):
+ self._assert_structs_equal(expected, actual)
+
+ @parameterized.parameters(
+ ("right_foot", "body", 6),
+ ("right_foot", enums.mjtObj.mjOBJ_BODY, 6),
+ ("left_knee", "joint", 11),
+ ("left_knee", enums.mjtObj.mjOBJ_JOINT, 11))
+ def testNamesIds(self, name, object_type, object_id):
+ output_id = self.model.name2id(name, object_type)
+ self.assertEqual(object_id, output_id)
+ output_name = self.model.id2name(object_id, object_type)
+ self.assertEqual(name, output_name)
+
+ def testNamesIdsExceptions(self):
+ with self.assertRaisesRegexp(core.Error, "does not exist"):
+ self.model.name2id("nonexistent_body_name", "body")
+ with self.assertRaisesRegexp(core.Error, "is not a valid object type"):
+ self.model.name2id("right_foot", "nonexistent_type_name")
+
+ def testNamelessObject(self):
+ # The model in humanoid.xml contains a single nameless camera.
+ name = self.model.id2name(0, "camera")
+ self.assertEqual("", name)
+
+ def testWarningCallback(self):
+ self.data.qpos[0] = np.inf
+ with mock.patch.object(core, "logging") as mock_logging:
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mock_logging.warn.assert_called_once_with(
+ "Nan, Inf or huge value in QPOS at DOF 0. The simulation is unstable. "
+ "Time = 0.0000.")
+
+ def testErrorCallback(self):
+ with mock.patch.object(core, "logging") as mock_logging:
+ mjlib.mj_activate(b"nonexistent_activation_key")
+ mock_logging.fatal.assert_called_once_with(
+ "Could not open activation key file nonexistent_activation_key")
+
+ def testSingleCallbackContext(self):
+
+ callback_was_called = [False]
+
+ def callback(unused_model, unused_data):
+ callback_was_called[0] = True
+
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertFalse(callback_was_called[0])
+
+ class DummyError(RuntimeError):
+ pass
+
+ try:
+ with core.callback_context("mjcb_passive", callback):
+
+ # Stepping invokes the `mjcb_passive` callback.
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertTrue(callback_was_called[0])
+
+ # Exceptions should not prevent `mjcb_passive` from being reset.
+ raise DummyError("Simulated exception.")
+
+ except DummyError:
+ pass
+
+ # `mjcb_passive` should have been reset to None.
+ callback_was_called[0] = False
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertFalse(callback_was_called[0])
+
+ def testNestedCallbackContexts(self):
+
+ last_called = [None]
+ outer_called = "outer called"
+ inner_called = "inner called"
+
+ def outer(unused_model, unused_data):
+ last_called[0] = outer_called
+
+ def inner(unused_model, unused_data):
+ last_called[0] = inner_called
+
+ with core.callback_context("mjcb_passive", outer):
+
+ # This should execute `outer` a few times.
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(last_called[0], outer_called)
+
+ with core.callback_context("mjcb_passive", inner):
+
+ # This should execute `inner` a few times.
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(last_called[0], inner_called)
+
+ # When we exit the inner context, the `mjcb_passive` callback should be
+ # reset to `outer`.
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(last_called[0], outer_called)
+
+ # When we exit the outer context, the `mjcb_passive` callback should be
+ # reset to None, and stepping should not affect `last_called`.
+ last_called[0] = None
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertIsNone(last_called[0])
+
+ def testDisableFlags(self):
+ xml_string = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+ model = core.MjModel.from_xml_string(xml_string)
+ data = core.MjData(model)
+ for _ in xrange(100): # Let the simulation settle for a while.
+ mjlib.mj_step(model.ptr, data.ptr)
+
+ # With gravity and contact enabled, the cube should be stationary and the
+ # touch sensor should give a reading of ~9.81 N.
+ self.assertAlmostEqual(data.qvel[0], 0, places=4)
+ self.assertAlmostEqual(data.sensordata[0], 9.81, places=2)
+
+ # If we disable both contacts and gravity then the cube should remain
+ # stationary and the touch sensor should read zero.
+ with model.disable("contact", "gravity"):
+ mjlib.mj_step(model.ptr, data.ptr)
+ self.assertAlmostEqual(data.qvel[0], 0, places=4)
+ self.assertEqual(data.sensordata[0], 0)
+
+ # If we disable contacts but not gravity then the cube should fall through
+ # the floor.
+ with model.disable(enums.mjtDisableBit.mjDSBL_CONTACT):
+ for _ in xrange(10):
+ mjlib.mj_step(model.ptr, data.ptr)
+ self.assertLess(data.qvel[0], -0.1)
+
+ def testDisableFlagsExceptions(self):
+ with self.assertRaisesRegexp(ValueError, "not a valid flag name"):
+ with self.model.disable("invalid_flag_name"):
+ pass
+ with self.assertRaisesRegexp(ValueError,
+ "not a value in `enums.mjtDisableBit`"):
+ with self.model.disable(-99):
+ pass
+
+
+def _get_attributes_test_params():
+ model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+ data = core.MjData(model)
+ # Get the names of the non-private attributes of model and data through
+ # introspection. These are passed as parameters to each of the test methods
+ # in AttributesTest.
+ array_args = []
+ scalar_args = []
+ skipped_args = []
+ for parent_name, parent_obj in zip(("model", "data"), (model, data)):
+ for attr_name in dir(parent_obj):
+ if not attr_name.startswith("_"): # Skip 'private' attributes
+ args = (parent_name, attr_name)
+ attr = getattr(parent_obj, attr_name)
+ if isinstance(attr, ARRAY_TYPES):
+ array_args.append(args)
+ elif isinstance(attr, SCALAR_TYPES):
+ scalar_args.append(args)
+ elif callable(attr):
+ # Methods etc. should be covered specifically in CoreTest.
+ continue
+ else:
+ skipped_args.append(args)
+ return array_args, scalar_args, skipped_args
+
+
+_array_args, _scalar_args, _skipped_args = _get_attributes_test_params()
+
+
+class AttributesTest(parameterized.TestCase):
+ """Generic tests covering attributes of MjModel and MjData."""
+
+ # Iterates over ('parent_name', 'attr_name') tuples
+ @parameterized.parameters(*_array_args)
+ def testReadWriteArray(self, parent_name, attr_name):
+ attr = getattr(getattr(self, parent_name), attr_name)
+ if not isinstance(attr, ARRAY_TYPES):
+ raise TypeError("{}.{} has incorrect type {!r} - must be one of {!r}."
+ .format(parent_name, attr_name, type(attr), ARRAY_TYPES))
+ # Check that we can read the contents of the array
+ old_contents = attr[:]
+ # Don't write to integer arrays since these might contain pointers.
+ if not np.issubdtype(old_contents.dtype, np.integer):
+ # Write unique values to the array, check that we can read them back.
+ new_contents = np.arange(old_contents.size, dtype=old_contents.dtype)
+ new_contents.shape = old_contents.shape
+ attr[:] = new_contents
+ np.testing.assert_array_equal(new_contents, attr[:])
+ self._take_steps() # Take a few steps, check that we don't get segfaults.
+
+ @parameterized.parameters(*_scalar_args)
+ def testReadWriteScalar(self, parent_name, attr_name):
+ parent_obj = getattr(self, parent_name)
+ # Check that we can read the value.
+ attr = getattr(parent_obj, attr_name)
+ if not isinstance(attr, SCALAR_TYPES):
+ raise TypeError("{}.{} has incorrect type {!r} - must be one of {!r}."
+ .format(parent_name, attr_name, type(attr), SCALAR_TYPES))
+ # Don't write to integers since these might be pointers.
+ if not isinstance(attr, int):
+ # Set the value of this attribute, check that we can read it back.
+ new_value = type(attr)(99)
+ setattr(parent_obj, attr_name, new_value)
+ self.assertEqual(new_value, getattr(parent_obj, attr_name))
+ self._take_steps() # Take a few steps, check that we don't get segfaults.
+
+ @parameterized.parameters(*_skipped_args)
+ @absltest.unittest.skip("No tests defined for attributes of this type.")
+ def testSkipped(self, *unused_args):
+ # This is a do-nothing test that indicates where we currently lack coverage.
+ pass
+
+ def setUp(self):
+ self.model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+ self.data = core.MjData(self.model)
+
+ def _take_steps(self, n=5):
+ for _ in xrange(n):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/dm_control/mujoco/wrapper/mjbindings/__init__.py b/dm_control/mujoco/wrapper/mjbindings/__init__.py
new file mode 100644
index 00000000..8e248e17
--- /dev/null
+++ b/dm_control/mujoco/wrapper/mjbindings/__init__.py
@@ -0,0 +1,36 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Import core names of MuJoCo ctypes bindings."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+
+from dm_control.mujoco.wrapper.mjbindings import constants
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.mujoco.wrapper.mjbindings import sizes
+from dm_control.mujoco.wrapper.mjbindings import types
+from dm_control.mujoco.wrapper.mjbindings import wrappers
+
+# pylint: disable=g-import-not-at-top
+try:
+ from dm_control.mujoco.wrapper.mjbindings import functions
+ from dm_control.mujoco.wrapper.mjbindings.functions import mjlib
+except (IOError, OSError):
+ logging.warning('mjbindings failed to import mjlib and other functions. '
+ 'libmujoco.so may not be accessible.')
diff --git a/dm_control/mujoco/wrapper/mjbindings_test.py b/dm_control/mujoco/wrapper/mjbindings_test.py
new file mode 100644
index 00000000..cd959a50
--- /dev/null
+++ b/dm_control/mujoco/wrapper/mjbindings_test.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for mjbindings."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.mujoco.wrapper.mjbindings import constants
+from dm_control.mujoco.wrapper.mjbindings import sizes
+
+
+class MjbindingsTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ ('mjdata', 'xpos', ('nbody', 3)),
+ ('mjmodel', 'geom_type', ('ngeom',)),
+ # Fields with identifiers in mjxmacro that are resolved at compile-time.
+ ('mjmodel', 'actuator_dynprm', ('nu', constants.mjNDYN)),
+ ('mjdata', 'efc_solref', ('njmax', constants.mjNREF)),
+ # Fields with multiple named indices.
+ ('mjmodel', 'key_qpos', ('nkey', 'nq')),
+ )
+ def testIndexDict(self, struct_name, field_name, expected_metadata):
+ self.assertEqual(expected_metadata,
+ sizes.array_sizes[struct_name][field_name])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/mujoco/wrapper/util.py b/dm_control/mujoco/wrapper/util.py
new file mode 100644
index 00000000..75cec637
--- /dev/null
+++ b/dm_control/mujoco/wrapper/util.py
@@ -0,0 +1,220 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Various helper functions and classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes
+import ctypes.util
+import functools
+import os
+import sys
+import threading
+# Internal dependencies.
+import numpy as np
+import six
+
+from dm_control.utils import resources
+
+# Environment variables that can be used to override the default paths to the
+# MuJoCo shared library and key file.
+ENV_MJLIB_PATH = "MJLIB_PATH"
+ENV_MJKEY_PATH = "MJKEY_PATH"
+
+
+def _find_shared_library_extension():
+ try:
+ libc_path = ctypes.util.find_library("c")
+ libc_filename = os.path.split(libc_path)[1]
+ return "." + libc_filename.split(".")[1]
+ except (AttributeError, IndexError, TypeError):
+ return ".so"
+
+
+SHARED_LIB_EXT = _find_shared_library_extension()
+DEFAULT_MJLIB_PATH = "~/.mujoco/mjpro150/bin/libmujoco150" + SHARED_LIB_EXT
+DEFAULT_MJKEY_PATH = "~/.mujoco/mjkey.txt"
+
+
+DEFAULT_ENCODING = sys.getdefaultencoding()
+
+
+def to_binary_string(s):
+ """Convert text string to binary."""
+ if isinstance(s, six.binary_type):
+ return s
+ return s.encode(DEFAULT_ENCODING)
+
+
+def to_native_string(s):
+ """Convert a text or binary string to the native string format."""
+ if six.PY3 and isinstance(s, six.binary_type):
+ return s.decode(DEFAULT_ENCODING)
+ elif six.PY2 and isinstance(s, six.text_type):
+ return s.encode(DEFAULT_ENCODING)
+ else:
+ return s
+
+
+def _get_full_path(path):
+ expanded_path = os.path.expanduser(os.path.expandvars(path))
+ return resources.GetResourceFilename(expanded_path)
+
+
+def get_mjlib():
+ """Loads `libmujoco.so` and returns it as a `ctypes.CDLL` object."""
+ try:
+ # Use the MJLIB_PATH environment variable if it has been set.
+ raw_path = os.environ[ENV_MJLIB_PATH]
+ except KeyError:
+ paths_to_try = [
+ # If libmujoco is in LD_LIBRARY_PATH then ctypes only needs its name.
+ os.path.basename(DEFAULT_MJLIB_PATH),
+ _get_full_path(DEFAULT_MJLIB_PATH),
+ ]
+ for library_path in paths_to_try:
+ try:
+ return ctypes.cdll.LoadLibrary(library_path)
+ except OSError:
+ pass
+ raw_path = DEFAULT_MJLIB_PATH
+ return ctypes.cdll.LoadLibrary(_get_full_path(raw_path))
+
+
+def get_mjkey_path():
+ """Returns a path to the MuJoCo key file."""
+ raw_path = os.environ.get(ENV_MJKEY_PATH, DEFAULT_MJKEY_PATH)
+ return _get_full_path(raw_path)
+
+
+class WrapperBase(object):
+ """Base class for wrappers that provide getters/setters for ctypes structs."""
+
+ # This is needed so that the __del__ methods of MjModel and MjData can still
+ # succeed in cases where an exception occurs during __init__() before the _ptr
+ # attribute has been assigned.
+ _ptr = None
+
+ def __init__(self, ptr, model=None):
+ """Constructs a wrapper instance from a `ctypes.Structure`.
+
+ Args:
+ ptr: `ctypes.POINTER` to the struct to be wrapped.
+ model: `MjModel` instance; needed by `MjDataWrapper` in order to get the
+ dimensions of dynamically-sized arrays at runtime.
+ """
+ self._ptr = ptr
+ self._model = model
+
+ @property
+ def ptr(self):
+ """Pointer to the underlying `ctypes.Structure` instance."""
+ return self._ptr
+
+
+class CachedProperty(property):
+ """A property that is evaluated only once per object instance."""
+
+ def __init__(self, func, doc=None):
+ super(CachedProperty, self).__init__(fget=func, doc=doc)
+ self.lock = threading.RLock()
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self
+ name = self.fget.__name__
+ obj_dict = obj.__dict__
+ with self.lock:
+ try:
+ # Return cached result if it was computed before the lock was acquired
+ return obj_dict[name]
+ except KeyError:
+ # Otherwise call the function, cache the result, and return it
+ return obj_dict.setdefault(name, self.fget(obj))
+
+
+# It's easy to create numpy arrays from a pointer then have these persist after
+# the model has been destroyed and its underlying memory freed. To mitigate the
+# risk of writing to a pointer after it has been freed, all array attributes are
+# read-only by default. In order to write to them you need to explicitly set
+# their ".writeable" flag to True (the SetFlags context manager above provides
+# a convenient way to do this).
+
+# The proper solution would be to prevent the model from being garbage-collected
+# whilst any of the views onto its buffers are still alive.
+
+
+def _as_array(src, shape):
+ """Converts a native `src` array to a managed numpy buffer.
+
+ Args:
+ src: A ctypes pointer or array.
+ shape: A tuple specifying the dimensions of the output array.
+
+ Returns:
+ A numpy array.
+ """
+
+ # To work around a memory leak in numpy, we have to go through this
+ # frombuffer method instead of calling ctypeslib.as_array. See
+ # https://github.com/numpy/numpy/issues/6511
+ # return np.ctypeslib.as_array(src, shape)
+
+ # This is part of the public API. See
+ # http://git.net/ml/python.ctypes/2008-02/msg00014.html
+ ctype = src._type_ # pylint: disable=protected-access
+
+ size = np.prod(shape)
+ ptr = ctypes.cast(src, ctypes.POINTER(ctype * size))
+ buf = np.frombuffer(ptr.contents, dtype=ctype)
+ buf.shape = shape
+ return buf
+
+
+def buf_to_npy(src, shape, np_dtype=None):
+ """Returns a numpy array view of the contents of a ctypes pointer or array.
+
+ Args:
+ src: A ctypes pointer or array.
+ shape: A tuple specifying the dimensions of the output array.
+ np_dtype: A string or `np.dtype` object specifying the dtype of the output
+ array. If None, the dtype is inferred from the type of `src`.
+
+ Returns:
+ A numpy array.
+ """
+ # This causes a harmless RuntimeWarning about mismatching buffer format
+ # strings due to a bug in ctypes: http://stackoverflow.com/q/4964101/1461210
+ arr = _as_array(src, shape)
+ if np_dtype is not None:
+ arr.dtype = np_dtype
+ return arr
+
+
+@functools.wraps(np.ctypeslib.ndpointer)
+def ndptr(*args, **kwargs):
+ """Wraps `np.ctypeslib.ndpointer` to allow passing None for NULL pointers."""
+ base = np.ctypeslib.ndpointer(*args, **kwargs)
+
+ def from_param(_, obj):
+ if obj is None:
+ return obj
+ else:
+ return base.from_param(obj)
+
+ return type(base.__name__, (base,), {"from_param": classmethod(from_param)})
diff --git a/dm_control/mujoco/wrapper/util_test.py b/dm_control/mujoco/wrapper/util_test.py
new file mode 100644
index 00000000..1d6cf982
--- /dev/null
+++ b/dm_control/mujoco/wrapper/util_test.py
@@ -0,0 +1,58 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import resource
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control.mujoco.wrapper import core
+from dm_control.mujoco.wrapper import util
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+_NUM_CALLS = 10000
+ _RSS_GROWTH_TOLERANCE = 150 # ru_maxrss units: kilobytes on Linux, bytes on macOS.
+
+
+class UtilTest(absltest.TestCase):
+
+ def test_buf_to_npy_no_memory_leak(self):
+ """Ensures we can call buf_to_npy without leaking memory."""
+ model = core.MjModel.from_xml_string("")
+ src = model._ptr.contents.name_geomadr
+ shape = (model.ngeom,)
+
+ # This uses high water marks to find memory leaks in native code.
+ old_max = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+ for _ in xrange(_NUM_CALLS):
+ buf = util.buf_to_npy(src, shape)
+ del buf
+ new_max = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+ growth = new_max - old_max
+
+ if growth > _RSS_GROWTH_TOLERANCE:
+ self.fail("Max RSS grew by {} (ru_maxrss units), exceeding tolerance of {}."
+ .format(growth, _RSS_GROWTH_TOLERANCE))
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/dm_control/render/__init__.py b/dm_control/render/__init__.py
new file mode 100644
index 00000000..bf1b0dd3
--- /dev/null
+++ b/dm_control/render/__init__.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""OpenGL context management for rendering MuJoCo scenes.
+
+The `Renderer` class will use one of the following rendering APIs, in order of
+descending priority: EGL > GLFW > OSMesa.
+"""
+
+# pylint: disable=g-import-not-at-top
+try:
+ from dm_control.render.glfw_renderer import GLFWRenderer as _GLFWRenderer
+except (ImportError, IOError):
+ _GLFWRenderer = None
+try:
+ from dm_control.render.egl_renderer import EGLRenderer as _EGLRenderer
+except ImportError:
+ _EGLRenderer = None
+try:
+ from dm_control.render.osmesa_renderer import OSMesaRenderer as _OSMesaRenderer
+except ImportError:
+ _OSMesaRenderer = None
+# pylint: enable=g-import-not-at-top
+
+# pylint: disable=invalid-name
+if _EGLRenderer:
+ Renderer = _EGLRenderer
+elif _GLFWRenderer:
+ Renderer = _GLFWRenderer
+elif _OSMesaRenderer:
+ Renderer = _OSMesaRenderer
+else:
+ # This is a workaround that allows imports from `dm_control.render` to succeed
+ # even when there is no rendering API available. We need this in order to run
+ # integration tests on headless servers.
+
+ def Renderer(*args, **kwargs):
+ del args, kwargs # Unused.
+ raise ImportError('No OpenGL rendering backend could be imported.')
+
+# pylint: enable=invalid-name
+
+
diff --git a/dm_control/render/base.py b/dm_control/render/base.py
new file mode 100644
index 00000000..6f1fb96b
--- /dev/null
+++ b/dm_control/render/base.py
@@ -0,0 +1,110 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for OpenGL context handlers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import contextlib
+
+# Internal dependencies.
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Renderer(object):
+ """Base `Renderer` class for managing OpenGL contexts."""
+
+ def __init__(self, max_width, max_height):
+ """Initializes this `Renderer`.
+
+ Arguments to this method are passed to `_create`.
+
+ Args:
+ max_width: Integer specifying the maximum framebuffer width in pixels.
+ max_height: Integer specifying the maximum framebuffer height in pixels.
+ """
+ self._max_width = max_width
+ self._max_height = max_height
+ self._create(max_width, max_height)
+
+ @abc.abstractmethod
+ def _create(self, max_width, max_height):
+ """Called internally by `__init__` to create the OpenGL context.
+
+ Args:
+ max_width: Integer specifying the maximum framebuffer width in pixels.
+ max_height: Integer specifying the maximum framebuffer height in pixels.
+ """
+
+ @contextlib.contextmanager
+ def make_current(self, width, height):
+ """Context manager that makes this Renderer's OpenGL context current.
+
+ Args:
+ width: Integer specifying the new framebuffer width in pixels.
+ height: Integer specifying the new framebuffer height in pixels.
+
+ Yields:
+ None
+
+ Raises:
+ ValueError: If width > max_width, or height > max_height.
+ """
+ if width > self._max_width:
+ raise ValueError('Maximal framebuffer width is {}. {} given.'
+ .format(self._max_width, width))
+ if height > self._max_height:
+ raise ValueError('Maximal framebuffer height is {}. {} given.'
+ .format(self._max_height, height))
+
+ previous_context = self._before_make_current(width, height)
+ try:
+ yield
+ finally:
+ self._after_make_current(previous_context)
+
+ @abc.abstractmethod
+ def _before_make_current(self, width, height):
+ """Called when entering the `make_current` context manager.
+
+ Args:
+ width: Integer specifying the new framebuffer width in pixels.
+ height: Integer specifying the new framebuffer height in pixels.
+
+ Returns:
+ Either a pointer to the previous OpenGL context to be passed to
+ `_after_make_current`, or else None.
+ """
+
+ @abc.abstractmethod
+ def _after_make_current(self, previous_context):
+ """Called when exiting the `make_current` context manager.
+
+ Args:
+ previous_context: The return value of `_before_make_current`. This should
+ either be a pointer to a previous OpenGL context to be made current, or
+ else None.
+ """
+
+ @abc.abstractmethod
+ def free_context(self):
+ """Frees resources associated with this context."""
+
+ def __del__(self):
+ self.free_context()
diff --git a/dm_control/render/glfw_renderer.py b/dm_control/render/glfw_renderer.py
new file mode 100644
index 00000000..94462bef
--- /dev/null
+++ b/dm_control/render/glfw_renderer.py
@@ -0,0 +1,66 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An OpenGL renderer backed by GLFW."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from dm_control.render import base
+import glfw
+
+_done_init_glfw = False
+
+
+def _maybe_init_glfw():
+ global _done_init_glfw
+ if not _done_init_glfw:
+ if not glfw.init():
+ raise OSError('Failed to initialize GLFW.')
+ _done_init_glfw = True
+
+
+class GLFWRenderer(base.Renderer):
+ """An OpenGL renderer backed by GLFW."""
+
+ def _create(self, max_width, max_height):
+ _maybe_init_glfw()
+ glfw.window_hint(glfw.VISIBLE, 0)
+ glfw.window_hint(glfw.DOUBLEBUFFER, 0)
+ self._context = glfw.create_window(width=max_width, height=max_height,
+ title='Invisible window', monitor=None,
+ share=None)
+ # This reference prevents `glfw` from being garbage-collected before the
+ # last window is destroyed, otherwise we may get `AttributeError`s when the
+ # `__del__` method is later called.
+ self._glfw = glfw
+
+ def _before_make_current(self, width, height):
+ previous_context = glfw.get_current_context()
+ glfw.make_context_current(self._context)
+ if (width, height) != glfw.get_window_size(self._context):
+ glfw.set_window_size(self._context, width, height)
+ return previous_context
+
+ def _after_make_current(self, previous_context):
+ glfw.make_context_current(previous_context)
+
+ def free_context(self):
+ if self._context is not None:
+ self._glfw.destroy_window(self._context)
+ self._context = None
diff --git a/dm_control/render/glfw_renderer_test.py b/dm_control/render/glfw_renderer_test.py
new file mode 100644
index 00000000..0822e9de
--- /dev/null
+++ b/dm_control/render/glfw_renderer_test.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for GLFWRenderer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control import render
+
+import glfw
+import mock
+
+MAX_WIDTH = 1024
+MAX_HEIGHT = 1024
+
+
+class GLFWRendererTest(absltest.TestCase):
+
+ @mock.patch(render.__name__ + ".glfw_renderer.glfw", spec=glfw)
+ def test_context_activation_and_deactivation(self, mock_glfw):
+ context = mock.MagicMock()
+
+ mock_glfw.create_window = mock.MagicMock(return_value=context)
+ mock_glfw.get_current_context = mock.MagicMock(return_value=None)
+
+ renderer = render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ renderer.make_context_current = mock.MagicMock()
+
+ with renderer.make_current(2, 2):
+ mock_glfw.make_context_current.assert_called_once_with(context)
+ mock_glfw.make_context_current.reset_mock()
+
+ mock_glfw.make_context_current.assert_called_once_with(None)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/dm_control/rl/__init__.py b/dm_control/rl/__init__.py
new file mode 100644
index 00000000..be364790
--- /dev/null
+++ b/dm_control/rl/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RL interface code."""
diff --git a/dm_control/rl/control.py b/dm_control/rl/control.py
new file mode 100644
index 00000000..7290b053
--- /dev/null
+++ b/dm_control/rl/control.py
@@ -0,0 +1,370 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An environment.Base subclass for control-specific environments."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import contextlib
+
+# Internal dependencies.
+
+import numpy as np
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from dm_control.rl import environment
+from dm_control.rl import specs
+
+FLAT_OBSERVATION_KEY = 'observations'
+
+
+class Environment(environment.Base):
+ """Class for physics-based reinforcement learning environments."""
+
+ def __init__(self,
+ physics,
+ task,
+ time_limit=float('inf'),
+ control_timestep=None,
+ n_sub_steps=None,
+ flat_observation=False):
+ """Initializes a new `Environment`.
+
+ Args:
+ physics: Instance of `Physics`.
+ task: Instance of `Task`.
+ time_limit: Optional `float`, maximum time for each episode in seconds. By
+ default this is set to infinite.
+ control_timestep: Optional control time-step, in seconds.
+ n_sub_steps: Optional number of physical time-steps in one control
+ time-step, aka "action repeats". Can only be supplied if
+ `control_timestep` is not specified.
+ flat_observation: If True, observations will be flattened and concatenated
+ into a single numpy array.
+
+ Raises:
+ ValueError: If both `n_sub_steps` and `control_timestep` are supplied.
+ """
+ self._task = task
+ self._physics = physics
+ self._time_limit = time_limit
+ self._flat_observation = flat_observation
+
+ if n_sub_steps is not None and control_timestep is not None:
+ raise ValueError('Both n_sub_steps and control_timestep were supplied.')
+ elif n_sub_steps is not None:
+ self._n_sub_steps = n_sub_steps
+ elif control_timestep is not None:
+ self._n_sub_steps = compute_n_steps(control_timestep,
+ self._physics.timestep())
+ else:
+ self._n_sub_steps = 1
+
+ self._reset_next_step = True
+
+ def reset(self):
+ """Starts a new episode and returns the first `TimeStep`."""
+ self._reset_next_step = False
+ with self._physics.reset_context():
+ self._task.initialize_episode(self._physics)
+
+ observation = self._task.get_observation(self._physics)
+ if self._flat_observation:
+ observation = flatten_observation(observation)
+
+ return environment.TimeStep(
+ step_type=environment.StepType.FIRST,
+ reward=None,
+ discount=None,
+ observation=observation)
+
+ def step(self, action):
+ """Updates the environment using the action and returns a `TimeStep`."""
+
+ if self._reset_next_step:
+ return self.reset()
+
+ self._task.before_step(action, self._physics)
+ for _ in xrange(self._n_sub_steps):
+ self._physics.step()
+ self._task.after_step(self._physics)
+
+ reward = self._task.get_reward(self._physics)
+ observation = self._task.get_observation(self._physics)
+ if self._flat_observation:
+ observation = flatten_observation(observation)
+
+ if self.physics.time() >= self._time_limit:
+ discount = 1.0
+ else:
+ discount = self._task.get_termination(self._physics)
+
+ if discount is None:
+ return environment.TimeStep(
+ environment.StepType.MID, reward, 1.0, observation)
+ else:
+ self._reset_next_step = True
+ return environment.TimeStep(
+ environment.StepType.LAST, reward, discount, observation)
+
+ def action_spec(self):
+ """Returns the action specification for this environment."""
+ return self._task.action_spec(self._physics)
+
+ def observation_spec(self):
+ """Returns the observation specification for this environment.
+
+ Infers the spec from the observation, unless the Task implements the
+ `observation_spec` method.
+
+ Returns:
+ A dict mapping observation name to `ArraySpec` containing observation
+ shape and dtype.
+ """
+ try:
+ return self._task.observation_spec(self._physics)
+ except NotImplementedError:
+ observation = self._task.get_observation(self._physics)
+ if self._flat_observation:
+ observation = flatten_observation(observation)
+ return _spec_from_observation(observation)
+
+ @property
+ def physics(self):
+ return self._physics
+
+ @property
+ def task(self):
+ return self._task
+
+ def control_timestep(self):
+ """Returns the interval between agent actions in seconds."""
+ return self.physics.timestep() * self._n_sub_steps
+
+
+def compute_n_steps(control_timestep, physics_timestep, tolerance=1e-8):
+ """Returns the number of physics timesteps in a single control timestep.
+
+ Args:
+ control_timestep: Control time-step, should be an integer multiple of the
+ physics timestep.
+ physics_timestep: The time-step of the physics simulation.
+ tolerance: Optional tolerance value for checking if `physics_timestep`
+ divides `control_timestep`.
+
+ Returns:
+ The number of physics timesteps in a single control timestep.
+
+ Raises:
+ ValueError: If `control_timestep` is smaller than `physics_timestep` or if
+ `control_timestep` is not an integer multiple of `physics_timestep`.
+ """
+ if control_timestep < physics_timestep:
+ raise ValueError(
+ 'Control timestep ({}) cannot be smaller than physics timestep ({}).'.
+ format(control_timestep, physics_timestep))
+ if abs((control_timestep / physics_timestep - round(
+ control_timestep / physics_timestep))) > tolerance:
+ raise ValueError(
+ 'Control timestep ({}) must be an integer multiple of physics timestep '
+ '({})'.format(control_timestep, physics_timestep))
+ return int(round(control_timestep / physics_timestep))
+
+
+def _spec_from_observation(observation):
+ result = collections.OrderedDict()
+ for key, value in six.iteritems(observation):
+ result[key] = specs.ArraySpec(value.shape, value.dtype)
+ return result
+
+# Base class definitions for objects supplied to Environment.
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Physics(object):
+ """Simulates a physical environment."""
+
+ @abc.abstractmethod
+ def step(self, n_sub_steps=1):
+ """Updates the simulation state.
+
+ Args:
+ n_sub_steps: Optional number of times to repeatedly update the simulation
+ state. Defaults to 1.
+ """
+
+ @abc.abstractmethod
+ def time(self):
+ """Returns the elapsed simulation time in seconds."""
+
+ @abc.abstractmethod
+ def timestep(self):
+ """Returns the simulation timestep."""
+
+ def set_control(self, control):
+ """Sets the control signal for the actuators."""
+ raise NotImplementedError('set_control is not supported.')
+
+ @contextlib.contextmanager
+ def reset_context(self):
+ """Context manager for resetting the simulation state.
+
+ Sets the internal simulation to a default state when entering the block.
+
+ ```python
+ with physics.reset_context():
+ # Set joint and object positions.
+
+ physics.step()
+ ```
+
+ Yields:
+ The `Physics` instance.
+ """
+ self.reset()
+ yield self
+ self.after_reset()
+
+ @abc.abstractmethod
+ def reset(self):
+ """Resets internal variables of the physics simulation."""
+
+ @abc.abstractmethod
+ def after_reset(self):
+ """Runs after resetting internal variables of the physics simulation."""
+
+ def check_divergence(self):
+ """Raises a `PhysicsError` if the simulation state is divergent.
+
+ The default implementation is a no-op.
+ """
+
+
+class PhysicsError(RuntimeError):
+ """Raised if the state of the physics simulation becomes divergent."""
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Task(object):
+ """Defines a task in a `control.Environment`."""
+
+ @abc.abstractmethod
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Called by `control.Environment` at the start of each episode *within*
+ `physics.reset_context()` (see the documentation for `control.Physics`).
+
+ Args:
+ physics: Instance of `Physics`.
+ """
+
+ @abc.abstractmethod
+ def before_step(self, action, physics):
+ """Updates the task from the provided action.
+
+ Called by `control.Environment` before stepping the physics engine.
+
+ Args:
+ action: Array or nested structure of arrays matching `action_spec`.
+ physics: Instance of `Physics`.
+ """
+
+ def after_step(self, physics):
+ """Optional method to update the task after the physics engine has stepped.
+
+ Called by `control.Environment` after stepping the physics engine and before
+ `control.Environment` calls `get_observation`, `get_reward` and
+ `get_termination`.
+
+ The default implementation is a no-op.
+
+ Args:
+ physics: Instance of `Physics`.
+ """
+
+ @abc.abstractmethod
+ def action_spec(self, physics):
+ """Returns a nested structure of `ArraySpec`s describing the actions.
+
+ Args:
+ physics: Instance of `Physics`.
+ """
+
+ @abc.abstractmethod
+ def get_observation(self, physics):
+ """Returns an observation from the environment.
+
+ Args:
+ physics: Instance of `Physics`.
+ """
+
+ @abc.abstractmethod
+ def get_reward(self, physics):
+ """Returns a reward from the environment.
+
+ Args:
+ physics: Instance of `Physics`.
+ """
+
+ def get_termination(self, physics):
+ """If the episode should end, returns a final discount, otherwise None."""
+
+ def observation_spec(self, physics):
+ """Optional method that returns the observation spec.
+
+ If not implemented, the Environment infers the spec from the observation.
+
+ Args:
+ physics: Instance of `Physics`.
+
+ Returns:
+ A dict mapping observation name to `ArraySpec` containing observation
+ shape and dtype.
+ """
+ raise NotImplementedError()
+
+
+def flatten_observation(observation, output_key=FLAT_OBSERVATION_KEY):
+ """Flattens multiple observation arrays into a single numpy array.
+
+ Args:
+ observation: A mutable mapping from observation names to numpy arrays.
+ output_key: The key for the flattened observation array in the output.
+
+ Returns:
+ A mutable mapping of the same type as `observation`. This will contain a
+ single key-value pair consisting of `output_key` and the flattened
+ and concatenated observation array.
+
+ Raises:
+ ValueError: If `observation` is not a `collections.MutableMapping`.
+ """
+ if not isinstance(observation, collections.MutableMapping):
+ raise ValueError('Can only flatten dict-like observations.')
+
+ if isinstance(observation, collections.OrderedDict):
+ keys = six.iterkeys(observation)
+ else:
+ # Keep a consistent ordering for other mappings.
+ keys = sorted(six.iterkeys(observation))
+
+ observation_arrays = [observation[key].ravel() for key in keys]
+ return type(observation)([(output_key, np.concatenate(observation_arrays))])
diff --git a/dm_control/rl/control_test.py b/dm_control/rl/control_test.py
new file mode 100644
index 00000000..f7def19f
--- /dev/null
+++ b/dm_control/rl/control_test.py
@@ -0,0 +1,129 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Control Environment tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.rl import control
+
+import mock
+import numpy as np
+
+from dm_control.rl import specs
+
+_CONSTANT_REWARD_VALUE = 1.0
+_CONSTANT_OBSERVATION = {'observations': np.asarray(_CONSTANT_REWARD_VALUE)}
+
+_ACTION_SPEC = specs.BoundedArraySpec(
+ shape=(1,), dtype=float, minimum=0.0, maximum=1.0)
+_OBSERVATION_SPEC = {'observations': specs.ArraySpec(shape=(), dtype=float)}
+
+
+class EnvironmentTest(parameterized.TestCase):
+
+ def setUp(self):
+ self._task = mock.Mock(spec=control.Task)
+ self._task.initialize_episode = mock.Mock()
+ self._task.get_observation = mock.Mock(return_value=_CONSTANT_OBSERVATION)
+ self._task.get_reward = mock.Mock(return_value=_CONSTANT_REWARD_VALUE)
+ self._task.get_termination = mock.Mock(return_value=None)
+ self._task.action_spec = mock.Mock(return_value=_ACTION_SPEC)
+ self._task.observation_spec.side_effect = NotImplementedError()
+
+ self._physics = mock.Mock(spec=control.Physics)
+ self._physics.time = mock.Mock(return_value=0.0)
+
+ self._physics.reset_context = mock.MagicMock()
+
+ self._env = control.Environment(physics=self._physics, task=self._task)
+
+ def test_environment_calls(self):
+ self._env.action_spec()
+ self._task.action_spec.assert_called_with(self._physics)
+
+ self._env.reset()
+ self._task.initialize_episode.assert_called_with(self._physics)
+ self._task.get_observation.assert_called_with(self._physics)
+
+ action = [1]
+ time_step = self._env.step(action)
+
+ self._task.before_step.assert_called()
+ self._task.after_step.assert_called_with(self._physics)
+ self._task.get_termination.assert_called_with(self._physics)
+
+ self.assertEqual(_CONSTANT_REWARD_VALUE, time_step.reward)
+
+ def test_timeout(self):
+ self._physics.time = mock.Mock(return_value=2.)
+ env = control.Environment(
+ physics=self._physics, task=self._task, time_limit=1.)
+ env.reset()
+ time_step = env.step([1])
+ self.assertTrue(time_step.last())
+
+ time_step = env.step([1])
+ self.assertTrue(time_step.first())
+
+ def test_observation_spec(self):
+ observation_spec = self._env.observation_spec()
+ self.assertEqual(_OBSERVATION_SPEC, observation_spec)
+
+ def test_redundant_args_error(self):
+ with self.assertRaises(ValueError):
+ control.Environment(physics=self._physics, task=self._task,
+ n_sub_steps=2, control_timestep=0.1)
+
+ def test_control_timestep(self):
+ self._physics.timestep.return_value = .002
+ env = control.Environment(
+ physics=self._physics, task=self._task, n_sub_steps=5)
+ self.assertEqual(.01, env.control_timestep())
+
+ def test_flatten_observations(self):
+ multimodal_obs = dict(_CONSTANT_OBSERVATION)
+ multimodal_obs['sensor'] = np.zeros(7, dtype=bool)
+ self._task.get_observation = mock.Mock(return_value=multimodal_obs)
+ env = control.Environment(
+ physics=self._physics, task=self._task, flat_observation=True)
+ timestep = env.reset()
+ self.assertEqual(len(timestep.observation), 1)
+ self.assertEqual(timestep.observation[control.FLAT_OBSERVATION_KEY].size,
+ 1 + 7)
+
+
+class ComputeNStepsTest(parameterized.TestCase):
+
+ @parameterized.parameters((0.2, 0.1, 2), (.111, .111, 1), (100, 5, 20),
+ (0.03, 0.005, 6))
+ def testComputeNSteps(self, control_timestep, physics_timestep, expected):
+ steps = control.compute_n_steps(control_timestep, physics_timestep)
+ self.assertEqual(expected, steps)
+
+ @parameterized.parameters((3, 2), (.003, .00101))
+ def testComputeNStepsFailures(self, control_timestep, physics_timestep):
+ with self.assertRaises(ValueError):
+ control.compute_n_steps(control_timestep, physics_timestep)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/rl/environment.py b/dm_control/rl/environment.py
new file mode 100644
index 00000000..0a9b981d
--- /dev/null
+++ b/dm_control/rl/environment.py
@@ -0,0 +1,203 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Python RL Environment API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+
+# Internal dependencies.
+
+import enum
+import six
+
+
+class TimeStep(collections.namedtuple(
+ 'TimeStep', ['step_type', 'reward', 'discount', 'observation'])):
+ """Returned with every call to `step` and `reset` on an environment.
+
+ A `TimeStep` contains the data emitted by an environment at each step of
+ interaction. A `TimeStep` holds a `step_type`, an `observation` (typically a
+ NumPy array or a dict or list of arrays), and an associated `reward` and
+ `discount`.
+
+ The first `TimeStep` in a sequence will have `StepType.FIRST`. The final
+ `TimeStep` will have `StepType.LAST`. All other `TimeStep`s in a sequence will
+ have `StepType.MID`.
+
+ Attributes:
+ step_type: A `StepType` enum value.
+ reward: A scalar, or `None` if `step_type` is `StepType.FIRST`, i.e. at the
+ start of a sequence.
+ discount: A discount value in the range `[0, 1]`, or `None` if `step_type`
+ is `StepType.FIRST`, i.e. at the start of a sequence.
+ observation: A NumPy array, or a nested dict, list or tuple of arrays.
+ """
+ __slots__ = ()
+
+ def first(self):
+ return self.step_type is StepType.FIRST
+
+ def mid(self):
+ return self.step_type is StepType.MID
+
+ def last(self):
+ return self.step_type is StepType.LAST
+
+
+class StepType(enum.IntEnum):
+ """Defines the status of a `TimeStep` within a sequence."""
+ # Denotes the first `TimeStep` in a sequence.
+ FIRST = 0
+ # Denotes any `TimeStep` in a sequence that is not FIRST or LAST.
+ MID = 1
+ # Denotes the last `TimeStep` in a sequence.
+ LAST = 2
+
+ def first(self):
+ return self is StepType.FIRST
+
+ def mid(self):
+ return self is StepType.MID
+
+ def last(self):
+ return self is StepType.LAST
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Base(object):
+ """Abstract base class for Python RL environments.
+
+ Observations and valid actions are described with `ArraySpec`s, defined in
+ the `specs` module.
+ """
+
+ @abc.abstractmethod
+ def reset(self):
+ """Starts a new sequence and returns the first `TimeStep` of this sequence.
+
+ Returns:
+ A `TimeStep` namedtuple containing:
+ step_type: A `StepType` of `FIRST`.
+ reward: `None`, indicating the reward is undefined.
+ discount: `None`, indicating the discount is undefined.
+ observation: A NumPy array, or a nested dict, list or tuple of arrays
+ corresponding to `observation_spec()`.
+ """
+
+ @abc.abstractmethod
+ def step(self, action):
+ """Updates the environment according to the action and returns a `TimeStep`.
+
+ If the environment returned a `TimeStep` with `StepType.LAST` at the
+ previous step, this call to `step` will start a new sequence and `action`
+ will be ignored.
+
+ This method will also start a new sequence if called after the environment
+ has been constructed and `reset` has not been called. Again, in this case
+ `action` will be ignored.
+
+ Args:
+ action: A NumPy array, or a nested dict, list or tuple of arrays
+ corresponding to `action_spec()`.
+
+ Returns:
+ A `TimeStep` namedtuple containing:
+ step_type: A `StepType` value.
+ reward: Reward at this timestep, or None if step_type is
+ `StepType.FIRST`.
+ discount: A discount in the range [0, 1], or None if step_type is
+ `StepType.FIRST`.
+ observation: A NumPy array, or a nested dict, list or tuple of arrays
+ corresponding to `observation_spec()`.
+ """
+
+ @abc.abstractmethod
+ def observation_spec(self):
+ """Defines the observations provided by the environment.
+
+ May use a subclass of `ArraySpec` that specifies additional properties such
+ as min and max bounds on the values.
+
+ Returns:
+ An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
+ """
+
+ @abc.abstractmethod
+ def action_spec(self):
+ """Defines the actions that should be provided to `step`.
+
+ May use a subclass of `ArraySpec` that specifies additional properties such
+ as min and max bounds on the values.
+
+ Returns:
+ An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
+ """
+
+ def close(self):
+ """Frees any resources used by the environment.
+
+ Implement this method for an environment backed by an external process.
+
+ This method can be used directly
+
+ ```python
+ env = Env(...)
+ # Use env.
+ env.close()
+ ```
+
+ or via a context manager
+
+ ```python
+ with Env(...) as env:
+ # Use env.
+ ```
+ """
+ pass
+
+ def __enter__(self):
+ """Allows the environment to be used in a with-statement context."""
+ return self
+
+ def __exit__(self, unused_exception_type, unused_exc_value, unused_traceback):
+ """Allows the environment to be used in a with-statement context."""
+ self.close()
+
+# Helper functions for creating TimeStep namedtuples with default settings.
+
+
+def restart(observation):
+ """Returns a `TimeStep` with `step_type` set to `StepType.FIRST`."""
+ return TimeStep(StepType.FIRST, None, None, observation)
+
+
+def transition(reward, observation, discount=1.0):
+ """Returns a `TimeStep` with `step_type` set to `StepType.MID`."""
+ return TimeStep(StepType.MID, reward, discount, observation)
+
+
+def termination(reward, observation):
+ """Returns a `TimeStep` with `step_type` set to `StepType.LAST`."""
+ return TimeStep(StepType.LAST, reward, 0.0, observation)
+
+
+def truncation(reward, observation, discount=1.0):
+ """Returns a `TimeStep` with `step_type` set to `StepType.LAST`."""
+ return TimeStep(StepType.LAST, reward, discount, observation)
diff --git a/dm_control/rl/specs.py b/dm_control/rl/specs.py
new file mode 100644
index 00000000..4e52dc39
--- /dev/null
+++ b/dm_control/rl/specs.py
@@ -0,0 +1,210 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Classes that describe the shape and dtype of numpy arrays."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+import numpy as np
+
+
+class ArraySpec(object):
+ """Describes a numpy array or scalar shape and dtype.
+
+ An `ArraySpec` allows an API to describe the arrays that it accepts or
+ returns, before that array exists.
+ The equivalent version describing a `tf.Tensor` is `TensorSpec`.
+ """
+ __slots__ = ('_shape', '_dtype', '_name')
+
+ def __init__(self, shape, dtype, name=None):
+ """Initializes a new `ArraySpec`.
+
+ Args:
+ shape: An iterable specifying the array shape.
+ dtype: numpy dtype or string specifying the array dtype.
+ name: Optional string containing a semantic name for the corresponding
+ array. Defaults to `None`.
+
+ Raises:
+ TypeError: If the shape is not an iterable or if the `dtype` is an invalid
+ numpy dtype.
+ """
+ self._shape = tuple(shape)
+ self._dtype = np.dtype(dtype)
+ self._name = name
+
+ @property
+ def shape(self):
+ """Returns a `tuple` specifying the array shape."""
+ return self._shape
+
+ @property
+ def dtype(self):
+ """Returns a numpy dtype specifying the array dtype."""
+ return self._dtype
+
+ @property
+ def name(self):
+ """Returns the name of the ArraySpec."""
+ return self._name
+
+ def __repr__(self):
+ return 'ArraySpec(shape={}, dtype={}, name={})'.format(self.shape,
+ repr(self.dtype),
+ repr(self.name))
+
+ def __eq__(self, other):
+ """Checks if the shape and dtype of two specs are equal."""
+ if not isinstance(other, ArraySpec):
+ return False
+ return self.shape == other.shape and self.dtype == other.dtype
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _fail_validation(self, message, *args):
+ message %= args
+ if self.name:
+ message += ' for spec %s' % self.name
+ raise ValueError(message)
+
+ def validate(self, value):
+ """Checks if value conforms to this spec.
+
+ Args:
+ value: a numpy array or value convertible to one via `np.asarray`.
+
+ Returns:
+ value, converted if necessary to a numpy array.
+
+ Raises:
+ ValueError: if value doesn't conform to this spec.
+ """
+ value = np.asarray(value)
+ if value.shape != self.shape:
+ self._fail_validation(
+ 'Expected shape %r but found %r', self.shape, value.shape)
+ if value.dtype != self.dtype:
+ self._fail_validation(
+ 'Expected dtype %s but found %s', self.dtype, value.dtype)
+
+ def generate_value(self):
+ """Generate a test value which conforms to this spec."""
+ return np.zeros(shape=self.shape, dtype=self.dtype)
+
+
+class BoundedArraySpec(ArraySpec):
+ """An `ArraySpec` that specifies minimum and maximum values.
+
+ Example usage:
+ ```python
+ # Specifying the same minimum and maximum for every element.
+ spec = BoundedArraySpec((3, 4), np.float64, minimum=0.0, maximum=1.0)
+
+ # Specifying a different minimum and maximum for each element.
+ spec = BoundedArraySpec(
+ (2,), np.float64, minimum=[0.1, 0.2], maximum=[0.9, 0.9])
+
+ # Specifying the same minimum and a different maximum for each element.
+ spec = BoundedArraySpec(
+ (3,), np.float64, minimum=-10.0, maximum=[4.0, 5.0, 3.0])
+ ```
+
+ Bounds are meant to be inclusive. This is especially important for
+ integer types. The following spec will be satisfied by arrays
+ with values in the set {0, 1, 2}:
+ ```python
+ spec = BoundedArraySpec((3, 4), np.int32, minimum=0, maximum=2)
+ ```
+ """
+
+ __slots__ = ('_minimum', '_maximum')
+
+ def __init__(self, shape, dtype, minimum, maximum, name=None):
+ """Initializes a new `BoundedArraySpec`.
+
+ Args:
+ shape: An iterable specifying the array shape.
+ dtype: numpy dtype or string specifying the array dtype.
+ minimum: Number or sequence specifying the minimum element bounds
+ (inclusive). Must be broadcastable to `shape`.
+ maximum: Number or sequence specifying the maximum element bounds
+ (inclusive). Must be broadcastable to `shape`.
+ name: Optional string containing a semantic name for the corresponding
+ array. Defaults to `None`.
+
+ Raises:
+ ValueError: If `minimum` or `maximum` are not broadcastable to `shape`.
+ TypeError: If the shape is not an iterable or if the `dtype` is an invalid
+ numpy dtype.
+ """
+ super(BoundedArraySpec, self).__init__(shape, dtype, name)
+
+ try:
+ np.broadcast_to(minimum, shape=shape)
+ except ValueError as numpy_exception:
+ raise ValueError('minimum is not compatible with shape. '
+ 'Message: {!r}.'.format(numpy_exception))
+
+ try:
+ np.broadcast_to(maximum, shape=shape)
+ except ValueError as numpy_exception:
+ raise ValueError('maximum is not compatible with shape. '
+ 'Message: {!r}.'.format(numpy_exception))
+
+ self._minimum = np.array(minimum)
+ self._minimum.setflags(write=False)
+
+ self._maximum = np.array(maximum)
+ self._maximum.setflags(write=False)
+
+ @property
+ def minimum(self):
+ """Returns a NumPy array specifying the minimum bounds (inclusive)."""
+ return self._minimum
+
+ @property
+ def maximum(self):
+ """Returns a NumPy array specifying the maximum bounds (inclusive)."""
+ return self._maximum
+
+ def __repr__(self):
+ template = ('BoundedArraySpec(shape={}, dtype={}, name={}, '
+ 'minimum={}, maximum={})')
+ return template.format(self.shape, repr(self.dtype), repr(self.name),
+ self._minimum, self._maximum)
+
+ def __eq__(self, other):
+ if not isinstance(other, BoundedArraySpec):
+ return False
+ return (super(BoundedArraySpec, self).__eq__(other) and
+ (self.minimum == other.minimum).all() and
+ (self.maximum == other.maximum).all())
+
+ def validate(self, value):
+ value = np.asarray(value)
+ super(BoundedArraySpec, self).validate(value)
+ if (value < self.minimum).any() or (value > self.maximum).any():
+ self._fail_validation(
+ 'Values were not all within bounds %s <= value <= %s',
+ self.minimum, self.maximum)
+
+ def generate_value(self):
+ return (np.ones(shape=self.shape, dtype=self.dtype) *
+ self.dtype.type(self.minimum))
diff --git a/dm_control/rl/specs_test.py b/dm_control/rl/specs_test.py
new file mode 100644
index 00000000..1b3feaaf
--- /dev/null
+++ b/dm_control/rl/specs_test.py
@@ -0,0 +1,188 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for specs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from dm_control.rl import specs as array_spec
+import numpy as np
+
+
+class ArraySpecTest(absltest.TestCase):
+
+ def testShapeTypeError(self):
+ with self.assertRaises(TypeError):
+ array_spec.ArraySpec(32, np.int32)
+
+ def testDtypeTypeError(self):
+ with self.assertRaises(TypeError):
+ array_spec.ArraySpec((1, 2, 3), "32")
+
+ def testStringDtype(self):
+ array_spec.ArraySpec((1, 2, 3), "int32")
+
+ def testNumpyDtype(self):
+ array_spec.ArraySpec((1, 2, 3), np.int32)
+
+ def testDtype(self):
+ spec = array_spec.ArraySpec((1, 2, 3), np.int32)
+ self.assertEqual(np.int32, spec.dtype)
+
+ def testShape(self):
+ spec = array_spec.ArraySpec([1, 2, 3], np.int32)
+ self.assertEqual((1, 2, 3), spec.shape)
+
+ def testEqual(self):
+ spec_1 = array_spec.ArraySpec((1, 2, 3), np.int32)
+ spec_2 = array_spec.ArraySpec((1, 2, 3), np.int32)
+ self.assertEqual(spec_1, spec_2)
+
+ def testNotEqualDifferentShape(self):
+ spec_1 = array_spec.ArraySpec((1, 2, 3), np.int32)
+ spec_2 = array_spec.ArraySpec((1, 3, 3), np.int32)
+ self.assertNotEqual(spec_1, spec_2)
+
+ def testNotEqualDifferentDtype(self):
+ spec_1 = array_spec.ArraySpec((1, 2, 3), np.int64)
+ spec_2 = array_spec.ArraySpec((1, 2, 3), np.int32)
+ self.assertNotEqual(spec_1, spec_2)
+
+ def testNotEqualOtherClass(self):
+ spec_1 = array_spec.ArraySpec((1, 2, 3), np.int32)
+ spec_2 = None
+ self.assertNotEqual(spec_1, spec_2)
+ self.assertNotEqual(spec_2, spec_1)
+
+ spec_2 = ()
+ self.assertNotEqual(spec_1, spec_2)
+ self.assertNotEqual(spec_2, spec_1)
+
+ def testValidateDtype(self):
+ spec = array_spec.ArraySpec((1, 2), np.int32)
+ spec.validate(np.zeros((1, 2), dtype=np.int32))
+ with self.assertRaises(ValueError):
+ spec.validate(np.zeros((1, 2), dtype=np.float32))
+
+ def testValidateShape(self):
+ spec = array_spec.ArraySpec((1, 2), np.int32)
+ spec.validate(np.zeros((1, 2), dtype=np.int32))
+ with self.assertRaises(ValueError):
+ spec.validate(np.zeros((1, 2, 3), dtype=np.int32))
+
+ def testGenerateValue(self):
+ spec = array_spec.ArraySpec((1, 2), np.int32)
+ test_value = spec.generate_value()
+ spec.validate(test_value)
+
+
+class BoundedArraySpecTest(absltest.TestCase):
+
+ def testInvalidMinimum(self):
+ with self.assertRaisesRegexp(ValueError, "not compatible"):
+ array_spec.BoundedArraySpec((3, 5), np.uint8, (0, 0, 0), (1, 1))
+
+ def testInvalidMaximum(self):
+ with self.assertRaisesRegexp(ValueError, "not compatible"):
+ array_spec.BoundedArraySpec((3, 5), np.uint8, 0, (1, 1, 1))
+
+ def testMinMaxAttributes(self):
+ spec = array_spec.BoundedArraySpec((1, 2, 3), np.float32, 0, (5, 5, 5))
+ self.assertEqual(type(spec.minimum), np.ndarray)
+ self.assertEqual(type(spec.maximum), np.ndarray)
+
+ def testNotWriteable(self):
+ spec = array_spec.BoundedArraySpec((1, 2, 3), np.float32, 0, (5, 5, 5))
+ with self.assertRaisesRegexp(ValueError, "read-only"):
+ spec.minimum[0] = -1
+ with self.assertRaisesRegexp(ValueError, "read-only"):
+ spec.maximum[0] = 100
+
+ def testEqualBroadcastingBounds(self):
+ spec_1 = array_spec.BoundedArraySpec(
+ (1, 2), np.int32, minimum=0.0, maximum=1.0)
+ spec_2 = array_spec.BoundedArraySpec(
+ (1, 2), np.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0])
+ self.assertEqual(spec_1, spec_2)
+
+ def testNotEqualDifferentMinimum(self):
+ spec_1 = array_spec.BoundedArraySpec(
+ (1, 2), np.int32, minimum=[0.0, -0.6], maximum=[1.0, 1.0])
+ spec_2 = array_spec.BoundedArraySpec(
+ (1, 2), np.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0])
+ self.assertNotEqual(spec_1, spec_2)
+
+ def testNotEqualOtherClass(self):
+ spec_1 = array_spec.BoundedArraySpec(
+ (1, 2), np.int32, minimum=[0.0, -0.6], maximum=[1.0, 1.0])
+ spec_2 = array_spec.ArraySpec((1, 2), np.int32)
+ self.assertNotEqual(spec_1, spec_2)
+ self.assertNotEqual(spec_2, spec_1)
+
+ spec_2 = None
+ self.assertNotEqual(spec_1, spec_2)
+ self.assertNotEqual(spec_2, spec_1)
+
+ spec_2 = ()
+ self.assertNotEqual(spec_1, spec_2)
+ self.assertNotEqual(spec_2, spec_1)
+
+ def testNotEqualDifferentMaximum(self):
+ spec_1 = array_spec.BoundedArraySpec(
+ (1, 2), np.int32, minimum=0.0, maximum=2.0)
+ spec_2 = array_spec.BoundedArraySpec(
+ (1, 2), np.int32, minimum=[0.0, 0.0], maximum=[1.0, 1.0])
+ self.assertNotEqual(spec_1, spec_2)
+
+ def testRepr(self):
+ as_string = repr(array_spec.BoundedArraySpec(
+ (1, 2), np.int32, minimum=101.0, maximum=73.0))
+ self.assertIn("101", as_string)
+ self.assertIn("73", as_string)
+
+ def testValidateBounds(self):
+ spec = array_spec.BoundedArraySpec((2, 2), np.int32, minimum=5, maximum=10)
+ spec.validate(np.array([[5, 6], [8, 10]], dtype=np.int32))
+ with self.assertRaises(ValueError):
+ spec.validate(np.array([[5, 6], [8, 11]], dtype=np.int32))
+ with self.assertRaises(ValueError):
+ spec.validate(np.array([[4, 6], [8, 10]], dtype=np.int32))
+
+ def testGenerateValue(self):
+ spec = array_spec.BoundedArraySpec((2, 2), np.int32, minimum=5, maximum=10)
+ test_value = spec.generate_value()
+ spec.validate(test_value)
+
+ def testScalarBounds(self):
+ spec = array_spec.BoundedArraySpec((), float, minimum=0.0, maximum=1.0)
+
+ self.assertIsInstance(spec.minimum, np.ndarray)
+ self.assertIsInstance(spec.maximum, np.ndarray)
+
+ # Sanity check that numpy compares correctly to a scalar for an empty shape.
+ self.assertEqual(0.0, spec.minimum)
+ self.assertEqual(1.0, spec.maximum)
+
+ # Check that the spec doesn't fail its own input validation.
+ _ = array_spec.BoundedArraySpec(
+ spec.shape, spec.dtype, spec.minimum, spec.maximum)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/dm_control/suite/README.md b/dm_control/suite/README.md
new file mode 100644
index 00000000..d5fabc7d
--- /dev/null
+++ b/dm_control/suite/README.md
@@ -0,0 +1,4 @@
+# DeepMind Control Suite
+
+This directory contains the domains and tasks described in the
+*DeepMind Control Suite* paper.
diff --git a/dm_control/suite/__init__.py b/dm_control/suite/__init__.py
new file mode 100644
index 00000000..021a4be9
--- /dev/null
+++ b/dm_control/suite/__init__.py
@@ -0,0 +1,142 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A collection of MuJoCo-based Reinforcement Learning environments."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import inspect
+import itertools
+
+from dm_control.rl import control
+
+from dm_control.suite import acrobot
+from dm_control.suite import ball_in_cup
+from dm_control.suite import cartpole
+from dm_control.suite import cheetah
+from dm_control.suite import finger
+from dm_control.suite import fish
+from dm_control.suite import hopper
+from dm_control.suite import humanoid
+from dm_control.suite import humanoid_CMU
+from dm_control.suite import lqr
+from dm_control.suite import manipulator
+from dm_control.suite import pendulum
+from dm_control.suite import point_mass
+from dm_control.suite import reacher
+from dm_control.suite import stacker
+from dm_control.suite import swimmer
+from dm_control.suite import walker
+
+# Find all domains imported.
+_DOMAINS = {name: module for name, module in locals().items()
+ if inspect.ismodule(module) and hasattr(module, 'SUITE')}
+
+
+def _get_tasks(tag):
+ """Returns a sequence of (domain name, task name) pairs for the given tag."""
+ result = []
+
+ for domain_name in sorted(_DOMAINS.keys()):
+ domain = _DOMAINS[domain_name]
+
+ if tag is None:
+ tasks_in_domain = domain.SUITE
+ else:
+ tasks_in_domain = domain.SUITE.tagged(tag)
+
+ for task_name in tasks_in_domain.keys():
+ result.append((domain_name, task_name))
+
+ return tuple(result)
+
+
+def _get_tasks_by_domain(tasks):
+ """Returns a dict mapping from domain name to a tuple of task names."""
+ result = collections.defaultdict(list)
+
+ for domain_name, task_name in tasks:
+ result[domain_name].append(task_name)
+
+ return {k: tuple(v) for k, v in result.items()}
+
+
+# A sequence containing all (domain name, task name) pairs.
+ALL_TASKS = _get_tasks(tag=None)
+
+# Subsets of ALL_TASKS, generated via the tag mechanism.
+BENCHMARKING = _get_tasks('benchmarking')
+EASY = _get_tasks('easy')
+HARD = _get_tasks('hard')
+EXTRA = tuple(sorted(set(ALL_TASKS) - set(BENCHMARKING)))
+
+# A mapping from each domain name to a sequence of its task names.
+TASKS_BY_DOMAIN = _get_tasks_by_domain(ALL_TASKS)
+
+
+def load(domain_name, task_name, task_kwargs=None, visualize_reward=False):
+ """Returns an environment from a domain name, task name and optional settings.
+
+ ```python
+ env = suite.load('cartpole', 'balance')
+ ```
+
+ Args:
+ domain_name: A string containing the name of a domain.
+ task_name: A string containing the name of a task.
+ task_kwargs: Optional `dict` of keyword arguments for the task.
+ visualize_reward: Optional `bool`. If `True`, object colours in rendered
+ frames are set to indicate the reward at each step. Default `False`.
+
+ Returns:
+ The requested environment.
+ """
+ return build_environment(domain_name, task_name, task_kwargs,
+ visualize_reward)
+
+
+def build_environment(domain_name, task_name, task_kwargs=None,
+ visualize_reward=False):
+ """Returns an environment from the suite given a domain name and a task name.
+
+ Args:
+ domain_name: A string containing the name of a domain.
+ task_name: A string containing the name of a task.
+ task_kwargs: Optional `dict` specifying keyword arguments for the task.
+ visualize_reward: Optional `bool`. If `True`, object colours in rendered
+ frames are set to indicate the reward at each step. Default `False`.
+
+ Raises:
+ ValueError: If the domain or task doesn't exist.
+
+ Returns:
+ An instance of the requested environment.
+ """
+ if domain_name not in _DOMAINS:
+ raise ValueError('Domain {!r} does not exist.'.format(domain_name))
+
+ domain = _DOMAINS[domain_name]
+
+ if task_name not in domain.SUITE:
+ raise ValueError('Task {!r} does not exist in domain {!r}.'.format(
+ task_name, domain_name))
+
+ task_kwargs = task_kwargs or {}
+ env = domain.SUITE[task_name](**task_kwargs)
+ env.task.visualize_reward = visualize_reward
+ return env
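The tag mechanism behind `ALL_TASKS`, `BENCHMARKING` and `EXTRA` can be illustrated with a minimal sketch. `containers.TaggedTasks` is not part of this section of the diff, so the class below is a hypothetical stand-in that only mimics the calls used above (`add`, `tagged`, `keys`); the domain contents are invented for the example.

```python
import collections


class TaggedTasks(dict):
    """Hypothetical stand-in for `containers.TaggedTasks`, not the real class."""

    def __init__(self):
        super().__init__()
        self._tags = collections.defaultdict(set)

    def add(self, *tags):
        """Returns a decorator that registers a task constructor under `tags`."""
        def wrap(factory):
            self[factory.__name__] = factory
            for tag in tags:
                self._tags[tag].add(factory.__name__)
            return factory
        return wrap

    def tagged(self, tag):
        """Returns a dict of the registered tasks that carry `tag`."""
        return {name: self[name] for name in self if name in self._tags[tag]}


def get_tasks(domains, tag=None):
    """Returns sorted-by-domain (domain name, task name) pairs for `tag`."""
    result = []
    for domain_name in sorted(domains):
        suite = domains[domain_name]
        tasks = suite if tag is None else suite.tagged(tag)
        for task_name in tasks.keys():
            result.append((domain_name, task_name))
    return tuple(result)


def get_tasks_by_domain(tasks):
    """Groups (domain name, task name) pairs into {domain: (tasks, ...)}."""
    grouped = collections.defaultdict(list)
    for domain_name, task_name in tasks:
        grouped[domain_name].append(task_name)
    return {k: tuple(v) for k, v in grouped.items()}


# Two toy domains; the return values are placeholders, not environments.
cartpole = TaggedTasks()
ball_in_cup = TaggedTasks()


@cartpole.add('benchmarking')
def balance():
    return 'cartpole balance'


@cartpole.add()
def two_poles():
    return 'cartpole two_poles'


@ball_in_cup.add('benchmarking', 'easy')
def catch():
    return 'ball_in_cup catch'


DOMAINS = {'cartpole': cartpole, 'ball_in_cup': ball_in_cup}
ALL_TASKS = get_tasks(DOMAINS)
BENCHMARKING = get_tasks(DOMAINS, 'benchmarking')
EXTRA = tuple(sorted(set(ALL_TASKS) - set(BENCHMARKING)))
```

In this sketch an untagged task such as `two_poles` appears in `ALL_TASKS` and `EXTRA` but not in `BENCHMARKING`, which matches how the module-level constants are derived.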
diff --git a/dm_control/suite/acrobot.py b/dm_control/suite/acrobot.py
new file mode 100644
index 00000000..72fe30cb
--- /dev/null
+++ b/dm_control/suite/acrobot.py
@@ -0,0 +1,123 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Acrobot domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+import numpy as np
+
+_DEFAULT_TIME_LIMIT = 10
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('acrobot.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Acrobot swing-up task with a smooth reward."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(sparse=False, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+@SUITE.add('benchmarking')
+def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Acrobot swing-up task with a sparse reward."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(sparse=True, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Acrobot domain."""
+
+ def horizontal(self):
+ """Returns horizontal (x) component of body frame z-axes."""
+ return self.named.data.xmat[['upper_arm', 'lower_arm'], 'xz']
+
+ def vertical(self):
+ """Returns vertical (z) component of body frame z-axes."""
+ return self.named.data.xmat[['upper_arm', 'lower_arm'], 'zz']
+
+ def to_target(self):
+ """Returns the distance from the tip to the target."""
+ tip_to_target = (self.named.data.site_xpos['target'] -
+ self.named.data.site_xpos['tip'])
+ return np.linalg.norm(tip_to_target)
+
+ def orientations(self):
+ """Returns the sines and cosines of the pole angles."""
+ return np.concatenate((self.horizontal(), self.vertical()))
+
+
+class Balance(base.Task):
+ """An Acrobot `Task` to swing up and balance the pole."""
+
+ def __init__(self, sparse, random=None):
+ """Initializes an instance of `Balance`.
+
+ Args:
+ sparse: A `bool` specifying whether to use a sparse (indicator) reward.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._sparse = sparse
+ super(Balance, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ The shoulder and elbow joints are set to random angles drawn uniformly from
+ [-pi, pi).
+
+ Args:
+ physics: An instance of `Physics`.
+ """
+ physics.named.data.qpos[
+ ['shoulder', 'elbow']] = self.random.uniform(-np.pi, np.pi, 2)
+
+ def get_observation(self, physics):
+ """Returns an observation of pole orientation and angular velocities."""
+ obs = collections.OrderedDict()
+ obs['orientations'] = physics.orientations()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def _get_reward(self, physics, sparse):
+ target_radius = physics.named.model.site_size['target', 0]
+ return rewards.tolerance(physics.to_target(),
+ bounds=(0, target_radius),
+ margin=0 if sparse else 1)
+
+ def get_reward(self, physics):
+ """Returns a sparse or a smooth reward, as specified in the constructor."""
+ return self._get_reward(physics, sparse=self._sparse)
diff --git a/dm_control/suite/acrobot.xml b/dm_control/suite/acrobot.xml
new file mode 100644
index 00000000..6d05fe35
--- /dev/null
+++ b/dm_control/suite/acrobot.xml
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/ball_in_cup.py b/dm_control/suite/ball_in_cup.py
new file mode 100644
index 00000000..2eab2471
--- /dev/null
+++ b/dm_control/suite/ball_in_cup.py
@@ -0,0 +1,99 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Ball-in-Cup Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+
+_DEFAULT_TIME_LIMIT = 20 # (seconds)
+_CONTROL_TIMESTEP = .02 # (seconds)
+
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('ball_in_cup.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking', 'easy')
+def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Ball-in-Cup task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = BallInCup(random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+class Physics(mujoco.Physics):
+ """Physics with additional features for the Ball-in-Cup domain."""
+
+ def ball_to_target(self):
+ """Returns the vector from the ball to the target."""
+ target = self.named.data.site_xpos['target', ['x', 'z']]
+ ball = self.named.data.xpos['ball', ['x', 'z']]
+ return target - ball
+
+ def in_target(self):
+ """Returns 1 if the ball is in the target, 0 otherwise."""
+ ball_to_target = abs(self.ball_to_target())
+ target_size = self.named.model.site_size['target', [0, 2]]
+ ball_size = self.named.model.geom_size['ball', 0]
+ return float(all(ball_to_target < target_size - ball_size))
+
+
+class BallInCup(base.Task):
+ """The Ball-in-Cup task. Put the ball in the cup."""
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Find a collision-free random initial position of the ball.
+ penetrating = True
+ while penetrating:
+ # Assign a random ball position.
+ physics.named.data.qpos['ball_x'] = self.random.uniform(-.2, .2)
+ physics.named.data.qpos['ball_z'] = self.random.uniform(.2, .5)
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+
+ def get_observation(self, physics):
+ """Returns an observation of the state."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.position()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a sparse reward."""
+ return physics.in_target()
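The two pieces of logic in `ball_in_cup.py`, the rejection sampling in `initialize_episode` and the box test in `in_target`, can be sketched without MuJoCo. This is a rough illustration only: `is_penetrating` is a hypothetical predicate standing in for `physics.after_reset()` followed by the `ncon > 0` check, and the `max_tries` cap is an addition that the original loop does not have.

```python
import random


def in_target(ball_to_target, target_size, ball_size):
    """Returns 1.0 if the ball fits inside the target on every axis, else 0.0."""
    return float(all(abs(offset) < size - ball_size
                     for offset, size in zip(ball_to_target, target_size)))


def sample_free_position(rng, is_penetrating, max_tries=1000):
    """Draws (x, z) from the task's ranges until `is_penetrating` is False."""
    for _ in range(max_tries):
        x = rng.uniform(-.2, .2)
        z = rng.uniform(.2, .5)
        if not is_penetrating(x, z):
            return x, z
    raise RuntimeError('No collision-free position found.')


# Toy collision rule (an assumption for the demo): anything below z = 0.3 collides.
rng = random.Random(0)
x, z = sample_free_position(rng, lambda x, z: z < 0.3)
```

Bounding the number of tries is a design choice for the sketch; it turns a model in which every sample collides into an error instead of an endless loop.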
diff --git a/dm_control/suite/ball_in_cup.xml b/dm_control/suite/ball_in_cup.xml
new file mode 100644
index 00000000..792073f0
--- /dev/null
+++ b/dm_control/suite/ball_in_cup.xml
@@ -0,0 +1,54 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/base.py b/dm_control/suite/base.py
new file mode 100644
index 00000000..3835b899
--- /dev/null
+++ b/dm_control/suite/base.py
@@ -0,0 +1,106 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for tasks in the Control Suite."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+
+import numpy as np
+
+
+class Task(control.Task):
+ """Base class for tasks in the Control Suite.
+
+ Maps actions directly to the states of MuJoCo actuators.
+
+ Attributes:
+ random: A `numpy.random.RandomState` instance. This should be used to
+ generate all random variables associated with the task, such as random
+ starting states, observation noise* etc.
+
+ *If sensor noise is enabled in the MuJoCo model then this will be generated
+ using MuJoCo's internal RNG, which has its own independent state.
+ """
+
+ def __init__(self, random=None):
+ """Initializes a new continuous control task.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an integer
+ seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ if not isinstance(random, np.random.RandomState):
+ random = np.random.RandomState(random)
+ self._random = random
+ self._visualize_reward = False
+
+ @property
+ def random(self):
+ """Task-specific `numpy.random.RandomState` instance."""
+ return self._random
+
+ def action_spec(self, physics):
+ """Returns the action spec corresponding to the MuJoCo actuators."""
+ return mujoco.action_spec(physics)
+
+ def before_step(self, actions, physics):
+ """Sets actuation from the continuous actions."""
+ # Support legacy internal code.
+ try:
+ physics.set_control(actions.continuous_actions)
+ except AttributeError:
+ physics.set_control(actions)
+
+ # Reset any reward visualisation at the start of a new episode.
+ if self._visualize_reward and physics.time() == 0.0:
+ _set_reward_colors(physics, reward=0.0)
+
+ def after_step(self, physics):
+ """Modifies colors according to the reward."""
+ if self._visualize_reward:
+ reward = np.clip(self.get_reward(physics), 0.0, 1.0)
+ _set_reward_colors(physics, reward)
+
+ @property
+ def visualize_reward(self):
+ return self._visualize_reward
+
+ @visualize_reward.setter
+ def visualize_reward(self, value):
+ if not isinstance(value, bool):
+ raise ValueError("Expected a boolean, got {}.".format(type(value)))
+ self._visualize_reward = value
+
+
+def _set_reward_colors(physics, reward):
+ """Sets the highlight, effector and target colors according to the reward."""
+ assert 0.0 <= reward <= 1.0
+
+ colors = physics.named.model.mat_rgba
+
+ def blend(color1, color2):
+ return reward * colors[color1] + (1.0 - reward) * colors[color2]
+
+ colors["self"] = blend("self_highlight", "self_default")
+ colors["effector"] = blend("effector_highlight", "effector_default")
+ colors["target"] = blend("target_highlight", "target_default")
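The reward visualisation in `base.py` is a linear interpolation between a default and a highlight colour, driven by the clipped reward. A minimal sketch, assuming RGBA colours stored as plain lists (the real code indexes `physics.named.model.mat_rgba`, and the colour values below are made up):

```python
def blend(reward, highlight, default):
    """Interpolates per channel: `default` at reward 0, `highlight` at reward 1."""
    if not 0.0 <= reward <= 1.0:
        raise ValueError('Expected a reward in [0, 1], got {}.'.format(reward))
    return [reward * h + (1.0 - reward) * d for h, d in zip(highlight, default)]


def clip(value, low=0.0, high=1.0):
    """Clamps `value` to [low, high], as `np.clip` does in `after_step`."""
    return min(max(value, low), high)


# Hypothetical RGBA values for a material and its highlighted variant.
target_default = [0.0, 0.0, 1.0, 1.0]
target_highlight = [1.0, 0.0, 0.0, 1.0]
half_way = blend(clip(0.5), target_highlight, target_default)
```

Clipping before blending is what lets tasks with rewards outside [0, 1] still use the visualisation without tripping the range check.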
diff --git a/dm_control/suite/cartpole.py b/dm_control/suite/cartpole.py
new file mode 100644
index 00000000..18775f69
--- /dev/null
+++ b/dm_control/suite/cartpole.py
@@ -0,0 +1,214 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Cartpole domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+from lxml import etree
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+_DEFAULT_TIME_LIMIT = 10
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets(num_poles=1):
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return _make_model(num_poles), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def balance(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Cartpole Balance task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(swing_up=False, sparse=False, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+@SUITE.add('benchmarking')
+def balance_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the sparse reward variant of the Cartpole Balance task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(swing_up=False, sparse=True, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+@SUITE.add('benchmarking')
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None, **kwargs):
+ """Returns the Cartpole Swing-Up task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(swing_up=True, sparse=False, random=random)
+ return control.Environment(physics, task, time_limit=time_limit, **kwargs)
+
+
+@SUITE.add('benchmarking')
+def swingup_sparse(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the sparse reward variant of the Cartpole Swing-Up task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Balance(swing_up=True, sparse=True, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+@SUITE.add()
+def two_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Cartpole Swing-Up task with two poles."""
+ physics = Physics.from_xml_string(*get_model_and_assets(num_poles=2))
+ task = Balance(swing_up=True, sparse=False, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+@SUITE.add()
+def three_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Cartpole Swing-Up task with three poles."""
+ physics = Physics.from_xml_string(*get_model_and_assets(num_poles=3))
+ task = Balance(swing_up=True, sparse=False, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+def _make_model(n_poles):
+ """Generates an xml string defining a cart with `n_poles` bodies."""
+ xml_string = common.read_model('cartpole.xml')
+ if n_poles == 1:
+ return xml_string
+ mjcf = etree.fromstring(xml_string)
+ parent = mjcf.find('./worldbody/body/body') # Find first pole.
+ # Make chain of poles.
+ for pole_index in xrange(2, n_poles+1):
+ child = etree.Element('body', name='pole_{}'.format(pole_index),
+ pos='0 0 1', childclass='pole')
+ etree.SubElement(child, 'joint', name='hinge_{}'.format(pole_index))
+ etree.SubElement(child, 'geom', name='pole_{}'.format(pole_index))
+ parent.append(child)
+ parent = child
+ # Move plane down.
+ floor = mjcf.find('./worldbody/geom')
+ floor.set('pos', '0 0 {}'.format(1 - n_poles - .05))
+ # Move cameras back.
+ cameras = mjcf.findall('./worldbody/camera')
+ cameras[0].set('pos', '0 {} 1'.format(-1 - 2*n_poles))
+ cameras[1].set('pos', '0 {} 2'.format(-2*n_poles))
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Cartpole domain."""
+
+ def cart_position(self):
+ """Returns the position of the cart."""
+ return self.named.data.qpos['slider'][0]
+
+ def angular_vel(self):
+ """Returns the angular velocity of the pole."""
+ return self.data.qvel[1:]
+
+ def pole_angle_cosine(self):
+ """Returns the cosine of the pole angle."""
+ return self.named.data.xmat[2:, 'zz']
+
+ def bounded_position(self):
+ """Returns the state, with pole angle split into sin/cos."""
+ return np.hstack((self.cart_position(),
+ self.named.data.xmat[2:, ['zz', 'xz']].ravel()))
+
+
+class Balance(base.Task):
+ """A Cartpole `Task` to balance the pole.
+
+ State is initialized either close to the target configuration or at a random
+ configuration.
+ """
+ _CART_RANGE = (-.25, .25)
+ _ANGLE_COSINE_RANGE = (.995, 1)
+
+ def __init__(self, swing_up, sparse, random=None):
+ """Initializes an instance of `Balance`.
+
+ Args:
+ swing_up: A `bool`, which if `True` sets the cart to the middle of the
+ slider and the pole pointing towards the ground. Otherwise, sets the
+ cart to a random position on the slider and the pole to a random
+ near-vertical position.
+ sparse: A `bool`, whether to return a sparse or a smooth reward.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._sparse = sparse
+ self._swing_up = swing_up
+ super(Balance, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Initializes the cart and pole according to `swing_up`, and in both cases
+ adds a small random initial velocity to break symmetry.
+
+ Args:
+ physics: An instance of `Physics`.
+ """
+ nv = physics.model.nv
+ if self._swing_up:
+ physics.named.data.qpos['slider'] = .01*self.random.randn()
+ physics.named.data.qpos['hinge_1'] = np.pi + .01*self.random.randn()
+ physics.named.data.qpos[2:] = .1*self.random.randn(nv - 2)
+ else:
+ physics.named.data.qpos['slider'] = self.random.uniform(-.1, .1)
+ physics.named.data.qpos[1:] = self.random.uniform(-.034, .034, nv - 1)
+ physics.named.data.qvel[:] = 0.01 * self.random.randn(nv)
+
+ def get_observation(self, physics):
+ """Returns an observation of the (bounded) physics state."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.bounded_position()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def _get_reward(self, physics, sparse):
+ if sparse:
+ cart_in_bounds = rewards.tolerance(physics.cart_position(),
+ self._CART_RANGE)
+ angle_in_bounds = rewards.tolerance(physics.pole_angle_cosine(),
+ self._ANGLE_COSINE_RANGE).prod()
+ return cart_in_bounds * angle_in_bounds
+ else:
+ upright = (physics.pole_angle_cosine() + 1) / 2
+ centered = rewards.tolerance(physics.cart_position(), margin=2)
+ centered = (1 + centered) / 2
+ small_control = rewards.tolerance(physics.control(), margin=1,
+ value_at_margin=0,
+ sigmoid='quadratic')[0]
+ small_control = (4 + small_control) / 5
+ small_velocity = rewards.tolerance(physics.angular_vel(), margin=5).min()
+ small_velocity = (1 + small_velocity) / 2
+ return upright.mean() * small_control * small_velocity * centered
+
+ def get_reward(self, physics):
+ """Returns a sparse or a smooth reward, as specified in the constructor."""
+ return self._get_reward(physics, sparse=self._sparse)
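The pole-chaining step in `_make_model` can be tried in isolation. The sketch below swaps `lxml` for the standard library's `xml.etree.ElementTree` and uses a made-up base model, since the contents of `cartpole.xml` are not visible in this section; only the element paths used by `_make_model` are assumed. The camera adjustment is left out.

```python
import xml.etree.ElementTree as ET

# Hypothetical stand-in for cartpole.xml with a cart and a single pole.
_BASE_XML = """
<mujoco>
  <worldbody>
    <geom name="floor" pos="0 0 -.05"/>
    <body name="cart">
      <body name="pole_1" childclass="pole">
        <joint name="hinge_1"/>
        <geom name="pole_1"/>
      </body>
    </body>
  </worldbody>
</mujoco>
""".strip()


def make_model(n_poles):
    """Returns an XML string in which poles 2..n_poles hang off the first pole."""
    root = ET.fromstring(_BASE_XML)
    parent = root.find('./worldbody/body/body')  # The first pole.
    for pole_index in range(2, n_poles + 1):
        # Each new pole is nested inside the previous one, forming a chain.
        child = ET.SubElement(parent, 'body',
                              name='pole_{}'.format(pole_index),
                              pos='0 0 1', childclass='pole')
        ET.SubElement(child, 'joint', name='hinge_{}'.format(pole_index))
        ET.SubElement(child, 'geom', name='pole_{}'.format(pole_index))
        parent = child
    # Lower the floor so that the longer chain can hang freely.
    floor = root.find('./worldbody/geom')
    floor.set('pos', '0 0 {}'.format(1 - n_poles - .05))
    return ET.tostring(root, encoding='unicode')


three_pole_xml = make_model(3)
```

Nesting each pole inside the previous body, instead of adding siblings under the cart, is what makes the poles a serial chain: every hinge is then expressed relative to the pole above it.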
diff --git a/dm_control/suite/cartpole.xml b/dm_control/suite/cartpole.xml
new file mode 100644
index 00000000..af638e5f
--- /dev/null
+++ b/dm_control/suite/cartpole.xml
@@ -0,0 +1,37 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/cheetah.py b/dm_control/suite/cheetah.py
new file mode 100644
index 00000000..9b26bf7b
--- /dev/null
+++ b/dm_control/suite/cheetah.py
@@ -0,0 +1,101 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Cheetah Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+
+# How long the simulation will run, in seconds.
+_DEFAULT_TIME_LIMIT = 10
+
+# Running speed above which reward is 1.
+_RUN_SPEED = 10
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('cheetah.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Cheetah(random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Cheetah domain."""
+
+ def speed(self):
+ """Returns the horizontal speed of the Cheetah."""
+ return self.named.data.subtree_linvel['torso', 'x']
+
+
+class Cheetah(base.Task):
+ """A `Task` to train a running Cheetah."""
+
+ def __init__(self, random=None):
+ """Initializes an instance of `Cheetah`.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(Cheetah, self).__init__(random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+
+ # Stabilize the model before the actual simulation.
+ for _ in range(200):
+ physics.step()
+
+ physics.data.time = 0
+ self._timeout_progress = 0
+
+ def get_observation(self, physics):
+ """Returns an observation of the state, ignoring horizontal position."""
+ obs = collections.OrderedDict()
+ # Ignores horizontal position to maintain translational invariance.
+ obs['position'] = physics.data.qpos[1:]
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ return rewards.tolerance(physics.speed(),
+ bounds=(_RUN_SPEED, float('inf')),
+ margin=_RUN_SPEED,
+ value_at_margin=0,
+ sigmoid='linear')
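Review note on the `get_reward` call in `cheetah.py` above: with `bounds=(_RUN_SPEED, inf)`, `margin=_RUN_SPEED`, `value_at_margin=0` and a linear sigmoid, the reward is expected to rise linearly from 0 at rest to 1 at `_RUN_SPEED`. The standalone sketch below assumes that behaviour; it is not the `rewards.tolerance` implementation, and `linear_run_reward` is a hypothetical name used only for illustration.

```python
def linear_run_reward(speed, run_speed=10.0):
    """Hypothetical stand-in for the cheetah run reward (not dm_control code)."""
    if speed >= run_speed:
        return 1.0
    # Distance below the lower bound, in units of the margin
    # (the margin equals run_speed in the call above).
    shortfall = (run_speed - speed) / run_speed
    # Linear falloff that is clipped at zero beyond the margin.
    return max(0.0, 1.0 - shortfall)


for v in (-2.0, 0.0, 5.0, 10.0, 12.0):
    print(v, linear_run_reward(v))
```

Under these assumptions, moving backwards earns nothing and any speed at or above `_RUN_SPEED` saturates the reward.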
diff --git a/dm_control/suite/cheetah.xml b/dm_control/suite/cheetah.xml
new file mode 100644
index 00000000..ef396644
--- /dev/null
+++ b/dm_control/suite/cheetah.xml
@@ -0,0 +1,70 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/common/__init__.py b/dm_control/suite/common/__init__.py
new file mode 100644
index 00000000..0c636473
--- /dev/null
+++ b/dm_control/suite/common/__init__.py
@@ -0,0 +1,38 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Functions to manage the common assets for domains."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from dm_control.utils import resources
+
+_SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
+_FILENAMES = [
+ "common/materials.xml",
+ "common/skybox.xml",
+ "common/visual.xml",
+]
+
+ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
+ for filename in _FILENAMES}
+
+
+def read_model(model_filename):
+ """Reads a model XML file and returns its contents as a string."""
+ return resources.GetResource(os.path.join(_SUITE_DIR, model_filename))
diff --git a/dm_control/suite/common/materials.xml b/dm_control/suite/common/materials.xml
new file mode 100644
index 00000000..cae6635e
--- /dev/null
+++ b/dm_control/suite/common/materials.xml
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/common/skybox.xml b/dm_control/suite/common/skybox.xml
new file mode 100644
index 00000000..9d6f7a79
--- /dev/null
+++ b/dm_control/suite/common/skybox.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
diff --git a/dm_control/suite/common/visual.xml b/dm_control/suite/common/visual.xml
new file mode 100644
index 00000000..ede15ad4
--- /dev/null
+++ b/dm_control/suite/common/visual.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
diff --git a/dm_control/suite/demos/mocap_demo.py b/dm_control/suite/demos/mocap_demo.py
new file mode 100644
index 00000000..731b2c0a
--- /dev/null
+++ b/dm_control/suite/demos/mocap_demo.py
@@ -0,0 +1,84 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Demonstration of amc parsing for CMU mocap database.
+
+To run the demo, supply a path to a `.amc` file:
+
+ python mocap_demo.py --filename='path/to/mocap.amc'
+
+CMU motion capture clips are available at mocap.cs.cmu.edu.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+# Internal dependencies.
+
+from absl import app
+from absl import flags
+
+from dm_control.suite import humanoid_CMU
+from dm_control.suite.utils import parse_amc
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('filename', None, 'amc file to be converted.')
+flags.DEFINE_integer('max_num_frames', 90,
+ 'Maximum number of frames for plotting/playback')
+
+
+def main(unused_argv):
+ env = humanoid_CMU.stand()
+
+ # Parse and convert specified clip.
+ converted = parse_amc.convert(FLAGS.filename,
+ env.physics, env.control_timestep())
+
+ max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1)
+
+ width = 480
+ height = 480
+ video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8)
+
+ for i in range(max_frame):
+ p_i = converted.qpos[:, i]
+ with env.physics.reset_context():
+ env.physics.data.qpos[:] = p_i
+ video[i] = np.hstack([env.physics.render(height, width, camera_id=0),
+ env.physics.render(height, width, camera_id=1)])
+
+ tic = time.time()
+ for i in range(max_frame):
+ if i == 0:
+ img = plt.imshow(video[i])
+ else:
+ img.set_data(video[i])
+ toc = time.time()
+ clock_dt = toc - tic
+ tic = time.time()
+ # Real-time playback not always possible as clock_dt > .03
+ plt.pause(max(0.01, .03 - clock_dt)) # Need min display time > 0.0.
+ plt.draw()
+ plt.waitforbuttonpress()
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('filename')
+ app.run(main)
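Review note on the playback loop in `mocap_demo.py` above: each frame pauses for whatever is left of a 0.03 s frame period, but never for less than 0.01 s. A minimal sketch of that pacing rule, using only the builtin `max`; `playback_pause` and its parameter names are hypothetical and not part of the demo:

```python
def playback_pause(clock_dt, frame_period=0.03, min_pause=0.01):
    """Returns the pause needed to space frames roughly frame_period apart.

    clock_dt is the wall-clock time already spent on the current frame.
    The result never drops below min_pause, so the figure always gets
    some time to redraw even when rendering runs behind real time.
    """
    return max(min_pause, frame_period - clock_dt)


for dt in (0.0, 0.01, 0.05):
    print(dt, playback_pause(dt))
```

A scalar `max` is enough here because both arguments are plain floats.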
diff --git a/dm_control/suite/demos/zeros.amc b/dm_control/suite/demos/zeros.amc
new file mode 100644
index 00000000..b4590a42
--- /dev/null
+++ b/dm_control/suite/demos/zeros.amc
@@ -0,0 +1,213 @@
+#DUMMY AMC for testing
+:FULLY-SPECIFIED
+:DEGREES
+1
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+2
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+3
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+4
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+5
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+6
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
+7
+root 0 0 0 0 0 0
+lowerback 0 0 0
+upperback 0 0 0
+thorax 0 0 0
+lowerneck 0 0 0
+upperneck 0 0 0
+head 0 0 0
+rclavicle 0 0
+rhumerus 0 0 0
+rradius 0
+rwrist 0
+rhand 0 0
+rfingers 0
+rthumb 0 0
+lclavicle 0 0
+lhumerus 0 0 0
+lradius 0
+lwrist 0
+lhand 0 0
+lfingers 0
+lthumb 0 0
+rfemur 0 0 0
+rtibia 0
+rfoot 0 0
+rtoes 0
+lfemur 0 0 0
+ltibia 0
+lfoot 0 0
+ltoes 0
diff --git a/dm_control/suite/finger.py b/dm_control/suite/finger.py
new file mode 100644
index 00000000..fc5b9632
--- /dev/null
+++ b/dm_control/suite/finger.py
@@ -0,0 +1,209 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Finger Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+_DEFAULT_TIME_LIMIT = 20 # (seconds)
+_CONTROL_TIMESTEP = .02 # (seconds)
+# For TURN tasks, the 'tip' geom needs to enter a spherical target of sizes:
+_EASY_TARGET_SIZE = 0.07
+_HARD_TARGET_SIZE = 0.03
+# Initial spin velocity for the Stop task.
+_INITIAL_SPIN_VELOCITY = 100
+# Spinning slower than this value (radian/second) is considered stopped.
+_STOP_VELOCITY = 1e-6
+# Spinning faster than this value (radian/second) is considered spinning.
+_SPIN_VELOCITY = 15.0
+
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('finger.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def spin(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Spin task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Spin(random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add('benchmarking')
+def turn_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the easy Turn task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Turn(target_radius=_EASY_TARGET_SIZE, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add('benchmarking')
+def turn_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the hard Turn task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Turn(target_radius=_HARD_TARGET_SIZE, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Finger domain."""
+
+ def touch(self):
+ """Returns logarithmically scaled signals from the two touch sensors."""
+ return np.log1p(self.named.data.sensordata[['touchtop', 'touchbottom']])
+
+ def hinge_velocity(self):
+ """Returns the velocity of the hinge joint."""
+ return self.named.data.sensordata['hinge_velocity']
+
+ def tip_position(self):
+ """Returns the (x,z) position of the tip relative to the hinge."""
+ return (self.named.data.sensordata['tip'][[0, 2]] -
+ self.named.data.sensordata['spinner'][[0, 2]])
+
+ def bounded_position(self):
+ """Returns the positions, with the hinge angle replaced by tip position."""
+ return np.hstack((self.named.data.sensordata[['proximal', 'distal']],
+ self.tip_position()))
+
+ def velocity(self):
+ """Returns the velocities (extracted from sensordata)."""
+ return self.named.data.sensordata[['proximal_velocity',
+ 'distal_velocity',
+ 'hinge_velocity']]
+
+ def target_position(self):
+ """Returns the (x,z) position of the target relative to the hinge."""
+ return (self.named.data.sensordata['target'][[0, 2]] -
+ self.named.data.sensordata['spinner'][[0, 2]])
+
+ def to_target(self):
+ """Returns the vector from the tip to the target."""
+ return self.target_position() - self.tip_position()
+
+ def dist_to_target(self):
+ """Returns the signed distance to the target surface, negative is inside."""
+ return (np.linalg.norm(self.to_target()) -
+ self.named.model.site_size['target', 0])
+
+
+class Spin(base.Task):
+ """A Finger `Task` to spin the stopped body."""
+
+ def __init__(self, random=None):
+ """Initializes a new `Spin` instance.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(Spin, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ physics.named.model.site_rgba['target', 3] = 0
+ physics.named.model.site_rgba['tip', 3] = 0
+ physics.named.model.dof_damping['hinge'] = .03
+ _set_random_joint_angles(physics, self.random)
+
+ def get_observation(self, physics):
+ """Returns state and touch sensors."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.bounded_position()
+ obs['velocity'] = physics.velocity()
+ obs['touch'] = physics.touch()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a sparse reward."""
+ return float(physics.hinge_velocity() <= -_SPIN_VELOCITY)
+
+
+class Turn(base.Task):
+ """A Finger `Task` to turn the body to a target angle."""
+
+ def __init__(self, target_radius, random=None):
+ """Initializes a new `Turn` instance.
+
+ Args:
+ target_radius: Radius of the target site, which specifies the goal angle.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._target_radius = target_radius
+ super(Turn, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ target_angle = self.random.uniform(-np.pi, np.pi)
+ hinge_x, hinge_z = physics.named.data.xanchor['hinge', ['x', 'z']]
+ radius = physics.named.model.geom_size['cap1'].sum()
+ target_x = hinge_x + radius * np.sin(target_angle)
+ target_z = hinge_z + radius * np.cos(target_angle)
+ physics.named.model.site_pos['target', ['x', 'z']] = target_x, target_z
+ physics.named.model.site_size['target', 0] = self._target_radius
+
+ _set_random_joint_angles(physics, self.random)
+
+ def get_observation(self, physics):
+ """Returns state, touch sensors, and target info."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.bounded_position()
+ obs['velocity'] = physics.velocity()
+ obs['touch'] = physics.touch()
+ obs['target_position'] = physics.target_position()
+ obs['dist_to_target'] = physics.dist_to_target()
+ return obs
+
+ def get_reward(self, physics):
+ return float(physics.dist_to_target() <= 0)
+
+
+def _set_random_joint_angles(physics, random, max_attempts=1000):
+ """Sets the joints to a random collision-free state."""
+
+ for _ in xrange(max_attempts):
+ randomizers.randomize_limited_and_rotational_joints(physics, random)
+ # Check for collisions.
+ physics.after_reset()
+ if physics.data.ncon == 0:
+ break
+ else:
+ raise RuntimeError('Could not find a collision-free state '
+ 'after {} attempts'.format(max_attempts))
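Review note on `_set_random_joint_angles` in `finger.py` above: it relies on Python's `for ... else`, where the `else` branch runs only if the loop ends without `break`. The sketch below shows the same bounded rejection-sampling pattern in isolation. `sample_until_valid`, `sample` and `is_valid` are hypothetical names, and the validity check stands in for the collision test.

```python
import random


def sample_until_valid(sample, is_valid, max_attempts=1000):
    """Draws candidates until one passes is_valid, or raises after max_attempts."""
    for _ in range(max_attempts):
        candidate = sample()
        if is_valid(candidate):
            break
    else:
        # Reached only when the loop finished without hitting `break`.
        raise RuntimeError(
            'No valid sample after {} attempts'.format(max_attempts))
    return candidate


rng = random.Random(0)
angle = sample_until_valid(lambda: rng.uniform(-1.0, 1.0),
                           lambda a: abs(a) < 0.5)
print(angle)
```

Capping the number of attempts turns a potentially endless search into an explicit error when no valid state can be found.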
diff --git a/dm_control/suite/finger.xml b/dm_control/suite/finger.xml
new file mode 100644
index 00000000..4692bdff
--- /dev/null
+++ b/dm_control/suite/finger.xml
@@ -0,0 +1,72 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/fish.py b/dm_control/suite/fish.py
new file mode 100644
index 00000000..651c6aa2
--- /dev/null
+++ b/dm_control/suite/fish.py
@@ -0,0 +1,172 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Fish Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+import numpy as np
+
+
+_DEFAULT_TIME_LIMIT = 40
+_CONTROL_TIMESTEP = .04
+_JOINTS = ['tail1',
+ 'tail_twist',
+ 'tail2',
+ 'finright_roll',
+ 'finright_pitch',
+ 'finleft_roll',
+ 'finleft_pitch']
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('fish.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def upright(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Fish Upright task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Upright(random=random)
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+
+
+@SUITE.add('benchmarking')
+def swim(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Fish Swim task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Swim(random=random)
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Fish domain."""
+
+ def upright(self):
+ """Returns projection from z-axis of torso to the z-axis of worldbody."""
+ return self.named.data.xmat['torso', 'zz']
+
+ def torso_velocity(self):
+ """Returns velocities and angular velocities of the torso."""
+ return self.data.sensordata
+
+ def joint_velocities(self):
+ """Returns the joint velocities."""
+ return self.named.data.qvel[_JOINTS]
+
+ def joint_angles(self):
+ """Returns the joint positions."""
+ return self.named.data.qpos[_JOINTS]
+
+ def mouth_to_target(self):
+ """Returns the vector from mouth to target in the mouth's local frame."""
+ data = self.named.data
+ mouth_to_target_global = data.geom_xpos['target'] - data.geom_xpos['mouth']
+ return mouth_to_target_global.dot(data.geom_xmat['mouth'].reshape(3, 3))
+
+
+class Upright(base.Task):
+ """A Fish `Task` for getting the torso upright with smooth reward."""
+
+ def __init__(self, random=None):
+ """Initializes an instance of `Upright`.
+
+ Args:
+ random: Either an existing `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically.
+ """
+ super(Upright, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Randomizes the tail and fin angles and the orientation of the Fish."""
+ quat = self.random.randn(4)
+ physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
+ for joint in _JOINTS:
+ physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
+ # Hide the target. It's irrelevant for this task.
+ physics.named.model.geom_rgba['target', 3] = 0
+
+ def get_observation(self, physics):
+ """Returns an observation of joint angles, velocities and uprightness."""
+ obs = collections.OrderedDict()
+ obs['joint_angles'] = physics.joint_angles()
+ obs['upright'] = physics.upright()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a smooth reward."""
+ return rewards.tolerance(physics.upright(), bounds=(1, 1), margin=1)
+
+
+class Swim(base.Task):
+ """A Fish `Task` for swimming with smooth reward."""
+
+ def __init__(self, random=None):
+ """Initializes an instance of `Swim`.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(Swim, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+
+ quat = self.random.randn(4)
+ physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
+ for joint in _JOINTS:
+ physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
+ # Randomize target position.
+ physics.named.model.geom_pos['target', 'x'] = self.random.uniform(-.4, .4)
+ physics.named.model.geom_pos['target', 'y'] = self.random.uniform(-.4, .4)
+ physics.named.model.geom_pos['target', 'z'] = self.random.uniform(.1, .3)
+
+ def get_observation(self, physics):
+ """Returns an observation of joints, target direction and velocities."""
+ obs = collections.OrderedDict()
+ obs['joint_angles'] = physics.joint_angles()
+ obs['upright'] = physics.upright()
+ obs['target'] = physics.mouth_to_target()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a smooth reward."""
+ radii = physics.named.model.geom_size[['mouth', 'target'], 0].sum()
+ in_target = rewards.tolerance(np.linalg.norm(physics.mouth_to_target()),
+ bounds=(0, radii), margin=2*radii)
+ is_upright = 0.5 * (physics.upright() + 1)
+ return (7*in_target + is_upright) / 8
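Review note on `initialize_episode` in `fish.py` above: both tasks draw the root orientation by sampling four standard-normal values and dividing by the norm, which gives a unit quaternion distributed uniformly over rotations. A standard-library sketch of that step follows; `random_unit_quaternion` is a hypothetical helper, and the diff itself uses `self.random.randn(4)` with `np.linalg.norm`.

```python
import math
import random


def random_unit_quaternion(rng):
    """Samples a 4-vector from a standard normal and scales it to unit length."""
    quat = [rng.gauss(0.0, 1.0) for _ in range(4)]
    norm = math.sqrt(sum(q * q for q in quat))
    return [q / norm for q in quat]


quat = random_unit_quaternion(random.Random(0))
print(quat, sum(q * q for q in quat))
```

Normalizing matters because MuJoCo expects the orientation part of a free joint's `qpos` to be a unit quaternion.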
diff --git a/dm_control/suite/fish.xml b/dm_control/suite/fish.xml
new file mode 100644
index 00000000..43de56d5
--- /dev/null
+++ b/dm_control/suite/fish.xml
@@ -0,0 +1,85 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/hopper.py b/dm_control/suite/hopper.py
new file mode 100644
index 00000000..e9f08376
--- /dev/null
+++ b/dm_control/suite/hopper.py
@@ -0,0 +1,136 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Hopper domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+import numpy as np
+
+
+SUITE = containers.TaggedTasks()
+
+_CONTROL_TIMESTEP = .02 # (Seconds)
+
+# Default duration of an episode, in seconds.
+_DEFAULT_TIME_LIMIT = 20
+
+# Minimal height of torso over foot above which stand reward is 1.
+_STAND_HEIGHT = 0.6
+
+# Hopping speed above which hop reward is 1.
+_HOP_SPEED = 2
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('hopper.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns a Hopper that strives to stand upright, balancing its pose."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Hopper(hopping=False, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add('benchmarking')
+def hop(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns a Hopper that strives to hop forward."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Hopper(hopping=True, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Hopper domain."""
+
+ def height(self):
+ """Returns height of torso with respect to foot."""
+ return (self.named.data.xipos['torso', 'z'] -
+ self.named.data.xipos['foot', 'z'])
+
+ def speed(self):
+ """Returns horizontal speed of the Hopper."""
+ return self.named.data.subtree_linvel['torso', 'x']
+
+ def touch(self):
+ """Returns logarithmically scaled signals from the two foot touch sensors."""
+ return np.log1p(self.named.data.sensordata[['touch_toe', 'touch_heel']])
+
+
+class Hopper(base.Task):
+ """A Hopper `Task` to train either a standing or a hopping Hopper."""
+
+ def __init__(self, hopping, random=None):
+ """Initializes an instance of `Hopper`.
+
+ Args:
+ hopping: Boolean, if True the task is to hop forwards, otherwise it is to
+ balance upright.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._hopping = hopping
+ super(Hopper, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ self._timeout_progress = 0
+
+ def get_observation(self, physics):
+ """Returns an observation of positions, velocities and touch sensors."""
+ obs = collections.OrderedDict()
+ # Ignores horizontal position to maintain translational invariance:
+ obs['position'] = physics.data.qpos[1:]
+ obs['velocity'] = physics.velocity()
+ obs['touch'] = physics.touch()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward applicable to the performed task."""
+ standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
+ if self._hopping:
+ hopping = rewards.tolerance(physics.speed(),
+ bounds=(_HOP_SPEED, float('inf')),
+ margin=_HOP_SPEED/2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+ return standing * hopping
+ else:
+ small_control = rewards.tolerance(physics.control(),
+ margin=1, value_at_margin=0,
+ sigmoid='quadratic').mean()
+ small_control = (small_control + 4) / 5
+ return standing * small_control
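Review note on the stand branch of `get_reward` in `hopper.py` above: the control term is averaged over actuators and then rescaled by `(small_control + 4) / 5`, so it can reduce the standing reward by at most 20%. The sketch below assumes the quadratic sigmoid with `margin=1` and `value_at_margin=0` behaves as `1 - u**2` for `|u| < 1` and 0 otherwise. `stand_control_factor` is a hypothetical helper, not dm_control code.

```python
def stand_control_factor(controls):
    """Hypothetical helper mirroring the control term of the stand reward."""
    # Quadratic falloff: 1 at zero control, 0 once |u| reaches the margin of 1.
    per_actuator = [max(0.0, 1.0 - u * u) for u in controls]
    small_control = sum(per_actuator) / len(per_actuator)
    # Rescale from [0, 1] to [0.8, 1] so effort is only a mild penalty.
    return (small_control + 4) / 5


print(stand_control_factor([0.0, 0.0, 0.0]))
print(stand_control_factor([0.5]))
print(stand_control_factor([1.0, -1.0]))
```

With this rescaling, standing height stays the dominant term while still favouring low actuation.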
diff --git a/dm_control/suite/hopper.xml b/dm_control/suite/hopper.xml
new file mode 100644
index 00000000..c97c4bc6
--- /dev/null
+++ b/dm_control/suite/hopper.xml
@@ -0,0 +1,65 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/humanoid.py b/dm_control/suite/humanoid.py
new file mode 100644
index 00000000..814f29bc
--- /dev/null
+++ b/dm_control/suite/humanoid.py
@@ -0,0 +1,207 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Humanoid Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+import numpy as np
+
+_DEFAULT_TIME_LIMIT = 25
+_CONTROL_TIMESTEP = .025
+
+# Height of head above which stand reward is 1.
+_STAND_HEIGHT = 1.4
+
+# Horizontal speeds above which move reward is 1.
+_WALK_SPEED = 1
+_RUN_SPEED = 10
+
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('humanoid.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Stand task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Humanoid(move_speed=0, pure_state=False, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add('benchmarking')
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Walk task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Humanoid(move_speed=_WALK_SPEED, pure_state=False, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add('benchmarking')
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Humanoid(move_speed=_RUN_SPEED, pure_state=False, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add()
+def run_pure_state(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Run task with pure-state observations."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Humanoid(move_speed=_RUN_SPEED, pure_state=True, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Humanoid domain."""
+
+ def torso_upright(self):
+ """Returns projection from z-axis of torso to the z-axis of world."""
+ return self.named.data.xmat['torso', 'zz']
+
+ def head_height(self):
+ """Returns the height of the head."""
+ return self.named.data.xpos['head', 'z']
+
+ def center_of_mass_position(self):
+ """Returns position of the center-of-mass."""
+ return self.named.data.subtree_com['torso']
+
+ def center_of_mass_velocity(self):
+ """Returns the velocity of the center-of-mass."""
+ return self.named.data.subtree_linvel['torso']
+
+ def torso_vertical_orientation(self):
+ """Returns the z-projection of the torso orientation matrix."""
+ return self.named.data.xmat['torso', ['zx', 'zy', 'zz']]
+
+ def joint_angles(self):
+ """Returns the state without global orientation or position."""
+ return self.data.qpos[7:] # Skip the 7 DoFs of the free root joint.
+
+ def extremities(self):
+ """Returns end effector positions in egocentric frame."""
+ torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['torso']
+ positions = []
+ for side in ('left_', 'right_'):
+ for limb in ('hand', 'foot'):
+ torso_to_limb = self.named.data.xpos[side + limb] - torso_pos
+ positions.append(torso_to_limb.dot(torso_frame))
+ return np.hstack(positions)
+
+
+class Humanoid(base.Task):
+ """A humanoid task."""
+
+ def __init__(self, move_speed, pure_state, random=None):
+ """Initializes an instance of `Humanoid`.
+
+ Args:
+ move_speed: A float. If this value is zero, reward is given simply for
+ standing up. Otherwise this specifies a target horizontal velocity for
+ the walking task.
+ pure_state: A bool. Whether the observations consist of the pure MuJoCo
+ state or includes some useful features thereof.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._move_speed = move_speed
+ self._pure_state = pure_state
+ super(Humanoid, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Joint angles are re-randomized until a configuration with no contacts is
+ found.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ # Find a collision-free random initial configuration.
+ penetrating = True
+ while penetrating:
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+
+ def get_observation(self, physics):
+ """Returns either the pure state or a set of egocentric features."""
+ obs = collections.OrderedDict()
+ if self._pure_state:
+ obs['position'] = physics.position()
+ obs['velocity'] = physics.velocity()
+ else:
+ obs['joint_angles'] = physics.joint_angles()
+ obs['head_height'] = physics.head_height()
+ obs['extremities'] = physics.extremities()
+ obs['torso_vertical'] = physics.torso_vertical_orientation()
+ obs['com_velocity'] = physics.center_of_mass_velocity()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ standing = rewards.tolerance(physics.head_height(),
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT/4)
+ upright = rewards.tolerance(physics.torso_upright(),
+ bounds=(0.9, float('inf')), sigmoid='linear',
+ margin=1.9, value_at_margin=0)
+ stand_reward = standing * upright
+ small_control = rewards.tolerance(physics.control(), margin=1,
+ value_at_margin=0,
+ sigmoid='quadratic').mean()
+ small_control = (4 + small_control) / 5
+ if self._move_speed == 0:
+ horizontal_velocity = physics.center_of_mass_velocity()[[0, 1]]
+ dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean()
+ return small_control * stand_reward * dont_move
+ else:
+ com_velocity = np.linalg.norm(physics.center_of_mass_velocity()[[0, 1]])
+ move = rewards.tolerance(com_velocity,
+ bounds=(self._move_speed, float('inf')),
+ margin=self._move_speed, value_at_margin=0,
+ sigmoid='linear')
+ move = (5*move + 1) / 6
+ return small_control * stand_reward * move
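Reviewer note: the `move` term in `get_reward` above can be read as a clipped linear ramp on horizontal speed, rescaled so that standing still earns 1/6 rather than 0. The sketch below is a standalone NumPy stand-in, not the `rewards.tolerance` implementation. `linear_ramp` and `move_reward` are hypothetical names, and the sketch assumes that the linear sigmoid with `value_at_margin=0` falls from 1 at the lower bound to 0 one margin below it.

```python
import numpy as np


def linear_ramp(value, lower, margin):
  """Returns 1 at or above `lower`, falling linearly to 0 at `lower - margin`."""
  distance = np.maximum(lower - value, 0.0) / margin
  return np.clip(1.0 - distance, 0.0, 1.0)


def move_reward(com_velocity_xy, move_speed):
  """Mirrors the shaping `move = (5*move + 1) / 6` used in the diff."""
  speed = np.linalg.norm(com_velocity_xy)
  move = linear_ramp(speed, lower=move_speed, margin=move_speed)
  return (5 * move + 1) / 6
```

Because the final reward is a product of terms, the 1/6 floor keeps some reward flowing for standing upright even before the agent reaches any forward speed.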
diff --git a/dm_control/suite/humanoid.xml b/dm_control/suite/humanoid.xml
new file mode 100644
index 00000000..de4b3958
--- /dev/null
+++ b/dm_control/suite/humanoid.xml
@@ -0,0 +1,159 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/humanoid_CMU.py b/dm_control/suite/humanoid_CMU.py
new file mode 100644
index 00000000..1d6eec20
--- /dev/null
+++ b/dm_control/suite/humanoid_CMU.py
@@ -0,0 +1,177 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Humanoid_CMU Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+import numpy as np
+
+_DEFAULT_TIME_LIMIT = 20
+_CONTROL_TIMESTEP = 0.02
+
+# Height of head above which stand reward is 1.
+_STAND_HEIGHT = 1.4
+
+# Horizontal speeds above which move reward is 1.
+_WALK_SPEED = 1
+_RUN_SPEED = 10
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('humanoid_CMU.xml'), common.ASSETS
+
+
+@SUITE.add()
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Stand task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = HumanoidCMU(move_speed=0, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add()
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = HumanoidCMU(move_speed=_RUN_SPEED, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the humanoid_CMU domain."""
+
+ def thorax_upright(self):
+ """Returns projection from y-axes of thorax to the z-axes of world."""
+ return self.named.data.xmat['thorax', 'zy']
+
+ def head_height(self):
+ """Returns the height of the head."""
+ return self.named.data.xpos['head', 'z']
+
+ def center_of_mass_position(self):
+ """Returns position of the center-of-mass."""
+ return self.named.data.subtree_com['thorax']
+
+ def center_of_mass_velocity(self):
+ """Returns the velocity of the center-of-mass."""
+ return self.named.data.subtree_linvel['thorax']
+
+ def torso_vertical_orientation(self):
+ """Returns the z-projection of the thorax orientation matrix."""
+ return self.named.data.xmat['thorax', ['zx', 'zy', 'zz']]
+
+ def joint_angles(self):
+ """Returns the state without global orientation or position."""
+ return self.data.qpos[7:] # Skip the 7 DoFs of the free root joint.
+
+ def extremities(self):
+ """Returns end effector positions in egocentric frame."""
+ torso_frame = self.named.data.xmat['thorax'].reshape(3, 3)
+ torso_pos = self.named.data.xpos['thorax']
+ positions = []
+ for side in ('l', 'r'):
+ for limb in ('hand', 'foot'):
+ torso_to_limb = self.named.data.xpos[side + limb] - torso_pos
+ positions.append(torso_to_limb.dot(torso_frame))
+ return np.hstack(positions)
+
+
+class HumanoidCMU(base.Task):
+ """A task for the CMU Humanoid."""
+
+ def __init__(self, move_speed, random=None):
+ """Initializes an instance of `HumanoidCMU`.
+
+ Args:
+ move_speed: A float. If this value is zero, reward is given simply for
+ standing up. Otherwise this specifies a target horizontal velocity for
+ the walking task.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._move_speed = move_speed
+ super(HumanoidCMU, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets a random collision-free configuration at the start of each episode.
+
+ Args:
+ physics: An instance of `Physics`.
+ """
+ penetrating = True
+ while penetrating:
+ randomizers.randomize_limited_and_rotational_joints(
+ physics, self.random)
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+
+ def get_observation(self, physics):
+ """Returns a set of egocentric features."""
+ obs = collections.OrderedDict()
+ obs['joint_angles'] = physics.joint_angles()
+ obs['head_height'] = physics.head_height()
+ obs['extremities'] = physics.extremities()
+ obs['torso_vertical'] = physics.torso_vertical_orientation()
+ obs['com_velocity'] = physics.center_of_mass_velocity()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ standing = rewards.tolerance(physics.head_height(),
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT/4)
+ upright = rewards.tolerance(physics.thorax_upright(),
+ bounds=(0.9, float('inf')), sigmoid='linear',
+ margin=1.9, value_at_margin=0)
+ stand_reward = standing * upright
+ small_control = rewards.tolerance(physics.control(), margin=1,
+ value_at_margin=0,
+ sigmoid='quadratic').mean()
+ small_control = (4 + small_control) / 5
+ if self._move_speed == 0:
+ horizontal_velocity = physics.center_of_mass_velocity()[[0, 1]]
+ dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean()
+ return small_control * stand_reward * dont_move
+ else:
+ com_velocity = np.linalg.norm(physics.center_of_mass_velocity()[[0, 1]])
+ move = rewards.tolerance(com_velocity,
+ bounds=(self._move_speed, float('inf')),
+ margin=self._move_speed, value_at_margin=0,
+ sigmoid='linear')
+ move = (5*move + 1) / 6
+ return small_control * stand_reward * move
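Reviewer note: `extremities` in both humanoid files right-multiplies a world-frame offset by the body's rotation matrix, which is the same as applying the transposed rotation and so expresses the offset in the body's own frame. A minimal sketch of that step, assuming `xmat` is a row-major body-to-world rotation matrix; `to_egocentric` is a hypothetical helper, not part of the diff:

```python
import numpy as np


def to_egocentric(world_offset, body_xmat):
  """Expresses a world-frame offset in the body frame (R^T applied to v)."""
  frame = np.asarray(body_xmat, dtype=float).reshape(3, 3)
  return np.asarray(world_offset, dtype=float).dot(frame)


# Body rotated 90 degrees about the world z-axis: its x-axis points along
# world +y, so a limb one unit along world +y lies along the body's x-axis.
xmat = [0., -1., 0.,
        1., 0., 0.,
        0., 0., 1.]
limb_in_body_frame = to_egocentric([0., 1., 0.], xmat)
```

The observation is therefore invariant to the global heading of the humanoid, which is the point of an egocentric feature.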
diff --git a/dm_control/suite/humanoid_CMU.xml b/dm_control/suite/humanoid_CMU.xml
new file mode 100644
index 00000000..238d6110
--- /dev/null
+++ b/dm_control/suite/humanoid_CMU.xml
@@ -0,0 +1,287 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/lqr.py b/dm_control/suite/lqr.py
new file mode 100644
index 00000000..6a021354
--- /dev/null
+++ b/dm_control/suite/lqr.py
@@ -0,0 +1,266 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Procedurally generated LQR domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import os
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import xml_tools
+
+from lxml import etree
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from dm_control.utils import resources
+
+_DEFAULT_TIME_LIMIT = float('inf')
+_CONTROL_COST_COEF = 0.1
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets(n_bodies, n_actuators, random):
+ """Returns the model description as an XML string and a dict of assets.
+
+ Args:
+ n_bodies: An int, number of bodies of the LQR.
+ n_actuators: An int, number of actuated bodies of the LQR. `n_actuators`
+ should be less than or equal to `n_bodies`.
+ random: A `numpy.random.RandomState` instance.
+
+ Returns:
+ A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of
+ `{filename: contents_string}` pairs.
+ """
+ return _make_model(n_bodies, n_actuators, random), common.ASSETS
+
+
+@SUITE.add()
+def lqr_2_1(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns an LQR environment with 2 bodies of which the first is actuated."""
+ return _make_lqr(n_bodies=2,
+ n_actuators=1,
+ control_cost_coef=_CONTROL_COST_COEF,
+ time_limit=time_limit,
+ random=random)
+
+
+@SUITE.add()
+def lqr_6_2(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns an LQR environment with 6 bodies of which first 2 are actuated."""
+ return _make_lqr(n_bodies=6,
+ n_actuators=2,
+ control_cost_coef=_CONTROL_COST_COEF,
+ time_limit=time_limit,
+ random=random)
+
+
+def _make_lqr(n_bodies, n_actuators, control_cost_coef, time_limit, random):
+ """Returns an LQR environment.
+
+ Args:
+ n_bodies: An int, number of bodies of the LQR.
+ n_actuators: An int, number of actuated bodies of the LQR. `n_actuators`
+ should be less than or equal to `n_bodies`.
+ control_cost_coef: A number, the coefficient of the control cost.
+ time_limit: A number, maximum time for each episode in seconds.
+ random: Either an existing `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically.
+
+ Returns:
+ An LQR environment with `n_bodies` bodies, of which the first `n_actuators`
+ are actuated.
+ """
+
+ if not isinstance(random, np.random.RandomState):
+ random = np.random.RandomState(random)
+
+ model_string, assets = get_model_and_assets(n_bodies, n_actuators,
+ random=random)
+ physics = Physics.from_xml_string(model_string, assets=assets)
+ task = LQRLevel(control_cost_coef, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+def _make_body(body_id, stiffness_range, damping_range, random):
+ """Returns an `etree.Element` defining a body.
+
+ Args:
+ body_id: Id of the created body.
+ stiffness_range: A tuple of (stiffness_lower_bound, stiffness_upper_bound).
+ The stiffness of the joint is drawn uniformly from this range.
+ damping_range: A tuple of (damping_lower_bound, damping_upper_bound). The
+ damping of the joint is drawn uniformly from this range.
+ random: A `numpy.random.RandomState` instance.
+
+ Returns:
+ A new instance of `etree.Element`. A body element with two children: joint
+ and geom.
+ """
+ body_name = 'body_{}'.format(body_id)
+ joint_name = 'joint_{}'.format(body_id)
+ geom_name = 'geom_{}'.format(body_id)
+
+ body = etree.Element('body', name=body_name)
+ body.set('pos', '.25 0 0')
+ joint = etree.SubElement(body, 'joint', name=joint_name)
+ body.append(etree.Element('geom', name=geom_name))
+ joint.set('stiffness',
+ str(random.uniform(stiffness_range[0], stiffness_range[1])))
+ joint.set('damping',
+ str(random.uniform(damping_range[0], damping_range[1])))
+ return body
+
+
+def _make_model(n_bodies,
+ n_actuators,
+ random,
+ stiffness_range=(15, 25),
+ damping_range=(0, 0)):
+ """Returns an MJCF XML string defining a model of springs and dampers.
+
+ Args:
+ n_bodies: An integer, the number of bodies (DoFs) in the system.
+ n_actuators: An integer, the number of actuated bodies.
+ random: A `numpy.random.RandomState` instance.
+ stiffness_range: A tuple containing minimum and maximum stiffness. Each
+ joint's stiffness is sampled uniformly from this interval.
+ damping_range: A tuple containing minimum and maximum damping. Each joint's
+ damping is sampled uniformly from this interval.
+
+ Returns:
+ An MJCF string describing the linear system.
+
+ Raises:
+ ValueError: If the number of bodies or actuators is erroneous.
+ """
+ if n_bodies < 1 or n_actuators < 1:
+ raise ValueError('At least 1 body and 1 actuator required.')
+ if n_actuators > n_bodies:
+ raise ValueError('At most 1 actuator per body.')
+
+ file_path = os.path.join(os.path.dirname(__file__), 'lqr.xml')
+ xml_file = resources.GetResourceAsFile(file_path)
+ mjcf = xml_tools.parse(xml_file)
+ parent = mjcf.find('./worldbody')
+ actuator = etree.SubElement(mjcf.getroot(), 'actuator')
+ tendon = etree.SubElement(mjcf.getroot(), 'tendon')
+
+ for body in xrange(n_bodies):
+ # Inserting body.
+ child = _make_body(body, stiffness_range, damping_range, random)
+ site_name = 'site_{}'.format(body)
+ child.append(etree.Element('site', name=site_name))
+
+ if body == 0:
+ child.set('pos', '.25 0 .1')
+ # Add actuators to the first n_actuators bodies.
+ if body < n_actuators:
+ # Adding actuator.
+ joint_name = 'joint_{}'.format(body)
+ motor_name = 'motor_{}'.format(body)
+ child.find('joint').set('name', joint_name)
+ actuator.append(etree.Element('motor', name=motor_name, joint=joint_name))
+
+ # Add a tendon between consecutive bodies (for visualisation purposes only).
+ if body < n_bodies - 1:
+ child_site_name = 'site_{}'.format(body + 1)
+ tendon_name = 'tendon_{}'.format(body)
+ spatial = etree.SubElement(tendon, 'spatial', name=tendon_name)
+ spatial.append(etree.Element('site', site=site_name))
+ spatial.append(etree.Element('site', site=child_site_name))
+ parent.append(child)
+ parent = child
+
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the LQR domain."""
+
+ def state_norm(self):
+ """Returns the norm of the physics state."""
+ return np.linalg.norm(self.state())
+
+
+class LQRLevel(base.Task):
+ """A Linear Quadratic Regulator `Task`."""
+
+ _TERMINAL_TOL = 1e-6
+
+ def __init__(self, control_cost_coef, random=None):
+ """Initializes an LQR level with cost = sum(states^2) + c*sum(controls^2).
+
+ Args:
+ control_cost_coef: The coefficient of the control cost.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+
+ Raises:
+ ValueError: If the control cost coefficient is not positive.
+ """
+ if control_cost_coef <= 0:
+ raise ValueError('control_cost_coef must be positive.')
+
+ self._control_cost_coef = control_cost_coef
+ super(LQRLevel, self).__init__(random=random)
+
+ @property
+ def control_cost_coef(self):
+ return self._control_cost_coef
+
+ def initialize_episode(self, physics):
+ """Samples a random initial position on a sphere of radius sqrt(2)."""
+ ndof = physics.model.nq
+ unit = self.random.randn(ndof)
+ physics.data.qpos[:] = np.sqrt(2) * unit / np.linalg.norm(unit)
+
+ def get_observation(self, physics):
+ """Returns an observation of the state."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.position()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a quadratic state and control reward."""
+ position = physics.position()
+ state_cost = 0.5 * np.dot(position, position)
+ control_signal = physics.control()
+ control_l2_norm = 0.5 * np.dot(control_signal, control_signal)
+ return 1 - (state_cost + control_l2_norm * self._control_cost_coef)
+
+ def get_evaluation(self, physics):
+ """Returns a sparse evaluation reward that is not used for learning."""
+ return float(physics.state_norm() <= 0.01)
+
+ def get_termination(self, physics):
+ """Terminates when the state norm is smaller than epsilon."""
+ if physics.state_norm() < self._TERMINAL_TOL:
+ return 0.0
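Reviewer note: the `sqrt(2)` scaling in `initialize_episode` and the `0.5` factors in `get_reward` fit together, since a position of norm `sqrt(2)` has state cost exactly 1 and so the reward at reset is 0 when no control is applied. A standalone NumPy sketch of the two formulas; the function names are hypothetical and the default coefficient copies `_CONTROL_COST_COEF` from the diff:

```python
import numpy as np


def sample_initial_position(ndof, rng):
  """Draws a point uniformly on the sphere of radius sqrt(2)."""
  unit = rng.randn(ndof)
  return np.sqrt(2) * unit / np.linalg.norm(unit)


def lqr_reward(position, control, control_cost_coef=0.1):
  """Returns 1 minus the quadratic state and control costs."""
  state_cost = 0.5 * np.dot(position, position)
  control_cost = 0.5 * np.dot(control, control)
  return 1 - (state_cost + control_cost_coef * control_cost)


rng = np.random.RandomState(0)
start = sample_initial_position(6, rng)
reward_at_reset = lqr_reward(start, np.zeros(2))
```

The reward reaches its maximum of 1 only at the origin with zero control, and any control effort lowers it further.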
diff --git a/dm_control/suite/lqr.xml b/dm_control/suite/lqr.xml
new file mode 100644
index 00000000..d403532a
--- /dev/null
+++ b/dm_control/suite/lqr.xml
@@ -0,0 +1,26 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/lqr_solver.py b/dm_control/suite/lqr_solver.py
new file mode 100644
index 00000000..4cc6ec63
--- /dev/null
+++ b/dm_control/suite/lqr_solver.py
@@ -0,0 +1,145 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+r"""Optimal policy for LQR levels.
+
+The LQR control problem is described at
+https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl import logging
+
+from dm_control.mujoco import wrapper
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+try:
+ import scipy.linalg as sp # pylint: disable=g-import-not-at-top
+except ImportError:
+ sp = None
+
+
+def _solve_dare(a, b, q, r):
+ """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration.
+
+ Algebraic Riccati Equation:
+ ```none
+ P_{t-1} = Q + A' * P_{t} * A -
+ A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A
+ ```
+
+ Args:
+ a: A 2 dimensional numpy array, transition matrix A.
+ b: A 2 dimensional numpy array, control matrix B.
+ q: A 2 dimensional numpy array, symmetric positive definite cost matrix.
+ r: A 2 dimensional numpy array, symmetric positive definite cost matrix.
+
+ Returns:
+ A numpy array, a real symmetric matrix P which is the solution to DARE.
+
+ Raises:
+ RuntimeError: If the computed P matrix is not symmetric and
+ positive-definite.
+ """
+ p = np.eye(len(a))
+ for _ in xrange(1000000):
+ a_p = a.T.dot(p) # A' * P_t
+ a_p_b = np.dot(a_p, b) # A' * P_t * B
+ # Algebraic Riccati Equation.
+ p_next = q + np.dot(a_p, a) - a_p_b.dot(
+ np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T))
+ p_next += p_next.T
+ p_next *= .5
+ if np.abs(p - p_next).max() < 1e-12:
+ break
+ p = p_next
+ else:
+ logging.warning('DARE solver did not converge')
+ try:
+ # Check that the result is symmetric and positive-definite.
+ np.linalg.cholesky(p_next)
+ except np.linalg.LinAlgError:
+ raise RuntimeError('DARE solver failed: P matrix is not symmetric and '
+ 'positive-definite.')
+ return p_next
+
+
+def solve(env):
+ """Returns the optimal value and policy for the LQR problem.
+
+ Args:
+ env: An instance of `control.Environment` with LQR level.
+
+ Returns:
+ p: A numpy array, the Hessian of the optimal total cost-to-go; the value
+ function at state x is V(x) = .5 * x' * p * x.
+ k: A numpy array which gives the optimal linear policy u = k * x.
+ beta: The largest eigenvalue magnitude of (a + b * k). Under the optimal
+ policy, at timestep n the state tends to 0 like beta^n.
+
+ Raises:
+ RuntimeError: If the controlled system is unstable.
+ """
+ n = env.physics.model.nq # number of DoFs
+ m = env.physics.model.nu # number of controls
+
+ # Compute the mass matrix.
+ mass = np.zeros((n, n))
+ wrapper.mjbindings.mjlib.mj_fullM(env.physics.model.ptr, mass,
+ env.physics.data.qM)
+
+ # Compute input matrices a, b, q and r to the DARE solvers.
+ # State transition matrix a.
+ stiffness = np.diag(env.physics.model.jnt_stiffness.ravel())
+ damping = np.diag(env.physics.model.dof_damping.ravel())
+ dt = env.physics.model.opt.timestep
+
+ j = np.linalg.solve(-mass, np.hstack((stiffness, damping)))
+ a = np.eye(2 * n) + dt * np.vstack(
+ (dt * j + np.hstack((np.zeros((n, n)), np.eye(n))), j))
+
+ # Control transition matrix b.
+ b = env.physics.data.actuator_moment.T
+ bc = np.linalg.solve(mass, b)
+ b = dt * np.vstack((dt * bc, bc))
+
+ # State cost Hessian q.
+ q = np.diag(np.hstack([np.ones(n), np.zeros(n)]))
+
+ # Control cost Hessian r.
+ r = env.task.control_cost_coef * np.eye(m)
+
+ if sp:
+ # Use scipy's faster DARE solver if available.
+ solve_dare = sp.solve_discrete_are
+ else:
+ # Otherwise fall back on a slower internal implementation.
+ solve_dare = _solve_dare
+
+ # Solve the discrete algebraic Riccati equation.
+ p = solve_dare(a, b, q, r)
+ k = -np.linalg.solve(b.T.dot(p.dot(b)) + r, b.T.dot(p.dot(a)))
+
+ # Under optimal policy, state tends to 0 like beta^n_timesteps
+ beta = np.abs(np.linalg.eigvals(a + b.dot(k))).max()
+ if beta >= 1.0:
+ raise RuntimeError('Controlled system is unstable.')
+ return p, k, beta
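Reviewer note: the fallback `_solve_dare` and the gain and `beta` computations in `solve` can be exercised without MuJoCo. The sketch below repeats the same Riccati recursion on a small hand-written system; the double-integrator matrices, the tolerance, and the function name are assumptions chosen for illustration and do not come from the LQR domain.

```python
import numpy as np


def solve_dare_by_iteration(a, b, q, r, tol=1e-10, max_iters=100000):
  """Iterates the Riccati recursion until P stops changing."""
  p = np.eye(len(a))
  for _ in range(max_iters):
    a_p = a.T.dot(p)       # A' * P
    a_p_b = a_p.dot(b)     # A' * P * B
    p_next = q + a_p.dot(a) - a_p_b.dot(
        np.linalg.solve(b.T.dot(p).dot(b) + r, a_p_b.T))
    p_next = 0.5 * (p_next + p_next.T)  # Keep P symmetric.
    if np.abs(p - p_next).max() < tol:
      return p_next
    p = p_next
  raise RuntimeError('Riccati iteration did not converge.')


# Hypothetical test system: a discretised double integrator.
dt = 0.1
a = np.array([[1.0, dt], [0.0, 1.0]])
b = np.array([[0.5 * dt**2], [dt]])
q = np.eye(2)
r = 0.1 * np.eye(1)

p = solve_dare_by_iteration(a, b, q, r)
# Optimal linear policy u = k * x and closed-loop spectral radius.
k = -np.linalg.solve(b.T.dot(p).dot(b) + r, b.T.dot(p).dot(a))
beta = np.abs(np.linalg.eigvals(a + b.dot(k))).max()
```

A `beta` below 1 confirms that the closed-loop system contracts, which is the same check `solve` performs before returning.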
diff --git a/dm_control/suite/manipulator.py b/dm_control/suite/manipulator.py
new file mode 100644
index 00000000..3e253b16
--- /dev/null
+++ b/dm_control/suite/manipulator.py
@@ -0,0 +1,284 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Planar Manipulator domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import xml_tools
+
+from lxml import etree
+import numpy as np
+
+_CLOSE = .01 # (Meters) Distance below which a thing is considered close.
+_CONTROL_TIMESTEP = .01 # (Seconds)
+_TIME_LIMIT = 10 # (Seconds)
+_P_IN_HAND = .1 # Probability of object-in-hand initial state
+_P_IN_TARGET = .1 # Probability of object-in-target initial state
+_ARM_JOINTS = ['arm_root', 'arm_shoulder', 'arm_elbow', 'arm_wrist',
+ 'finger', 'fingertip', 'thumb', 'thumbtip']
+_ALL_PROPS = frozenset(['ball', 'target_ball', 'cup',
+ 'peg', 'target_peg', 'slot'])
+
+SUITE = containers.TaggedTasks()
+
+
+def make_model(use_peg, insert):
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ xml_string = common.read_model('manipulator.xml')
+ parser = etree.XMLParser(remove_blank_text=True)
+ mjcf = etree.XML(xml_string, parser)
+
+ # Select the desired prop.
+ if use_peg:
+ required_props = ['peg', 'target_peg']
+ if insert:
+ required_props += ['slot']
+ else:
+ required_props = ['ball', 'target_ball']
+ if insert:
+ required_props += ['cup']
+
+ # Remove unused props
+ for unused_prop in _ALL_PROPS.difference(required_props):
+ prop = xml_tools.find_element(mjcf, 'body', unused_prop)
+ prop.getparent().remove(prop)
+
+ return etree.tostring(mjcf, pretty_print=True), common.ASSETS
+
+
+@SUITE.add('benchmarking', 'hard')
+def bring_ball(observe_target=True, time_limit=_TIME_LIMIT, random=None):
+ """Returns manipulator bring task with the ball prop."""
+ use_peg = False
+ insert = False
+ physics = Physics.from_xml_string(*make_model(use_peg, insert))
+ task = Bring(use_peg, insert, observe_target, random=random)
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+
+
+@SUITE.add('hard')
+def bring_peg(observe_target=True, time_limit=_TIME_LIMIT, random=None):
+ """Returns manipulator bring task with the peg prop."""
+ use_peg = True
+ insert = False
+ physics = Physics.from_xml_string(*make_model(use_peg, insert))
+ task = Bring(use_peg, insert, observe_target, random=random)
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+
+
+@SUITE.add('hard')
+def insert_ball(observe_target=True, time_limit=_TIME_LIMIT, random=None):
+ """Returns manipulator insert task with the ball prop."""
+ use_peg = False
+ insert = True
+ physics = Physics.from_xml_string(*make_model(use_peg, insert))
+ task = Bring(use_peg, insert, observe_target, random=random)
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+
+
+@SUITE.add('hard')
+def insert_peg(observe_target=True, time_limit=_TIME_LIMIT, random=None):
+ """Returns manipulator insert task with the peg prop."""
+ use_peg = True
+ insert = True
+ physics = Physics.from_xml_string(*make_model(use_peg, insert))
+ task = Bring(use_peg, insert, observe_target, random=random)
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+
+
+class Physics(mujoco.Physics):
+ """Physics with additional features for the Planar Manipulator domain."""
+
+ def bounded_position(self):
+ """Returns the position, with unbounded angles as sine/cosine."""
+ state = []
+ hinge_joint = enums.mjtJoint.mjJNT_HINGE
+ for joint_id in range(self.model.njnt):
+ joint_value = self.named.data.qpos[joint_id]
+ if (not self.model.jnt_limited[joint_id] and
+ self.model.jnt_type[joint_id] == hinge_joint): # Unbounded hinge.
+ state += [np.sin(joint_value), np.cos(joint_value)]
+ else:
+ state.append(joint_value)
+ return np.asarray(state)
+
+ def body_location(self, body):
+ """Returns the x,z position and y orientation of a body."""
+ body_position = self.named.model.body_pos[body, ['x', 'z']]
+ body_orientation = self.named.model.body_quat[body, ['qw', 'qy']]
+ return np.hstack((body_position, body_orientation))
+
+ def proprioception(self):
+ """Returns the arm state, with unbounded angles as sine/cosine."""
+ arm = []
+ for joint in _ARM_JOINTS:
+ joint_value = self.named.data.qpos[joint]
+ if not self.named.model.jnt_limited[joint]:
+ arm += [np.sin(joint_value), np.cos(joint_value)]
+ else:
+ arm.append(joint_value)
+ return np.hstack(arm + [self.named.data.qvel[_ARM_JOINTS]])
+
+ def touch(self):
+ return np.log1p(self.data.sensordata)
+
+ def site_distance(self, site1, site2):
+ site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0)
+ return np.linalg.norm(site1_to_site2)
+
+
+class Bring(base.Task):
+ """A Bring `Task`: bring the prop to the target."""
+
+ def __init__(self, use_peg, insert, observe_target, random=None):
+ """Initialize an instance of the `Bring` task.
+
+ Args:
+ use_peg: A `bool`, whether to replace the ball prop with the peg prop.
+ insert: A `bool`, whether to insert the prop in a receptacle.
+ observe_target: A `bool`, whether the observation contains target info.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._use_peg = use_peg
+ self._target = 'target_peg' if use_peg else 'target_ball'
+ self._object = 'peg' if self._use_peg else 'ball'
+ self._receptacle = 'slot' if self._use_peg else 'cup'
+ self._insert = insert
+ self._observe_target = observe_target
+ super(Bring, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ # local shortcuts
+ uniform = self.random.uniform
+ model = physics.named.model
+ data = physics.named.data
+
+ # Find a collision-free random initial configuration.
+ penetrating = True
+ while penetrating:
+
+ # Randomise angles of arm joints.
+ is_limited = model.jnt_limited[_ARM_JOINTS].astype(bool)
+ joint_range = model.jnt_range[_ARM_JOINTS]
+ lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi)
+ upper_limits = np.where(is_limited, joint_range[:, 1], np.pi)
+ angles = uniform(lower_limits, upper_limits)
+ data.qpos[_ARM_JOINTS] = angles
+
+ # Symmetrize hand.
+ data.qpos['finger'] = data.qpos['thumb']
+
+ # Randomise target location.
+ target_x = uniform(-.4, .4)
+ target_z = uniform(.1, .4)
+ if self._insert:
+ target_angle = uniform(-np.pi/3, np.pi/3)
+ model.body_pos[self._receptacle, ['x', 'z']] = target_x, target_z
+ model.body_quat[self._receptacle, ['qw', 'qy']] = [
+ np.cos(target_angle/2), np.sin(target_angle/2)]
+ else:
+ target_angle = uniform(-np.pi, np.pi)
+
+ model.body_pos[self._target, ['x', 'z']] = target_x, target_z
+ model.body_quat[self._target, ['qw', 'qy']] = [
+ np.cos(target_angle/2), np.sin(target_angle/2)]
+
+ # Randomise object location.
+ object_init_probs = [_P_IN_HAND, _P_IN_TARGET, 1-_P_IN_HAND-_P_IN_TARGET]
+ init_type = self.random.choice(['in_hand', 'in_target', 'uniform'], 1,
+ p=object_init_probs)[0]
+ if init_type == 'in_target':
+ object_x = target_x
+ object_z = target_z
+ object_angle = target_angle
+ elif init_type == 'in_hand':
+ physics.after_reset()
+ object_x = data.site_xpos['grasp', 'x']
+ object_z = data.site_xpos['grasp', 'z']
+ grasp_direction = data.site_xmat['grasp', ['xx', 'zx']]
+ object_angle = np.pi-np.arctan2(grasp_direction[1], grasp_direction[0])
+ else:
+ object_x = uniform(-.5, .5)
+ object_z = uniform(0, .7)
+ object_angle = uniform(0, 2*np.pi)
+ data.qvel[self._object + '_x'] = uniform(-5, 5)
+
+ data.qpos[self._object + '_x'] = object_x
+ data.qpos[self._object + '_z'] = object_z
+ data.qpos[self._object + '_y'] = object_angle
+
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+
+ def get_observation(self, physics):
+ """Returns either features or only sensors (to be used with pixels)."""
+ obs = collections.OrderedDict()
+ if self._observe_target:
+ obs['position'] = physics.bounded_position()
+ obs['hand'] = physics.body_location('hand')
+ obs['target'] = physics.body_location(self._target)
+ obs['velocity'] = physics.velocity()
+ obs['touch'] = physics.touch()
+ else:
+ obs['proprioception'] = physics.proprioception()
+ obs['touch'] = physics.touch()
+ return obs
+
+ def _is_close(self, distance):
+ return rewards.tolerance(distance, (0, _CLOSE), _CLOSE*2)
+
+ def _peg_reward(self, physics):
+ """Returns a reward for bringing the peg prop to the target."""
+ grasp = self._is_close(physics.site_distance('peg_grasp', 'grasp'))
+ pinch = self._is_close(physics.site_distance('peg_pinch', 'pinch'))
+ grasping = (grasp + pinch) / 2
+ bring = self._is_close(physics.site_distance('peg', 'target_peg'))
+ bring_tip = self._is_close(physics.site_distance('target_peg_tip',
+ 'peg_tip'))
+ bringing = (bring + bring_tip) / 2
+ return max(bringing, grasping/3)
+
+ def _ball_reward(self, physics):
+ """Returns a reward for bringing the ball prop to the target."""
+ return self._is_close(physics.site_distance('ball', 'target_ball'))
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ if self._use_peg:
+ return self._peg_reward(physics)
+ else:
+ return self._ball_reward(physics)
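Review note on `bounded_position` and `proprioception`: both replace each unbounded angle with a sine/cosine pair so that angles differing by a full turn produce the same observation. The sketch below illustrates the idea in plain Python. The helper name `encode_joints` and its list-based interface are hypothetical, and it omits the hinge-type check that `bounded_position` performs.

```python
import math


def encode_joints(values, limited):
    """Encodes joint values; unlimited angles become (sin, cos) pairs.

    Hypothetical helper for illustration only, not part of dm_control.
    """
    state = []
    for value, is_limited in zip(values, limited):
        if is_limited:
            # Limited joints stay inside a fixed range, so keep the raw value.
            state.append(value)
        else:
            # Unlimited angles wrap around, so use a periodic encoding.
            state.extend([math.sin(value), math.cos(value)])
    return state


# The same pose expressed one full turn apart yields the same encoding
# (up to floating-point rounding).
pose_a = encode_joints([0.5, 0.2], [False, True])
pose_b = encode_joints([0.5 + 2 * math.pi, 0.2], [False, True])
```

Each unlimited joint contributes two entries instead of one, so the observation length depends on how many joints are unlimited.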
diff --git a/dm_control/suite/manipulator.xml b/dm_control/suite/manipulator.xml
new file mode 100644
index 00000000..6e9b2014
--- /dev/null
+++ b/dm_control/suite/manipulator.xml
@@ -0,0 +1,211 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+ >
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/pendulum.py b/dm_control/suite/pendulum.py
new file mode 100644
index 00000000..f6331f66
--- /dev/null
+++ b/dm_control/suite/pendulum.py
@@ -0,0 +1,113 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Pendulum domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+import numpy as np
+
+
+_DEFAULT_TIME_LIMIT = 20
+_ANGLE_BOUND = 8
+_COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND))
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('pendulum.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the pendulum swingup task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = SwingUp(random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Pendulum domain."""
+
+ def pole_vertical(self):
+ """Returns vertical (z) component of pole frame."""
+ return self.named.data.xmat['pole', 'zz']
+
+ def angular_velocity(self):
+ """Returns the angular velocity of the pole."""
+ return self.named.data.qvel['hinge']
+
+ def pole_orientation(self):
+ """Returns both horizontal and vertical components of pole frame."""
+ return self.named.data.xmat['pole', ['zz', 'xz']]
+
+
+class SwingUp(base.Task):
+ """A Pendulum `Task` to swing up and balance the pole."""
+
+ def __init__(self, random=None):
+ """Initialize an instance of `SwingUp`.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(SwingUp, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Pole is set to a random angle between [-pi, pi).
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ physics.named.data.qpos['hinge'] = self.random.uniform(-np.pi, np.pi)
+
+ def get_observation(self, physics):
+ """Returns an observation.
+
+ Observations are states concatenating pole orientation and angular
+ velocity.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ Returns:
+ A `dict` of observation.
+ """
+ obs = collections.OrderedDict()
+ obs['orientation'] = physics.pole_orientation()
+ obs['velocity'] = physics.angular_velocity()
+ return obs
+
+ def get_reward(self, physics):
+ return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1))
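Review note on the `SwingUp` reward: `_COSINE_BOUND` turns the 8 degree angle bound into a threshold on the vertical component of the pole frame. The sketch below assumes that `rewards.tolerance` with its default zero margin acts as an indicator on the bounds, and that the `zz` entry of the pole frame equals the cosine of the pole's angle from upright. The function name `swingup_reward` is hypothetical.

```python
import math

ANGLE_BOUND_DEG = 8
COSINE_BOUND = math.cos(math.radians(ANGLE_BOUND_DEG))


def swingup_reward(angle_from_vertical):
    """Returns 1.0 if the pole is within the angle bound of upright, else 0.0.

    Assumption: the reward is a hard indicator on [COSINE_BOUND, 1].
    """
    pole_vertical = math.cos(angle_from_vertical)
    return 1.0 if COSINE_BOUND <= pole_vertical <= 1.0 else 0.0
```

Under these assumptions the reward is sparse: the agent receives nothing until the pole is almost upright.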
diff --git a/dm_control/suite/pendulum.xml b/dm_control/suite/pendulum.xml
new file mode 100644
index 00000000..14377ae6
--- /dev/null
+++ b/dm_control/suite/pendulum.xml
@@ -0,0 +1,26 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/point_mass.py b/dm_control/suite/point_mass.py
new file mode 100644
index 00000000..c499e00e
--- /dev/null
+++ b/dm_control/suite/point_mass.py
@@ -0,0 +1,128 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Point-mass domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+import numpy as np
+
+_DEFAULT_TIME_LIMIT = 20
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('point_mass.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking', 'easy')
+def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the easy point_mass task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PointMass(randomize_gains=False, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+@SUITE.add()
+def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the hard point_mass task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PointMass(randomize_gains=True, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+class Physics(mujoco.Physics):
+ """Physics for the point_mass domain."""
+
+ def mass_to_target(self):
+ """Returns the vector from mass to target in global coordinates."""
+ return (self.named.data.geom_xpos['target'] -
+ self.named.data.geom_xpos['pointmass'])
+
+ def mass_to_target_dist(self):
+ """Returns the distance from mass to the target."""
+ return np.linalg.norm(self.mass_to_target())
+
+
+class PointMass(base.Task):
+ """A point_mass `Task` to reach target with smooth reward."""
+
+ def __init__(self, randomize_gains, random=None):
+ """Initialize an instance of `PointMass`.
+
+ Args:
+ randomize_gains: A `bool`, whether to randomize the actuator gains.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._randomize_gains = randomize_gains
+ super(PointMass, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ If _randomize_gains is True, the relationship between the controls and
+ the joints is randomized, so that each control actuates a random linear
+ combination of joints.
+
+ Args:
+ physics: An instance of `mujoco.Physics`.
+ """
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ if self._randomize_gains:
+ dir1 = self.random.randn(2)
+ dir1 /= np.linalg.norm(dir1)
+ # Find another actuation direction that is not 'too parallel' to dir1.
+ parallel = True
+ while parallel:
+ dir2 = self.random.randn(2)
+ dir2 /= np.linalg.norm(dir2)
+ parallel = abs(np.dot(dir1, dir2)) > 0.9
+ physics.model.wrap_prm[[0, 1]] = dir1
+ physics.model.wrap_prm[[2, 3]] = dir2
+
+ def get_observation(self, physics):
+ """Returns an observation of the state."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.position()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ target_size = physics.named.model.geom_size['target', 0]
+ near_target = rewards.tolerance(physics.mass_to_target_dist(),
+ bounds=(0, target_size), margin=target_size)
+ control_reward = rewards.tolerance(physics.control(), margin=1,
+ value_at_margin=0,
+ sigmoid='quadratic').mean()
+ small_control = (control_reward + 4) / 5
+ return near_target * small_control
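Review note on `PointMass.initialize_episode`: when gains are randomized, the second actuation direction is drawn by rejection sampling until it is not too parallel to the first. The sketch below reproduces that loop with the standard library only. The names `random_unit_vector` and `sample_actuation_directions` are hypothetical, and `random.Random` stands in for the task's `RandomState`.

```python
import math
import random


def random_unit_vector(rng):
    """Draws a 2D direction by normalising a Gaussian sample."""
    x, y = rng.gauss(0, 1), rng.gauss(0, 1)
    norm = math.hypot(x, y)
    return (x / norm, y / norm)


def sample_actuation_directions(rng, max_abs_dot=0.9):
    """Returns two unit vectors whose absolute dot product is at most 0.9."""
    dir1 = random_unit_vector(rng)
    while True:
        dir2 = random_unit_vector(rng)
        dot = dir1[0] * dir2[0] + dir1[1] * dir2[1]
        # Reject candidates that are nearly parallel or anti-parallel.
        if abs(dot) <= max_abs_dot:
            return dir1, dir2


dir1, dir2 = sample_actuation_directions(random.Random(0))
```

Rejecting near-parallel pairs keeps the two controls able to span the plane, so the target stays reachable.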
diff --git a/dm_control/suite/point_mass.xml b/dm_control/suite/point_mass.xml
new file mode 100644
index 00000000..c447cf61
--- /dev/null
+++ b/dm_control/suite/point_mass.xml
@@ -0,0 +1,49 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/reacher.py b/dm_control/suite/reacher.py
new file mode 100644
index 00000000..4b701e26
--- /dev/null
+++ b/dm_control/suite/reacher.py
@@ -0,0 +1,113 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Reacher domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+import numpy as np
+
+SUITE = containers.TaggedTasks()
+_DEFAULT_TIME_LIMIT = 20
+_BIG_TARGET = .05
+_SMALL_TARGET = .015
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('reacher.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking', 'easy')
+def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns reacher with sparse reward with 5e-2 tol and randomized target."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Reacher(target_size=_BIG_TARGET, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+@SUITE.add('benchmarking')
+def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns reacher with sparse reward, 1.5e-2 tol and randomized target."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = Reacher(target_size=_SMALL_TARGET, random=random)
+ return control.Environment(physics, task, time_limit=time_limit)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Reacher domain."""
+
+ def finger_to_target(self):
+ """Returns the vector from finger to target in global coordinates."""
+ return (self.named.data.geom_xpos['target'] -
+ self.named.data.geom_xpos['finger'])
+
+ def finger_to_target_dist(self):
+ """Returns the distance between the finger and target centers."""
+ return np.linalg.norm(self.finger_to_target())
+
+
+class Reacher(base.Task):
+ """A reacher `Task` to reach the target."""
+
+ def __init__(self, target_size, random=None):
+ """Initialize an instance of `Reacher`.
+
+ Args:
+ target_size: A `float`, tolerance to determine whether finger reached the
+ target.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._target_size = target_size
+ super(Reacher, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ physics.named.model.geom_size['target', 0] = self._target_size
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+
+ # randomize target position
+ angle = self.random.uniform(0, 2 * np.pi)
+ radius = self.random.uniform(.05, .20)
+ physics.named.model.geom_pos['target', 'x'] = radius * np.sin(angle)
+ physics.named.model.geom_pos['target', 'y'] = radius * np.cos(angle)
+
+ def get_observation(self, physics):
+ """Returns an observation of the state and the target position."""
+ obs = collections.OrderedDict()
+ obs['position'] = physics.position()
+ obs['to_target'] = physics.finger_to_target()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ radii = physics.named.model.geom_size[['target', 'finger'], 0].sum()
+ return rewards.tolerance(physics.finger_to_target_dist(), (0, radii))
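Review note on `Reacher.initialize_episode`: the target is placed by sampling an angle and a radius independently, so it always lands in an annulus around the origin. The sketch below mirrors that sampling in plain Python. The function name `sample_target` is hypothetical, and `random.Random` stands in for the task's `RandomState`.

```python
import math
import random


def sample_target(rng, min_radius=0.05, max_radius=0.20):
    """Samples an (x, y) target position in polar form, as the task does."""
    angle = rng.uniform(0, 2 * math.pi)
    radius = rng.uniform(min_radius, max_radius)
    return radius * math.sin(angle), radius * math.cos(angle)


rng = random.Random(0)
targets = [sample_target(rng) for _ in range(100)]
```

Because the radius is drawn uniformly, targets are denser per unit area near the inner edge of the annulus than near the outer edge. A distribution that is uniform over area would need a different radius sampler.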
diff --git a/dm_control/suite/reacher.xml b/dm_control/suite/reacher.xml
new file mode 100644
index 00000000..343f799c
--- /dev/null
+++ b/dm_control/suite/reacher.xml
@@ -0,0 +1,47 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/stacker.py b/dm_control/suite/stacker.py
new file mode 100644
index 00000000..8de91d27
--- /dev/null
+++ b/dm_control/suite/stacker.py
@@ -0,0 +1,204 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Planar Stacker domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.utils import containers
+from dm_control.utils import rewards
+from dm_control.utils import xml_tools
+
+from lxml import etree
+import numpy as np
+
+
+_CLOSE = .01 # (Meters) Distance below which a thing is considered close.
+_CONTROL_TIMESTEP = .01 # (Seconds)
+_TIME_LIMIT = 10 # (Seconds)
+_ARM_JOINTS = ['arm_root', 'arm_shoulder', 'arm_elbow', 'arm_wrist',
+ 'finger', 'fingertip', 'thumb', 'thumbtip']
+
+SUITE = containers.TaggedTasks()
+
+
+def make_model(n_boxes):
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ xml_string = common.read_model('stacker.xml')
+ parser = etree.XMLParser(remove_blank_text=True)
+ mjcf = etree.XML(xml_string, parser)
+
+ # Remove unused boxes
+ for b in range(n_boxes, 4):
+ box = xml_tools.find_element(mjcf, 'body', 'box' + str(b))
+ box.getparent().remove(box)
+
+ return etree.tostring(mjcf, pretty_print=True), common.ASSETS
+
+
+@SUITE.add('hard')
+def stack_2(observable=True, time_limit=_TIME_LIMIT, random=None):
+ """Returns stacker task with 2 boxes."""
+ n_boxes = 2
+ physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes))
+ task = Stack(n_boxes, observable, random=random)
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+
+
+@SUITE.add('hard')
+def stack_4(observable=True, time_limit=_TIME_LIMIT, random=None):
+ """Returns stacker task with 4 boxes."""
+ n_boxes = 4
+ physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes))
+ task = Stack(n_boxes, observable, random=random)
+ return control.Environment(
+ physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit)
+
+
+class Physics(mujoco.Physics):
+ """Physics with additional features for the Planar Stacker domain."""
+
+ def bounded_position(self):
+ """Returns the state, with unbounded angles as sine/cosine."""
+ state = []
+ hinge_joint = enums.mjtJoint.mjJNT_HINGE
+ for joint_id in range(self.model.njnt):
+ joint_value = self.named.data.qpos[joint_id]
+ if (not self.model.jnt_limited[joint_id] and
+ self.model.jnt_type[joint_id] == hinge_joint): # Unbounded hinge.
+ state += [np.sin(joint_value), np.cos(joint_value)]
+ else:
+ state.append(joint_value)
+ return np.asarray(state)
+
+ def body_location(self, body):
+ """Returns the x,z position and y orientation of a body."""
+ body_position = self.named.model.body_pos[body, ['x', 'z']]
+ body_orientation = self.named.model.body_quat[body, ['qw', 'qy']]
+ return np.hstack((body_position, body_orientation))
+
+ def proprioception(self):
+ """Returns the arm state, with unbounded angles as sine/cosine."""
+ arm = []
+ for joint in _ARM_JOINTS:
+ joint_value = self.named.data.qpos[joint]
+ if not self.named.model.jnt_limited[joint]:
+ arm += [np.sin(joint_value), np.cos(joint_value)]
+ else:
+ arm.append(joint_value)
+ return np.hstack(arm + [self.named.data.qvel[_ARM_JOINTS]])
+
+ def touch(self):
+ return np.log1p(self.data.sensordata)
+
+ def site_distance(self, site1, site2):
+ site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0)
+ return np.linalg.norm(site1_to_site2)
+
+
+class Stack(base.Task):
+ """A Stack `Task`: stack the boxes."""
+
+ def __init__(self, n_boxes, observable, random=None):
+ """Initialize an instance of the `Stack` task.
+
+ Args:
+ n_boxes: An `int`, number of boxes to stack.
+ observable: A `bool`, whether the observation contains target info.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._n_boxes = n_boxes
+ self._observable = observable
+ super(Stack, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode."""
+ # local shortcuts
+ uniform = self.random.uniform
+ model = physics.named.model
+ data = physics.named.data
+
+ # Find a collision-free random initial configuration.
+ penetrating = True
+ while penetrating:
+
+ # Randomise angles of arm joints.
+ is_limited = model.jnt_limited[_ARM_JOINTS].astype(bool)
+ joint_range = model.jnt_range[_ARM_JOINTS]
+ lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi)
+ upper_limits = np.where(is_limited, joint_range[:, 1], np.pi)
+ angles = uniform(lower_limits, upper_limits)
+ data.qpos[_ARM_JOINTS] = angles
+
+ # Symmetrize hand.
+ data.qpos['finger'] = data.qpos['thumb']
+
+ # Randomise target location.
+ target_height = 2*self.random.randint(self._n_boxes) + 1
+ box_size = model.geom_size['target', 0]
+ model.body_pos['target', 'z'] = box_size * target_height
+ model.body_pos['target', 'x'] = uniform(-.37, .37)
+
+ # Randomise box locations.
+ for b in range(self._n_boxes):
+ box = 'box' + str(b)
+ data.qpos[box + '_x'] = uniform(.1, .3)
+ data.qpos[box + '_z'] = uniform(0, .7)
+ data.qpos[box + '_y'] = uniform(0, 2*np.pi)
+
+ # Check for collisions.
+ physics.after_reset()
+ penetrating = physics.data.ncon > 0
+
+ def get_observation(self, physics):
+ """Returns either features or only sensors (to be used with pixels)."""
+ obs = collections.OrderedDict()
+ if self._observable:
+ box_locations = [physics.body_location('box' + str(b))
+ for b in range(self._n_boxes)]
+ obs['position'] = physics.bounded_position()
+ obs['hand'] = physics.body_location('hand')
+ obs['boxes'] = np.hstack(box_locations)
+ obs['velocity'] = physics.velocity()
+ obs['touch'] = physics.touch()
+ else:
+ obs['proprioception'] = physics.proprioception()
+ obs['touch'] = physics.touch()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ box_size = physics.named.model.geom_size['target', 0]
+ def target_to_box(b):
+ return rewards.tolerance(physics.site_distance('box' + str(b), 'target'),
+ margin=2*box_size)
+ box_is_close = max(target_to_box(b) for b in range(self._n_boxes))
+ hand_to_target = physics.site_distance('grasp', 'target')
+ hand_is_far = rewards.tolerance(hand_to_target, (.1, float('inf')), _CLOSE)
+ return box_is_close * hand_is_far
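Review note on the target height in `Stack.initialize_episode`: the expression `box_size * (2*k + 1)` gives the centre height of the box at level `k` of a stack. This reading assumes that `geom_size` stores a half-extent and that the floor is at z = 0. The sketch below lists the candidate heights. The function name `stack_center_heights` and the half-size value are hypothetical.

```python
def stack_center_heights(n_boxes, half_size):
    """Returns the centre height of each level in a stack resting on z = 0.

    Assumption: every box is a cube with edge length 2 * half_size.
    """
    return [half_size * (2 * level + 1) for level in range(n_boxes)]


# Hypothetical half-size chosen so the numbers are easy to read.
heights = stack_center_heights(4, 0.5)
```

Drawing the level uniformly from `range(n_boxes)` means the target can sit anywhere from the floor level up to the top of a full stack.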
diff --git a/dm_control/suite/stacker.xml b/dm_control/suite/stacker.xml
new file mode 100644
index 00000000..06e4846c
--- /dev/null
+++ b/dm_control/suite/stacker.xml
@@ -0,0 +1,193 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+ >
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/swimmer.py b/dm_control/suite/swimmer.py
new file mode 100644
index 00000000..057004c5
--- /dev/null
+++ b/dm_control/suite/swimmer.py
@@ -0,0 +1,208 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Procedurally generated Swimmer domain."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+from lxml import etree
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+_DEFAULT_TIME_LIMIT = 30
+_CONTROL_TIMESTEP = .03 # (Seconds)
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets(n_joints):
+ """Returns a tuple containing the model XML string and a dict of assets.
+
+ Args:
+ n_joints: An integer specifying the number of bodies in the swimmer.
+
+ Returns:
+ A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of
+ `{filename: contents_string}` pairs.
+ """
+ return _make_model(n_joints), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def swimmer6(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns a 6-link swimmer."""
+ return _make_swimmer(6, time_limit, random=random)
+
+
+@SUITE.add('benchmarking')
+def swimmer15(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns a 15-link swimmer."""
+ return _make_swimmer(15, time_limit, random=random)
+
+
+def swimmer(n_links=3, time_limit=_DEFAULT_TIME_LIMIT,
+ random=None):
+ """Returns a swimmer with n links."""
+ return _make_swimmer(n_links, time_limit, random=random)
+
+
+def _make_swimmer(n_joints, time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns a swimmer control environment."""
+ model_string, assets = get_model_and_assets(n_joints)
+ physics = Physics.from_xml_string(model_string, assets=assets)
+ task = Swimmer(random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+def _make_model(n_bodies):
+ """Generates an xml string defining a swimmer with `n_bodies` bodies."""
+ if n_bodies < 3:
+ raise ValueError('At least 3 bodies required. Received {}'.format(n_bodies))
+ mjcf = etree.fromstring(common.read_model('swimmer.xml'))
+ head_body = mjcf.find('./worldbody/body')
+ actuator = etree.SubElement(mjcf, 'actuator')
+ sensor = etree.SubElement(mjcf, 'sensor')
+
+ parent = head_body
+ for body_index in xrange(n_bodies - 1):
+ site_name = 'site_{}'.format(body_index)
+ child = _make_body(body_index=body_index)
+ child.append(etree.Element('site', name=site_name))
+ joint_name = 'joint_{}'.format(body_index)
+ joint_limit = 360.0/n_bodies
+ joint_range = '{} {}'.format(-joint_limit, joint_limit)
+ child.append(etree.Element('joint', {'name': joint_name,
+ 'range': joint_range}))
+ motor_name = 'motor_{}'.format(body_index)
+ actuator.append(etree.Element('motor', name=motor_name, joint=joint_name))
+ velocimeter_name = 'velocimeter_{}'.format(body_index)
+ sensor.append(etree.Element('velocimeter', name=velocimeter_name,
+ site=site_name))
+ gyro_name = 'gyro_{}'.format(body_index)
+ sensor.append(etree.Element('gyro', name=gyro_name, site=site_name))
+ parent.append(child)
+ parent = child
+
+ # Move tracking cameras further away from the swimmer according to its length.
+ cameras = mjcf.findall('./worldbody/body/camera')
+ scale = n_bodies / 6.0
+ for cam in cameras:
+ if cam.get('mode') == 'trackcom':
+ old_pos = cam.get('pos').split(' ')
+ new_pos = ' '.join([str(float(dim) * scale) for dim in old_pos])
+ cam.set('pos', new_pos)
+
+ return etree.tostring(mjcf, pretty_print=True)
+
+
+def _make_body(body_index):
+ """Generates an XML element defining a single physical body."""
+ body_name = 'segment_{}'.format(body_index)
+ visual_name = 'visual_{}'.format(body_index)
+ inertial_name = 'inertial_{}'.format(body_index)
+ body = etree.Element('body', name=body_name)
+ body.set('pos', '0 .1 0')
+ etree.SubElement(body, 'geom', {'class': 'visual', 'name': visual_name})
+ etree.SubElement(body, 'geom', {'class': 'inertial', 'name': inertial_name})
+ return body
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the swimmer domain."""
+
+ def nose_to_target(self):
+ """Returns a vector from nose to target in the head's local coordinates."""
+ nose_to_target = (self.named.data.geom_xpos['target'] -
+ self.named.data.geom_xpos['nose'])
+ head_orientation = self.named.data.xmat['head'].reshape(3, 3)
+ return nose_to_target.dot(head_orientation)[:2]
+
+ def nose_to_target_dist(self):
+ """Returns the distance from the nose to the target."""
+ return np.linalg.norm(self.nose_to_target())
+
+ def body_velocities(self):
+ """Returns local body velocities: x,y linear, z rotational."""
+ xvel_local = self.data.sensordata[12:].reshape((-1, 6))
+ vx_vy_wz = [0, 1, 5] # Indices for linear x,y vels and rotational z vel.
+ return xvel_local[:, vx_vy_wz].ravel()
+
+ def joints(self):
+ """Returns all internal joint angles (excluding root joints)."""
+ return self.data.qpos[3:]
+
+
+class Swimmer(base.Task):
+ """A swimmer `Task` to reach the target or just swim."""
+
+ def __init__(self, random=None):
+ """Initializes an instance of `Swimmer`.
+
+ Args:
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ super(Swimmer, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Initializes the swimmer orientation to [-pi, pi) and the relative joint
+ angle of each joint uniformly within its range.
+
+ Args:
+ physics: An instance of `Physics`.
+ """
+ # Random joint angles:
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+ # Random target position.
+ close_target = self.random.rand() < .2 # Probability of a close target.
+ target_box = .3 if close_target else 2
+ xpos, ypos = self.random.uniform(-target_box, target_box, size=2)
+ physics.named.model.geom_pos['target', 'x'] = xpos
+ physics.named.model.geom_pos['target', 'y'] = ypos
+ physics.named.model.light_pos['target_light', 'x'] = xpos
+ physics.named.model.light_pos['target_light', 'y'] = ypos
+
+ def get_observation(self, physics):
+ """Returns an observation of joint angles, body velocities and target."""
+ obs = collections.OrderedDict()
+ obs['joints'] = physics.joints()
+ obs['to_target'] = physics.nose_to_target()
+ obs['body_velocities'] = physics.body_velocities()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a smooth reward."""
+ target_size = physics.named.model.geom_size['target', 0]
+ return rewards.tolerance(physics.nose_to_target_dist(),
+ bounds=(0, target_size),
+ margin=5*target_size,
+ sigmoid='long_tail')
diff --git a/dm_control/suite/swimmer.xml b/dm_control/suite/swimmer.xml
new file mode 100644
index 00000000..29c7bc81
--- /dev/null
+++ b/dm_control/suite/swimmer.xml
@@ -0,0 +1,57 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/tests/domains_test.py b/dm_control/suite/tests/domains_test.py
new file mode 100644
index 00000000..933ec40b
--- /dev/null
+++ b/dm_control/suite/tests/domains_test.py
@@ -0,0 +1,157 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.suite domains."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control import suite
+
+import numpy as np
+import six
+
+_NUM_EPISODES = 5
+_NUM_STEPS_PER_EPISODE = 10
+
+
+class DomainTest(parameterized.TestCase):
+ """Tests run on all the tasks registered."""
+
+ def test_constants(self):
+ num_tasks = sum(len(tasks) for tasks in
+ six.itervalues(suite.TASKS_BY_DOMAIN))
+
+ self.assertEqual(len(suite.ALL_TASKS), num_tasks)
+
+ def _validate_observation(self, observation_dict, observation_spec):
+ obs = observation_dict.copy()
+ for name, spec in six.iteritems(observation_spec):
+ arr = obs.pop(name)
+ self.assertEqual(arr.shape, spec.shape)
+ self.assertEqual(arr.dtype, spec.dtype)
+ self.assertTrue(
+ np.all(np.isfinite(arr)),
+ msg='{!r} has non-finite value(s): {!r}'.format(name, arr))
+ self.assertEmpty(
+ obs,
+ msg='Observation contains array(s) that are not in the spec: {!r}'
+ .format(obs))
+
+ def _validate_reward_range(self, time_step):
+ if time_step.first():
+ self.assertIsNone(time_step.reward)
+ else:
+ self.assertIsInstance(time_step.reward, float)
+ self.assertBetween(time_step.reward, 0, 1)
+
+ def _validate_discount(self, time_step):
+ if time_step.first():
+ self.assertIsNone(time_step.discount)
+ else:
+ self.assertIsInstance(time_step.discount, float)
+ self.assertBetween(time_step.discount, 0, 1)
+
+ def _validate_control_range(self, lower_bounds, upper_bounds):
+ for b in lower_bounds:
+ self.assertEqual(b, -1.0)
+ for b in upper_bounds:
+ self.assertEqual(b, 1.0)
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_components_have_names(self, domain, task):
+ env = suite.load(domain, task)
+ model = env.physics.model
+
+ object_types_and_size_fields = {
+ 'body': 'nbody',
+ 'joint': 'njnt',
+ 'geom': 'ngeom',
+ 'site': 'nsite',
+ 'camera': 'ncam',
+ 'light': 'nlight',
+ 'mesh': 'nmesh',
+ 'hfield': 'nhfield',
+ 'texture': 'ntex',
+ 'material': 'nmat',
+ 'equality': 'neq',
+ 'tendon': 'ntendon',
+ 'actuator': 'nu',
+ 'sensor': 'nsensor',
+ 'numeric': 'nnumeric',
+ 'text': 'ntext',
+ 'tuple': 'ntuple',
+ }
+
+ for object_type, size_field in six.iteritems(object_types_and_size_fields):
+ for idx in range(getattr(model, size_field)):
+ object_name = model.id2name(idx, object_type)
+ self.assertNotEqual(object_name, '',
+ msg='Model {!r} contains unnamed {!r} with ID {}.'
+ .format(model.name, object_type, idx))
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_task_runs(self, domain, task):
+ """Tests task runs correctly and observation is coherent with spec."""
+ is_benchmark = (domain, task) in suite.BENCHMARKING
+ env = suite.load(domain, task)
+
+ observation_spec = env.observation_spec()
+ action_spec = env.action_spec()
+ model = env.physics.model
+
+ # Check cameras.
+ self.assertGreaterEqual(model.ncam, 2, 'Model {!r} should have at least 2 '
+ 'cameras, has {!r}.'.format(model.name, model.ncam))
+
+ # Check action bounds.
+ lower_bounds = action_spec.minimum
+ upper_bounds = action_spec.maximum
+
+ if is_benchmark:
+ self._validate_control_range(lower_bounds, upper_bounds)
+
+ lower_bounds = np.where(np.isinf(lower_bounds), -1.0, lower_bounds)
+ upper_bounds = np.where(np.isinf(upper_bounds), 1.0, upper_bounds)
+
+ # Run a partial episode, check observations, rewards, discount.
+ for _ in range(_NUM_EPISODES):
+ time_step = env.reset()
+ for _ in range(_NUM_STEPS_PER_EPISODE):
+ self._validate_observation(time_step.observation, observation_spec)
+ if is_benchmark:
+ self._validate_reward_range(time_step)
+ self._validate_discount(time_step)
+ action = np.random.uniform(lower_bounds, upper_bounds)
+ time_step = env.step(action)
+
+ @parameterized.parameters(*suite.ALL_TASKS)
+ def test_visualize_reward(self, domain, task):
+ env = suite.load(domain, task)
+ env.task.visualize_reward = True
+ env.reset()
+ action = np.zeros(env.action_spec().shape)
+ for _ in range(2):
+ env.step(action)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/suite/tests/loader_test.py b/dm_control/suite/tests/loader_test.py
new file mode 100644
index 00000000..cbce4f50
--- /dev/null
+++ b/dm_control/suite/tests/loader_test.py
@@ -0,0 +1,52 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the dm_control.suite loader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control import suite
+from dm_control.rl import control
+
+
+class LoaderTest(absltest.TestCase):
+
+ def test_load_without_kwargs(self):
+ env = suite.load('cartpole', 'swingup')
+ self.assertIsInstance(env, control.Environment)
+
+ def test_load_with_kwargs(self):
+ env = suite.load('cartpole', 'swingup',
+ task_kwargs={'time_limit': 40, 'random': 99})
+ self.assertIsInstance(env, control.Environment)
+
+
+class LoaderConstantsTest(absltest.TestCase):
+
+ def testSuiteConstants(self):
+ self.assertNotEmpty(suite.BENCHMARKING)
+ self.assertNotEmpty(suite.EASY)
+ self.assertNotEmpty(suite.HARD)
+ self.assertNotEmpty(suite.EXTRA)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/suite/tests/lqr_test.py b/dm_control/suite/tests/lqr_test.py
new file mode 100644
index 00000000..214fe822
--- /dev/null
+++ b/dm_control/suite/tests/lqr_test.py
@@ -0,0 +1,88 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests specific to the LQR domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import unittest
+
+# Internal dependencies.
+from absl import logging
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.suite import lqr
+from dm_control.suite import lqr_solver
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+class LqrTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('lqr_2_1', lqr.lqr_2_1),
+ ('lqr_6_2', lqr.lqr_6_2))
+ def test_lqr_optimal_policy(self, make_env):
+ env = make_env()
+ p, k, beta = lqr_solver.solve(env)
+ self.assertPolicyisOptimal(env, p, k, beta)
+
+ @parameterized.named_parameters(
+ ('lqr_2_1', lqr.lqr_2_1),
+ ('lqr_6_2', lqr.lqr_6_2))
+ @unittest.skipUnless(
+ condition=lqr_solver.sp,
+ reason='scipy is not available, so non-scipy DARE solver is the default.')
+ def test_lqr_optimal_policy_no_scipy(self, make_env):
+ env = make_env()
+ old_sp = lqr_solver.sp
+ try:
+ lqr_solver.sp = None # Force the solver to use the non-scipy code path.
+ p, k, beta = lqr_solver.solve(env)
+ finally:
+ lqr_solver.sp = old_sp
+ self.assertPolicyisOptimal(env, p, k, beta)
+
+ def assertPolicyisOptimal(self, env, p, k, beta):
+ tolerance = 1e-3
+ n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta)))
+ logging.info('%d timesteps for %g convergence.', n_steps, tolerance)
+ total_loss = 0.0
+
+ timestep = env.reset()
+ initial_state = np.hstack((timestep.observation['position'],
+ timestep.observation['velocity']))
+ logging.info('Measuring total cost over %d steps.', n_steps)
+ for _ in xrange(n_steps):
+ x = np.hstack((timestep.observation['position'],
+ timestep.observation['velocity']))
+ # u = k*x is the optimal policy
+ u = k.dot(x)
+ total_loss += 1 - (timestep.reward or 0.0)
+ timestep = env.step(u)
+
+ logging.info('Analytical expected total cost is .5*x^T*p*x.')
+ expected_loss = .5 * initial_state.T.dot(p).dot(initial_state)
+ logging.info('Comparing measured and predicted costs.')
+ np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/suite/utils/__init__.py b/dm_control/suite/utils/__init__.py
new file mode 100644
index 00000000..bde50111
--- /dev/null
+++ b/dm_control/suite/utils/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utility functions used in the control suite."""
+
+from dm_control.suite.utils import randomizers
diff --git a/dm_control/suite/utils/parse_amc.py b/dm_control/suite/utils/parse_amc.py
new file mode 100644
index 00000000..95c34932
--- /dev/null
+++ b/dm_control/suite/utils/parse_amc.py
@@ -0,0 +1,254 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Parse and convert amc motion capture data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control.mujoco.wrapper import mjbindings
+
+import numpy as np
+
+from scipy import interpolate
+
+mjlib = mjbindings.mjlib
+
+MOCAP_DT = 1.0/120.0
+CONVERSION_LENGTH = 0.056444
+
+_CMU_MOCAP_JOINT_ORDER = (
+ 'root0', 'root1', 'root2', 'root3', 'root4', 'root5', 'lowerbackrx',
+ 'lowerbackry', 'lowerbackrz', 'upperbackrx', 'upperbackry', 'upperbackrz',
+ 'thoraxrx', 'thoraxry', 'thoraxrz', 'lowerneckrx', 'lowerneckry',
+ 'lowerneckrz', 'upperneckrx', 'upperneckry', 'upperneckrz', 'headrx',
+ 'headry', 'headrz', 'rclaviclery', 'rclaviclerz', 'rhumerusrx',
+ 'rhumerusry', 'rhumerusrz', 'rradiusrx', 'rwristry', 'rhandrx', 'rhandrz',
+ 'rfingersrx', 'rthumbrx', 'rthumbrz', 'lclaviclery', 'lclaviclerz',
+ 'lhumerusrx', 'lhumerusry', 'lhumerusrz', 'lradiusrx', 'lwristry',
+ 'lhandrx', 'lhandrz', 'lfingersrx', 'lthumbrx', 'lthumbrz', 'rfemurrx',
+ 'rfemurry', 'rfemurrz', 'rtibiarx', 'rfootrx', 'rfootrz', 'rtoesrx',
+ 'lfemurrx', 'lfemurry', 'lfemurrz', 'ltibiarx', 'lfootrx', 'lfootrz',
+ 'ltoesrx'
+)
+
+Converted = collections.namedtuple('Converted',
+ ['qpos', 'qvel', 'time'])
+
+
+def convert(file_name, physics, timestep):
+ """Converts the parsed .amc values into qpos and qvel values and resamples.
+
+ Args:
+ file_name: The .amc file to be parsed and converted.
+ physics: The corresponding physics instance.
+ timestep: Desired output interval between resampled frames.
+
+ Returns:
+ A namedtuple with fields:
+ `qpos`, a numpy array containing converted positional variables.
+ `qvel`, a numpy array containing converted velocity variables.
+ `time`, a numpy array containing the corresponding times.
+ """
+ frame_values = parse(file_name)
+ joint2index = {}
+ for name in physics.named.data.qpos.axes.row.names:
+ joint2index[name] = physics.named.data.qpos.axes.row.convert_key_item(name)
+ index2joint = {}
+ for joint, index in joint2index.items():
+ if isinstance(index, slice):
+ indices = range(index.start, index.stop)
+ else:
+ indices = [index]
+ for ii in indices:
+ index2joint[ii] = joint
+
+ # Convert frame_values to qpos
+ amcvals2qpos_transformer = Amcvals2qpos(index2joint, _CMU_MOCAP_JOINT_ORDER)
+ qpos_values = []
+ for frame_value in frame_values:
+ qpos_values.append(amcvals2qpos_transformer(frame_value))
+ qpos_values = np.stack(qpos_values) # Time by nq
+
+ # Interpolate/resample. The root orientation is resampled as a quaternion
+ # rather than as euler angles; each component is splined independently, which
+ # only approximates slerp (see https://en.wikipedia.org/wiki/Slerp).
+ qpos_values_resampled = []
+ time_vals = np.arange(0, len(frame_values)*MOCAP_DT - 1e-8, MOCAP_DT)
+ time_vals_new = np.arange(0, len(frame_values)*MOCAP_DT, timestep)
+ while time_vals_new[-1] > time_vals[-1]:
+ time_vals_new = time_vals_new[:-1]
+
+ for i in range(qpos_values.shape[1]):
+ f = interpolate.splrep(time_vals, qpos_values[:, i])
+ qpos_values_resampled.append(interpolate.splev(time_vals_new, f))
+
+ qpos_values_resampled = np.stack(qpos_values_resampled) # nq by ntime
+
+ qvel_list = []
+ for t in range(qpos_values_resampled.shape[1]-1):
+ p_tp1 = qpos_values_resampled[:, t + 1]
+ p_t = qpos_values_resampled[:, t]
+ qvel = [(p_tp1[:3]-p_t[:3])/ timestep,
+ mj_quat2vel(mj_quatdiff(p_t[3:7], p_tp1[3:7]), timestep),
+ (p_tp1[7:]-p_t[7:])/ timestep]
+ qvel_list.append(np.concatenate(qvel))
+
+ qvel_values_resampled = np.vstack(qvel_list).T
+
+ return Converted(qpos_values_resampled, qvel_values_resampled, time_vals_new)
+
+
+def parse(file_name):
+ """Parses the amc file format."""
+ values = []
+ fid = open(file_name, 'r')
+ line = fid.readline().strip()
+ frame_ind = 1
+ first_frame = True
+ while True:
+ # Parse first frame.
+ if first_frame and line[0] == str(frame_ind):
+ first_frame = False
+ frame_ind += 1
+ frame_vals = []
+ while True:
+ line = fid.readline().strip()
+ if not line or line == str(frame_ind):
+ values.append(np.array(frame_vals, dtype=float))
+ break
+ tokens = line.split()
+ frame_vals.extend(tokens[1:])
+ # Parse other frames.
+ elif line == str(frame_ind):
+ frame_ind += 1
+ frame_vals = []
+ while True:
+ line = fid.readline().strip()
+ if not line or line == str(frame_ind):
+ values.append(np.array(frame_vals, dtype=float))
+ break
+ tokens = line.split()
+ frame_vals.extend(tokens[1:])
+ else:
+ line = fid.readline().strip()
+ if not line:
+ break
+ return values
+
+
+class Amcvals2qpos(object):
+ """Callable that converts .amc values for a frame to MuJoCo qpos format.
+ """
+
+ def __init__(self, index2joint, joint_order):
+ """Initializes a new Amcvals2qpos instance.
+
+ Args:
+ index2joint: Dict mapping each qpos index to its MuJoCo joint name.
+ joint_order: Sequence of joint names in the order used by the .amc file.
+ """
+ # The root is x, y, z followed by a quaternion.
+ # Indices into qpos for the root, matching the default .amc ordering.
+ self.qpos_root_xyz_ind = [0, 1, 2]
+ self.root_xyz_ransform = np.array(
+ [[1, 0, 0], [0, 0, -1], [0, 1, 0]]) * CONVERSION_LENGTH
+ self.qpos_root_quat_ind = [3, 4, 5, 6]
+ amc2qpos_transform = np.zeros((len(index2joint), len(joint_order)))
+ for i in range(len(index2joint)):
+ for j in range(len(joint_order)):
+ if index2joint[i] == joint_order[j]:
+ if 'rx' in index2joint[i]:
+ amc2qpos_transform[i][j] = 1
+ elif 'ry' in index2joint[i]:
+ amc2qpos_transform[i][j] = 1
+ elif 'rz' in index2joint[i]:
+ amc2qpos_transform[i][j] = 1
+ self.amc2qpos_transform = amc2qpos_transform
+
+ def __call__(self, amc_val):
+ """Converts a `.amc` frame to MuJoCo qpos format."""
+ amc_val_rad = np.deg2rad(amc_val)
+ qpos = np.dot(self.amc2qpos_transform, amc_val_rad)
+
+ # Root.
+ qpos[:3] = np.dot(self.root_xyz_ransform, amc_val[:3])
+ qpos_quat = euler2quat(amc_val[3], amc_val[4], amc_val[5])
+ qpos_quat = mj_quatprod(euler2quat(90, 0, 0), qpos_quat)
+
+ for i, ind in enumerate(self.qpos_root_quat_ind):
+ qpos[ind] = qpos_quat[i]
+
+ return qpos
+
+
+def euler2quat(ax, ay, az):
+ """Converts euler angles to a quaternion.
+
+ Note: rotation order is zyx
+
+ Args:
+ ax: Roll angle (deg).
+ ay: Pitch angle (deg).
+ az: Yaw angle (deg).
+
+ Returns:
+ A numpy array representing the rotation as a quaternion.
+ """
+ r1 = az
+ r2 = ay
+ r3 = ax
+
+ c1 = np.cos(np.deg2rad(r1 / 2))
+ s1 = np.sin(np.deg2rad(r1 / 2))
+ c2 = np.cos(np.deg2rad(r2 / 2))
+ s2 = np.sin(np.deg2rad(r2 / 2))
+ c3 = np.cos(np.deg2rad(r3 / 2))
+ s3 = np.sin(np.deg2rad(r3 / 2))
+
+ q0 = c1 * c2 * c3 + s1 * s2 * s3
+ q1 = c1 * c2 * s3 - s1 * s2 * c3
+ q2 = c1 * s2 * c3 + s1 * c2 * s3
+ q3 = s1 * c2 * c3 - c1 * s2 * s3
+
+ return np.array([q0, q1, q2, q3])
+
+
+def mj_quatprod(q, r):
+ quaternion = np.zeros(4)
+ mjlib.mju_mulQuat(quaternion, np.ascontiguousarray(q),
+ np.ascontiguousarray(r))
+ return quaternion
+
+
+def mj_quat2vel(q, dt):
+ vel = np.zeros(3)
+ mjlib.mju_quat2Vel(vel, np.ascontiguousarray(q), dt)
+ return vel
+
+
+def mj_quatneg(q):
+ quaternion = np.zeros(4)
+ mjlib.mju_negQuat(quaternion, np.ascontiguousarray(q))
+ return quaternion
+
+
+def mj_quatdiff(source, target):
+ return mj_quatprod(mj_quatneg(source), np.ascontiguousarray(target))
diff --git a/dm_control/suite/utils/parse_amc_test.py b/dm_control/suite/utils/parse_amc_test.py
new file mode 100644
index 00000000..1d2f40e2
--- /dev/null
+++ b/dm_control/suite/utils/parse_amc_test.py
@@ -0,0 +1,68 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for parse_amc utility."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from dm_control.suite import humanoid_CMU
+from dm_control.suite.utils import parse_amc
+
+from dm_control.utils import resources
+
+_TEST_AMC_PATH = resources.GetResourceFilename(
+ os.path.join(os.path.dirname(__file__), '../demos/zeros.amc'))
+
+
+class ParseAMCTest(absltest.TestCase):
+
+ def test_sizes_of_parsed_data(self):
+
+ # Instantiate the humanoid environment.
+ env = humanoid_CMU.stand()
+
+ # Parse and convert specified clip.
+ converted = parse_amc.convert(
+ _TEST_AMC_PATH, env.physics, env.control_timestep())
+
+ self.assertEqual(converted.qpos.shape[0], 63)
+ self.assertEqual(converted.qvel.shape[0], 62)
+ self.assertEqual(converted.time.shape[0], converted.qpos.shape[1])
+ self.assertEqual(converted.qpos.shape[1],
+ converted.qvel.shape[1] + 1)
+
+ # Parse and convert specified clip -- WITH SMALLER TIMESTEP
+ converted2 = parse_amc.convert(
+ _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep())
+
+ self.assertEqual(converted2.qpos.shape[0], 63)
+ self.assertEqual(converted2.qvel.shape[0], 62)
+ self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1])
+ self.assertEqual(converted2.qpos.shape[1],
+ converted2.qvel.shape[1] + 1)
+
+ # Compare sizes of parsed objects for different timesteps
+ self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/suite/utils/randomizers.py b/dm_control/suite/utils/randomizers.py
new file mode 100644
index 00000000..df557ae1
--- /dev/null
+++ b/dm_control/suite/utils/randomizers.py
@@ -0,0 +1,94 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Randomization functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from dm_control.mujoco.wrapper import mjbindings
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+def random_limited_quaternion(random, limit):
+ """Generates a random quaternion limited to the specified rotations."""
+ axis = random.randn(3)
+ axis /= np.linalg.norm(axis)
+ angle = random.rand() * limit
+
+ quaternion = np.zeros(4)
+ mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle)
+
+ return quaternion
+
+
+def randomize_limited_and_rotational_joints(physics, random=None):
+ """Randomizes the positions of joints defined in the physics body.
+
+ The following randomization rules apply:
+ - Bounded joints (hinges or sliders) are sampled uniformly in the bounds.
+ - Unbounded hinges are sampled uniformly in [-pi, pi].
+ - Quaternions for unlimited free joints and ball joints are sampled
+ uniformly on the unit 3-sphere.
+ - Quaternions for limited ball joints are sampled uniformly on a sector
+ of the unit 3-sphere.
+ - The linear degrees of freedom of free joints are not randomized.
+
+ Args:
+ physics: Instance of 'Physics' class that holds a loaded model.
+ random: Optional instance of 'np.random.RandomState'. Defaults to the global
+ NumPy random state.
+ """
+ random = random or np.random
+
+ hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
+ slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
+ ball = mjbindings.enums.mjtJoint.mjJNT_BALL
+ free = mjbindings.enums.mjtJoint.mjJNT_FREE
+
+ qpos = physics.named.data.qpos
+
+ for joint_id in xrange(physics.model.njnt):
+ joint_name = physics.model.id2name(joint_id, 'joint')
+ joint_type = physics.model.jnt_type[joint_id]
+ is_limited = physics.model.jnt_limited[joint_id]
+ range_min, range_max = physics.model.jnt_range[joint_id]
+
+ if is_limited:
+ if joint_type == hinge or joint_type == slide:
+ qpos[joint_name] = random.uniform(range_min, range_max)
+
+ elif joint_type == ball:
+ qpos[joint_name] = random_limited_quaternion(random, range_max)
+
+ else:
+ if joint_type == hinge:
+ qpos[joint_name] = random.uniform(-np.pi, np.pi)
+
+ elif joint_type == ball:
+ quat = random.randn(4)
+ quat /= np.linalg.norm(quat)
+ qpos[joint_name] = quat
+
+ elif joint_type == free:
+ quat = random.randn(4)
+ quat /= np.linalg.norm(quat)
+ qpos[joint_name][3:] = quat
+
diff --git a/dm_control/suite/utils/randomizers_test.py b/dm_control/suite/utils/randomizers_test.py
new file mode 100644
index 00000000..f6dd6a1d
--- /dev/null
+++ b/dm_control/suite/utils/randomizers_test.py
@@ -0,0 +1,165 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for randomizers.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.mujoco import engine
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+from dm_control.suite.utils import randomizers
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+class RandomizeUnlimitedJointsTest(parameterized.TestCase):
+
+ def setUp(self):
+ self.rand = np.random.RandomState(100)
+
+ def test_single_joint_of_each_type(self):
+ physics = engine.Physics.from_xml_string("""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """)
+
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+ self.assertNotEqual(0., physics.named.data.qpos['hinge'])
+ self.assertNotEqual(0., physics.named.data.qpos['limited_hinge'])
+ self.assertNotEqual(0., physics.named.data.qpos['limited_slide'])
+
+ self.assertNotEqual(0., np.sum(physics.named.data.qpos['ball']))
+ self.assertNotEqual(0., np.sum(physics.named.data.qpos['limited_ball']))
+
+ self.assertNotEqual(0., np.sum(physics.named.data.qpos['free'][3:]))
+
+ # Unlimited slide and the positional part of the free joint remain
+ # uninitialized.
+ self.assertEqual(0., physics.named.data.qpos['slide'])
+ self.assertEqual(0., np.sum(physics.named.data.qpos['free'][:3]))
+
+ def test_multiple_joints_of_same_type(self):
+ physics = engine.Physics.from_xml_string("""
+
+
+
+
+
+
+
+
+ """)
+
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+ self.assertNotEqual(0., physics.named.data.qpos['hinge_1'])
+ self.assertNotEqual(0., physics.named.data.qpos['hinge_2'])
+ self.assertNotEqual(0., physics.named.data.qpos['hinge_3'])
+
+ self.assertNotEqual(physics.named.data.qpos['hinge_1'],
+ physics.named.data.qpos['hinge_2'])
+
+ self.assertNotEqual(physics.named.data.qpos['hinge_2'],
+ physics.named.data.qpos['hinge_3'])
+
+ self.assertNotEqual(physics.named.data.qpos['hinge_1'],
+ physics.named.data.qpos['hinge_3'])
+
+ def test_unlimited_hinge_randomization_range(self):
+ physics = engine.Physics.from_xml_string("""
+
+
+
+
+
+
+ """)
+
+ for _ in xrange(10):
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+ self.assertBetween(physics.named.data.qpos['hinge'], -np.pi, np.pi)
+
+ def test_limited_1d_joint_limits_are_respected(self):
+ physics = engine.Physics.from_xml_string("""
+
+
+
+
+
+
+
+
+
+
+ """)
+
+ for _ in xrange(10):
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+ self.assertBetween(physics.named.data.qpos['hinge'],
+ np.deg2rad(0), np.deg2rad(10))
+ self.assertBetween(physics.named.data.qpos['slide'], 30, 50)
+
+ def test_limited_ball_joint_are_respected(self):
+ physics = engine.Physics.from_xml_string("""
+
+
+
+
+
+
+ """)
+
+ body_axis = np.array([1., 0., 0.])
+ joint_axis = np.zeros(3)
+ for _ in xrange(10):
+ randomizers.randomize_limited_and_rotational_joints(physics, self.rand)
+
+ quat = physics.named.data.qpos['ball']
+ mjlib.mju_rotVecQuat(joint_axis, body_axis, quat)
+ angle_cos = np.dot(body_axis, joint_axis)
+ self.assertGreater(angle_cos, 0.5) # cos(60) = 0.5
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/suite/walker.py b/dm_control/suite/walker.py
new file mode 100644
index 00000000..9aedf9b7
--- /dev/null
+++ b/dm_control/suite/walker.py
@@ -0,0 +1,153 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Planar Walker Domain."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control import mujoco
+from dm_control.rl import control
+from dm_control.suite import base
+from dm_control.suite import common
+from dm_control.suite.utils import randomizers
+from dm_control.utils import containers
+from dm_control.utils import rewards
+
+
+_DEFAULT_TIME_LIMIT = 25
+_CONTROL_TIMESTEP = .025
+
+# Minimal height of torso over foot above which stand reward is 1.
+_STAND_HEIGHT = 1.2
+
+# Horizontal speeds (meters/second) above which move reward is 1.
+_WALK_SPEED = 1
+_RUN_SPEED = 8
+
+
+SUITE = containers.TaggedTasks()
+
+
+def get_model_and_assets():
+ """Returns a tuple containing the model XML string and a dict of assets."""
+ return common.read_model('walker.xml'), common.ASSETS
+
+
+@SUITE.add('benchmarking')
+def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Stand task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PlanarWalker(move_speed=0, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add('benchmarking')
+def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Walk task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PlanarWalker(move_speed=_WALK_SPEED, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+@SUITE.add('benchmarking')
+def run(time_limit=_DEFAULT_TIME_LIMIT, random=None):
+ """Returns the Run task."""
+ physics = Physics.from_xml_string(*get_model_and_assets())
+ task = PlanarWalker(move_speed=_RUN_SPEED, random=random)
+ return control.Environment(
+ physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP)
+
+
+class Physics(mujoco.Physics):
+ """Physics simulation with additional features for the Walker domain."""
+
+ def torso_upright(self):
+ """Returns the projection of the torso z-axis onto the world z-axis."""
+ return self.named.data.xmat['torso', 'zz']
+
+ def torso_height(self):
+ """Returns the height of the torso."""
+ return self.named.data.xpos['torso', 'z']
+
+ def horizontal_velocity(self):
+ """Returns the horizontal velocity of the center-of-mass."""
+ return self.named.data.subtree_linvel['torso', 'x']
+
+ def orientations(self):
+ """Returns planar orientations of all bodies."""
+ return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
+
+
+class PlanarWalker(base.Task):
+ """A planar walker task."""
+
+ def __init__(self, move_speed, random=None):
+ """Initializes an instance of `PlanarWalker`.
+
+ Args:
+ move_speed: A float. If this value is zero, reward is given simply for
+ standing up. Otherwise this specifies a target horizontal velocity for
+ the walking task.
+ random: Optional, either a `numpy.random.RandomState` instance, an
+ integer seed for creating a new `RandomState`, or None to select a seed
+ automatically (default).
+ """
+ self._move_speed = move_speed
+ super(PlanarWalker, self).__init__(random=random)
+
+ def initialize_episode(self, physics):
+ """Sets the state of the environment at the start of each episode.
+
+ Randomizes the configuration of the limited and rotational joints using
+ this task's random state.
+
+ Args:
+ physics: An instance of `Physics`.
+
+ """
+ randomizers.randomize_limited_and_rotational_joints(physics, self.random)
+
+ def get_observation(self, physics):
+ """Returns an observation of body orientations, height and velocities."""
+ obs = collections.OrderedDict()
+ obs['orientations'] = physics.orientations()
+ obs['height'] = physics.torso_height()
+ obs['velocity'] = physics.velocity()
+ return obs
+
+ def get_reward(self, physics):
+ """Returns a reward to the agent."""
+ standing = rewards.tolerance(physics.torso_height(),
+ bounds=(_STAND_HEIGHT, float('inf')),
+ margin=_STAND_HEIGHT/2)
+ upright = (1 + physics.torso_upright()) / 2
+ stand_reward = (3*standing + upright) / 4
+ if self._move_speed == 0:
+ return stand_reward
+ else:
+ move_reward = rewards.tolerance(physics.horizontal_velocity(),
+ bounds=(self._move_speed, float('inf')),
+ margin=self._move_speed/2,
+ value_at_margin=0.5,
+ sigmoid='linear')
+ return stand_reward * (5*move_reward + 1) / 6
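Reviewer note on `PlanarWalker.get_reward` in `walker.py`: the arithmetic below is a minimal sketch of how the reward terms combine, assuming the three inputs have already been computed and lie in [0, 1]. The name `combine_walker_reward` and its arguments are hypothetical and are not part of dm_control; `upright` here stands for the already rescaled value `(1 + torso_upright) / 2`.

```python
def combine_walker_reward(standing, upright, move_reward=None):
    """Combines the reward terms the way `get_reward` does (illustrative only).

    Args:
      standing: Height term in [0, 1].
      upright: Rescaled uprightness term in [0, 1].
      move_reward: Speed term in [0, 1], or None for the stand task.
    """
    # Standing height is weighted three times as much as uprightness.
    stand_reward = (3 * standing + upright) / 4
    if move_reward is None:
        return stand_reward
    # The move term scales the stand reward between 1/6 and 1.
    return stand_reward * (5 * move_reward + 1) / 6


stand_only = combine_walker_reward(1.0, 1.0)
standing_still = combine_walker_reward(1.0, 1.0, move_reward=0.0)
at_target_speed = combine_walker_reward(1.0, 1.0, move_reward=1.0)
```

With a nonzero target speed, a walker that stands perfectly but does not move still earns 1/6 of the maximum reward, because the move term only scales the stand reward and never drops it to zero.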
diff --git a/dm_control/suite/walker.xml b/dm_control/suite/walker.xml
new file mode 100644
index 00000000..b9072c23
--- /dev/null
+++ b/dm_control/suite/walker.xml
@@ -0,0 +1,66 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dm_control/suite/wrappers/pixels.py b/dm_control/suite/wrappers/pixels.py
new file mode 100644
index 00000000..da66865e
--- /dev/null
+++ b/dm_control/suite/wrappers/pixels.py
@@ -0,0 +1,122 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Wrapper that adds pixel observations to a control environment."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from dm_control.rl import environment
+from dm_control.rl import specs
+
+STATE_KEY = 'state'
+
+
+class Wrapper(environment.Base):
+ """Wraps a control environment and adds a rendered pixel observation."""
+
+ def __init__(self, env, pixels_only=True, render_kwargs=None,
+ observation_key='pixels'):
+ """Initializes a new pixel Wrapper.
+
+ Args:
+ env: The environment to wrap.
+ pixels_only: If True (default), the original set of 'state' observations
+ returned by the wrapped environment will be discarded, and the
+ `OrderedDict` of observations will only contain pixels. If False, the
+ `OrderedDict` will contain the original observations as well as the
+ pixel observations.
+ render_kwargs: Optional `dict` containing keyword arguments passed to the
+ `mujoco.Physics.render` method.
+ observation_key: Optional custom string specifying the pixel observation's
+ key in the `OrderedDict` of observations. Defaults to 'pixels'.
+
+ Raises:
+ ValueError: If `env`'s observation spec is not compatible with the
+ wrapper. Supported formats are a single array, or a dict of arrays.
+ ValueError: If `env`'s observation already contains the specified
+ `observation_key`.
+ """
+ if render_kwargs is None:
+ render_kwargs = {}
+
+ wrapped_observation_spec = env.observation_spec()
+
+ if isinstance(wrapped_observation_spec, specs.ArraySpec):
+ self._observation_is_dict = False
+ invalid_keys = set([STATE_KEY])
+ elif isinstance(wrapped_observation_spec, collections.MutableMapping):
+ self._observation_is_dict = True
+ invalid_keys = set(wrapped_observation_spec.keys())
+ else:
+ raise ValueError('Unsupported observation spec structure.')
+
+ if not pixels_only and observation_key in invalid_keys:
+ raise ValueError('Duplicate or reserved observation key {!r}.'
+ .format(observation_key))
+
+ if pixels_only:
+ self._observation_spec = collections.OrderedDict()
+ elif self._observation_is_dict:
+ self._observation_spec = wrapped_observation_spec.copy()
+ else:
+ self._observation_spec = collections.OrderedDict()
+ self._observation_spec[STATE_KEY] = wrapped_observation_spec
+
+ # Extend observation spec.
+ pixels = env.physics.render(**render_kwargs)
+ pixels_spec = specs.ArraySpec(
+ shape=pixels.shape, dtype=pixels.dtype, name=observation_key)
+ self._observation_spec[observation_key] = pixels_spec
+
+ self._env = env
+ self._pixels_only = pixels_only
+ self._render_kwargs = render_kwargs
+ self._observation_key = observation_key
+
+ def reset(self):
+ time_step = self._env.reset()
+ return self._add_pixel_observation(time_step)
+
+ def step(self, action):
+ time_step = self._env.step(action)
+ return self._add_pixel_observation(time_step)
+
+ def observation_spec(self):
+ return self._observation_spec
+
+ def action_spec(self):
+ return self._env.action_spec()
+
+ def _add_pixel_observation(self, time_step):
+ if self._pixels_only:
+ observation = collections.OrderedDict()
+ elif self._observation_is_dict:
+ observation = type(time_step.observation)(time_step.observation)
+ else:
+ observation = collections.OrderedDict()
+ observation[STATE_KEY] = time_step.observation
+
+ pixels = self._env.physics.render(**self._render_kwargs)
+ observation[self._observation_key] = pixels
+ return time_step._replace(observation=observation)
+
+ def __getattr__(self, name):
+ return getattr(self._env, name)
diff --git a/dm_control/suite/wrappers/pixels_test.py b/dm_control/suite/wrappers/pixels_test.py
new file mode 100644
index 00000000..e9ea48d0
--- /dev/null
+++ b/dm_control/suite/wrappers/pixels_test.py
@@ -0,0 +1,135 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the pixel wrapper."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+from dm_control.rl import environment
+from dm_control.rl import specs
+from dm_control.suite import cartpole
+from dm_control.suite.wrappers import pixels
+
+import numpy as np
+
+
+class FakePhysics(object):
+
+ def render(self, *args, **kwargs):
+ del args
+ del kwargs
+ return np.zeros((4, 5, 3), dtype=np.uint8)
+
+
+class FakeArrayObservationEnvironment(environment.Base):
+
+ def __init__(self):
+ self.physics = FakePhysics()
+
+ def reset(self):
+ return environment.restart(np.zeros((2,)))
+
+ def step(self, action):
+ del action
+ return environment.transition(0.0, np.zeros((2,)))
+
+ def action_spec(self):
+ pass
+
+ def observation_spec(self):
+ return specs.ArraySpec(shape=(2,), dtype=float)
+
+
+class PixelsTest(parameterized.TestCase):
+
+ @parameterized.parameters(True, False)
+ def test_dict_observation(self, pixels_only):
+ pixel_key = 'rgb'
+
+ env = cartpole.swingup()
+
+ # Make sure we are testing the right environment for the test.
+ observation_spec = env.observation_spec()
+ self.assertIsInstance(observation_spec, collections.OrderedDict)
+
+ width = 320
+ height = 240
+
+ # The wrapper should only add one observation.
+ wrapped = pixels.Wrapper(env,
+ observation_key=pixel_key,
+ pixels_only=pixels_only,
+ render_kwargs={'width': width, 'height': height})
+
+ wrapped_observation_spec = wrapped.observation_spec()
+ self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
+
+ if pixels_only:
+ self.assertEqual(1, len(wrapped_observation_spec))
+ self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
+ else:
+ self.assertEqual(len(observation_spec) + 1, len(wrapped_observation_spec))
+ expected_keys = list(observation_spec.keys()) + [pixel_key]
+ self.assertEqual(expected_keys, list(wrapped_observation_spec.keys()))
+
+ # Check that the added spec item is consistent with the added observation.
+ time_step = wrapped.reset()
+ rgb_observation = time_step.observation[pixel_key]
+ wrapped_observation_spec[pixel_key].validate(rgb_observation)
+
+ self.assertEqual(rgb_observation.shape, (height, width, 3))
+ self.assertEqual(rgb_observation.dtype, np.uint8)
+
+ @parameterized.parameters(True, False)
+ def test_single_array_observation(self, pixels_only):
+ pixel_key = 'depth'
+
+ env = FakeArrayObservationEnvironment()
+ observation_spec = env.observation_spec()
+ self.assertIsInstance(observation_spec, specs.ArraySpec)
+
+ wrapped = pixels.Wrapper(env, observation_key=pixel_key,
+ pixels_only=pixels_only)
+ wrapped_observation_spec = wrapped.observation_spec()
+ self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict)
+
+ if pixels_only:
+ self.assertEqual(1, len(wrapped_observation_spec))
+ self.assertEqual([pixel_key], list(wrapped_observation_spec.keys()))
+ else:
+ self.assertEqual(2, len(wrapped_observation_spec))
+ self.assertEqual([pixels.STATE_KEY, pixel_key],
+ list(wrapped_observation_spec.keys()))
+
+ time_step = wrapped.reset()
+
+ depth_observation = time_step.observation[pixel_key]
+ wrapped_observation_spec[pixel_key].validate(depth_observation)
+
+ self.assertEqual(depth_observation.shape, (4, 5, 3))
+ self.assertEqual(depth_observation.dtype, np.uint8)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/utils/README.md b/dm_control/utils/README.md
new file mode 100644
index 00000000..d605e4c8
--- /dev/null
+++ b/dm_control/utils/README.md
@@ -0,0 +1,6 @@
+# Tolerance
+
+`tolerance()` is a soft indicator function evaluating whether a number is within
+bounds.
+
+See [package documentation](/third_party/py/dm_control/utils).
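Reviewer note on the README above: the following is a minimal scalar sketch of what "soft indicator" means here, assuming the gaussian sigmoid from `_sigmoids` in `rewards.py` and assuming that the out-of-bounds distance is measured in units of `margin`. The name `soft_indicator` is hypothetical; this is not the library implementation, which also accepts numpy arrays and other sigmoid types.

```python
import math


def soft_indicator(x, bounds=(0.0, 0.0), margin=0.0, value_at_margin=0.1):
    """Returns 1 inside `bounds`, decaying towards 0 outside (scalar sketch)."""
    lower, upper = bounds
    if lower <= x <= upper:
        return 1.0
    if margin == 0:
        # No margin: a hard indicator function.
        return 0.0
    # Distance outside the interval, in units of `margin` (an assumption).
    distance = (lower - x if x < lower else x - upper) / margin
    # Gaussian scale chosen so that the output equals `value_at_margin`
    # when `distance` is exactly 1.
    scale = math.sqrt(-2 * math.log(value_at_margin))
    return math.exp(-0.5 * (distance * scale) ** 2)


inside = soft_indicator(1.5, bounds=(1.2, math.inf), margin=0.6)
at_margin = soft_indicator(0.6, bounds=(1.2, math.inf), margin=0.6)
hard_outside = soft_indicator(0.0, bounds=(1.0, 2.0), margin=0.0)
```

The bounds and margin in the usage lines mirror the standing term of the walker task (`_STAND_HEIGHT` and `_STAND_HEIGHT/2`); a torso exactly one margin below the target height scores `value_at_margin`.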
diff --git a/dm_control/utils/__init__.py b/dm_control/utils/__init__.py
new file mode 100644
index 00000000..1ebb270f
--- /dev/null
+++ b/dm_control/utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/dm_control/utils/containers.py b/dm_control/utils/containers.py
new file mode 100644
index 00000000..362362cf
--- /dev/null
+++ b/dm_control/utils/containers.py
@@ -0,0 +1,160 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Container classes used in control domains."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+class Tasks(collections.Mapping):
+ """Maps task names to their corresponding factory functions.
+
+ To store a function in a `Tasks` container, we can use its `.add` decorator:
+
+ ```python
+ tasks = Tasks()
+
+ @tasks.add
+ def example_task():
+ ...
+ return environment
+
+ environment_factory = tasks['example_task']
+ ```
+
+ To add tasks that are procedurally generated, we can pass the optional `name`
+ argument to the `.add` method:
+
+ ```python
+ for difficulty in ('easy', 'normal', 'hard'):
+ func = my_task_generator(difficulty)
+ tasks.add(func, name='my_task_{}'.format(difficulty))
+ ```
+
+ """
+
+ def __init__(self):
+ self._tasks = collections.OrderedDict()
+
+ def add(self, factory_func, name=None):
+ """Decorator that adds a factory function to the container.
+
+ Args:
+ factory_func: A function that returns a `ControlEnvironment` instance.
+ name: Optional task name. If unspecified, `factory_func.__name__` is used.
+
+ Returns:
+ The same function.
+
+ Raises:
+ ValueError: if a function with the same name already exists within the
+ container.
+ """
+ if name is None:
+ name = factory_func.__name__
+ if name in self:
+ raise ValueError("Function named {!r} already exists in the container."
+ "".format(name))
+ self._tasks[name] = factory_func
+ return factory_func
+
+ def __getitem__(self, k):
+ return self._tasks[k]
+
+ def __iter__(self):
+ return iter(self._tasks)
+
+ def __len__(self):
+ return len(self._tasks)
+
+ def __repr__(self):
+ return "{}({})".format(self.__class__.__name__, str(self._tasks))
+
+
+class TaggedTasks(collections.Mapping):
+ """Maps task names to their corresponding factory functions with tags.
+
+ To store a function in a `TaggedTasks` container, we can use its `.add`
+ decorator:
+
+ ```python
+ tasks = TaggedTasks()
+
+ @tasks.add('easy', 'stable')
+ def example_task():
+ ...
+ return environment
+
+ environment_factory = tasks['example_task']
+
+ # Or to restrict to a given tag:
+ environment_factory = tasks.tagged('easy')['example_task']
+ ```
+ """
+
+ def __init__(self):
+ self._tasks = collections.OrderedDict()
+ self._tags = collections.defaultdict(dict)
+
+ def add(self, *tags):
+ """Decorator that adds a factory function to the container with tags.
+
+ Args:
+ *tags: Strings specifying the tags for this function.
+
+ Returns:
+ The same function.
+
+ Raises:
+ ValueError: if a function with the same name already exists within the
+ container.
+ """
+ def wrap(factory_func):
+ name = factory_func.__name__
+ if name in self:
+ raise ValueError("Function named {!r} already exists in the container."
+ "".format(name))
+ self._tasks[name] = factory_func
+ for tag in tags:
+ self._tags[tag][name] = factory_func
+ return factory_func
+ return wrap
+
+ def tagged(self, tag):
+ """Returns a (possibly empty) dict of all items that match the given tag."""
+ if tag not in self._tags:
+ return {}
+ else:
+ return self._tags[tag]
+
+ def tags(self):
+ """Returns a list of all the tags in this container."""
+ return list(self._tags.keys())
+
+ def __getitem__(self, k):
+ return self._tasks[k]
+
+ def __iter__(self):
+ return iter(self._tasks)
+
+ def __len__(self):
+ return len(self._tasks)
+
+ def __repr__(self):
+ return "{}({})".format(self.__class__.__name__, str(self._tasks))
diff --git a/dm_control/utils/containers_test.py b/dm_control/utils/containers_test.py
new file mode 100644
index 00000000..151ada40
--- /dev/null
+++ b/dm_control/utils/containers_test.py
@@ -0,0 +1,130 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.utils.containers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control.utils import containers
+
+
+class TaskTest(absltest.TestCase):
+
+ def test_factory_registered(self):
+ tasks = containers.Tasks()
+
+ @tasks.add
+ def test_factory1(): # pylint: disable=unused-variable
+ return 'executed 1'
+
+ @tasks.add
+ def test_factory2(): # pylint: disable=unused-variable
+ return 'executed 2'
+
+ with self.assertRaises(ValueError):
+ @tasks.add
+ def test_factory1(): # pylint: disable=function-redefined
+ return
+
+ self.assertEqual(2, len(tasks))
+ self.assertEqual(set(['test_factory1', 'test_factory2']),
+ set(tasks.keys()))
+ self.assertEqual('executed 1', tasks['test_factory1']())
+ self.assertEqual('executed 2', tasks['test_factory2']())
+
+ def test_procedural_names(self):
+ tasks = containers.Tasks()
+ names = set(('easy', 'normal', 'hard'))
+ for name in names:
+ tasks.add(lambda: None, name=name)
+ self.assertEqual(len(names), len(tasks))
+ self.assertSetEqual(names, set(tasks.keys()))
+
+ def test_iteration_order(self):
+ tasks = containers.Tasks()
+ expected_order = ['first', 'second', 'third', 'fourth']
+ for name in expected_order:
+ tasks.add(lambda: None, name=name)
+ actual_order = list(tasks)
+ self.assertEqual(expected_order, actual_order)
+
+
+class TaggedTaskTest(absltest.TestCase):
+
+ def test_registration(self):
+ tasks = containers.TaggedTasks()
+
+ @tasks.add()
+ def test_factory1(): # pylint: disable=unused-variable
+ return 'executed 1'
+
+ @tasks.add('basic', 'stable')
+ def test_factory2(): # pylint: disable=unused-variable
+ return 'executed 2'
+
+ @tasks.add('expert', 'stable')
+ def test_factory3(): # pylint: disable=unused-variable
+ return 'executed 3'
+
+ @tasks.add('expert', 'unstable')
+ def test_factory4(): # pylint: disable=unused-variable
+ return 'executed 4'
+
+ self.assertEqual(4, len(tasks))
+ self.assertEqual(set(['basic', 'expert', 'stable', 'unstable']),
+ set(tasks.tags()))
+
+ self.assertEqual(1, len(tasks.tagged('basic')))
+ self.assertEqual(2, len(tasks.tagged('expert')))
+ self.assertEqual(2, len(tasks.tagged('stable')))
+ self.assertEqual(1, len(tasks.tagged('unstable')))
+
+ self.assertEqual('executed 2', tasks['test_factory2']())
+
+ self.assertEqual('executed 3', tasks.tagged('expert')['test_factory3']())
+
+ self.assertNotIn('test_factory4', tasks.tagged('stable'))
+
+ def test_iteration_order(self):
+ tasks = containers.TaggedTasks()
+
+ @tasks.add()
+ def first(): # pylint: disable=unused-variable
+ pass
+
+ @tasks.add()
+ def second(): # pylint: disable=unused-variable
+ pass
+
+ @tasks.add()
+ def third(): # pylint: disable=unused-variable
+ pass
+
+ @tasks.add()
+ def fourth(): # pylint: disable=unused-variable
+ pass
+
+ expected_order = ['first', 'second', 'third', 'fourth']
+ actual_order = list(tasks)
+ self.assertEqual(expected_order, actual_order)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/utils/corruptors.py b/dm_control/utils/corruptors.py
new file mode 100644
index 00000000..c365e08e
--- /dev/null
+++ b/dm_control/utils/corruptors.py
@@ -0,0 +1,115 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Corruptors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import copy
+import functools
+
+# Internal dependencies.
+
+import numpy as np
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class CorruptorBase(object):
+
+ @abc.abstractmethod
+ def __call__(self, x):
+ """Returns a corrupted version of the input x."""
+
+ @abc.abstractmethod
+ def reset(self):
+ """Resets the internal state of the corruptor."""
+
+
+class Delay(CorruptorBase):
+ """Applies a delay to the input."""
+
+ def __init__(self, steps, padding=None):
+ """Initialize an instance of `Delay`.
+
+ Args:
+ steps: An int, number of steps for the delay.
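+ padding: An optional numpy array or a function. The value returned
+ during the first `steps` calls; if a function, it is called on the
+ current input. Defaults to zeros with the same shape as the input.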
+
+ Raises:
+ ValueError: When `steps` <= 0.
+ """
+ if steps <= 0:
+ raise ValueError(
+ 'Delay steps should be greater than 0, %d found' % steps)
+ self._buffer = collections.deque(maxlen=steps + 1)
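+ # Compare against None explicitly: `padding or ...` would raise for a
+ # multi-element numpy array, whose truth value is ambiguous.
+ self._padding = (
+ padding if padding is not None else (lambda x: np.zeros(x.shape)))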
+
+ def __call__(self, x):
+ """Returns the input to this function from `steps` calls ago."""
+ self._buffer.append(copy.deepcopy(x))
+ if len(self._buffer) == self._buffer.maxlen:
+ return self._buffer.popleft()
+ else:
+ return self._padding(x) if callable(self._padding) else self._padding
+
+ def reset(self):
+ """Resets the buffer."""
+ self._buffer.clear()
+
+
+class StatelessNoise(CorruptorBase):
+ """Applies noise to an input without relying on any internal state."""
+
+ def __init__(self, noise_function, **noise_parameters):
+ """Initialize an instance of `StatelessNoise`.
+
+ Args:
+ noise_function: A function, adding noise to its input.
+ **noise_parameters: Additional keyword arguments taken by the
+ `noise_function`.
+ """
+ self._noise_function = functools.partial(noise_function, **noise_parameters)
+
+ def __call__(self, x):
+ """Returns the input to this function with noise added."""
+ return self._noise_function(x)
+
+ def reset(self):
+ pass
+
+
+def gaussian_noise(x, std):
+ """Adds gaussian noise to each dimension of x.
+
+ Example of gaussian noise corruptor:
+ ```python
+ corruptor = StatelessNoise(noise_function=gaussian_noise, std=.1)
+ ```
+
+ Args:
+ x: A numpy array, the input.
+ std: A number, standard deviation of the gaussian noise.
+
+ Returns:
+ A numpy array with the same shape as x, in which noise drawn from a
+ zero-mean normal distribution has been added to each element of x.
+ """
+ return x + np.random.standard_normal(x.shape) * std
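Reviewer note on `Delay` in `corruptors.py`: the buffer logic can be sketched with the standard library alone. This is a simplified illustration, assuming a constant padding value rather than the array-or-function padding the class accepts; `DelaySketch` is a hypothetical name and not part of dm_control.

```python
import collections
import copy


class DelaySketch(object):
    """Returns the input from `steps` calls ago (illustrative only)."""

    def __init__(self, steps, padding=0.0):
        if steps <= 0:
            raise ValueError('steps must be greater than 0, got %d' % steps)
        # One extra slot so the newest input can be appended before the
        # oldest one is popped.
        self._buffer = collections.deque(maxlen=steps + 1)
        self._padding = padding

    def __call__(self, x):
        self._buffer.append(copy.deepcopy(x))
        if len(self._buffer) == self._buffer.maxlen:
            return self._buffer.popleft()
        # Not enough history yet: emit the padding value.
        return self._padding

    def reset(self):
        self._buffer.clear()


delay = DelaySketch(steps=2)
outputs = [delay(i) for i in range(5)]
delay.reset()
after_reset = delay(7)
```

The first `steps` calls return the padding value, after which each call returns the input seen `steps` calls earlier; `reset()` clears the history so the padding phase starts again.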
diff --git a/dm_control/utils/corruptors_test.py b/dm_control/utils/corruptors_test.py
new file mode 100644
index 00000000..a651f933
--- /dev/null
+++ b/dm_control/utils/corruptors_test.py
@@ -0,0 +1,91 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for corruptors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.utils import corruptors
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+class DelayTest(absltest.TestCase):
+
+ def setUp(self):
+ self.n = 10
+ self.delay = corruptors.Delay(steps=self.n)
+ super(DelayTest, self).setUp()
+
+ def testProcess(self):
+ obs = np.array(range(2 * self.n))
+ actual_obs = []
+ for i in obs:
+ actual_obs.append(self.delay(i))
+ expected = np.hstack(([.0] * self.n, obs[:self.n]))
+ np.testing.assert_array_equal(expected, actual_obs)
+
+ actual_obs = []
+ for i in obs:
+ actual_obs.append(self.delay(i))
+ expected = np.hstack((obs[self.n:], obs[:self.n]))
+ np.testing.assert_array_equal(expected, actual_obs)
+
+ def testReset(self):
+ obs = np.array(range(2 * self.n))
+ for _ in xrange(2):
+ actual_obs = []
+ for i in obs:
+ actual_obs.append(self.delay(i))
+ self.delay.reset()
+
+ expected = np.hstack(([.0] * self.n, obs[:self.n]))
+ np.testing.assert_array_equal(expected, actual_obs)
+
+
+class StatelessNoiseTest(absltest.TestCase):
+
+ def testProcess(self):
+ c = corruptors.StatelessNoise(noise_function=corruptors.gaussian_noise,
+ std=1e-3)
+ x = np.array([.0] * 3)
+ y = np.array([.0] * 3)
+ n = 1e3
+ for _ in xrange(int(n)):
+ y += c(x)
+ y /= n
+ np.testing.assert_allclose(x, y, atol=1e-4)
+
+
+class NoiseFunctionTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('1D', np.array([3., 4.]), .1),
+ ('2D', np.array([[.0, .1], [1., 2.]]), .4)
+ )
+ def testGaussianNoise_Shape(self, x, std):
+ noisy_x = corruptors.gaussian_noise(x, std)
+ self.assertEqual(x.shape, noisy_x.shape)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/dm_control/utils/resources.py b/dm_control/utils/resources.py
new file mode 100644
index 00000000..b5976fd1
--- /dev/null
+++ b/dm_control/utils/resources.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""IO functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def GetResource(name, mode='rb'):
+ with open(name, mode=mode) as f:
+ return f.read()
+
+
+def GetResourceFilename(name, mode='rb'):
+ del mode # Unused.
+ return name
+
+GetResourceAsFile = open # pylint: disable=invalid-name
diff --git a/dm_control/utils/rewards.py b/dm_control/utils/rewards.py
new file mode 100644
index 00000000..2e2fba55
--- /dev/null
+++ b/dm_control/utils/rewards.py
@@ -0,0 +1,132 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Soft indicator function evaluating whether a number is within bounds."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+import numpy as np
+
+# The value returned by tolerance() at `margin` distance from `bounds` interval.
+_DEFAULT_VALUE_AT_MARGIN = 0.1
+
+
+def _sigmoids(x, value_at_1, sigmoid):
+  """Returns 1 when `x` == 0, between 0 and 1 otherwise.
+
+  Args:
+    x: A scalar or numpy array.
+    value_at_1: A float between 0 and 1 specifying the output when `x` == 1.
+    sigmoid: String, choice of sigmoid type.
+
+  Returns:
+    A numpy array with values between 0.0 and 1.0.
+
+  Raises:
+    ValueError: If not 0 < `value_at_1` < 1, except for `linear`, `cosine` and
+      `quadratic` sigmoids which allow `value_at_1` == 0.
+    ValueError: If `sigmoid` is of an unknown type.
+  """
+  if sigmoid in ('cosine', 'linear', 'quadratic'):
+    if not 0 <= value_at_1 < 1:
+      raise ValueError('`value_at_1` must be nonnegative and smaller than 1, '
+                       'got {}.'.format(value_at_1))
+  else:
+    if not 0 < value_at_1 < 1:
+      raise ValueError('`value_at_1` must be strictly between 0 and 1, '
+                       'got {}.'.format(value_at_1))
+
+  if sigmoid == 'gaussian':
+    scale = np.sqrt(-2 * np.log(value_at_1))
+    return np.exp(-0.5 * (x*scale)**2)
+
+  elif sigmoid == 'hyperbolic':
+    scale = np.arccosh(1/value_at_1)
+    return 1 / np.cosh(x*scale)
+
+  elif sigmoid == 'long_tail':
+    scale = np.sqrt(1/value_at_1 - 1)
+    return 1 / ((x*scale)**2 + 1)
+
+  elif sigmoid == 'cosine':
+    scale = np.arccos(2*value_at_1 - 1) / np.pi
+    scaled_x = x*scale
+    return np.where(abs(scaled_x) < 1, (1 + np.cos(np.pi*scaled_x))/2, 0.0)
+
+  elif sigmoid == 'linear':
+    scale = 1-value_at_1
+    scaled_x = x*scale
+    return np.where(abs(scaled_x) < 1, 1 - abs(scaled_x), 0.0)
+
+  elif sigmoid == 'quadratic':
+    scale = np.sqrt(1-value_at_1)
+    scaled_x = x*scale
+    return np.where(abs(scaled_x) < 1, 1 - scaled_x**2, 0.0)
+
+  elif sigmoid == 'tanh_squared':
+    scale = np.arctanh(np.sqrt(1-value_at_1))
+    return 1 - np.tanh(x*scale)**2
+
+  else:
+    raise ValueError('Unknown sigmoid type {!r}.'.format(sigmoid))
+
+
+def tolerance(x, bounds=(0.0, 0.0), margin=0.0, sigmoid='gaussian',
+              value_at_margin=_DEFAULT_VALUE_AT_MARGIN):
+  """Returns 1 when `x` falls inside the bounds, between 0 and 1 otherwise.
+
+  Args:
+    x: A scalar or numpy array.
+    bounds: A tuple of floats specifying inclusive `(lower, upper)` bounds for
+      the target interval. These can be infinite if the interval is unbounded
+      at one or both ends, or they can be equal to one another if the target
+      value is exact.
+    margin: Float. Parameter that controls how steeply the output decreases as
+      `x` moves out-of-bounds.
+      * If `margin == 0` then the output will be 0 for all values of `x`
+        outside of `bounds`.
+      * If `margin > 0` then the output will decrease sigmoidally with
+        increasing distance from the nearest bound.
+    sigmoid: String, choice of sigmoid type. One of 'gaussian', 'linear',
+      'hyperbolic', 'long_tail', 'cosine', 'quadratic', 'tanh_squared'.
+    value_at_margin: A float between 0 and 1 specifying the output value when
+      the distance from `x` to the nearest bound is equal to `margin`. Ignored
+      if `margin == 0`.
+
+  Returns:
+    A float or numpy array with values between 0.0 and 1.0.
+
+  Raises:
+    ValueError: If `bounds[0] > bounds[1]`.
+    ValueError: If `margin` is negative.
+  """
+  lower, upper = bounds
+  if lower > upper:
+    raise ValueError('Lower bound must be <= upper bound.')
+  if margin < 0:
+    raise ValueError('`margin` must be non-negative.')
+
+  in_bounds = np.logical_and(lower <= x, x <= upper)
+  if margin == 0:
+    value = np.where(in_bounds, 1.0, 0.0)
+  else:
+    d = np.where(x < lower, lower - x, x - upper) / margin
+    value = np.where(in_bounds, 1.0, _sigmoids(d, value_at_margin, sigmoid))
+
+  return float(value) if np.isscalar(x) else value
+
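The scalar behaviour of `tolerance()` with the default `'gaussian'` sigmoid can be sketched without NumPy. This is only an illustration, not part of `dm_control`: the name `gaussian_tolerance` is hypothetical, it assumes `x` is a single float, and it supports only the Gaussian shape.

```python
import math


def gaussian_tolerance(x, bounds=(0.0, 0.0), margin=0.0, value_at_margin=0.1):
  """Scalar-only sketch of `tolerance()` with the 'gaussian' sigmoid."""
  lower, upper = bounds
  if lower > upper:
    raise ValueError('Lower bound must be <= upper bound.')
  if margin < 0:
    raise ValueError('`margin` must be non-negative.')
  if lower <= x <= upper:
    return 1.0  # Inside the (inclusive) bounds.
  if margin == 0:
    return 0.0  # Hard indicator: no soft falloff outside the bounds.
  if not 0 < value_at_margin < 1:
    raise ValueError('`value_at_margin` must be strictly between 0 and 1.')
  # Distance to the nearest bound, measured in units of `margin`.
  d = (lower - x if x < lower else x - upper) / margin
  # Choose `scale` so that exp(-0.5 * scale**2) == value_at_margin at d == 1.
  scale = math.sqrt(-2 * math.log(value_at_margin))
  return math.exp(-0.5 * (d * scale) ** 2)


print(gaussian_tolerance(0.5, bounds=(0.0, 1.0)))              # 1.0, in bounds
print(gaussian_tolerance(1.5, bounds=(0.0, 1.0), margin=0.5))  # approx. 0.1
print(gaussian_tolerance(2.0, bounds=(0.0, 1.0)))              # 0.0, no margin
```

The second call places `x` exactly one `margin` beyond the upper bound, so the result is approximately `value_at_margin`, which is the property the parameterised tests in `rewards_test.py` check for every sigmoid type.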
diff --git a/dm_control/utils/rewards_test.py b/dm_control/utils/rewards_test.py
new file mode 100644
index 00000000..a3428ca2
--- /dev/null
+++ b/dm_control/utils/rewards_test.py
@@ -0,0 +1,127 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.utils.rewards."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.utils import rewards
+
+import numpy as np
+
+
+_INPUT_VECTOR_SIZE = 10
+EPS = np.finfo(np.double).eps
+INF = float("inf")
+
+
+class ToleranceTest(parameterized.TestCase):
+
+ @parameterized.parameters((0.5, 0.95), (1e12, 1-EPS), (1e12, EPS),
+ (EPS, 1-EPS), (EPS, EPS))
+ def test_tolerance_sigmoid_parameterisation(self, margin, value_at_margin):
+ actual = rewards.tolerance(x=margin, bounds=(0, 0), margin=margin,
+ value_at_margin=value_at_margin)
+ self.assertAlmostEqual(actual, value_at_margin)
+
+ @parameterized.parameters(("gaussian",), ("hyperbolic",), ("long_tail",),
+ ("cosine",), ("tanh_squared",), ("linear",),
+ ("quadratic",))
+ def test_tolerance_sigmoids(self, sigmoid):
+ margins = [0.01, 1.0, 100, 10000]
+ values_at_margin = [0.1, 0.5, 0.9]
+ bounds_list = [(0, 0), (-1, 1), (-np.pi, np.pi), (-100, 100)]
+ for bounds in bounds_list:
+ for margin in margins:
+ for value_at_margin in values_at_margin:
+ upper_margin = bounds[1]+margin
+ value = rewards.tolerance(x=upper_margin, bounds=bounds,
+ margin=margin,
+ value_at_margin=value_at_margin,
+ sigmoid=sigmoid)
+ self.assertAlmostEqual(value, value_at_margin, delta=np.sqrt(EPS))
+ lower_margin = bounds[0]-margin
+ value = rewards.tolerance(x=lower_margin, bounds=bounds,
+ margin=margin,
+ value_at_margin=value_at_margin,
+ sigmoid=sigmoid)
+ self.assertAlmostEqual(value, value_at_margin, delta=np.sqrt(EPS))
+
+ @parameterized.parameters((-1, 0), (-0.5, 0.1), (0, 1), (0.5, 0.1), (1, 0))
+ def test_tolerance_margin_loss_shape(self, x, expected):
+ actual = rewards.tolerance(x=x, bounds=(0, 0), margin=0.5,
+ value_at_margin=0.1)
+ self.assertAlmostEqual(actual, expected, delta=1e-3)
+
+ def test_tolerance_vectorization(self):
+ bounds = (-.1, .1)
+ margin = 0.2
+ x_array = np.random.randn(2, 3, 4)
+ value_array = rewards.tolerance(x=x_array, bounds=bounds, margin=margin)
+ self.assertEqual(x_array.shape, value_array.shape)
+ for i, x in enumerate(x_array.ravel()):
+ value = rewards.tolerance(x=x, bounds=bounds, margin=margin)
+ self.assertEqual(value, value_array.ravel()[i])
+
+ # pylint: disable=bad-whitespace
+ @parameterized.parameters(
+ # Exact target.
+ (0, (0, 0), 1),
+ (EPS, (0, 0), 0),
+ (-EPS, (0, 0), 0),
+ # Interval with one open end.
+ (0, (0, INF), 1),
+ (EPS, (0, INF), 1),
+ (-EPS, (0, INF), 0),
+ # Closed interval.
+ (0, (0, 1), 1),
+ (EPS, (0, 1), 1),
+ (-EPS, (0, 1), 0),
+ (1, (0, 1), 1),
+ (1+EPS, (0, 1), 0))
+ def test_tolerance_bounds(self, x, bounds, expected):
+ actual = rewards.tolerance(x, bounds=bounds, margin=0)
+ self.assertEqual(actual, expected) # Should be exact, since margin == 0.
+
+ def test_tolerance_incorrect_bounds_order(self):
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, "Lower bound must be <= upper bound."):
+ rewards.tolerance(0, bounds=(1, 0), margin=0.05)
+
+ def test_tolerance_negative_margin(self):
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, "`margin` must be non-negative."):
+ rewards.tolerance(0, bounds=(0, 1), margin=-0.05)
+
+ def test_tolerance_bad_value_at_margin(self):
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, "`value_at_1` must be strictly between 0 and 1, got 0."):
+ rewards.tolerance(0, bounds=(0, 1), margin=1, value_at_margin=0)
+
+ def test_tolerance_unknown_sigmoid(self):
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, "Unknown sigmoid type 'unsupported_sigmoid'."):
+ rewards.tolerance(0, bounds=(0, 1), margin=.1,
+ sigmoid="unsupported_sigmoid")
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/dm_control/utils/xml_tools.py b/dm_control/utils/xml_tools.py
new file mode 100644
index 00000000..d0e04355
--- /dev/null
+++ b/dm_control/utils/xml_tools.py
@@ -0,0 +1,93 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Helper functions for model xml creation and modification."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+# Internal dependencies.
+
+from lxml import etree
+
+
+def find_element(root, tag, name):
+  """Finds and returns the first element of specified tag and name.
+
+  Args:
+    root: `etree.Element` to be searched recursively.
+    tag: The `tag` property of the sought element.
+    name: The `name` attribute of the sought element.
+
+  Returns:
+    An `etree.Element` with the specified properties.
+
+  Raises:
+    ValueError: If no matching element is found.
+  """
+  result = root.find('.//{}[@name={!r}]'.format(tag, name))
+  if result is None:
+    raise ValueError(
+        'Element with tag {!r} and name {!r} not found'.format(tag, name))
+  return result
+
+
+def nested_element(element, depth):
+  """Makes a nested `etree.Element` given a single element.
+
+  If `depth=2`, the new tree will look like
+
+  ```xml
+  <element>
+    <element>
+      <element>
+      </element>
+    </element>
+  </element>
+  ```
+
+  Args:
+    element: The `etree.Element` used to create a nested structure.
+    depth: An `int` denoting the nesting depth. The resulting tree will contain
+      `element` nested `depth` times.
+
+
+  Returns:
+    A nested `etree.Element`.
+  """
+  if depth > 0:
+    child = nested_element(copy.deepcopy(element), depth=(depth - 1))
+    element.append(child)
+  return element
+
+
+def parse(file_obj):
+  """Reads xml from a file and returns an `etree.ElementTree`.
+
+  Compared to `etree.fromstring()`, this function removes the blank text
+  (whitespace-only nodes) in the xml file. This means that a user can later
+  pretty-print the result with `etree.tostring(element, pretty_print=True)`.
+
+  Args:
+    file_obj: A file or file-like object.
+
+  Returns:
+    An `etree.ElementTree` of the xml file.
+  """
+  parser = etree.XMLParser(remove_blank_text=True)
+  return etree.parse(file_obj, parser)
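The nesting and lookup helpers above can be tried out without `lxml`. The sketch below swaps in the standard library's `xml.etree.ElementTree`, which is an assumption made purely for portability (the module itself uses `lxml.etree`), and the tag and attribute values are made up for illustration.

```python
import copy
import xml.etree.ElementTree as ET


def nested_element(element, depth):
  """Nests deep copies of `element` inside itself, `depth` levels deep."""
  if depth > 0:
    child = nested_element(copy.deepcopy(element), depth=depth - 1)
    element.append(child)
  return element


def find_element(root, tag, name):
  """Returns the first descendant of `root` with the given tag and name."""
  result = root.find(".//{}[@name='{}']".format(tag, name))
  if result is None:
    raise ValueError(
        'Element with tag {!r} and name {!r} not found'.format(tag, name))
  return result


# depth=2 yields three `body` elements in total: the original plus two copies.
body = nested_element(ET.Element('body', name='link'), depth=2)
print(len(list(body.iter('body'))))

# Look up an element by tag and `name` attribute anywhere below the root.
root = ET.fromstring("<root><world name='w'><geom name='g'/></world></root>")
print(find_element(root, 'geom', 'g').tag)
```

Note that `nested_element` modifies its argument in place and returns the same object, so callers that need the original element untouched should pass in a copy.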
diff --git a/dm_control/utils/xml_tools_test.py b/dm_control/utils/xml_tools_test.py
new file mode 100644
index 00000000..28a03712
--- /dev/null
+++ b/dm_control/utils/xml_tools_test.py
@@ -0,0 +1,81 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for utils.xml_tools."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control.utils import xml_tools
+
+from lxml import etree
+
+import six
+
+
+class XmlHelperTest(absltest.TestCase):
+
+ def test_nested(self):
+ element = etree.Element('inserted')
+ xml_tools.nested_element(element, depth=2)
+ level_1 = element.find('inserted')
+ self.assertIsNotNone(level_1)
+ level_2 = level_1.find('inserted')
+ self.assertIsNotNone(level_2)
+
+ def test_tostring(self):
+ xml_str = """
+
+
+
+
+ """
+ tree = xml_tools.parse(six.StringIO(xml_str))
+ self.assertEqual(b'\n \n \n \n\n',
+ etree.tostring(tree, pretty_print=True))
+
+ def test_find_element(self):
+ xml_str = """
+
+
+
+
+
+ """
+ tree = xml_tools.parse(six.StringIO(xml_str))
+ world = xml_tools.find_element(root=tree, tag='world', name='world_name')
+ self.assertEqual(world.tag, 'world')
+ self.assertEqual(world.attrib['name'], 'world_name')
+
+ geom = xml_tools.find_element(root=tree, tag='geom', name='geom_name')
+ self.assertEqual(geom.tag, 'geom')
+ self.assertEqual(geom.attrib['name'], 'geom_name')
+
+ with self.assertRaisesRegexp(ValueError, 'Element with tag'):
+ xml_tools.find_element(root=tree, tag='does_not_exist', name='name')
+
+ with self.assertRaisesRegexp(ValueError, 'Element with tag'):
+ xml_tools.find_element(root=tree, tag='world', name='does_not_exist')
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..bf8b7c7f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+absl-py==0.1.5
+enum34==1.1.6
+future==0.16.0
+glfw==1.4.0
+lxml==4.1.1
+mock==2.0.0
+nose==1.3.7
+numpy==1.13.3
+pyparsing==2.2.0
+six==1.11.0
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..bafe2114
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,156 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Install script for setuptools."""
+
+import os
+import subprocess
+import sys
+
+from distutils import cmd
+from distutils import log
+from setuptools import find_packages
+from setuptools import setup
+from setuptools.command import install
+from setuptools.command import test
+
+DEFAULT_HEADERS_DIR = '~/.mujoco/mjpro150/include'
+
+# Relative paths to the binding generator script and the output directory.
+AUTOWRAP_PATH = 'dm_control/autowrap/autowrap.py'
+MJBINDINGS_DIR = 'dm_control/mujoco/wrapper/mjbindings'
+
+# We specify the header filenames explicitly rather than listing the contents
+# of the `HEADERS_DIR` at runtime, since it will probably contain other stuff
+# (e.g. `glfw.h`).
+HEADER_FILENAMES = [
+ 'mjdata.h',
+ 'mjmodel.h',
+ 'mjrender.h',
+ 'mjvisualize.h',
+ 'mjxmacro.h',
+ 'mujoco.h',
+]
+
+
+class BuildMJBindingsCommand(cmd.Command):
+  """Runs `autowrap.py` to generate the low-level ctypes bindings for MuJoCo."""
+  description = __doc__
+  user_options = [
+      # The format is (long option, short option, description).
+      ('headers-dir=', None,
+       'Path to directory containing MuJoCo headers.'),
+      ('inplace=', None,
+       'Place generated files in source directory rather than `build-lib`.'),
+  ]
+  boolean_options = ['inplace']
+
+  def initialize_options(self):
+    """Set default values for options."""
+    # A default value must be assigned to each user option here.
+    self.inplace = 0
+    self.headers_dir = os.path.expanduser(DEFAULT_HEADERS_DIR)
+
+  def finalize_options(self):
+    """Post-process options."""
+    header_paths = []
+    for filename in HEADER_FILENAMES:
+      full_path = os.path.join(self.headers_dir, filename)
+      if not os.path.exists(full_path):
+        raise IOError('Header file {!r} does not exist.'.format(full_path))
+      header_paths.append(full_path)
+    self._header_paths = ' '.join(header_paths)
+
+  def run(self):
+    cwd = os.path.realpath(os.curdir)
+    if self.inplace:
+      dist_root = cwd
+    else:
+      build_cmd = self.get_finalized_command('build')
+      dist_root = os.path.realpath(build_cmd.build_lib)
+    output_dir = os.path.join(dist_root, MJBINDINGS_DIR)
+    command = [
+        sys.executable or 'python',
+        AUTOWRAP_PATH,
+        '--header_paths={}'.format(self._header_paths),
+        '--output_dir={}'.format(output_dir)
+    ]
+    self.announce('Running command: {}'.format(command), level=log.DEBUG)
+    # Prepend the current directory to $PYTHONPATH so that internal imports
+    # in `autowrap` can succeed before we've installed anything. The modified
+    # copy is passed to the child process; `os.environ` itself is left alone.
+    env = os.environ.copy()
+    new_pythonpath = [cwd]
+    if 'PYTHONPATH' in env:
+      new_pythonpath.append(env['PYTHONPATH'])
+    env['PYTHONPATH'] = os.pathsep.join(new_pythonpath)
+    subprocess.check_call(command, env=env)
+
+
+class InstallCommand(install.install):
+ """Runs 'build_mjbindings' before installation."""
+
+ def run(self):
+ self.run_command('build_mjbindings')
+ install.install.run(self)
+
+
+class TestCommand(test.test):
+ """Prepends path to generated sources before running unit tests."""
+
+ def run(self):
+ # Generate ctypes bindings in-place so that they can be imported in tests.
+ self.reinitialize_command('build_mjbindings', inplace=1)
+ self.run_command('build_mjbindings')
+ test.test.run(self)
+
+setup(
+ name='dm_control',
+ description='Continuous control environments and MuJoCo Python bindings.',
+ author='DeepMind',
+ license='Apache License, Version 2.0',
+ keywords='machine learning control physics MuJoCo AI',
+ install_requires=[
+ 'absl-py',
+ 'enum34',
+ 'future',
+ 'glfw',
+ 'lxml',
+ 'numpy',
+ 'pyparsing',
+ 'setuptools',
+ 'six',
+ ],
+ tests_require=[
+ 'mock',
+ 'nose',
+ ],
+ test_suite='nose.collector',
+ packages=find_packages(),
+ package_data={
+ 'dm_control.mujoco.testing':
+ ['assets/*.png', 'assets/*.stl', 'assets/*.xml'],
+ 'dm_control.suite':
+ ['*.xml', 'common/*.xml'],
+ },
+ cmdclass={
+ 'build_mjbindings': BuildMJBindingsCommand,
+ 'install': InstallCommand,
+ 'test': TestCommand,
+ },
+ entry_points={},
+)
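The `PYTHONPATH` handling needed by `build_mjbindings` can be sketched in isolation. In this minimal example the child process receives a modified copy of the environment through the `env=` argument of `subprocess.check_call`, so the parent's `os.environ` is never touched. The helper name `run_with_prepended_pythonpath` is hypothetical and does not appear in `setup.py`.

```python
import os
import subprocess
import sys


def run_with_prepended_pythonpath(command, extra_path):
  """Runs `command` with `extra_path` placed first on $PYTHONPATH."""
  env = os.environ.copy()  # Work on a copy; the parent stays unchanged.
  parts = [extra_path]
  if env.get('PYTHONPATH'):
    parts.append(env['PYTHONPATH'])
  # `os.pathsep` is ':' on POSIX and ';' on Windows.
  env['PYTHONPATH'] = os.pathsep.join(parts)
  return subprocess.check_call(command, env=env)


# The child exits with status 0 only if `cwd` is the first PYTHONPATH entry;
# otherwise `check_call` raises `subprocess.CalledProcessError`.
cwd = os.path.realpath(os.curdir)
check = ('import os, sys; '
         'first = os.environ["PYTHONPATH"].split(os.pathsep)[0]; '
         'sys.exit(0 if first == sys.argv[1] else 1)')
run_with_prepended_pythonpath([sys.executable, '-c', check, cwd], cwd)
```

Passing `env=` avoids the need for a `try`/`finally` block to restore the environment afterwards.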
diff --git a/tech_report.pdf b/tech_report.pdf
new file mode 100644
index 00000000..904aa0eb
Binary files /dev/null and b/tech_report.pdf differ