on-demand loading of text encoders#1509
Draft
dxqb wants to merge 8 commits into
Draft
Conversation
…model composition in ModelType - Gradient checkpointing and layer offloading are now configured per component (text encoder, transformer, VAE) rather than globally - ModelType centralizes model composition and training method associations Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Introduces OnDemandModule, a persistent delegating proxy for text encoders that must be loaded on demand and freed after use rather than parked on the CPU temp device. Adds load_on_demand per-component config and four text_encoder_N_on_demand() resolvers in TrainConfig. BaseModel.to(device) is removed as an abstract method; release() is now the sole abstract method for parking a model. Each concrete model reads self.train_config.temp_device directly. Call sites in modelSetup, dataLoader, trainer, and SampleWindow are updated to model.release(). Co-Authored-By: dxqb <183307934+dxqb@users.noreply.github.com> Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This comment was marked as resolved.
This comment was marked as resolved.
dxqb
added a commit
that referenced
this pull request
Jun 14, 2026
This comment was marked as resolved.
This comment was marked as resolved.
3 tasks
Several models' release() forwarded self.train_config.temp_device (a str) directly to *_to() methods typed as device: torch.device. This crashes inside LayerOffloadConductor.to() when layer/block-swap offloading is enabled, since it accesses device.type. nn.Module.to() tolerates str so the bug was latent for runs without offloading enabled. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Resolves conflict in the Flux2 LoRA 8GB preset: keeps this branch's per-component offload_fraction scheme and drops the superseded top-level gradient_checkpointing/layer_offload_fraction fields, while picking up master's dynamic_timestep_shifting addition. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…) rename The rename to release() in this PR accidentally dropped the eval() call that used to follow to(temp_device) before caching and before sampling. Without it, the model stays in train() mode during in-training sampling, which breaks models whose forward pass branches on self.training (e.g. HiDream's unpatchify). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
# Conflicts: # modules/modelSetup/BaseErnieSetup.py # modules/modelSetup/BaseWuerstchenSetup.py # modules/util/checkpointing_util.py
dxqb
added a commit
that referenced
this pull request
Jun 19, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Text encoders mostly sit in RAM, and are only moved to VRAM for caching and sampling.
This PR introduces a mechanism to not load the text encoder at all, and load it directly from disk onto the GPU whenever it is needed.
This is needed by the Lens model, because it doesn't seem to be possible to move the quantized GTP-OSS encoder between CPU and GPU: microsoft/Lens#11
It might also be useful for other models (to save RAM), but this PR doesn't implement it for any other models.
includes #1476
Test plan
pre-commit run --all-filespassesAI assistance