Skip to content

Commit 0f76d42

Browse files
author
Chris Cummins
authored
Merge pull request #423 from ChrisCummins/reset-rety
[env] Wrap all RPC calls in reset() in retry loop.
2 parents 03f0a3b + 653f65d commit 0f76d42

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

compiler_gym/envs/compiler_env.py

Lines changed: 33 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -705,6 +705,32 @@ def reset( # pylint: disable=arguments-differ
705705
:raises TypeError: If no benchmark has been set, and the environment
706706
does not have a default benchmark to select from.
707707
"""
708+
709+
def _call_with_retry(stub_method, *args, **kwargs):
710+
"""Call the given stub method. If it fails with an "acceptable"
711+
error, abort this reset and retry.
712+
"""
713+
try:
714+
return self.service(stub_method, *args, **kwargs)
715+
except (ServiceError, ServiceTransportError, TimeoutError) as e:
716+
# Abort and retry on error.
717+
self.logger.warning("%s on reset(): %s", type(e).__name__, e)
718+
if self.service:
719+
self.service.close()
720+
self.service = None
721+
722+
if retry_count >= self._connection_settings.init_max_attempts:
723+
raise OSError(
724+
f"Failed to reset environment after {retry_count - 1} attempts.\n"
725+
f"Last error ({type(e).__name__}): {e}"
726+
) from e
727+
else:
728+
return self.reset(
729+
benchmark=benchmark,
730+
action_space=action_space,
731+
retry_count=retry_count + 1,
732+
)
733+
708734
if not self._next_benchmark:
709735
raise TypeError(
710736
"No benchmark set. Set a benchmark using "
@@ -723,7 +749,7 @@ def reset( # pylint: disable=arguments-differ
723749
# Stop an existing episode.
724750
if self.in_episode:
725751
self.logger.debug("Ending session %d", self._session_id)
726-
self.service(
752+
_call_with_retry(
727753
self.service.stub.EndSession,
728754
EndSessionRequest(session_id=self._session_id),
729755
)
@@ -759,33 +785,19 @@ def reset( # pylint: disable=arguments-differ
759785
)
760786

761787
try:
762-
reply = self.service(self.service.stub.StartSession, start_session_request)
788+
reply = _call_with_retry(
789+
self.service.stub.StartSession, start_session_request
790+
)
763791
except FileNotFoundError:
764792
# The benchmark was not found, so try adding it and repeating the
765793
# request.
766794
self.service(
767795
self.service.stub.AddBenchmark,
768796
AddBenchmarkRequest(benchmark=[self._benchmark_in_use.proto]),
769797
)
770-
reply = self.service(self.service.stub.StartSession, start_session_request)
771-
except (ServiceError, ServiceTransportError, TimeoutError) as e:
772-
# Abort and retry on error.
773-
self.logger.warning("%s on reset(): %s", type(e).__name__, e)
774-
if self.service:
775-
self.service.close()
776-
self.service = None
777-
778-
if retry_count >= self._connection_settings.init_max_attempts:
779-
raise OSError(
780-
f"Failed to reset environment after {retry_count - 1} attempts.\n"
781-
f"Last error ({type(e).__name__}): {e}"
782-
) from e
783-
else:
784-
return self.reset(
785-
benchmark=benchmark,
786-
action_space=action_space,
787-
retry_count=retry_count + 1,
788-
)
798+
reply = _call_with_retry(
799+
self.service.stub.StartSession, start_session_request
800+
)
789801

790802
self._session_id = reply.session_id
791803
self.observation.session_id = reply.session_id

0 commit comments

Comments
 (0)