@@ -36,6 +36,83 @@ class QEffDynamicCache(DynamicCache):
    """
+    def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
+        """
+        Write the new `key_states` and `value_states` into the cache for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache update; `position_ids` is read from here, and
+                `batch_index` optionally selects the batch rows to scatter into.
+        """
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            position_ids = cache_kwargs.get("position_ids")
+            batch_index = cache_kwargs.get("batch_index", None)
+
+            # Scatter the new states into the preallocated cache; negative position_ids
+            # (padding) are redirected to an out-of-range slot so their writes are dropped.
+            if batch_index is not None:
+                invalid_scatter_index = torch.iinfo(torch.int32).max
+                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)
+
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], position_ids, value_states
+                )
+
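`CtxScatterFunc` and `CtxScatterFuncCB` are the repo's ONNX-exportable scatter ops, so the diff only shows them being applied. As a rough mental model, here is a minimal plain-PyTorch sketch of the non-batched write path; the `[batch, n_heads, ctx_len, head_dim]` cache layout and the helper name are assumptions for illustration, not part of this change:

```python
import torch

def scatter_write_sketch(cache, position_ids, new_states):
    # Hypothetical eager-mode equivalent of CtxScatterFunc.apply(cache, position_ids, new_states).
    # cache:        [batch, n_heads, ctx_len, head_dim], preallocated
    # position_ids: [batch, seq_len], the cache slots to overwrite
    # new_states:   [batch, n_heads, seq_len, head_dim]
    n_heads, head_dim = cache.shape[1], cache.shape[3]
    idx = position_ids[:, None, :, None].expand(-1, n_heads, -1, head_dim)
    # Scatter the new states along the context-length dimension (dim 2).
    return cache.scatter(2, idx, new_states)
```

The int32-max redirect in the batched branch builds on the same idea: pointing padded positions at a deliberately out-of-range slot effectively turns their writes into no-ops on the target.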
+    def read_only(self, layer_idx, cache_kwargs):
+        """
+        Read the `key_states` and `value_states` cached for the layer `layer_idx`.
+
+        Parameters:
+            layer_idx (`int`):
+                The index of the layer to read the states from.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache read; `position_ids` is read from here, and
+                `batch_index` optionally selects the batch rows to gather from.
+
+        Return:
+            A tuple containing the cached key and value states.
+        """
+        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+        position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index", None)
+
+        # Gather every context slot up to the highest position id; later slots are invalid.
+        ctx_len = k_out.shape[2]
+        ctx_indices = torch.arange(ctx_len)[None, None, ...]
+        gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        if batch_index is not None:
+            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+        else:
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+
+        # Zero out values gathered from invalid slots so they cannot contribute to attention.
+        v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+        return k_out, v_out
+
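The read side can be sketched the same way. Again plain PyTorch under the same assumed layout; `CtxGatherFunc` is the repo's exportable gather op, and the helper below is only an illustration:

```python
import torch

def gather_read_sketch(cache, position_ids):
    # Hypothetical eager-mode equivalent of the non-batched read path above.
    # cache:        [batch, n_heads, ctx_len, head_dim]
    # position_ids: [batch, seq_len]
    batch, n_heads, ctx_len, head_dim = cache.shape
    ctx_indices = torch.arange(ctx_len)[None, None, :]
    gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
    invalid_mask = ctx_indices > gather_limit   # slots past the last written position
    safe_indices = torch.where(invalid_mask, 0, ctx_indices)
    idx = safe_indices.unsqueeze(-1).expand(batch, n_heads, ctx_len, head_dim)
    out = cache.gather(2, idx)
    # Zero out whatever was gathered from invalid slots, as the diff does for v_out.
    return torch.where(invalid_mask.unsqueeze(-1), 0.0, out)
```

Note that the diff zeroes only `v_out`; keys read from invalid slots are presumably neutralized by the attention mask instead.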
    def update(
        self,
        key_states: torch.Tensor,
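For completeness, a hypothetical eager-mode round trip through the two new methods (shapes invented for illustration):

```python
import torch

cache = QEffDynamicCache()
k = torch.randn(1, 8, 16, 64)   # [batch, n_heads, seq_len, head_dim]
v = torch.randn(1, 8, 16, 64)
kwargs = {"position_ids": torch.arange(16)[None, :], "batch_index": None}

# The first call for layer 0 appends, establishing the cache buffers;
# later calls with the same layer_idx scatter into them instead.
cache.write_only(k, v, layer_idx=0, cache_kwargs=kwargs)

# read_only gathers every slot up to the highest position id (all 16 here).
k_out, v_out = cache.read_only(layer_idx=0, cache_kwargs=kwargs)
assert k_out.shape == k.shape
```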