Skip to content

Commit d2b6d2e

Browse files
authored
Merge pull request #41 from vakker/entity-types
Entity types
2 parents 80820a1 + 6d84437 commit d2b6d2e

File tree

5 files changed

+32
-4
lines changed

5 files changed

+32
-4
lines changed

.github/workflows/python-package.yml

-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ jobs:
3737
- name: Install dependencies
3838
run: |
3939
python -m pip install --upgrade pip
40-
sudo apt-get update
4140
pip install flake8 pytest
4241
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
4342
- name: Lint with flake8

simple_playgrounds/playground.py

+12
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def __init__(self, size):
7070
self._handle_interactions()
7171
self.sensor_collision_index = 2
7272

73+
self.entity_types_map = {}
74+
7375
@staticmethod
7476
def _initialize_space():
7577
""" Method to initialize Pymunk empty space for 2D physics.
@@ -353,6 +355,16 @@ def _add_scene_element(self, new_scene_element, keep_position):
353355
if new_scene_element in self._disappeared_scene_elements:
354356
self._disappeared_scene_elements.remove(new_scene_element)
355357

358+
def create_entity_types_map(self,
359+
additional_types=[]):
360+
entity_types = [type(e) for e in self.scene_elements]
361+
entity_types.extend(additional_types)
362+
363+
self.entity_types_map = {}
364+
for et in entity_types:
365+
if et not in self.entity_types_map:
366+
self.entity_types_map[et] = len(self.entity_types_map)
367+
356368
def _entity_colliding(self, entity):
357369

358370
collides = False

simple_playgrounds/playgrounds/collection/rl/foraging.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
Trajectory)
66

77

8-
@PlaygroundRegister.register('foraging', 'candy_collect')
9-
class CandyCollectEnv(SingleRoom):
8+
@PlaygroundRegister.register('foraging', 'candy_poison')
9+
class CandyPoisonEnv(SingleRoom):
1010
def __init__(self):
1111
super().__init__(size=(200, 200))
1212

@@ -17,6 +17,7 @@ def __init__(self):
1717
width_length=(100, 100))
1818
self.agent_starting_area = area_start
1919

20+
additional_types = set()
2021
for loc in ["down-left", "up-right"]:
2122
area_center, size_area = self.get_area((0, 0), loc)
2223
area = CoordinateSampler(center=area_center,
@@ -27,16 +28,20 @@ def __init__(self):
2728
probability=0.01,
2829
limit=2)
2930
self.add_scene_element(field)
31+
additional_types.add(field.entity_produced)
3032

3133
field = scene_elements.Field(entity_produced=scene_elements.Poison,
3234
production_area=area,
3335
probability=0.01,
3436
limit=2)
3537
self.add_scene_element(field)
38+
additional_types.add(field.entity_produced)
3639

3740
self.time_limit = 2000
3841
self.time_limit_reached_reward = -1
3942

43+
self.create_entity_types_map(additional_types)
44+
4045

4146
@PlaygroundRegister.register('foraging', 'candy_fireballs')
4247
class CandyFireballs(SingleRoom):
@@ -49,6 +54,8 @@ def __init__(self, time_limit=100, probability_production=0.4):
4954
'size_tiles': 4
5055
}
5156

57+
additional_types = set()
58+
5259
# First Fireball
5360
text_1 = {'color_min': [220, 0, 200], 'color_max': [255, 100, 220]}
5461
trajectory = Trajectory('waypoints',
@@ -60,6 +67,7 @@ def __init__(self, time_limit=100, probability_production=0.4):
6067
**text_1
6168
})
6269
self.add_scene_element(fireball, trajectory)
70+
additional_types.add(type(fireball))
6371

6472
# Second Fireball
6573
text_2 = {'color_min': [180, 0, 0], 'color_max': [220, 100, 0]}
@@ -92,5 +100,8 @@ def __init__(self, time_limit=100, probability_production=0.4):
92100
production_area=area_prod,
93101
probability=probability_production)
94102
self.add_scene_element(field)
103+
additional_types.add(field.entity_produced)
95104

96105
self.time_limit = time_limit
106+
107+
self.create_entity_types_map(additional_types)

simple_playgrounds/playgrounds/collection/rl/navigation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def __init__(self,
6161
self.time_limit = time_limit
6262
self.time_limit_reached_reward = reward_reached_time_limit
6363

64+
self.create_entity_types_map([Basic])
65+
6466
def _set_goal(self):
6567

6668
index_goal = random.randint(0, 3)
@@ -115,4 +117,4 @@ def __init__(self,
115117

116118
self.time_limit = time_limit
117119

118-
120+
self.create_entity_types_map()

simple_playgrounds/playgrounds/collection/rl/sequential.py

+4
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def _place_scene_elements(self):
7272
allow_overlapping=False)
7373
self.add_scene_element(self.dispenser, self.area_dispenser)
7474

75+
self.create_entity_types_map([self.dispenser.entity_produced])
76+
7577
def reset(self):
7678
self.remove_scene_element(self.dispenser)
7779

@@ -148,3 +150,5 @@ def __init__(
148150
area_shape='rectangle',
149151
width_length=area_start_shape)
150152
self.initial_agent_coordinates = area_start
153+
154+
self.create_entity_types_map([dispenser.entity_produced])

0 commit comments

Comments
 (0)