Skip to content

Commit

Permalink
mjcf parser: Add support for <include> tag
Browse files Browse the repository at this point in the history
  • Loading branch information
JafarAbdi committed Jan 24, 2025
1 parent 1c5c797 commit 3c6ee1c
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 0 deletions.
92 changes: 92 additions & 0 deletions src/parsers/mjcf/mjcf-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,91 @@
#include "pinocchio/multibody/model.hpp"
#include "pinocchio/algorithm/contact-info.hpp"

namespace
{
void updateMeshPaths(boost::property_tree::ptree & tree, const boost::filesystem::path & basePath)
{
for (auto & [key, child] : tree)
{
// Check if the tag is <mesh> and has a 'file' attribute
if (key == "mesh")
{

// Check if the 'file' attribute exists
if (auto fileAttr = child.get_optional<std::string>("<xmlattr>.file"))
{
// Update the 'file' attribute
child.put("<xmlattr>.file", (basePath / *fileAttr).string());
}
}
else if (!child.empty())
{
// Recursively process child nodes
updateMeshPaths(child, basePath);
}
}
}

// Merge the content of an included XML tree into the main XML tree
void mergeXmlTrees(
boost::property_tree::ptree & mainTree,
const boost::property_tree::ptree & includedTree,
const std::string & tagName)
{
if (auto includedChild = includedTree.get_child_optional(tagName))
{
for (const auto & [key, child] : *includedChild)
{
mainTree.add_child(key, child);
}
}
}

// Recursively process <include> tags in the XML tree
void processIncludes(boost::property_tree::ptree & tree, const boost::filesystem::path & basePath)
{
auto it = tree.begin();
while (it != tree.end())
{
if (it->first == "include")
{
// Check if the 'file' attribute exists
if (auto fileAttr = it->second.get_optional<std::string>("<xmlattr>.file"))
{
std::string filePath = *fileAttr;
boost::filesystem::path fullPath = basePath / filePath;

// Read the included XML file
boost::property_tree::ptree includedTree;
boost::property_tree::read_xml(fullPath.string(), includedTree);

// Update mesh paths in the included tree
updateMeshPaths(includedTree, fullPath.parent_path());

// Merge the included content into the main tree
mergeXmlTrees(tree, includedTree, "mujoco");

// Remove the <include> tag after merging
it = tree.erase(it);
}
else
{
PINOCCHIO_THROW_PRETTY(std::runtime_error, "Missing 'file' attribute in <include> tag");
}
}
else
{
// Recursively process child nodes
if (!it->second.empty())
{
processIncludes(it->second, basePath);
}
++it;
}
}
}
} // namespace

namespace pinocchio
{
namespace mjcf
Expand Down Expand Up @@ -899,6 +984,13 @@ namespace pinocchio
void MjcfGraph::parseGraphFromXML(const std::string & xmlStr)
{
boost::property_tree::read_xml(xmlStr, pt);
// Recursively process includes in the entire XML tree
if (pt.get_child_optional("mujoco"))
{
auto basePath = boost::filesystem::path(xmlStr).parent_path();
auto & mujocoTree = pt.get_child("mujoco");
processIncludes(mujocoTree, basePath);
}
parseGraph();
}

Expand Down
247 changes: 247 additions & 0 deletions unittest/mjcf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <iostream>
#include <cstdio> // for std::tmpnam
#include <algorithm>

#include "pinocchio/multibody/model.hpp"

Expand Down Expand Up @@ -1431,4 +1432,250 @@ BOOST_AUTO_TEST_CASE(test_get_unknown_size_vector_from_stream)
BOOST_CHECK(v3 == expected3);
}

