|
23 | 23 | from packaging.version import parse
|
24 | 24 |
|
25 | 25 | from keras_nlp.src.api_export import keras_nlp_export
|
| 26 | +from keras_nlp.src.backend import config |
26 | 27 | from keras_nlp.src.backend import config as backend_config
|
27 | 28 | from keras_nlp.src.backend import keras
|
| 29 | +from keras_nlp.src.utils.keras_utils import print_msg |
| 30 | + |
| 31 | +try: |
| 32 | + import tensorflow as tf |
| 33 | +except ImportError: |
| 34 | + raise ImportError( |
| 35 | + "To use `keras_nlp`, please install Tensorflow: `pip install tensorflow`. " |
| 36 | + "The TensorFlow package is required for data preprocessing with any backend." |
| 37 | + ) |
28 | 38 |
|
29 | 39 | try:
|
30 | 40 | import kagglehub
|
|
43 | 53 | GS_PREFIX = "gs://"
|
44 | 54 | HF_PREFIX = "hf://"
|
45 | 55 |
|
| 56 | +KAGGLE_SCHEME = "kaggle" |
| 57 | +GS_SCHEME = "gs" |
| 58 | +HF_SCHEME = "hf" |
| 59 | + |
46 | 60 | TOKENIZER_ASSET_DIR = "assets/tokenizer"
|
47 | 61 |
|
48 | 62 | # Config file names.
|
@@ -99,13 +113,18 @@ def get_file(preset, path):
|
99 | 113 | )
|
100 | 114 | if preset in BUILTIN_PRESETS:
|
101 | 115 | preset = BUILTIN_PRESETS[preset]["kaggle_handle"]
|
102 |
| - if preset.startswith(KAGGLE_PREFIX): |
| 116 | + |
| 117 | + scheme = None |
| 118 | + if "://" in preset: |
| 119 | + scheme = preset.split("://")[0].lower() |
| 120 | + |
| 121 | + if scheme == KAGGLE_SCHEME: |
103 | 122 | if kagglehub is None:
|
104 | 123 | raise ImportError(
|
105 | 124 | "`from_preset()` requires the `kagglehub` package. "
|
106 | 125 | "Please install with `pip install kagglehub`."
|
107 | 126 | )
|
108 |
| - kaggle_handle = preset.removeprefix(KAGGLE_PREFIX) |
| 127 | + kaggle_handle = preset.removeprefix(KAGGLE_SCHEME + "://") |
109 | 128 | num_segments = len(kaggle_handle.split("/"))
|
110 | 129 | if num_segments not in (4, 5):
|
111 | 130 | raise ValueError(
|
@@ -134,25 +153,23 @@ def get_file(preset, path):
|
134 | 153 | else:
|
135 | 154 | raise ValueError(message)
|
136 | 155 |
|
137 |
| - elif preset.startswith(GS_PREFIX): |
| 156 | + elif scheme in tf.io.gfile.get_registered_schemes(): |
138 | 157 | url = os.path.join(preset, path)
|
139 |
| - url = url.replace(GS_PREFIX, "https://storage.googleapis.com/") |
140 |
| - subdir = preset.replace(GS_PREFIX, "gs_") |
141 |
| - subdir = subdir.replace("/", "_").replace("-", "_") |
| 158 | + subdir = preset.replace("://", "_").replace("-", "_") |
142 | 159 | filename = os.path.basename(path)
|
143 | 160 | subdir = os.path.join(subdir, os.path.dirname(path))
|
144 |
| - return keras.utils.get_file( |
| 161 | + return copy_gfile_to_cache( |
145 | 162 | filename,
|
146 | 163 | url,
|
147 | 164 | cache_subdir=os.path.join("models", subdir),
|
148 | 165 | )
|
149 |
| - elif preset.startswith(HF_PREFIX): |
| 166 | + elif scheme == HF_SCHEME: |
150 | 167 | if huggingface_hub is None:
|
151 | 168 | raise ImportError(
|
152 | 169 | f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
|
153 | 170 | "Please install with `pip install huggingface_hub`."
|
154 | 171 | )
|
155 |
| - hf_handle = preset.removeprefix(HF_PREFIX) |
| 172 | + hf_handle = preset.removeprefix(HF_SCHEME + "://") |
156 | 173 | try:
|
157 | 174 | return huggingface_hub.hf_hub_download(
|
158 | 175 | repo_id=hf_handle, filename=path
|
@@ -192,6 +209,25 @@ def get_file(preset, path):
|
192 | 209 | )
|
193 | 210 |
|
194 | 211 |
|
def copy_gfile_to_cache(filename, url, cache_subdir):
    """Copy a remote file into the local Keras cache and return its path.

    Much of this is adapted from `get_file` of keras core. The file is only
    copied when no file with the same name already exists in the target
    cache directory.

    Args:
        filename: Base name the file is stored under inside the cache
            directory.
        url: Source path handed to `tf.io.gfile.copy`.
        cache_subdir: Directory, relative to the cache root, that the file
            is placed in. If `None` or empty, the file is placed directly
            in the cache root.

    Returns:
        The local path of the cached file.
    """
    # The cache root is always the Keras home directory. Previously it was
    # only assigned when `cache_subdir` was None, which left `cache_dir`
    # unbound (UnboundLocalError) for every call that passed a subdirectory.
    datadir_base = os.path.expanduser(config.keras_home())
    if not os.access(datadir_base, os.W_OK):
        # Fall back to a temp location when the cache root is not writable.
        datadir_base = os.path.join("/tmp", ".keras")

    if cache_subdir:
        datadir = os.path.join(datadir_base, cache_subdir)
    else:
        datadir = datadir_base
    os.makedirs(datadir, exist_ok=True)

    fpath = os.path.join(datadir, filename)
    if not os.path.exists(fpath):
        print_msg(f"Downloading data from {url}")
        try:
            tf.io.gfile.copy(url, fpath)
        except BaseException:
            # Do not leave a possibly truncated file behind; the existence
            # check above would treat it as a valid cache hit next time.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

    return fpath
| 229 | + |
| 230 | + |
195 | 231 | def check_file_exists(preset, path):
|
196 | 232 | try:
|
197 | 233 | get_file(preset, path)
|
|
0 commit comments