Skip to content

Commit a149634

Browse files
authored
Changed from_preset file downloading to use GFile when able (#1665)
* Changed from_preset file downloading to use GFile when able * Formatted * Addressed Comments * Conditional TF import * More comment addressing * Renamed arg and changed import * Better handling of scheme detection
1 parent 5a734f9 commit a149634

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

keras_nlp/src/utils/preset_utils.py

+45-9
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,18 @@
2323
from packaging.version import parse
2424

2525
from keras_nlp.src.api_export import keras_nlp_export
26+
from keras_nlp.src.backend import config
2627
from keras_nlp.src.backend import config as backend_config
2728
from keras_nlp.src.backend import keras
29+
from keras_nlp.src.utils.keras_utils import print_msg
30+
31+
try:
32+
import tensorflow as tf
33+
except ImportError:
34+
raise ImportError(
35+
"To use `keras_nlp`, please install Tensorflow: `pip install tensorflow`. "
36+
"The TensorFlow package is required for data preprocessing with any backend."
37+
)
2838

2939
try:
3040
import kagglehub
@@ -43,6 +53,10 @@
4353
GS_PREFIX = "gs://"
4454
HF_PREFIX = "hf://"
4555

56+
KAGGLE_SCHEME = "kaggle"
57+
GS_SCHEME = "gs"
58+
HF_SCHEME = "hf"
59+
4660
TOKENIZER_ASSET_DIR = "assets/tokenizer"
4761

4862
# Config file names.
@@ -99,13 +113,18 @@ def get_file(preset, path):
99113
)
100114
if preset in BUILTIN_PRESETS:
101115
preset = BUILTIN_PRESETS[preset]["kaggle_handle"]
102-
if preset.startswith(KAGGLE_PREFIX):
116+
117+
scheme = None
118+
if "://" in preset:
119+
scheme = preset.split("://")[0].lower()
120+
121+
if scheme == KAGGLE_SCHEME:
103122
if kagglehub is None:
104123
raise ImportError(
105124
"`from_preset()` requires the `kagglehub` package. "
106125
"Please install with `pip install kagglehub`."
107126
)
108-
kaggle_handle = preset.removeprefix(KAGGLE_PREFIX)
127+
kaggle_handle = preset.removeprefix(KAGGLE_SCHEME + "://")
109128
num_segments = len(kaggle_handle.split("/"))
110129
if num_segments not in (4, 5):
111130
raise ValueError(
@@ -134,25 +153,23 @@ def get_file(preset, path):
134153
else:
135154
raise ValueError(message)
136155

137-
elif preset.startswith(GS_PREFIX):
156+
elif scheme in tf.io.gfile.get_registered_schemes():
138157
url = os.path.join(preset, path)
139-
url = url.replace(GS_PREFIX, "https://storage.googleapis.com/")
140-
subdir = preset.replace(GS_PREFIX, "gs_")
141-
subdir = subdir.replace("/", "_").replace("-", "_")
158+
subdir = preset.replace("://", "_").replace("-", "_")
142159
filename = os.path.basename(path)
143160
subdir = os.path.join(subdir, os.path.dirname(path))
144-
return keras.utils.get_file(
161+
return copy_gfile_to_cache(
145162
filename,
146163
url,
147164
cache_subdir=os.path.join("models", subdir),
148165
)
149-
elif preset.startswith(HF_PREFIX):
166+
elif scheme == HF_SCHEME:
150167
if huggingface_hub is None:
151168
raise ImportError(
152169
f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
153170
"Please install with `pip install huggingface_hub`."
154171
)
155-
hf_handle = preset.removeprefix(HF_PREFIX)
172+
hf_handle = preset.removeprefix(HF_SCHEME + "://")
156173
try:
157174
return huggingface_hub.hf_hub_download(
158175
repo_id=hf_handle, filename=path
@@ -192,6 +209,25 @@ def get_file(preset, path):
192209
)
193210

194211

212+
def copy_gfile_to_cache(filename, url, cache_subdir):
213+
"""Much of this is adapted from get_file of keras core."""
214+
if cache_subdir is None:
215+
cache_dir = config.keras_home()
216+
217+
datadir_base = os.path.expanduser(cache_dir)
218+
if not os.access(datadir_base, os.W_OK):
219+
datadir_base = os.path.join("/tmp", ".keras")
220+
datadir = os.path.join(datadir_base, cache_subdir)
221+
os.makedirs(datadir, exist_ok=True)
222+
223+
fpath = os.path.join(datadir, filename)
224+
if not os.path.exists(fpath):
225+
print_msg(f"Downloading data from {url}")
226+
tf.io.gfile.copy(url, fpath)
227+
228+
return fpath
229+
230+
195231
def check_file_exists(preset, path):
196232
try:
197233
get_file(preset, path)

0 commit comments

Comments
 (0)