diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index e285eb6d..9e4d4bf0 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -372,3 +372,19 @@ def test_path_class_as_cassette(): ) with use_cassette(path): pass + + +def test_use_cassette_generator_return(): + ret_val = object() + + vcr = VCR() + + @vcr.use_cassette("test") + def gen(): + return ret_val + yield + + with pytest.raises(StopIteration) as exc_info: + next(gen()) + + assert exc_info.value.value is ret_val diff --git a/vcr/cassette.py b/vcr/cassette.py index fad0d25d..5a189e32 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -125,7 +125,7 @@ def _handle_generator(self, fn): duration of the generator. """ with self as cassette: - yield from fn(cassette) + return (yield from fn(cassette)) def _handle_function(self, fn): with self as cassette: