Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions android_env/components/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import json
import re
import threading
from typing import Any
from typing import Any, Optional

from absl import logging
from android_env.components import adb_call_parser as adb_call_parser_lib
Expand All @@ -40,6 +40,11 @@
class TaskManager:
"""Handles all events and information related to the task."""

_setup_step_interpreter: setup_step_interpreter.SetupStepInterpreter
_dumpsys_thread: dumpsys_thread.DumpsysThread
_task_start_time: datetime.datetime
_logcat_thread: logcat_thread.LogcatThread

def __init__(
self,
task: task_pb2.Task,
Expand All @@ -55,9 +60,6 @@ def __init__(
self._task = task
self._config = config or config_classes.TaskManagerConfig()
self._lock = threading.Lock()
self._logcat_thread = None
self._dumpsys_thread = None
self._setup_step_interpreter = None

# Initialize stats.
self._stats = {
Expand All @@ -71,7 +73,6 @@ def __init__(
}

# Initialize internal state
self._task_start_time = None
self._bad_state_counter = 0
self._is_bad_episode = False

Expand All @@ -84,6 +85,11 @@ def __init__(

logging.info('Task config: %s', self._task)

@property
def _logcate_thread_ok(self) -> logcat_thread.LogcatThread:
assert self._logcat_thread is not None
return self._logcat_thread

def stats(self) -> dict[str, Any]:
"""Returns a dictionary of stats.

Expand All @@ -109,16 +115,16 @@ def start(
"""Starts task processing."""

self._start_logcat_thread(log_stream=log_stream)
self._logcat_thread.resume()
self._logcate_thread_ok.resume()
self._start_dumpsys_thread(adb_call_parser_factory())
self._start_setup_step_interpreter(adb_call_parser_factory())

def reset_task(self) -> None:
"""Resets a task for a new run."""

self._logcat_thread.pause()
self._logcate_thread_ok.pause()
self._setup_step_interpreter.interpret(self._task.reset_steps)
self._logcat_thread.resume()
self._logcate_thread_ok.resume()

# Reset some other variables.
if not self._is_bad_episode:
Expand All @@ -139,7 +145,7 @@ def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:

self._stats['episode_steps'] = 0

self._logcat_thread.line_ready().wait()
self._logcate_thread_ok.line_ready().wait()
with self._lock:
extras = self._get_current_extras()

Expand All @@ -156,7 +162,7 @@ def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:

self._stats['episode_steps'] += 1

self._logcat_thread.line_ready().wait()
self._logcate_thread_ok.line_ready().wait()
with self._lock:
reward = self._get_current_reward()
extras = self._get_current_extras()
Expand Down
Loading