Add device as parameter to TP and rotary_embedding functions #11888

Closed

13 changes: 8 additions & 5 deletions vllm/distributed/parallel_state.py
@@ -881,8 +881,11 @@

_TP: Optional[GroupCoordinator] = None


Check failure on line 884 in vllm/distributed/parallel_state.py

GitHub Actions / mypy (3.9, 3.10, 3.11, 3.12)

Incompatible return value type (got "None", expected "GroupCoordinator") [return-value]
def get_tp_group() -> GroupCoordinator:
def get_tp_group(device: Optional[str] = None) -> GroupCoordinator:
if device == "cpu":
return None

assert _TP is not None, ("tensor model parallel group is not initialized")
return _TP
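
The mypy failures recorded above come from this hunk: `get_tp_group` is still annotated `-> GroupCoordinator`, so the new `return None` branch for the CPU path trips the `[return-value]` check in every Python version the workflow runs. Purely as an illustration of what the checker expects, and not part of this diff, a variant that keeps the CPU short-circuit but type-checks would widen the return annotation:

```python
# Sketch only, inside vllm/distributed/parallel_state.py where _TP and
# GroupCoordinator are already defined; not part of this PR's diff.
from typing import Optional


def get_tp_group(device: Optional[str] = None) -> Optional[GroupCoordinator]:
    if device == "cpu":
        # Widening the return type to Optional means every caller must now
        # handle the None returned for the CPU path.
        return None
    assert _TP is not None, "tensor model parallel group is not initialized"
    return _TP
```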

@@ -1140,14 +1143,14 @@
_TP = old_tp_group


def get_tensor_model_parallel_world_size():
def get_tensor_model_parallel_world_size(device: Optional[str] = None):
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
return get_tp_group().world_size if device != "cpu" else 1


def get_tensor_model_parallel_rank():
def get_tensor_model_parallel_rank(device: Optional[str] = None):
"""Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group
return get_tp_group().rank_in_group if device != "cpu" else 0


def destroy_model_parallel():
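Taken together, the parallel-state helpers now short-circuit on CPU instead of consulting the tensor-parallel group. A minimal usage sketch of the new keyword (assuming the usual vLLM import path for this module):

```python
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

# CPU path: the TP group is never consulted, so these return the
# single-rank defaults even before any distributed initialization.
assert get_tensor_model_parallel_world_size(device="cpu") == 1
assert get_tensor_model_parallel_rank(device="cpu") == 0

# Default (non-CPU) path: the initialized TP group is used, so
# init_distributed_environment() and initialize_model_parallel()
# must have run first.
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
```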
43 changes: 23 additions & 20 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -83,6 +83,7 @@ def __init__(
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: Optional[str] = None,
) -> None:
super().__init__()
self.head_size = head_size
@@ -91,6 +92,7 @@ def __init__(
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
self.device = device

cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
@@ -611,23 +613,22 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
Credits to Peng et al. github.com/jquesnelle/yarn
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
def __init__(self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
device: Optional[str] = None) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
@@ -639,11 +640,11 @@ def __init__(
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
is_neox_style, dtype, device)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
0, self.rotary_dim, 2, dtype=torch.float, device=self.device) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
@@ -662,7 +663,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
device=self.device,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
@@ -943,6 +944,7 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
@@ -1037,6 +1039,7 @@ def get_rope(
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow", "mscale", "mscale_all_dim")
}
extra_kwargs["device"] = device
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
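On the rotary-embedding side, `device` now flows from `get_rope` into `DeepseekScalingRotaryEmbedding` via `extra_kwargs`, replacing the hard-coded `device="cuda"` tensors in `_compute_inv_freq` and `_compute_cos_sin_cache`. A short usage sketch of the new keyword on `get_rope` itself, assuming its full signature matches the portion shown above (the sizes are illustrative, not from a real model config, and only the Deepseek-YaRN branch is shown forwarding `device` further down):

```python
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=64,            # illustrative sizes; real values come
    rotary_dim=64,           # from the model config
    max_position=2048,
    base=10000,
    is_neox_style=True,
    rope_scaling=None,       # no scaling -> plain RotaryEmbedding
    dtype=torch.float32,
    device="cpu",            # new parameter added by this PR
)

positions = torch.arange(4, dtype=torch.long)
query = torch.randn(4, 64)   # [num_tokens, num_heads * head_size]
key = torch.randn(4, 64)
# forward_native is the pure-PyTorch path, so no CUDA kernel is needed.
query, key = rope.forward_native(positions, query, key)
```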