Fix p2p pushing in rpc_inference, support transformers 4.38.2 #563

Merged 4 commits on Mar 17, 2024
setup.cfg (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers==4.37.1 # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.38.2 # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2

src/petals/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -22,8 +22,8 @@

 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"
+        version.parse("4.38.2") <= version.parse(transformers.__version__) < version.parse("4.39.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.39.0"


 def _override_bfloat16_mode_default():
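
The guard above runs at import time and can be skipped through the PETALS_IGNORE_DEPENDENCY_VERSION environment variable it checks. A minimal sketch of opting out, for local testing against a transformers build outside the pinned range (use at your own risk):

```python
import os

# The variable must be set before petals is imported, because the assert on
# transformers.__version__ runs when the petals package is first loaded.
os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"

import petals  # version check is skipped; compatibility is now your responsibility
```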

src/petals/models/llama/block.py (17 changes: 13 additions & 4 deletions)
@@ -50,9 +50,15 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         assert not output_attentions
-        assert position_ids is None
+        if position_ids is None:
+            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
+            position_ids = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            ).unsqueeze(0)
+
         bsz, q_len, _ = hidden_states.size()

         if self.config.pretraining_tp > 1:
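
The fallback above reconstructs position_ids from the number of tokens already held in the per-request KV cache, so each single-token decoding step keeps its absolute position. A standalone sketch of the same arithmetic (the helper name is mine; the cached key tensor is assumed to have shape (batch, num_heads, past_len, head_dim), matching the shape[2] indexing above):

```python
import torch

def infer_position_ids(hidden_states, past_key_value=None):
    # Positions continue from the number of tokens already cached and cover
    # the hidden_states.shape[1] new tokens of this forward pass.
    past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
    return torch.arange(
        past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
    ).unsqueeze(0)

print(infer_position_ids(torch.zeros(1, 5, 16)))                              # prefill: [[0, 1, 2, 3, 4]]
cached_keys = torch.zeros(1, 8, 5, 4)                                         # 5 tokens already cached
print(infer_position_ids(torch.zeros(1, 1, 16), (cached_keys, cached_keys)))  # next step: [[5]]
```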
@@ -84,9 +90,8 @@ def forward(
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[kv_seq_len - q_len :]
-        sin = sin[kv_seq_len - q_len :]
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
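
The rotary call now takes position_ids, and the returned cos/sin are unsqueezed before use. A shape-only sketch of why, assuming the 4.38-style rotary embedding returns per-position tensors of shape (batch, seq_len, head_dim) while query/key states are (batch, num_heads, seq_len, head_dim):

```python
import torch

batch, num_heads, seq_len, head_dim = 1, 8, 4, 16
query_states = torch.randn(batch, num_heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, head_dim)   # stand-in for what rotary_emb is assumed to return
cos = cos.unsqueeze(1)                        # -> (batch, 1, seq_len, head_dim)

# The singleton head dimension lets the element-wise rotary product broadcast
# over num_heads without materializing per-head copies.
assert (query_states * cos).shape == query_states.shape
```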
@@ -160,6 +165,8 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -190,6 +197,8 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )

         hidden_states = residual + hidden_states

src/petals/models/llama/model.py (3 changes: 3 additions & 0 deletions)
@@ -47,6 +47,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -62,6 +63,8 @@
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"

src/petals/server/block_functions.py (4 changes: 2 additions & 2 deletions)
@@ -153,7 +153,7 @@ async def iterate_rpc_inference(
     points: int,
     quant_type: QuantType,
     args_structure: Any = None,
-) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
     assert len(cache_handles) == len(requested_backends)

     prefix_length = 0
@@ -224,7 +224,7 @@ async def iterate_rpc_inference(
             for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
         can_push = not has_prompts
-        yield output_tensors, can_push
+        yield output_tensors, can_push, step_metadata

         # prepare for next step
         prefix_length += length_increment
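
With this change every step yields its own metadata next to the output tensors, so the caller can forward the metadata that belongs to the current step rather than the metadata of the request that opened the session. A minimal sketch of the new contract (the fake generator and its contents are illustrative only):

```python
import asyncio
from typing import Any, AsyncIterator, Dict, Sequence, Tuple

async def fake_iterate_rpc_inference() -> AsyncIterator[Tuple[Sequence[Any], bool, Dict]]:
    for step in range(3):
        step_metadata = {"step": step}                 # hypothetical per-step metadata
        yield [f"hidden_states_step_{step}"], True, step_metadata

async def main() -> None:
    async for output_tensors, can_push, step_metadata in fake_iterate_rpc_inference():
        if can_push:
            print("would push", output_tensors[0], "with", step_metadata)

asyncio.run(main())
```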

src/petals/server/handler.py (4 changes: 2 additions & 2 deletions)
@@ -171,7 +171,7 @@ async def rpc_inference(
                 requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
             ) as cache_handles:
                 background_tasks = set()
-                async for output_tensors, can_push in iterate_rpc_inference(
+                async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
                     requested_uids=requested_uids,
                     requested_backends=requested_backends,
                     active_adapter=self._get_active_adapter(metadata),
@@ -186,7 +186,7 @@ async def rpc_inference(
                     args_structure=args_structure,
                 ):
                     if can_push:
-                        task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+                        task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
                         background_tasks.add(task)  # Keep reference until it is done to save it from GC
                         task.add_done_callback(background_tasks.discard)
                     yield runtime_pb2.ExpertResponse(tensors=output_tensors)
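
Besides pushing with step_metadata, the handler keeps each push task in background_tasks until it completes: asyncio holds only weak references to tasks, so without a strong reference a pending push could be garbage-collected before it runs. A self-contained sketch of that idiom:

```python
import asyncio

background_tasks = set()

async def push_outputs(step):
    await asyncio.sleep(0.01)        # stand-in for the real network push
    print("pushed step", step)

async def main():
    for step in range(3):
        task = asyncio.create_task(push_outputs(step))
        background_tasks.add(task)                       # strong reference until done
        task.add_done_callback(background_tasks.discard)
    await asyncio.gather(*background_tasks)              # here: wait so the demo finishes cleanly

asyncio.run(main())
```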