Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support keyword arguments in iteration jobs #536

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions lib/job-iteration/iteration.rb
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def deserialize(job_data) # @private
self.total_time = Float(job_data["total_time"] || 0.0)
end

def perform(*params) # @private
interruptible_perform(*params)
def perform(*args, **kwargs) # @private
interruptible_perform(*args, **kwargs)

nil
end
Expand All @@ -128,12 +128,20 @@ def enumerator_builder
JobIteration.enumerator_builder.new(self)
end

def interruptible_perform(*arguments)
def interruptible_perform(*args, **kwargs)
self.start_time = Time.now.utc

enumerator = nil
ActiveSupport::Notifications.instrument("build_enumerator.iteration", instrumentation_tags) do
enumerator = build_enumerator(*arguments, cursor: cursor_position)
enumerator = if has_only_required_kwargs?(method_parameters(:build_enumerator))
if kwargs.empty?
build_enumerator(*args, cursor: cursor_position)
else
build_enumerator(*args, kwargs, cursor: cursor_position)
end
else
build_enumerator(*args, **kwargs, cursor: cursor_position)
end
end

unless enumerator
Expand All @@ -153,7 +161,7 @@ def interruptible_perform(*arguments)
end

completed = catch(:abort) do
iterate_with_enumerator(enumerator, arguments)
iterate_with_enumerator(enumerator, args, kwargs)
end

run_callbacks(:shutdown)
Expand All @@ -170,8 +178,9 @@ def interruptible_perform(*arguments)
end
end

def iterate_with_enumerator(enumerator, arguments)
arguments = arguments.dup.freeze
def iterate_with_enumerator(enumerator, args, kwargs)
args = args.dup.freeze
kwargs = kwargs.dup.freeze
found_record = false
@needs_reenqueue = false

Expand All @@ -183,7 +192,7 @@ def iterate_with_enumerator(enumerator, arguments)
ActiveSupport::Notifications.instrument("each_iteration.iteration", tags) do
found_record = true
run_callbacks(:iterate) do
each_iteration(object_from_enumerator, *arguments)
each_iteration(object_from_enumerator, *args, **kwargs)
end
self.cursor_position = cursor_from_enumerator
end
Expand Down Expand Up @@ -335,6 +344,11 @@ def valid_cursor_parameter?(parameters)
false
end

def has_only_required_kwargs?(parameters)
# puts "parameters: #{parameters.inspect}"
!parameters.any? { |parameter_type, parameter_name| parameter_type != :req && parameter_name != :cursor }
end

SIMPLE_SERIALIZABLE_CLASSES = [String, Integer, Float, NilClass, TrueClass, FalseClass].freeze
private_constant :SIMPLE_SERIALIZABLE_CLASSES
def serializable?(object)
Expand Down
40 changes: 31 additions & 9 deletions test/unit/active_job_iteration_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ class SimpleIterationJob < ActiveJob::Base
end

class MultiArgumentIterationJob < SimpleIterationJob
def build_enumerator(_one_arg, _another_arg, cursor:)
def build_enumerator(_one_arg, _another_arg, keyword_arg:, cursor:)
enumerator_builder.build_times_enumerator(2, cursor: cursor)
end

def each_iteration(item, one_arg, another_arg)
self.class.records_performed << [item, one_arg, another_arg]
def each_iteration(item, one_arg, another_arg, keyword_arg:)
self.class.records_performed << [item, one_arg, another_arg, keyword_arg]
end
end

Expand All @@ -58,6 +58,16 @@ def each_iteration(_record, params)
end
end

class KwargsIterationJob < SimpleIterationJob
def build_enumerator(times:, cursor:)
enumerator_builder.build_times_enumerator(times, cursor: cursor)
end

def each_iteration(_record, times:)
self.class.records_performed << { times: times }
end
end

class ActiveRecordIterationJob < SimpleIterationJob
def build_enumerator(cursor:)
enumerator_builder.active_record_on_records(
Expand Down Expand Up @@ -575,25 +585,25 @@ def perform(*)
end

def test_supports_multiple_job_arguments_and_global_id
MultiArgumentIterationJob.perform_later(Product.first, nil)
MultiArgumentIterationJob.perform_later(Product.first, nil, keyword_arg: "d")

work_one_job

expected = [
[0, Product.first, nil],
[1, Product.first, nil],
[0, Product.first, nil, "d"],
[1, Product.first, nil, "d"],
]
assert_equal(expected, MultiArgumentIterationJob.records_performed)
end

def test_supports_multiple_job_arguments
MultiArgumentIterationJob.perform_later(2, ["a", "b", "c"])
MultiArgumentIterationJob.perform_later(2, ["a", "b", "c"], keyword_arg: "d")

work_one_job

expected = [
[0, 2, ["a", "b", "c"]],
[1, 2, ["a", "b", "c"]],
[0, 2, ["a", "b", "c"], "d"],
[1, 2, ["a", "b", "c"], "d"],
]
assert_equal(expected, MultiArgumentIterationJob.records_performed)
end
Expand All @@ -605,6 +615,18 @@ def test_passes_params_to_each_iteration
assert_equal([params, params], ParamsIterationJob.records_performed)
end

def test_passes_kwargs_to_jobs_without_kwargs_in_build_enumerator
ParamsIterationJob.perform_later(times: 3)
work_one_job
assert_equal([{ times: 3 }, { times: 3 }, { times: 3 }], ParamsIterationJob.records_performed)
end

def test_passes_kwargs_to_each_iteration
KwargsIterationJob.perform_later(times: 3)
work_one_job
assert_equal([{ times: 3 }, { times: 3 }, { times: 3 }], KwargsIterationJob.records_performed)
end

def test_passes_params_to_each_iteration_without_extra_information_on_interruption
iterate_exact_times(1.times)
params = { "walrus" => "yes", "morewalrus" => "si" }
Expand Down
Loading