Skip to content

GH-3614: Fix identify_dynamic_embeddings for composite DataPoints #3659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

alanakbik
Copy link
Collaborator

Closes #3614

This PR fixes an issue where identify_dynamic_embeddings did not correctly detect dynamic embeddings (those with requires_grad=True) within composite DataPoint types like DataPair or Sentence with token embeddings.

The logic has been refactored by:

  1. Adding internal helper methods (_get_dynamic_embedding_names and _get_all_embedding_names) to the DataPoint base class with default implementations.
  2. Overriding these methods in composite classes (Sentence, Span, DataPair, DataTriple) to recursively check their constituent parts.
  3. Simplifying the identify_dynamic_embeddings function in training_utils.py to use these helpers.

This ensures all relevant dynamic embeddings are correctly identified across different DataPoint structures. Unit tests have been added in tests/test_training_utils.py to cover various scenarios and verify the fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: identify_dynamic_embeddings does not work for DataPair
1 participant