157 changes: 105 additions & 52 deletions src/transformers/integrations/executorch.py
@@ -15,13 +15,18 @@
from typing import Optional

import torch
import torch.utils._pytree as pytree

from ..cache_utils import (
    Cache,
    ChunkedSlidingLayer,
    DynamicCache,
    DynamicLayer,
    DynamicSlidingWindowLayer,
    EncoderDecoderCache,
    SlidingWindowLayer,
    StaticCache,
    StaticLayer,
)
from ..generation.configuration_utils import GenerationConfig
from ..masking_utils import (
@@ -34,7 +39,6 @@
from ..pytorch_utils import (
    is_torch_greater_or_equal,
    is_torch_greater_or_equal_than_2_3,
    is_torch_greater_or_equal_than_2_6,
)


@@ -855,7 +859,7 @@ def __init__(self, model, max_static_cache_length, batch_size):
        self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device)
        self.cache = EncoderDecoderCache(self.static_cache, DynamicCache(config=self.config))

        register_dynamic_cache_export_support()
        register_pytree_cache()

        # Register cache buffers to make them exportable
        for i in range(len(self.static_cache)):
@@ -1041,7 +1045,7 @@ def export_with_dynamic_cache(
    ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
    model.config._attn_implementation = "sdpa_without_vmap"

    register_dynamic_cache_export_support()
    register_pytree_cache()

    with torch.no_grad():
        exported_program = torch.export.export(
@@ -1058,57 +1062,106 @@ def export_with_dynamic_cache(
    return exported_program


def register_dynamic_cache_export_support():
    """
    Utilities for `DynamicCache` <> torch.export support
    """
def _register_pytree_cache_layer(cache_layer_cls):
    def _flatten_layer(layer):
        attributes = {
            "keys": layer.keys,
            "values": layer.values,
        }

    try:
        torch.utils._pytree.register_pytree_node(
            DynamicCache,
            lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
            _unflatten_dynamic_cache,
            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
            flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
                _get_cache_dict(dynamic_cache)
            ),
        )
        # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
        torch.fx._pytree.register_pytree_flatten_spec(
            DynamicCache,
            lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec),
        if isinstance(layer, StaticLayer):
            attributes["max_cache_len"] = layer.max_cache_len
            attributes["max_batch_size"] = getattr(layer, "max_batch_size", None)

        if isinstance(layer, SlidingWindowLayer):
            attributes["cumulative_length"] = layer.cumulative_length

        return list(attributes.values()), list(attributes.keys())

    # Unflattening rebuilds each layer from scratch, so the flatten context must carry
    # every constructor argument (max_cache_len, cumulative_length, ...) besides the tensors
    def _unflatten_layer(values, context):
        attributes = dict(zip(context, values))

        if cache_layer_cls == StaticLayer:
            static_layer = cache_layer_cls(
                max_cache_len=attributes["max_cache_len"],
            )
            static_layer.keys = attributes["keys"]
            static_layer.values = attributes["values"]
            static_layer.max_batch_size = attributes["max_batch_size"]
            return static_layer

        elif cache_layer_cls in (SlidingWindowLayer, ChunkedSlidingLayer):
            sliding_window_layer = cache_layer_cls(
                max_cache_len=attributes["max_cache_len"],
                cumulative_length=attributes["cumulative_length"],
            )
            sliding_window_layer.keys = attributes["keys"]
            sliding_window_layer.values = attributes["values"]
            sliding_window_layer.max_batch_size = attributes["max_batch_size"]
            return sliding_window_layer

        elif cache_layer_cls in (DynamicLayer, DynamicSlidingWindowLayer):
            dynamic_layer = cache_layer_cls()
            dynamic_layer.keys = attributes["keys"]
            dynamic_layer.values = attributes["values"]
            return dynamic_layer

    def _flatten_layer_with_keys(layer):
        values, context = _flatten_layer(layer)
        return [(pytree.MappingKey(k), v) for k, v in zip(context, values)], context

    try:
        pytree.register_pytree_node(
            cache_layer_cls,
            _flatten_layer,
            _unflatten_layer,
            serialized_type_name=f"{cache_layer_cls.__module__}.{cache_layer_cls.__name__}",
            flatten_with_keys_fn=_flatten_layer_with_keys,
        )
    except ValueError as e:
        # register_pytree_cache() can run more than once (model init and export both call it)
        if "already registered as pytree node" not in str(e):
            raise


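# Whole caches flatten the same way: their layer list plus the flags needed to rebuild them.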
def _register_pytree_cache(cache_cls):
    def _flatten_cache(cache):
        attributes = {
            "layers": cache.layers,
            "offloading": cache.offloading,
            "only_non_sliding": getattr(cache, "only_non_sliding", None),
        }
        return list(attributes.values()), list(attributes.keys())

    def _flatten_cache_with_keys(cache):
        values, context = _flatten_cache(cache)
        return [(pytree.MappingKey(k), v) for k, v in zip(context, values)], context

    def _unflatten_cache(values, context):
        attributes = dict(zip(context, values))

        cache = Cache(
            layers=attributes["layers"],
            offloading=attributes["offloading"],
            offload_only_non_sliding=attributes["only_non_sliding"],
        )
    # Catching this in case there are multiple runs for some test runs
    except ValueError as e:
        if "already registered as pytree node" not in str(e):
            raise


def _get_cache_dict(cache: DynamicCache):
    """Convert cache to dictionary format for pytree operations."""
    if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
        raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")

    if not is_torch_greater_or_equal_than_2_6:
        logging.warning("DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions.")

    return {
        "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
        "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
    }


def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
    dictionary = torch.utils._pytree._dict_unflatten(values, context)
    cache = DynamicCache()
    # Reconstruct layers from keys and values lists
    key_list = dictionary.get("key_cache", [])
    value_list = dictionary.get("value_cache", [])
    for idx in range(max(len(key_list), len(value_list))):
        key = key_list[idx] if idx < len(key_list) else None
        value = value_list[idx] if idx < len(value_list) else None
        cache.update(key, value, idx)
    return cache
        # Reassigning __class__ turns the base Cache into the concrete cache type
        # without re-running the subclass __init__
        cache.__class__ = cache_cls
        cache.is_initialized = True
        return cache

    try:
        torch.utils._pytree.register_pytree_node(
            cache_cls,
            _flatten_cache,
            _unflatten_cache,
            serialized_type_name=f"{cache_cls.__module__}.{cache_cls.__name__}",
            flatten_with_keys_fn=_flatten_cache_with_keys,
        )
    except ValueError as e:
        # Guard against repeated registration across multiple exports in one process
        if "already registered as pytree node" not in str(e):
            raise


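# Public entry point; the export paths above call this before torch.export runs.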
def register_pytree_cache():
    _register_pytree_cache_layer(StaticLayer)
    _register_pytree_cache_layer(SlidingWindowLayer)
    _register_pytree_cache_layer(ChunkedSlidingLayer)
    _register_pytree_cache_layer(DynamicLayer)
    _register_pytree_cache_layer(DynamicSlidingWindowLayer)

    _register_pytree_cache(StaticCache)
    _register_pytree_cache(DynamicCache)


def sdpa_mask_without_vmap(
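Taken together, these registrations let torch.export see a cache as a container of plain
tensor leaves. A minimal round-trip sketch, mirroring the new test below (illustrative,
not part of the diff; assumes this branch's transformers):

    import torch
    import torch.utils._pytree as pytree

    from transformers import StaticCache
    from transformers.configuration_utils import PretrainedConfig
    from transformers.integrations.executorch import register_pytree_cache

    register_pytree_cache()

    cache = StaticCache(config=PretrainedConfig(num_hidden_layers=2), max_cache_len=1000)
    cache.update(torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), 0)

    # Flattening yields tensor leaves plus a spec carrying max_cache_len etc.;
    # unflattening rebuilds an equivalent StaticCache
    leaves, spec = pytree.tree_flatten(cache)
    rebuilt = pytree.tree_unflatten(leaves, spec)
    assert torch.allclose(rebuilt.layers[0].values, cache.layers[0].values)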
19 changes: 18 additions & 1 deletion tests/utils/test_cache_utils.py
@@ -42,6 +42,7 @@

if is_torch_available():
    import torch
    import torch.utils._pytree as pytree

    from transformers import (
        AutoModelForCausalLM,
@@ -56,7 +57,8 @@
        convert_and_export_with_cache,
        pipeline,
    )
    from transformers.integrations.executorch import export_with_dynamic_cache
    from transformers.configuration_utils import PretrainedConfig
    from transformers.integrations.executorch import export_with_dynamic_cache, register_pytree_cache


TEST_CACHE_IMPLEMENTATIONS = [
@@ -540,6 +542,21 @@ def test_cache_gptj_model(self, cache_implementation):
class CacheExportIntegrationTest(unittest.TestCase):
    """Cache tests that rely on `torch.export()` and model loading"""

    @pytest.mark.torch_export_test
    def test_static_cache_pytree(self):
        cache = StaticCache(config=PretrainedConfig(num_hidden_layers=2), max_cache_len=1000)
        cache.update(torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), 0)

        # Before registration, the cache is an opaque pytree leaf
        tree_spec = pytree.tree_flatten(cache)[1]
        self.assertEqual(type(tree_spec), pytree.LeafSpec)

        register_pytree_cache()

        flattened, spec = pytree.tree_flatten(cache)
        new_cache = pytree.tree_unflatten(flattened, spec)

        self.assertTrue(torch.allclose(new_cache.layers[0].values, cache.layers[0].values))

    @pytest.mark.torch_export_test
    def test_dynamic_cache_exportability(self):
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
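The same round trip holds for DynamicCache, which export_with_dynamic_cache relies on. A
hedged sketch under the same assumptions as above:

    import torch
    import torch.utils._pytree as pytree

    from transformers import DynamicCache
    from transformers.integrations.executorch import register_pytree_cache

    register_pytree_cache()

    cache = DynamicCache()
    cache.update(torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), 0)

    leaves, spec = pytree.tree_flatten(cache)
    rebuilt = pytree.tree_unflatten(leaves, spec)

    # _unflatten_cache builds a base Cache and casts it back to DynamicCache
    assert isinstance(rebuilt, DynamicCache)
    assert torch.allclose(rebuilt.layers[0].keys, cache.layers[0].keys)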