BOOST_AUTO_TEST_CASE(process_include_basic)
{
namespace pt = boost::property_tree;
namespace fs = boost::filesystem;

std::istringstream includedXml(R"(
<mujoco>
<mesh name="mesh1" file="mesh1.obj"/>
<mesh name="mesh2" file="mesh2.obj"/>
</mujoco>
)");
auto includedFile = createTempFile(includedXml);

std::istringstream includedWorldBodyXml(R"(
<mujoco>
<worldbody>
<body name="base">
</body>
</worldbody>
</mujoco>
)");
auto includedWorldBodyFile = createTempFile(includedWorldBodyXml);

std::istringstream mainXml(
R"(
<mujoco>
<asset>
<include file=")"
+ includedFile.path.filename().string() + R"("/>
</asset>
<include file=")"
+ includedWorldBodyFile.path.filename().string() + R"("/>
</mujoco>
)");

auto mainFile = createTempFile(mainXml);

fs::path basePath = includedFile.path.parent_path();

pt::ptree mainTree;
pt::read_xml(mainFile.path.string(), mainTree);

typedef ::pinocchio::mjcf::details::MjcfGraph MjcfGraph;
pinocchio::Model model_m;
MjcfGraph::UrdfVisitor visitor(model_m);

MjcfGraph graph(visitor, "fakeMjcf");
graph.parseGraphFromXML(mainFile.name());
graph.parseRootTree();
BOOST_CHECK(graph.mapOfBodies.find("base") != graph.mapOfBodies.end());
BOOST_CHECK(graph.mapOfMeshes.at("mesh1").filePath == (basePath / "mesh1.obj").string());
BOOST_CHECK(graph.mapOfMeshes.at("mesh2").filePath == (basePath / "mesh2.obj").string());
}

BOOST_AUTO_TEST_CASE(process_include_nested)
{
namespace pt = boost::property_tree;
namespace fs = boost::filesystem;

std::istringstream included2Xml(R"(
<mujoco>
<material name="0,0,0" specular="1.0" shininess="1.0" rgba="0.0 0.0 0.0 1.0" />
<mesh name="mesh3" file="mesh3.obj"/>
<mesh name="mesh4" file="mesh4.obj"/>
</mujoco>
)");
auto included2File = createTempFile(included2Xml);

std::istringstream included1Xml(
R"(
<mujoco>
<include file=")"
+ included2File.path.filename().string() + R"("/>
<mesh name="mesh1" file="mesh1.obj"/>
<mesh name="mesh2" file="mesh2.obj"/>
</mujoco>
)");
auto included1File = createTempFile(included1Xml);

std::istringstream includedGeomXml(R"(
<mujoco>
<geom name="geom3" mesh="mesh3" />
<geom name="geom4" mesh="mesh4" />
</mujoco>
)");
auto includedGeomFile = createTempFile(includedGeomXml);

std::istringstream mainXml(
R"(
<mujoco>
<asset>
<include file=")"
+ included1File.path.filename().string() + R"("/>
</asset>
<worldbody>
<body name="base">
<geom name="geom1" mesh="mesh1" />
<geom name="geom2" mesh="mesh2" />
<include file=")"
+ includedGeomFile.path.filename().string() + R"("/>
</body>
</worldbody>
</mujoco>
)");

auto mainFile = createTempFile(mainXml);

fs::path basePath = included1File.path.parent_path();

pt::ptree mainTree;
pt::read_xml(mainFile.path.string(), mainTree);

typedef ::pinocchio::mjcf::details::MjcfGraph MjcfGraph;
pinocchio::Model model_m;
MjcfGraph::UrdfVisitor visitor(model_m);

MjcfGraph graph(visitor, "fakeMjcf");
graph.parseGraphFromXML(mainFile.name());
graph.parseRootTree();

// Verify that the mesh paths are updated correctly
BOOST_CHECK(graph.mapOfMeshes.at("mesh1").filePath == (basePath / "mesh1.obj").string());
BOOST_CHECK(graph.mapOfMeshes.at("mesh2").filePath == (basePath / "mesh2.obj").string());
BOOST_CHECK(graph.mapOfMeshes.at("mesh3").filePath == (basePath / "mesh3.obj").string());
BOOST_CHECK(graph.mapOfMeshes.at("mesh4").filePath == (basePath / "mesh4.obj").string());
BOOST_CHECK(graph.mapOfMaterials.find("0,0,0") != graph.mapOfMaterials.end());

// Verify that the geoms are correctly added to the base body
const auto baseBody = graph.mapOfBodies.find("base");
BOOST_CHECK(baseBody != graph.mapOfBodies.cend());
BOOST_CHECK(baseBody->second.geomChildren.size() == 4);

std::set<std::string> expectedGeoms = {"geom1", "geom2", "geom3", "geom4"};
std::set<std::string> actualGeoms;
std::transform(
baseBody->second.geomChildren.cbegin(), baseBody->second.geomChildren.cend(),
std::inserter(actualGeoms, actualGeoms.begin()),
[](const auto & geom) { return geom.geomName; });
BOOST_CHECK(std::equal(actualGeoms.cbegin(), actualGeoms.cend(), expectedGeoms.cbegin()));
}

