From f75787794737225734f1c3502121b614df081e68 Mon Sep 17 00:00:00 2001 From: Jared Armstrong Date: Sat, 28 Sep 2024 15:44:58 +1200 Subject: [PATCH] Ensure ActiveSupport::BroadcastLogger only executes blocks once. [Fixes #49745] [Related #51883 #49771] Prior to this change, BroadcastLoggers would iterate each broadcast and execute the user provided block for each. This resulted in unintended behaviour since a user-provided block could execute multiple times. The consumer of any Logger would reasonably expect than when calling a method with a block, that block would only execute a single time. That is, the fact that a Logger is a BroadcastLogger should be irrelevant to consumer. The most significant example of this is with ActiveSupport::TaggedLogging. If a BroadcastLogger is used, and there are multiple loggers being broadcast to that respond to the `tagged` method, then calling `tagged` with a block would result in the block being called multiple times. For example: ```ruby broadcasts = ActiveSupport::BroadcastLogger.new( *Array.new(2) { ActiveSupport::TaggedLogging.logger } ) number = 0 broadcasts.tagged("FOO") { broadcasts.log(++number.to_s) } # Outputs: # [FOO] 1 # [FOO] 1 # [FOO] 2 # [FOO] 2 ``` The same issue also applies when calling `info`, `warn`, etc. with a block. This commit modifies the implementation used for dispatching to instead 'wrap' the block calls such that the user-provided block is only executed in the innermost call. (An assumption is made that when passed a block, all loggers will yield it, and have the same return semantics.) For example, the above example would effectively be executed as so: ```ruby broadcasts[0].tagged("FOO") { broadcasts[1].tagged("FOO") { yield } } ``` --- .../lib/active_support/broadcast_logger.rb | 73 +++++++++++-------- .../log_subscriber/test_helper.rb | 4 + activesupport/test/broadcast_logger_test.rb | 32 +++++++- 3 files changed, 77 insertions(+), 32 deletions(-) diff --git a/activesupport/lib/active_support/broadcast_logger.rb b/activesupport/lib/active_support/broadcast_logger.rb index 4cc1e35914056..693a3f8f11c56 100644 --- a/activesupport/lib/active_support/broadcast_logger.rb +++ b/activesupport/lib/active_support/broadcast_logger.rb @@ -110,57 +110,55 @@ def level end def <<(message) - dispatch { |logger| logger.<<(message) } + dispatch(:<<, message) end def add(...) - dispatch { |logger| logger.add(...) } + dispatch(:add, ...) end alias_method :log, :add def debug(...) - dispatch { |logger| logger.debug(...) } + dispatch(:debug, ...) end def info(...) - dispatch { |logger| logger.info(...) } + dispatch(:info, ...) end def warn(...) - dispatch { |logger| logger.warn(...) } + dispatch(:warn, ...) end def error(...) - dispatch { |logger| logger.error(...) } + dispatch(:error, ...) end def fatal(...) - dispatch { |logger| logger.fatal(...) } + dispatch(:fatal, ...) end def unknown(...) - dispatch { |logger| logger.unknown(...) } + dispatch(:unknown, ...) end def formatter=(formatter) - dispatch { |logger| logger.formatter = formatter } + dispatch(:formatter=, formatter) @formatter = formatter end def level=(level) - dispatch { |logger| logger.level = level } + dispatch(:level=, level) end alias_method :sev_threshold=, :level= def local_level=(level) - dispatch do |logger| - logger.local_level = level if logger.respond_to?(:local_level=) - end + dispatch(:local_level=, level) end def close - dispatch { |logger| logger.close } + dispatch(:close) end # +True+ if the log level allows entries with severity Logger::DEBUG to be written @@ -171,7 +169,7 @@ def debug? # Sets the log level to Logger::DEBUG for the whole broadcast. def debug! - dispatch { |logger| logger.debug! } + dispatch(:debug!) end # +True+ if the log level allows entries with severity Logger::INFO to be written @@ -182,7 +180,7 @@ def info? # Sets the log level to Logger::INFO for the whole broadcast. def info! - dispatch { |logger| logger.info! } + dispatch(:info!) end # +True+ if the log level allows entries with severity Logger::WARN to be written @@ -193,7 +191,7 @@ def warn? # Sets the log level to Logger::WARN for the whole broadcast. def warn! - dispatch { |logger| logger.warn! } + dispatch(:warn!) end # +True+ if the log level allows entries with severity Logger::ERROR to be written @@ -204,7 +202,7 @@ def error? # Sets the log level to Logger::ERROR for the whole broadcast. def error! - dispatch { |logger| logger.error! } + dispatch(:error!) end # +True+ if the log level allows entries with severity Logger::FATAL to be written @@ -215,7 +213,7 @@ def fatal? # Sets the log level to Logger::FATAL for the whole broadcast. def fatal! - dispatch { |logger| logger.fatal! } + dispatch(:fatal!) end def initialize_copy(other) @@ -227,20 +225,37 @@ def initialize_copy(other) end private - def dispatch(&block) - @broadcasts.each { |logger| block.call(logger) } - true + def dispatch(name, *args, **kwargs, &block) + loggers = @broadcasts.select { |logger| logger.respond_to?(name) } + if block_given? + # We record the return value of the user provided block so that we can + # ensure each logger gets the same return value. + returns = nil + yielder = proc { |*args, **kwargs| + returns = block.call(*args, **kwargs) + } + + # Wrap the block in a proc nested for each logger. + loggers.inject(yielder) { |yielder, logger| + proc { + logger.send(name, *args, **kwargs) do |*args, **kwargs| + yielder.call(*args, **kwargs) + returns + end + } + }.call + else + loggers.map { |logger| + logger.send(name, *args, **kwargs) + }.first + end end def method_missing(name, ...) - loggers = @broadcasts.select { |logger| logger.respond_to?(name) } - - if loggers.none? - super - elsif loggers.one? - loggers.first.send(name, ...) + if @broadcasts.any? { |logger| logger.respond_to?(name) } + dispatch(name, ...) else - loggers.map { |logger| logger.send(name, ...) } + super end end diff --git a/activesupport/lib/active_support/log_subscriber/test_helper.rb b/activesupport/lib/active_support/log_subscriber/test_helper.rb index b528a7fc10f5f..628d38009c5c1 100644 --- a/activesupport/lib/active_support/log_subscriber/test_helper.rb +++ b/activesupport/lib/active_support/log_subscriber/test_helper.rb @@ -71,6 +71,10 @@ def method_missing(level, message = nil) end end + def respond_to_missing?(level, include_private = false) + %i[debug info warn error fatal].include?(level) || super + end + def logged(level) @logged[level].compact.map { |l| l.to_s.strip } end diff --git a/activesupport/test/broadcast_logger_test.rb b/activesupport/test/broadcast_logger_test.rb index 5f94480cb3db0..9ca18b06bae22 100644 --- a/activesupport/test/broadcast_logger_test.rb +++ b/activesupport/test/broadcast_logger_test.rb @@ -259,13 +259,28 @@ def info(msg, &block) test "calling a method when *multiple* loggers in the broadcast have implemented it" do logger = BroadcastLogger.new(CustomLogger.new, CustomLogger.new) - assert_equal([true, true], logger.foo) + assert(logger.foo) end test "calling a method when a subset of loggers in the broadcast have implemented" do - logger = BroadcastLogger.new(CustomLogger.new, FakeLogger.new) + special_logger = Class.new(CustomLogger) do + def special_method = true + end.new + logger = BroadcastLogger.new(CustomLogger.new, FakeLogger.new, special_logger) - assert(logger.foo) + assert(logger.special_method) + end + + test "calling a method returns the first return value" do + special_logger = Class.new(CustomLogger) do + def special_method = "bar" + end.new + other_logger = Class.new(CustomLogger) do + def special_method = "foo" + end.new + logger = BroadcastLogger.new(CustomLogger.new, other_logger, FakeLogger.new, special_logger) + + assert_equal "foo", logger.special_method end test "calling a method that accepts a block" do @@ -278,6 +293,16 @@ def info(msg, &block) assert(called) end + test "calling a method that accepts a block with multiple loggers" do + logger = BroadcastLogger.new(CustomLogger.new, CustomLogger.new) + + called = 0 + logger.bar do + called += 1 + end + assert_equal 1, called, "block should be called just once" + end + test "calling a method that accepts args" do logger = BroadcastLogger.new(CustomLogger.new) @@ -386,6 +411,7 @@ def <<(x) def add(message_level, message = nil, progname = nil, &block) @adds << [message_level, message, progname] if message_level >= local_level + true end def debug?