Add GLM-10B-Chinese support #868

Closed · wants to merge 7 commits
1 change: 1 addition & 0 deletions README.md
@@ -46,6 +46,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- GLM (`THUDM/glm-10b-chinese`)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
3 changes: 3 additions & 0 deletions docs/source/models/supported_models.rst
@@ -26,6 +26,9 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`FalconForCausalLM`
- Falcon
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
* - :code:`GLMModel`
- GLM
- :code:`THUDM/glm-10b-chinese`
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
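With the two registry entries above, the model becomes loadable through vLLM's usual entrypoints. A minimal usage sketch; the prompt text, sampling values, and the need for `trust_remote_code=True` are assumptions (GLM ships custom modeling code on the Hub), not something this diff documents:

```python
from vllm import LLM, SamplingParams

# Hypothetical invocation of the GLM support added by this PR.
llm = LLM(model="THUDM/glm-10b-chinese", trust_remote_code=True)
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

outputs = llm.generate(["中国的首都是"], params)
for out in outputs:
    print(out.prompt, out.outputs[0].text)
```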
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -55,6 +55,9 @@ def __init__(
self.seed = seed

self.hf_config = get_config(model, trust_remote_code)
self.skip_special_tokens = getattr(self.hf_config,
    "skip_special_tokens", True)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_tokenizer_mode()

18 changes: 11 additions & 7 deletions vllm/engine/async_llm_engine.py
@@ -82,11 +82,13 @@ async def engine_step(self, kicking_request_id: Optional[str] = None):
self.request_events[request_id].set()

async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
**kwargs,
) -> RequestOutput:
"""Generate outputs for a request.

Generate outputs for a request. This method is a coroutine. It adds the
@@ -126,13 +128,15 @@ async def generate(
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
arrival_time=arrival_time,
**kwargs)
else:
self.engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
arrival_time=arrival_time,
**kwargs)

# The vLLM engine does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
11 changes: 8 additions & 3 deletions vllm/engine/llm_engine.py
@@ -230,6 +230,7 @@ def add_request(
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
**kwargs,
) -> None:
"""Add a request to the engine's request pool.

@@ -258,7 +259,8 @@ def add_request(
seqs: List[Sequence] = []
for _ in range(sampling_params.best_of):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
**kwargs)
seqs.append(seq)

# Create the sequence group.
@@ -410,7 +412,7 @@ def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
self.tokenizer,
seq.output_tokens,
seq.get_last_token_id(),
skip_special_tokens=True,
skip_special_tokens=self.model_config.skip_special_tokens,
)
if new_token is not None:
seq.output_tokens.append(new_token)
@@ -447,7 +449,10 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
continue
# Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos:
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
# Prefer the tokenizer's eop_token_id (GLM) and fall back to eos_token_id.
stop_token_id = self.tokenizer.eop_token_id if hasattr(
    self.tokenizer, "eop_token_id") else self.tokenizer.eos_token_id
if seq.get_last_token_id() == stop_token_id:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
continue
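The two fallbacks introduced above (a config-level `skip_special_tokens` default and an `eop_token_id` stop token) reduce to plain `getattr` lookups. A self-contained sketch with stand-in tokenizer objects; the token ids are made up for illustration:

```python
class FakeGLMTokenizer:
    eos_token_id = 50000
    eop_token_id = 50005  # hypothetical id; GLM adds an end-of-piece token

class FakeGPT2Tokenizer:
    eos_token_id = 50256  # no eop_token_id attribute

def stop_token_id(tokenizer) -> int:
    # Prefer the GLM-style end-of-piece token, else fall back to EOS.
    return getattr(tokenizer, "eop_token_id", tokenizer.eos_token_id)

assert stop_token_id(FakeGLMTokenizer()) == 50005
assert stop_token_id(FakeGPT2Tokenizer()) == 50256
```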
24 changes: 16 additions & 8 deletions vllm/entrypoints/llm.py
@@ -82,6 +82,7 @@ def generate(
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
**kwargs,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.

@@ -120,24 +121,31 @@
num_requests = len(prompts)
else:
num_requests = len(prompt_token_ids)
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
pos_ids = position_ids[i] if position_ids is not None else None
attn_mask = attention_mask[i] if attention_mask is not None else None
if prompt_token_ids is None:
token_ids = None
else:
token_ids = prompt_token_ids[i]
self._add_request(prompt, sampling_params, token_ids)
self._add_request(prompt,
sampling_params,
token_ids,
position_ids=pos_ids,
attention_mask=attn_mask,
**kwargs)
return self._run_engine(use_tqdm)

def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
) -> None:
def _add_request(self, prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]], **kwargs) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, prompt, sampling_params,
prompt_token_ids)
prompt_token_ids, **kwargs)

def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
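With the `**kwargs` plumbing above, a caller could in principle hand per-prompt `position_ids` and `attention_mask` tensors to `LLM.generate`. A rough sketch of such a call; the tensor shapes, the two-row position layout, and whether `Sequence` and the worker actually consume these names are assumptions not documented by this diff:

```python
import torch
from vllm import LLM, SamplingParams

llm = LLM(model="THUDM/glm-10b-chinese", trust_remote_code=True)

prompt_token_ids = [[101, 2769, 102, 50005]]      # made-up token ids
position_ids = [torch.tensor([[0, 1, 2, 3],       # GLM-style: absolute and
                              [0, 0, 0, 1]])]     # intra-span positions
attention_mask = [torch.ones(4, 4)]               # illustrative 0/1 mask

outputs = llm.generate(prompts=None,
                       sampling_params=SamplingParams(max_tokens=32),
                       prompt_token_ids=prompt_token_ids,
                       position_ids=position_ids,
                       attention_mask=attention_mask)
print(outputs[0].outputs[0].text)
```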
6 changes: 4 additions & 2 deletions vllm/model_executor/input_metadata.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import torch
from xformers.ops import AttentionBias
@@ -29,6 +29,7 @@ def __init__(
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
**kwargs,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
@@ -50,7 +51,8 @@ def __init__(
assert context_lens.shape[0] == self.num_generation_tokens

# Set during the execution of the first attention op.
self.attn_bias: List[AttentionBias] = []
self.attn_bias: List[Union[AttentionBias, torch.Tensor]] = []
self.custom_attention_masks = kwargs.get('custom_attention_masks', [])

def __repr__(self) -> str:
# Print only useful metadata.
19 changes: 16 additions & 3 deletions vllm/model_executor/layers/attention.py
@@ -3,6 +3,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
@@ -74,12 +75,24 @@ def __init__(self,
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

def set_attn_bias(self, input_metadata: InputMetadata) -> None:
def set_attn_bias(self,
input_metadata: InputMetadata,
dtype: torch.dtype = torch.float32) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
prompt_lens = input_metadata.prompt_lens
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
if input_metadata.custom_attention_masks:
attn_mask = torch.block_diag(
*input_metadata.custom_attention_masks)
pad_len = -attn_mask.stride(-2) % 8
attn_mask = F.pad(attn_mask, (0, pad_len, 0, pad_len))
attn_mask = attn_mask.repeat(1, self.head_size, 1, 1)
attn_bias = torch.finfo(dtype).min * (1.0 - attn_mask).cuda()
if pad_len != 0:
attn_bias = attn_bias[..., :-pad_len, :-pad_len]
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
input_metadata.attn_bias.append(attn_bias)

def multi_query_kv_attention(
@@ -198,7 +211,7 @@ def forward(
if num_prompt_tokens > 0:
# Prompt run.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata)
self.set_attn_bias(input_metadata, dtype=query.dtype)
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
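The `set_attn_bias` branch above packs the per-prompt masks into one block-diagonal matrix, pads it so the row stride becomes a multiple of 8 (apparently to satisfy xformers' stride-alignment expectations for tensor biases), converts the 0/1 mask into an additive bias, and slices the padding back off. The same arithmetic in isolation, without the per-head `repeat` (a CPU-only sketch):

```python
import torch
import torch.nn.functional as F

# Two prompts of lengths 3 and 2, each fully bidirectional within itself.
masks = [torch.ones(3, 3), torch.ones(2, 2)]
dtype = torch.float16

attn_mask = torch.block_diag(*masks)        # (5, 5); zeros across prompts
pad_len = -attn_mask.stride(-2) % 8         # here: -5 % 8 == 3
attn_mask = F.pad(attn_mask, (0, pad_len, 0, pad_len))  # (8, 8), row stride 8

# 1 -> 0.0 (attend); 0 -> most negative representable value (blocked).
attn_bias = torch.finfo(dtype).min * (1.0 - attn_mask)
if pad_len != 0:
    attn_bias = attn_bias[..., :-pad_len, :-pad_len]    # view back to (5, 5)

print(attn_bias.shape, attn_bias.stride())  # torch.Size([5, 5]) (8, 1)
```

The final slice keeps the padded storage underneath, so the bias still presents an 8-aligned row stride while matching the unpadded sequence length.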
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader.py
@@ -4,7 +4,6 @@
import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import ModelConfig
from vllm.model_executor.models import * # pylint: disable=wildcard-import
from vllm.model_executor.weight_utils import initialize_dummy_weights
@@ -27,6 +26,7 @@
"OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
"RWForCausalLM": FalconForCausalLM,
"GLMModel": GLMForCausalLM,
}


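The registry entry above keys on the architecture string reported in the model's config.json (assumed to be `GLMModel` for `THUDM/glm-10b-chinese`) and maps it to the new `GLMForCausalLM` implementation. The lookup amounts to this simplified sketch, with class names standing in for the real classes:

```python
# Simplified stand-in for the lookup done in vllm/model_executor/model_loader.py.
_MODEL_REGISTRY = {
    "GLMModel": "GLMForCausalLM",
    "GPT2LMHeadModel": "GPT2LMHeadModel",
}

def resolve_model_class(architectures):
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Model architectures {architectures} are not supported.")

print(resolve_model_class(["GLMModel"]))  # -> GLMForCausalLM
```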
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -3,6 +3,7 @@
BaichuanForCausalLM)
from vllm.model_executor.models.bloom import BloomForCausalLM
from vllm.model_executor.models.falcon import FalconForCausalLM
from vllm.model_executor.models.glm import GLMForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
@@ -19,6 +20,7 @@
"BaichuanForCausalLM",
"BloomForCausalLM",
"FalconForCausalLM",
"GLMForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
"GPTJForCausalLM",