Skip to content

Commit

Permalink
Fix hash_with_views error (#587)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgoldberg-habana authored Dec 8, 2023
1 parent 67bd806 commit f6af3fe
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
5 changes: 3 additions & 2 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
model_on_meta,
write_checkpoints_json,
)
from optimum.habana.utils import check_habana_frameworks_min_version, check_optimum_habana_min_version, set_seed
from optimum.habana.utils import check_habana_frameworks_version, check_optimum_habana_min_version, set_seed


def override_print(enable):
Expand Down Expand Up @@ -132,7 +132,8 @@ def setup_model(args, model_dtype, model_kwargs, logger):
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

if check_habana_frameworks_min_version("1.13.0"):
# TODO: remove the following check from SynapseAI v1.15
if check_habana_frameworks_version("1.13.0"):
if model.config.model_type == "falcon":
args.skip_hash_with_views = True
model = wrap_in_hpu_graph(model, hash_with_views=not args.skip_hash_with_views)
Expand Down
7 changes: 7 additions & 0 deletions optimum/habana/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,13 @@ def check_habana_frameworks_min_version(min_version):
return True


def check_habana_frameworks_version(req_version):
"""
Checks if the installed version of `habana_frameworks` is equal to `req_version`.
"""
return get_habana_frameworks_version() == version.parse(req_version)


def get_device_name():
"""
Returns the name of the current device: Gaudi or Gaudi2.
Expand Down

0 comments on commit f6af3fe

Please sign in to comment.