fix derived scan logprob when observed provides more broadcastable information #8016
+95
−3
Description
Fixes a failure in derived scan logprob construction when the observed/value tensor provides more static broadcastability information than the generative scan graph (e.g. the observed tensor has a size-1 axis, like `(date, 1)`, while the scan state was inferred as non-broadcastable on that axis).

In this case, `model.logp()` could fail during the measurable scan rewrite with a scan `outputs_info` broadcast pattern mismatch (scan output inferred as matrix-like vs. `outputs_info` expecting vector-like).

… `Apply` nodes (so scan reconstruction remains valid).

Note
I think the same idea can be generalized by treating static broadcastability metadata as part of the measurable scan rewrite contract:
- … `outputs_info` proxies for the logprob-rewritten scan, ensure their `TensorType.shape` reflects any size-1/broadcastable axes implied by the outer variables.
- … `pt.join`/`outputs_info`) so that init and scan outputs agree on broadcastability, without inserting broadcast `Apply` nodes into the inner graph (placeholders must remain nominal vars).

Related Issue
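As a rough illustration of the idea in the note above (proxy types should carry the size-1 axes implied by the outer variables), the sketch below models static shapes as plain tuples instead of real tensor types. It is not the code from this PR: `broadcastable` and `refine_static_shape` are hypothetical helpers, and the convention that `None` means "length unknown" and that only a static length of 1 counts as broadcastable is an assumption about how `TensorType.shape` behaves.

```python
from typing import Optional, Tuple

# Static shape modeled as a tuple; None means the axis length is unknown.
StaticShape = Tuple[Optional[int], ...]


def broadcastable(shape: StaticShape) -> Tuple[bool, ...]:
    """An axis counts as broadcastable only when its static length is 1."""
    return tuple(dim == 1 for dim in shape)


def refine_static_shape(proxy: StaticShape, outer: StaticShape) -> StaticShape:
    """Hypothetical helper: copy size-1 axes known on `outer` into `proxy`.

    Only unknown (None) entries of `proxy` are refined; known entries are kept.
    The result is a new shape tuple, so nothing is added to any graph.
    """
    if len(proxy) != len(outer):
        raise ValueError("proxy and outer shapes must have the same rank")
    return tuple(
        1 if (p is None and o == 1) else p
        for p, o in zip(proxy, outer)
    )


# Shape inferred from the generative scan graph: nothing known statically.
inferred = (None, None)
# Shape of the observed/value tensor, e.g. (date, 1).
observed = (None, 1)

refined = refine_static_shape(inferred, observed)

# Before refinement the two patterns disagree on the last axis.
assert broadcastable(inferred) != broadcastable(observed)
# After refinement they agree.
assert broadcastable(refined) == broadcastable(observed)
```

The point of refining the type metadata, rather than wrapping the variable in a broadcasting operation, is that the placeholder stays a plain input variable, which matches the constraint that placeholders must remain nominal vars.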
Checklist
Type of change