Skip to content

Commit 9fb8019

Browse files
authored
ref: simplify cache dir creation and remove repeated parts (#568)
* ref: simplify cache dir creation and remove repeated parts
* fix: use constant for index filename in cache index copy function
1 parent b1645b6 commit 9fb8019

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

src/litdata/utilities/dataset_utilities.py

+18-14
Original file line number | Diff line number | Diff line change
@@ -148,9 +148,10 @@ def _read_updated_at(
148148
index_json_content = None
149149
assert isinstance(input_dir, Dir)
150150

151+
# Try to read index.json locally
151152
if input_dir.path is not None and os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
152-
# read index.json file and read last_updation_timestamp
153153
index_json_content = load_index_file(input_dir.path)
154+
# Try to read index.json remotely
154155
elif input_dir.url is not None:
155156
assert input_dir.url is not None
156157
# download index.json file and read last_updation_timestamp
@@ -170,11 +171,14 @@ def _read_updated_at(
170171

171172

172173
def _clear_cache_dir_if_updated(input_dir_hash_filepath: str, updated_at_hash: str) -> None:
173-
"""Clear cache dir if it is updated.
174+
"""Clear the cache directory if it is outdated.
174175
175-
If last_updated has changed and /cache/chunks/{HASH(input_dir.url)} isn't empty, we remove all the files and then
176-
create the cache.
176+
If the directory at `input_dir_hash_filepath` exists and does not contain only a single subdirectory named
177+
`updated_at_hash`, the entire directory is deleted to prevent using stale or partial cache data.
177178
179+
Args:
180+
input_dir_hash_filepath (str): Path to the hashed cache directory (e.g., /cache/chunks/{HASH(input_dir.url)}).
181+
updated_at_hash (str): The expected hash or timestamp for the current dataset state.
178182
"""
179183
if os.path.exists(input_dir_hash_filepath):
180184
# check if it only contains one directory with updated_at_hash
@@ -189,24 +193,24 @@ def _try_create_cache_dir(
189193
storage_options: Optional[Dict] = {},
190194
index_path: Optional[str] = None,
191195
) -> Optional[str]:
196+
"""Prepare and return the cache directory for a dataset."""
192197
resolved_input_dir = _resolve_dir(input_dir)
193198
updated_at = _read_updated_at(resolved_input_dir, storage_options, index_path)
194199

200+
# Fallback to a hash of the input_dir if updated_at is "0"
195201
if updated_at == "0" and input_dir is not None:
196202
updated_at = hashlib.md5(input_dir.encode()).hexdigest() # noqa: S324
197203

198204
dir_url_hash = hashlib.md5((resolved_input_dir.url or "").encode()).hexdigest() # noqa: S324
199205

200-
if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
201-
input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_CACHE_DIR, dir_url_hash)
202-
_clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at)
203-
cache_dir = os.path.join(input_dir_hash_filepath, updated_at)
204-
os.makedirs(cache_dir, exist_ok=True)
205-
return cache_dir
206+
# Determine cache root based on environment
207+
is_lightning_cloud = "LIGHTNING_CLUSTER_ID" in os.environ and "LIGHTNING_CLOUD_PROJECT_ID" in os.environ
208+
default_cache_root = _DEFAULT_LIGHTNING_CACHE_DIR if is_lightning_cloud else _DEFAULT_CACHE_DIR
209+
cache_root = cache_dir or default_cache_root
206210

207-
input_dir_hash_filepath = os.path.join(cache_dir or _DEFAULT_LIGHTNING_CACHE_DIR, dir_url_hash)
208-
_clear_cache_dir_if_updated(input_dir_hash_filepath, updated_at)
209-
cache_dir = os.path.join(input_dir_hash_filepath, updated_at)
211+
input_dir_hash_path = os.path.join(cache_root, dir_url_hash)
212+
_clear_cache_dir_if_updated(input_dir_hash_path, updated_at)
213+
cache_dir = os.path.join(input_dir_hash_path, updated_at)
210214
os.makedirs(cache_dir, exist_ok=True)
211215
return cache_dir
212216

@@ -305,7 +309,7 @@ def copy_index_to_cache_index_filepath(index_path: str, cache_index_filepath: st
305309
"""Copy Index file from index_path to cache_index_filepath."""
306310
# If index_path is a directory, append "index.json"
307311
if os.path.isdir(index_path):
308-
index_path = os.path.join(index_path, "index.json")
312+
index_path = os.path.join(index_path, _INDEX_FILENAME)
309313
# Check if the file exists before copying
310314
if not os.path.isfile(index_path):
311315
raise FileNotFoundError(f"Index file not found: {index_path}")

0 commit comments

Comments
 (0)