Typo fixes #273

Open · wants to merge 1 commit into base: main

2 changes: 1 addition & 1 deletion benchmarks/mlperf/backend.py
@@ -327,7 +327,7 @@ def flush_queries(self):
self.accuracy_log.write(json.dumps(pred_outputs))
self.accuracy_log.flush()
self.accuracy_log.close()
log.info("Dumpped prediction outputs to accuracy log... ")
log.info("Dumped prediction outputs to accuracy log... ")

def __del__(self):
print("Finished destroying SUT.")
2 changes: 1 addition & 1 deletion benchmarks/tests/test_benchmark_serving.py
@@ -44,7 +44,7 @@ async def test_benchmark(self):
disable_tqdm = True

async def mocked_decode_response():
"""Mocks decode reponse as an async generator."""
"""Mocks decode response as an async generator."""
responses = [
jetstream_pb2.DecodeResponse(
stream_content=jetstream_pb2.DecodeResponse.StreamContent(
@@ -1,6 +1,6 @@
# Observability in JetStream Server

-In JetStream Server, we use [Prometheus](https://prometheus.io/docs/introduction/overview/) to collect key metrics within JetStream orchestrator and engines. We implemented a [Prometheus client server](https://prometheus.github.io/client_python/exporting/http/) in JetStream `server_lib.py` and use `MetricsServerConfig` (by passing `prometheus_port` in server entrypoint) to gaurd the metrics observability feature.
+In JetStream Server, we use [Prometheus](https://prometheus.io/docs/introduction/overview/) to collect key metrics within JetStream orchestrator and engines. We implemented a [Prometheus client server](https://prometheus.github.io/client_python/exporting/http/) in JetStream `server_lib.py` and use `MetricsServerConfig` (by passing `prometheus_port` in server entrypoint) to guard the metrics observability feature.

## Enable Prometheus server to observe Jetstream metrics

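A minimal sketch of what enabling this might look like at the server entrypoint, assuming `MetricsServerConfig` and `ServerConfig` are exposed from `jetstream.core.config_lib` and that `server_lib.run` accepts a `metrics_server_config` argument (check `server_lib.py` for the exact signatures in your checkout):

```python
# Hypothetical wiring sketch, not the project's canonical entrypoint:
# passing prometheus_port through MetricsServerConfig turns on the
# Prometheus client server; leaving metrics_server_config as None keeps
# the metrics feature disabled.
import jax
from jetstream.core import config_lib, server_lib

prometheus_port = 9090  # hypothetical choice of port
metrics_server_config = config_lib.MetricsServerConfig(port=prometheus_port)

server = server_lib.run(
    port=9000,  # gRPC port for decode requests
    config=config_lib.ServerConfig(),  # placeholder; supply your real engine/model config
    devices=jax.devices(),
    metrics_server_config=metrics_server_config,
)
server.wait_for_termination()
```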
@@ -64,12 +64,12 @@ def benchmark():
num_input_tokens = sum(map(lambda r: len(r.input_tokens), res_list))
num_output_tokens = sum(map(lambda r: len(r.generated_tokens), res_list))

print("Benchmarking result: ")
print(" Total requests:", len(dataset))
print(" Total input tokens:", num_input_tokens)
print(" Total output tokens:", num_output_tokens)
print(f" Input token thruput: {num_input_tokens/duration: .2f} tokens/sec")
print(f" Output token thruput: {num_output_tokens/duration: .2f} tokens/sec")
print("Benchmarking result:")
print(" Total requests: ", len(dataset))
print(" Total input tokens: ", num_input_tokens)
print(" Total output tokens: ", num_output_tokens)
print(f" Input token throughput: {num_input_tokens/duration: .2f} tokens/sec")
print(f" Output token throughput: {num_output_tokens/duration: .2f} tokens/sec")


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion experimental/jax/inference/nn/linear.py
@@ -7,7 +7,7 @@

https://www.apache.org/licenses/LICENSE-2.0

-Unless reuired by applicable law or agreed to in writing, software
+Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
2 changes: 1 addition & 1 deletion experimental/jax/inference/parallel/operations.py
@@ -22,7 +22,7 @@


def reduce_scatter(operand, scatter_dimension, axis_names):
"""reduce-scatter sum operation via ppermute."""
"""reduce-scatter sum operation via permute."""
idx = get_partition_index(axis_names=axis_names)
num_partitions = get_num_partitions(axis_names=axis_names)
chunk_size = operand.shape[scatter_dimension] // num_partitions
2 changes: 1 addition & 1 deletion experimental/jax/inference/parallel/util.py
@@ -27,6 +27,6 @@ def pspec(a):
elif isinstance(a, int) or isinstance(a, float):
return P()
else:
raise ValueError(f"unknown parition spec for {a}")
raise ValueError(f"unknown partition spec for {a}")

return jax.tree_util.tree_map(pspec, sharded_pytree)
18 changes: 9 additions & 9 deletions experimental/jax/inference/runtime/batch_scheduler.py
@@ -110,20 +110,20 @@ def schedule(
cur_prompt_chunk_len = (
total_len - next_prefill_req.chunk_idx * next_prefill_req.chunk_size
)
-alloced_pages = self.kv_manager.alloc_prefill_hbm_pages(
+allocated_pages = self.kv_manager.alloc_prefill_hbm_pages(
cur_prompt_chunk_len
)
-if len(alloced_pages) == 0:
+if len(allocated_pages) == 0:
# TODO: introduce priority for the request and better
# eviction algorithm.
raise NotImplementedError("Eviction is not supported yet")
else:
start_idx = (
next_prefill_req.chunk_idx * next_prefill_req.chunk_size
) // self.kv_manager.page_size
-for i, page in enumerate(alloced_pages):
+for i, page in enumerate(allocated_pages):
next_prefill_req.page_indices[start_idx + i] = page
-prefill_pages_update = PrefillPagesUpdate(alloced_pages)
+prefill_pages_update = PrefillPagesUpdate(allocated_pages)

# Schedule new generate reqs and allocate memory for all reqs.
with generate_state.map_mutex:
@@ -150,12 +150,12 @@ def schedule(
next_generate_reqs.append(gr)

# Check and alloc memory for generate.
-alloced_pages = self.kv_manager.alloc_hbm_pages(
+allocated_pages = self.kv_manager.alloc_hbm_pages(
len(generate_state.active_slot_req_map)
)
if (
len(generate_state.active_slot_req_map) != 0
-and len(alloced_pages) == 0
+and len(allocated_pages) == 0
):
raise NotImplementedError(
"Eviction isn't supported yet, please set a lower value for batch_size"
@@ -169,17 +169,17 @@
if idx >= len(req.page_indices):
continue

-req.page_indices[idx] = alloced_pages[page_to_use]
+req.page_indices[idx] = allocated_pages[page_to_use]
generate_state_page_updates.append(
GenerateStatePageUpdate(
slot=slot,
page_idx=idx,
-mapped_idx=alloced_pages[page_to_use],
+mapped_idx=allocated_pages[page_to_use],
)
)
page_to_use += 1

-self.kv_manager.free_hbm_pages(alloced_pages[page_to_use:])
+self.kv_manager.free_hbm_pages(allocated_pages[page_to_use:])

if len(generate_state.active_slot_req_map) == 0:
schedule_generate = False
2 changes: 1 addition & 1 deletion experimental/jax/inference/runtime/engine.py
@@ -223,7 +223,7 @@ def __init__(
)
print(" preprocess,", end="")
self._preprocess_queue: queue.Queue[Request] = queue.Queue()
-# TODO: Seperate the running loop with the static inference model.
+# TODO: Separate the running loop with the static inference model.
self._preprocess_thread = threading.Thread(
name="preprocess", target=self._preprocess
)
2 changes: 1 addition & 1 deletion experimental/jetstream-maxtext-stable-stack/build.sh
@@ -37,4 +37,4 @@ docker build --no-cache \
-t ${LOCAL_IMAGE_TAG} \
-f ./Dockerfile .

echo "********* Sucessfully built Stable Stack Image with tag $LOCAL_IMAGE_TAG *********"
echo "********* Successfully built Stable Stack Image with tag $LOCAL_IMAGE_TAG *********"
2 changes: 1 addition & 1 deletion jetstream/core/README.md
@@ -1,3 +1,3 @@
# JetStream core Subpackage - Server and Library that support continuous batching serving.

-Interleaved mode: Provide continuous batching to optimize inference. Uses JAX directy on single-host TPU.
+Interleaved mode: Provide continuous batching to optimize inference. Uses JAX directly on single-host TPU.
10 changes: 5 additions & 5 deletions jetstream/core/lora/adapter_tensorstore.py
@@ -231,9 +231,9 @@ def _initialize_decoding_adapters_cache(self, adapter_weights):
"""
Create a new PyTree with zero tensors at the paths corresponding to
non-None leaves in the input PyTree. The zero tensors have an added
-dimension of size `self.totol_slots`.
+dimension of size `self.total_slots`.
Args:
-adatper_weights: The input PyTree, whose structure will be mirrored.
+adapter_weights: The input PyTree, whose structure will be mirrored.
Returns:
A new PyTree with zero Tensors or None values, mirroring the structure
of the input PyTree.
@@ -437,7 +437,7 @@ async def load_adapter(

# --- Handle LOADING state ---
if metadata.status == AdapterStatus.LOADING:
-# Wait untill loading is done.
+# Wait until loading is done.
logging.info(
"Adapter %s is already loading by another task, waiting...",
adapter_id,
@@ -655,7 +655,7 @@ async def get_lora_weights(
async def unload_adapter(self, adapter_id: str):
"""Unloads a LoRA adapter's weights and removes it from the TensorStore."""
if adapter_id not in self.adapter_registry:
raise ValueError(f"Adatper with ID '{adapter_id}' not found.")
raise ValueError(f"Adapter with ID '{adapter_id}' not found.")

event_to_wait_on: Optional[asyncio.Event] = None
async with self.lock:
@@ -677,7 +677,7 @@ async def unload_adapter(self, adapter_id: str):
self._unsafe_unload_adapter(adapter_id)

def list_adapters(self) -> Dict[str, AdapterMetadata]:
"""Lists all registered adatpers and their metadata."""
"""Lists all registered adapters and their metadata."""
return self.adapter_registry

def _evict(self, from_hbm: bool = True) -> bool:
8 changes: 4 additions & 4 deletions jetstream/core/orchestrator.py
@@ -60,7 +60,7 @@
on queues that don't have an ongoing activity (i.e. everything but the
generation queue) because we don't control to go back to those queues until
necessary. Blocking means that the GIL doesn't switch back to that thread,
-wheras continual queue get operations 'chop' control and mean that we do not
+whereas continual queue get operations 'chop' control and mean that we do not
achieve good throughput. This is okay on the prefill/transfer/detokenization
threads because we don't need to do anything other than react to the presence
of items on these queues, wheras the generation thread needs to also run a
@@ -811,9 +811,9 @@ def _prefill_thread(self, idx: int):

# Here we are applying the LoRA adapter params to the base params and
# them. In the interleaved mode, the prefill and generate shares the
-# same params. But as long as prefill and decode happens sequentially,
-# there is no issues. Issue will arrise if prefill and decode is running
-# in parallel and sharing the same params. Issue arrise because prefill
+# same params. But as long as prefill and decode happen sequentially,
+# there are no issues. Issue will arise if prefill and decode are running
+# in parallel and sharing the same params. Issues arise because prefill
# uses pre-merged weights and generate uses only base weights.
final_prefill_params = prefill_params
if adapter_id and adapter_tensorstore is not None:
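The orchestrator module comment above motivates why the prefill, transfer, and detokenization threads block on their queues instead of busy polling. A small standard-library sketch of that pattern (illustration only, not JetStream code):

```python
# Reactive worker thread: block on queue.get(), which releases the GIL while
# waiting, instead of spinning with repeated non-blocking gets.
import queue
import threading

work_q: "queue.Queue[int | None]" = queue.Queue()

def reactive_worker() -> None:
  while True:
    item = work_q.get()  # blocks until an item arrives
    if item is None:     # sentinel to shut the thread down
      return
    print(f"processed {item}")

t = threading.Thread(target=reactive_worker, name="prefill-like", daemon=True)
t.start()
for i in range(3):
  work_q.put(i)
work_q.put(None)
t.join()
```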
4 changes: 2 additions & 2 deletions jetstream/core/server_lib.py
@@ -147,7 +147,7 @@ def create_driver(
config: A ServerConfig to config engine, model, device slices, etc.
devices: Device objects, will be used to get engine with proper slicing.
jax_padding: The flag to enable JAX padding during tokenization.
-metrics_collector: The JetStream Promethus metric collector.
+metrics_collector: The JetStream Prometheus metric collector.
enable_model_warmup: The flag to enable model server warmup.
multi_sampling: The flag to enable multi-sampling.
prefix_caching_config: Config to prefix caching. Disable if None.
@@ -291,7 +291,7 @@ def run(
threads: Number of RPC handlers worker threads. This should be at least
equal to the decoding batch size to fully saturate the decoding queue.
jax_padding: The flag to enable JAX padding during tokenization.
-metrics_server_config: The config to enable Promethus metric server.
+metrics_server_config: The config to enable Prometheus metric server.
enable_jax_profiler: The flag to enable JAX profiler server.
jax_profiler_port: The port JAX profiler server (default to 9999).
enable_model_warmup: The flag to enable model server warmup.
8 changes: 4 additions & 4 deletions jetstream/engine/engine_api.py
@@ -31,17 +31,17 @@


# The model parameters - their partitioning will be unique for different prefill
-# and decode topoologies.
+# and decode topologies.
Params = Any
# The result of a prefill operation, often a batch size 1 KVCache.
Prefix = Any
# The inputs into a generation step, often a prefill and generate cache tuple.
DecodeState = Any
# Accelerator representation of tokens.
DeviceTokens = Any
-# Cpus asscociated with the mesh.
+# Cpus associated with the mesh.
CpuDevices = Any
-# Tokenkizer used by the engine
+# Tokenizer used by the engine
Tokenizer = Any
# PRNG key used for prefilling
PRNGKeyType = Any
@@ -264,7 +264,7 @@ def free_resource(
) -> Any:
"""Free cache and other decode resource for the slot.

-This function is needed for advanced attetnion kenel like PageAttetion.
+This function is needed for advanced attention kernel like PageAttention.
After finishing one request, the engine need to free all used page block
resource and reuse for coming requests.
"""
4 changes: 2 additions & 2 deletions jetstream/engine/mock_engine.py
@@ -332,7 +332,7 @@ def generate(
# TODO: Do we need a left aligned one to test spec sampling?
# Don't need the + 1 you normally would, because we don't provide a
# token from prefill in the dummy.
-# This iota and masking is to allow for a cicular cache.
+# This iota and masking is to allow for a circular cache.
length_mask = (
-(l_iota - generate_cache_index) % self.cache_length
) <= generate_lengths[:, None]
@@ -540,7 +540,7 @@ def colocated_cpus(self) -> None:

@property
def use_chunked_prefill(self) -> bool:
"""Wether to use chunked prefill."""
"""Whether to use chunked prefill."""
return self._use_chunked_prefill

@property
2 changes: 1 addition & 1 deletion jetstream/engine/mock_utils.py
@@ -56,7 +56,7 @@ class TestVocab(Vocabulary):
tokenizer: TestTokenizer = TestTokenizer()

def _encode(self, s: str) -> Sequence[int]:
"""Converts a string into a integer sequenc."""
"""Converts a string into a integer sequence."""
# 'We use array methods, not python iterables so we don't
# implement this method in the mock vocab.
raise NotImplementedError
6 changes: 3 additions & 3 deletions jetstream/engine/token_utils.py
@@ -409,7 +409,7 @@ def encode(
return tokens, true_length

def decode(self, token_ids: list[int], **kwargs) -> str:
"""Processess input token ids to generate a string.
"""Processes input token ids to generate a string.
Args:
token_ids: List of token ids.
**kwargs: Additional keyword arguments.
@@ -483,7 +483,7 @@ def encode(
return tokens, true_length

def decode(self, token_ids: list[int]) -> str:
"""Processess input token ids to generate a string.
"""Processes input token ids to generate a string.
Args:
token_ids: List of token ids.
Returns:
@@ -566,7 +566,7 @@ def encode(
return tokens, true_length

def decode(self, token_ids: list[int]) -> str:
"""Processess input token ids to generate a string.
"""Processes input token ids to generate a string.
Args:
token_ids: List of token ids.
Returns:
2 changes: 1 addition & 1 deletion jetstream/engine/tokenizer_api.py
@@ -43,7 +43,7 @@ def encode(

@abc.abstractmethod
def decode(self, token_ids: list[int], **kwargs) -> str:
"""Processess input token ids to generate a string.
"""Processes input token ids to generate a string.
Args:
token_ids: List of token ids.
**kwargs: Additional keyword arguments.
2 changes: 1 addition & 1 deletion jetstream/engine/warmup_utils.py
@@ -141,7 +141,7 @@ def initialize_insert_generate_jit_cache(
generate_params: Any,
generate_idx: int,
):
"""Initialiszes jit cache for insert and generate.
"""Initializes jit cache for insert and generate.

Args:
generate_engine: A generate engine to be compiled for.
2 changes: 1 addition & 1 deletion jetstream/external_tokenizers/llama3/llama3_tokenizer.py
@@ -125,7 +125,7 @@ def encode(
By default, setting disallowed_special=() encodes a string by ignoring
special tokens. Specifically:
- Setting `disallowed_special` to () will cause all text corresponding
-to special tokens to be encoded as natural text (insteading of raising
+to special tokens to be encoded as natural text (instead of raising
an error).
- Setting `allowed_special` to "all" will treat all text corresponding
to special tokens to be encoded as special tokens.
10 changes: 5 additions & 5 deletions jetstream/tests/engine/test_token_utils.py
@@ -27,7 +27,7 @@


class SPTokenizer:
"""Tokenier used in original llama2 git"""
"""Tokenizer used in original llama2 git"""

def __init__(self, tokenizer_path: str):
self.tokenizer = SentencePieceProcessor()
@@ -40,7 +40,7 @@ def decode(self, t: List[int]) -> str:


class JetStreamTokenizer:
"""Tokenier used in JetStream before mix_token"""
"""Tokenizer used in JetStream before mix_token"""

def __init__(self, tokenizer_path: str):
metadata = tokenizer_pb2.TokenizerParameters(path=tokenizer_path)
@@ -91,13 +91,13 @@ def setup_hftoken(self):
def test_decode_vs_piece(self):
self.setup_sentencepiece()
tokens = [304, 13, 2266, 526, 777, 9590, 2020, 29901]
-expeted_sp_output = []
+expected_sp_output = []
jt_output = []
for t in tokens:
-expeted_sp_output.append(self.sp_tokenizer.decode([t]))
+expected_sp_output.append(self.sp_tokenizer.decode([t]))
jt_output.append(self.jt_tokenizer.decode(t))

-self.assertNotEqual(jt_output, expeted_sp_output)
+self.assertNotEqual(jt_output, expected_sp_output)

def test_sp_vs_seqio(self):
self.setup_sentencepiece()
2 changes: 1 addition & 1 deletion jetstream/tools/multi_lora_decode_requester.py
@@ -69,7 +69,7 @@ def get_tokenizer(
model_id: str,
tokenizer_name: str,
) -> Any:
"""Return a tokenizer or a tokenizer placholder."""
"""Return a tokenizer or a tokenizer placeholder."""
if tokenizer_name == "test":
print("Using test tokenizer")
return "test"