BOOST_AUTO_TEST_CASE(process_include_missing_file_attribute)
{
namespace pt = boost::property_tree;

std::istringstream mainXml(R"(
<mujoco>
<asset>
<include/>
</asset>
<worldbody>
<body name="base">
</body>
</worldbody>
</mujoco>
)");

auto mainFile = createTempFile(mainXml);

pt::ptree mainTree;
pt::read_xml(mainFile.path.string(), mainTree);

typedef ::pinocchio::mjcf::details::MjcfGraph MjcfGraph;
pinocchio::Model model_m;
MjcfGraph::UrdfVisitor visitor(model_m);

MjcfGraph graph(visitor, "fakeMjcf");
BOOST_CHECK_THROW(graph.parseGraphFromXML(mainFile.name()), std::runtime_error);
}

BOOST_AUTO_TEST_CASE(process_include_invalid_file_path)
{
namespace pt = boost::property_tree;

std::istringstream mainXml(R"(
<mujoco>
<asset>
<include file="nonexistent.xml"/>
</asset>
<worldbody>
<body name="base">
</body>
</worldbody>
</mujoco>
)");

auto mainFile = createTempFile(mainXml);

pt::ptree mainTree;
pt::read_xml(mainFile.path.string(), mainTree);

typedef ::pinocchio::mjcf::details::MjcfGraph MjcfGraph;
pinocchio::Model model_m;
MjcfGraph::UrdfVisitor visitor(model_m);

MjcfGraph graph(visitor, "fakeMjcf");
BOOST_CHECK_THROW(graph.parseGraphFromXML(mainFile.name()), std::runtime_error);
}

BOOST_AUTO_TEST_CASE(process_mesh_paths_updated)
{
namespace pt = boost::property_tree;
namespace fs = boost::filesystem;

std::istringstream includedXml(R"(
<mujoco>
<mesh name="mesh1" file="meshes/mesh1.obj"/>
<mesh name="mesh2" file="meshes/mesh2.obj"/>
</mujoco>
)");
auto includedFile = createTempFile(includedXml);

std::istringstream mainXml(
R"(
<mujoco>
<asset>
<include file=")"
+ includedFile.path.filename().string() + R"("/>
</asset>
<worldbody>
<body name="base">
</body>
</worldbody>
</mujoco>
)");

auto mainFile = createTempFile(mainXml);

fs::path basePath = includedFile.path.parent_path();

pt::ptree mainTree;
pt::read_xml(mainFile.path.string(), mainTree);

typedef ::pinocchio::mjcf::details::MjcfGraph MjcfGraph;
pinocchio::Model model_m;
MjcfGraph::UrdfVisitor visitor(model_m);

MjcfGraph graph(visitor, "fakeMjcf");
graph.parseGraphFromXML(mainFile.name());
graph.parseRootTree();

// Verify that the mesh paths are updated correctly
BOOST_CHECK(graph.mapOfMeshes.at("mesh1").filePath == (basePath / "meshes/mesh1.obj").string());
BOOST_CHECK(graph.mapOfMeshes.at("mesh2").filePath == (basePath / "meshes/mesh2.obj").string());
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 3c6ee1c

Please sign in to comment.