@@ -705,6 +705,32 @@ def reset( # pylint: disable=arguments-differ
705705 :raises TypeError: If no benchmark has been set, and the environment
706706 does not have a default benchmark to select from.
707707 """
708+
709+ def _call_with_retry (stub_method , * args , ** kwargs ):
710+ """Call the given stub method. If it fails with an "acceptable"
711+ error, abort this reset and retry.
712+ """
713+ try :
714+ return self .service (stub_method , * args , ** kwargs )
715+ except (ServiceError , ServiceTransportError , TimeoutError ) as e :
716+ # Abort and retry on error.
717+ self .logger .warning ("%s on reset(): %s" , type (e ).__name__ , e )
718+ if self .service :
719+ self .service .close ()
720+ self .service = None
721+
722+ if retry_count >= self ._connection_settings .init_max_attempts :
723+ raise OSError (
724+ f"Failed to reset environment after { retry_count - 1 } attempts.\n "
725+ f"Last error ({ type (e ).__name__ } ): { e } "
726+ ) from e
727+ else :
728+ return self .reset (
729+ benchmark = benchmark ,
730+ action_space = action_space ,
731+ retry_count = retry_count + 1 ,
732+ )
733+
708734 if not self ._next_benchmark :
709735 raise TypeError (
710736 "No benchmark set. Set a benchmark using "
@@ -723,7 +749,7 @@ def reset( # pylint: disable=arguments-differ
723749 # Stop an existing episode.
724750 if self .in_episode :
725751 self .logger .debug ("Ending session %d" , self ._session_id )
726- self . service (
752+ _call_with_retry (
727753 self .service .stub .EndSession ,
728754 EndSessionRequest (session_id = self ._session_id ),
729755 )
@@ -759,33 +785,19 @@ def reset( # pylint: disable=arguments-differ
759785 )
760786
761787 try :
762- reply = self .service (self .service .stub .StartSession , start_session_request )
788+ reply = _call_with_retry (
789+ self .service .stub .StartSession , start_session_request
790+ )
763791 except FileNotFoundError :
764792 # The benchmark was not found, so try adding it and repeating the
765793 # request.
766794 self .service (
767795 self .service .stub .AddBenchmark ,
768796 AddBenchmarkRequest (benchmark = [self ._benchmark_in_use .proto ]),
769797 )
770- reply = self .service (self .service .stub .StartSession , start_session_request )
771- except (ServiceError , ServiceTransportError , TimeoutError ) as e :
772- # Abort and retry on error.
773- self .logger .warning ("%s on reset(): %s" , type (e ).__name__ , e )
774- if self .service :
775- self .service .close ()
776- self .service = None
777-
778- if retry_count >= self ._connection_settings .init_max_attempts :
779- raise OSError (
780- f"Failed to reset environment after { retry_count - 1 } attempts.\n "
781- f"Last error ({ type (e ).__name__ } ): { e } "
782- ) from e
783- else :
784- return self .reset (
785- benchmark = benchmark ,
786- action_space = action_space ,
787- retry_count = retry_count + 1 ,
788- )
798+ reply = _call_with_retry (
799+ self .service .stub .StartSession , start_session_request
800+ )
789801
790802 self ._session_id = reply .session_id
791803 self .observation .session_id = reply .session_id
0 commit comments