Skip to content

Commit d835b27

Browse files
committed
refactored tree class
1 parent eb702e2 commit d835b27

File tree

1 file changed

+45
-44
lines changed

1 file changed

+45
-44
lines changed

pybullet_tree_sim/tree.py

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import numpy as np
1919
import pybullet
2020
import pywavefront
21-
from nptyping import NDArray, Shape, Float
21+
22+
# from nptyping import NDArray, Shape, Float
2223
from numpy.typing import ArrayLike
2324
from pybullet_tree_sim import RGB_LABEL, URDF_PATH, MESHES_PATH, PKL_PATH
2425
from pybullet_tree_sim.utils.pyb_utils import PyBUtils
@@ -94,30 +95,26 @@ def __init__(
9495
self.tree_namespace = namespace
9596
self.tree_id = tree_id
9697
self.tree_type = tree_type
97-
self.id_str = self.create_id_string(tree_id=tree_id, tree_type=tree_type, namespace=namespace, urdf_path=urdf_path)
98-
self.urdf_path = os.path.join(self._tree_generated_urdf_path, self.id_str+'.urdf')
99-
self.mesh_path = os.path.join(self._tree_meshes_unlabeled_path, self.id_str+".obj")
100-
self.labeled_mesh_path = os.path.join(self._tree_meshes_labeled_path, self.id_str+"_labeled.obj")
98+
self.id_str = self.create_id_string(
99+
tree_id=tree_id, tree_type=tree_type, namespace=namespace, urdf_path=urdf_path
100+
)
101+
self.urdf_path = os.path.join(self._tree_generated_urdf_path, self.id_str + ".urdf")
102+
self.mesh_path = os.path.join(self._tree_meshes_unlabeled_path, self.id_str + ".obj")
103+
self.labeled_mesh_path = os.path.join(self._tree_meshes_labeled_path, self.id_str + "_labeled.obj")
101104
self.init_pos = position
102105
self.init_orientation = orientation
103106

104-
105-
log.info(f"__init__ {self.id_str}")
106-
107-
108107
# URDF
109108
self.load_tree_urdf(scale=scale, parent=parent)
110109
# OBJ
111110
tree_obj = self.load_tree_obj()
112-
log.info(f"Tree mesh loaded: {self.mesh_path}")
113-
# Labelled OBJ
111+
# Labeled OBJ
114112
labeled_tree_obj = self.load_labeled_tree_obj()
115113

116114
# Tree specific parameters
117115
self.rgb_label = RGB_LABEL
118-
self.pyb_tree_id = None
119-
120-
116+
# PyBullet parameters
117+
self.pyb_id: int = None
121118

122119
# Set tree pose
123120
if randomize_pose:
@@ -185,24 +182,28 @@ def __init__(
185182

186183
return
187184

188-
def create_id_string(self,
185+
def create_id_string(
186+
self,
189187
tree_id: int | None = None,
190188
tree_type: str | None = None,
191189
namespace: str | None = None,
192-
urdf_path: str | None = None
190+
urdf_path: str | None = None,
193191
) -> str:
194192
if tree_id is None and urdf_path is None:
195193
# log.error("Both urdf_path and tree parameters cannot be None.")
196194
raise TreeException("Both urdf_path and tree parameters cannot be None.")
197195

196+
if tree_id is not None:
197+
tree_id = str(tree_id).zfill(5)
198+
198199
if urdf_path is None:
199-
id_str = f"{namespace}_{tree_type}_tree{tree_id}"
200+
id_str = f"{namespace}_{tree_type}_{tree_id}"
200201
else:
201202
id_str = Path(urdf_path).stem
202-
id_str_components = id_str.split('_')
203+
id_str_components = id_str.split("_")
203204
self.tree_namespace = id_str_components[0]
204205
self.tree_type = id_str_components[1]
205-
self.tree_id = id_str_components[2]
206+
self.tree_id = str(id_str_components[2]).zfill(5)
206207
return id_str
207208

208209
def _load_points_from_pickle(self, pkl_path):
@@ -448,10 +449,8 @@ def load_tree_urdf(
448449
orientation: str = "0.0 0.0 0.0",
449450
save_urdf: bool = True,
450451
regenerate_urdf: bool = False, # TODO: make save/regenerate work well together. Will need to add delete URDF function
451-
) -> None:
452-
"""Load a tree URDF from a given path or generate a tree URDF from a xacro file. Returns the URDF content.
453-
If `tree_urdf_path` is not None, then load that URDF.
454-
Otherwise, process an xacro file with given input parameters.
452+
) -> str:
453+
"""Load a tree URDF from a given path or generate a tree URDF from a xacro file. If content is generated, by default saves the content to /urdf/trees/<tree_type>/generated Returns the URDF content.
455454
456455
Returns
457456
-------
@@ -463,34 +462,37 @@ def load_tree_urdf(
463462
if not os.path.isdir(Tree._tree_generated_urdf_path):
464463
os.mkdir(Tree._tree_generated_urdf_path)
465464

466-
urdf_mappings = {
467-
"namespace": self.tree_namespace,
468-
"tree_id": str(self.tree_id),
469-
"tree_type": self.tree_type,
470-
"parent": parent,
471-
"xyz": position,
472-
"rpy": orientation,
473-
}
474-
# If the tree macro information doesn't describe a generated file, generate it using the generic tree xacro.
475-
urdf_content = xutils.load_urdf_from_xacro(
476-
xacro_path=Tree._tree_xacro_path, mappings=urdf_mappings
477-
).toprettyxml()
478-
if save_urdf:
479-
xutils.save_urdf(urdf_content=urdf_content, urdf_path=self.urdf_path)
480-
log.info(f"Saved URDF to file '{self.urdf_path}'.")
465+
_tree_id = str(self.tree_id).zfill(5)
466+
urdf_mappings = {
467+
"namespace": self.tree_namespace,
468+
"tree_id": _tree_id,
469+
"tree_type": self.tree_type,
470+
"parent": parent,
471+
"xyz": position,
472+
"rpy": orientation,
473+
}
474+
475+
# If the tree macro information doesn't describe a generated file, generate it using the generic tree xacro.
476+
urdf_content = xutils.load_urdf_from_xacro(
477+
xacro_path=Tree._tree_xacro_path, mappings=urdf_mappings
478+
).toprettyxml()
479+
if save_urdf:
480+
xutils.save_urdf(urdf_content=urdf_content, urdf_path=self.urdf_path)
481481
else:
482482
urdf_content = xutils.load_urdf_from_xacro(xacro_path=self.urdf_path).toprettyxml()
483483
log.info(f"Loaded URDF from file '{self.urdf_path}'.")
484484

485-
return
485+
return urdf_content
486486

487487
def load_tree_obj(self):
488+
"""Loads a mesh .obj mesh file with its path defined by the tree_id_str"""
488489
if not os.path.exists(self.mesh_path):
489490
raise TreeException(f"Could not find file '{self.mesh_path}.")
490491
tree_obj = pywavefront.Wavefront(self.mesh_path, create_materials=True, collect_faces=True)
491492
return tree_obj
492493

493494
def load_labeled_tree_obj(self):
495+
"""Loads a labeled mesh .obj mesh file with its path defined by the tree_id_str"""
494496
if not os.path.exists(self.labeled_mesh_path):
495497
raise TreeException(f"Could not find the file {self.labeled_mesh_path}")
496498
labeled_tree_obj = pywavefront.Wavefront(self.labeled_mesh_path, create_materials=True, collect_faces=True)
@@ -500,11 +502,11 @@ def load_labeled_tree_obj(self):
500502
def make_trees_from_ids(
501503
pbutils: PyBUtils,
502504
tree_ids: list[int],
503-
namespace: str = '',
504-
pos: np.ndarray = np.array([0,0,0]),
505-
orientation: np.ndarray = np.array([0,0,0,1]),
505+
namespace: str = "",
506+
pos: np.ndarray = np.array([0, 0, 0]),
507+
orientation: np.ndarray = np.array([0, 0, 0, 1]),
506508
scale: float = 1.0,
507-
randomize_pose: bool = False
509+
randomize_pose: bool = False,
508510
) -> list[Tree]:
509511
trees: list[Tree] = []
510512

@@ -513,7 +515,6 @@ def make_trees_from_ids(
513515
Tree(
514516
pbutils=pbutils,
515517
tree_id=tree_id,
516-
517518
)
518519
)
519520
return trees

0 commit comments

Comments
 (0)