Skip to content

Commit

Permalink
Incorporate review suggestions
Browse files (browse the repository at this point in the history)
  • Loading branch information
ssarkar2 committed Dec 10, 2023
1 parent 97df69b commit ffc19a2
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
12 changes: 5 additions & 7 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,10 @@ def setup_parser(parser):
parser.add_argument("--verbose_workers", action="store_true", help="Enable output from non-master workers")
parser.add_argument(
"--simulate_dyn_prompt",
default="",
type=str,
help="If empty static prompt is used. If a comma separated list of integers are passed, we warmup and use those shapes for prompt length",
default=None,
type=int,
nargs="*",
help="If empty, static prompt is used. If a comma separated list of integers is passed, we warmup and use those shapes for prompt length.",
)
parser.add_argument(
"--reduce_recompile",
Expand Down Expand Up @@ -303,10 +304,7 @@ def generate(size=None, reduce_recompile=False):
HabanaProfile.disable()
# Compilation
logger.info("Graph compilation...")
if len(args.simulate_dyn_prompt) > 0:
dyn_prompt_lens = [int(k) for k in args.simulate_dyn_prompt.split(",")]
else:
dyn_prompt_lens = None
dyn_prompt_lens = args.simulate_dyn_prompt
t0 = time.perf_counter()
# The first three iterations take longer because of graph compilation
if dyn_prompt_lens is None or len(set(dyn_prompt_lens)) == 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,5 @@ def __init__(self, **kwargs):
self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None)
self.reuse_cache = kwargs.get("reuse_cache", None)
self.bucket_size = kwargs.get("bucket_size", -1)
self.reduce_recompile = kwargs.get("reduce_recompile", False)
self.reduce_recompile = kwargs.get("reduce_recompile", None)
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
4 changes: 2 additions & 2 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,8 @@ def generate(
or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH
)
model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1
model_kwargs["reduce_recompile"] = generation_config.reduce_recompile
if generation_config.reduce_recompile:
model_kwargs["reduce_recompile"] = (generation_config.reduce_recompile, False)[generation_config.reduce_recompile is None]
if model_kwargs["reduce_recompile"]:
assert generation_config.bucket_size
if generation_config.reuse_cache:
assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together"
Expand Down

0 comments on commit ffc19a2

Please sign in to comment.