Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion s3torchconnector/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ lightning = [

lightning-tests = [
"s3torchconnector[lightning]",
"s3fs"
"s3fs",
"torchmetrics != 1.7.0",
]

dcp = [
Expand Down
59 changes: 48 additions & 11 deletions s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import os
import gc
import threading
from functools import partial
from typing import Optional, Any
Expand All @@ -16,13 +17,14 @@
HeadObjectResult,
ListObjectStream,
GetObjectStream,
join_all_managed_threads,
)

from s3torchconnector._user_agent import UserAgent

"""
_s3client.py
Internal client wrapper class on top of S3 client implementation
Internal client wrapper class on top of S3 client implementation
with multi-process support.
"""

Expand All @@ -35,6 +37,32 @@ def _identity(obj: Any) -> Any:


_client_lock = threading.Lock()
NATIVE_S3_CLIENT = None


def _before_fork_handler():
    """Release the native CRT S3 client and join its background threads before a fork.

    The native client (`MountpointS3Client`) is not fork-safe: its background
    threads do not survive ``fork()``, and a child inheriting live native state
    may crash or hang. This handler drops the process-wide client reference,
    forces a GC pass so the native destructor runs immediately, then waits
    (bounded) for the CRT-managed threads to finish joining.

    Any failure is reported as a best-effort warning rather than raised, so
    cleanup problems can never abort the caller's fork.
    """
    global NATIVE_S3_CLIENT
    try:
        if NATIVE_S3_CLIENT is not None:
            # Release the client before fork as it's not fork-safe
            NATIVE_S3_CLIENT = None
            # Force collection now so the native client is torn down before
            # the fork, not lazily afterwards in the child.
            gc.collect()
            # Wait for native background threads to complete joining (0.5 sec timeout)
            join_all_managed_threads(0.5)
    except Exception as e:
        # Best-effort: raising from an at-fork handler would break the fork
        # itself. NOTE(review): this goes to stdout via print — confirm
        # whether stderr or the logging module is preferred here.
        print(
            "Warning: Failed to properly clean up native background threads before fork. "
            f"Error: {e}\n"
            "Your subprocess may crash or hang. To prevent this:\n"
            "1. Ensure no active S3 client usage during fork operations\n"
            "2. Use multiprocessing with 'spawn' or 'forkserver' start method instead"
        )


# Register the handler so the S3 client is released and the native background
# threads are joined before any fork happens: a fork does not inherit those
# threads, so joining first avoids crashes or hangs in the child process.
os.register_at_fork(before=_before_fork_handler)


class S3Client:
Expand All @@ -48,25 +76,34 @@ def __init__(
):
self._region = region
self._endpoint = endpoint
self._real_client: Optional[MountpointS3Client] = None
self._client_pid: Optional[int] = None
user_agent = user_agent or UserAgent()
self._user_agent_prefix = user_agent.prefix
self._s3client_config = s3client_config or S3ClientConfig()
self._client_pid: Optional[int] = None
global NATIVE_S3_CLIENT
NATIVE_S3_CLIENT = None

@property
def _client(self) -> MountpointS3Client:
# This is a fast check to avoid acquiring the lock unnecessarily.
if self._client_pid is None or self._client_pid != os.getpid():
global NATIVE_S3_CLIENT
if (
self._client_pid is None
or self._client_pid != os.getpid()
or NATIVE_S3_CLIENT is None
):
# Acquire the lock to ensure thread-safety when creating the client.
with _client_lock:
# This double-check ensures that the client is only created once.
if self._client_pid is None or self._client_pid != os.getpid():
# `MountpointS3Client` does not survive forking, so re-create it if the PID has changed.
self._real_client = self._client_builder()
if (
self._client_pid is None
or self._client_pid != os.getpid()
or NATIVE_S3_CLIENT is None
):
# This double-check ensures that the client is only created once.
NATIVE_S3_CLIENT = self._client_builder()
self._client_pid = os.getpid()
assert self._real_client is not None
return self._real_client

assert NATIVE_S3_CLIENT is not None
return NATIVE_S3_CLIENT

@property
def region(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@
from s3torchconnector._s3client import S3Client
from s3torchconnectorclient._mountpoint_s3_client import MountpointS3Client

NATIVE_S3_CLIENT = None


class S3ClientWithoutLock(S3Client):
@property
def _client(self) -> MountpointS3Client:
global NATIVE_S3_CLIENT
if self._client_pid is None or self._client_pid != os.getpid():
self._client_pid = os.getpid()
# `MountpointS3Client` does not survive forking, so re-create it if the PID has changed.
self._real_client = self._client_builder()
assert self._real_client is not None
return self._real_client
NATIVE_S3_CLIENT = self._client_builder()
assert NATIVE_S3_CLIENT is not None
return NATIVE_S3_CLIENT

def _client_builder(self):
time.sleep(1)
Expand Down
5 changes: 3 additions & 2 deletions s3torchconnectorclient/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions s3torchconnectorclient/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ built = "0.7"
pyo3 = "0.22.4"
futures = "0.3.28"
mountpoint-s3-client = { version = "0.13.0", features = ["mock"] }
mountpoint-s3-crt-sys = { version = "0.12.1" }
log = "0.4.20"
tracing = { version = "0.1.40", default-features = false, features = ["std", "log"] }
tracing-subscriber = { version = "0.3.18", features = ["fmt", "env-filter"]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,5 @@ class S3Exception(Exception):
pass

__version__: str

def join_all_managed_threads(timeout_secs: float) -> None: ...
8 changes: 7 additions & 1 deletion s3torchconnectorclient/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ use crate::exception::S3Exception;
use crate::get_object_stream::GetObjectStream;
use crate::list_object_stream::ListObjectStream;
use crate::mock_client::PyMockClient;
use crate::mountpoint_s3_client::join_all_managed_threads;
use crate::mountpoint_s3_client::MountpointS3Client;
use crate::put_object_stream::PutObjectStream;
use crate::python_structs::py_head_object_result::PyHeadObjectResult;
use crate::python_structs::py_list_object_result::PyListObjectResult;
use crate::python_structs::py_object_info::PyObjectInfo;
use crate::python_structs::py_head_object_result::PyHeadObjectResult;
use crate::python_structs::py_restore_status::PyRestoreStatus;
use pyo3::prelude::*;

Expand Down Expand Up @@ -41,5 +42,10 @@ fn make_lib(py: Python, mountpoint_s3_client: &Bound<'_, PyModule>) -> PyResult<
mountpoint_s3_client.add_class::<PyRestoreStatus>()?;
mountpoint_s3_client.add("S3Exception", py.get_type_bound::<S3Exception>())?;
mountpoint_s3_client.add("__version__", build_info::FULL_VERSION)?;
mountpoint_s3_client.add_function(wrap_pyfunction!(
join_all_managed_threads,
mountpoint_s3_client
)?)?;

Ok(())
}
65 changes: 46 additions & 19 deletions s3torchconnectorclient/rust/src/mountpoint_s3_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
use mountpoint_s3_client::config::{
AddressingStyle, EndpointConfig, S3ClientAuthConfig, S3ClientConfig,
};
use mountpoint_s3_client::config::{Allocator, Uri};
use mountpoint_s3_client::types::{GetObjectParams, HeadObjectParams, PutObjectParams};
use mountpoint_s3_client::user_agent::UserAgent;
use mountpoint_s3_client::{ObjectClient, S3CrtClient};
use mountpoint_s3_client::config::{Allocator, Uri};
use nix::unistd::Pid;
use mountpoint_s3_crt_sys::{aws_thread_join_all_managed, aws_thread_set_managed_join_timeout_ns};
use pyo3::marker::Python;
use pyo3::types::PyTuple;
use pyo3::{pyclass, pymethods, Bound, PyRef, PyResult, ToPyObject};
use pyo3::{pyclass, pyfunction, pymethods, Bound, PyErr, PyRef, PyResult, ToPyObject};
use std::sync::Arc;

use crate::exception::python_exception;
Expand Down Expand Up @@ -48,8 +49,48 @@ pub struct MountpointS3Client {
user_agent_prefix: String,
#[pyo3(get)]
endpoint: Option<String>,
}

/// Waits for all managed CRT threads to complete, with a specified timeout.
///
/// This function blocks the calling thread until all CRT-managed threads have
/// completed execution or until the timeout expires.
///
/// Args:
/// timeout_secs (float): Maximum time to wait for threads to join, in seconds.
/// Use 0.0 for no timeout.
///
/// Returns:
/// None: On successful completion when all threads have joined.
///
/// Raises:
/// RuntimeError: If threads failed to join within the timeout period.
///
/// Note:
/// This function must only be called from the main thread or a non-managed thread.
/// Calling it from a managed thread may result in deadlock or other undefined behavior.
///
/// Example:
/// >>> join_all_managed_threads(0.5) # Wait up to 0.5 seconds for threads to join
#[pyfunction]
pub fn join_all_managed_threads(py: Python<'_>, timeout_secs: f64) -> PyResult<()> {
unsafe {
// Convert seconds to nanoseconds (1 second = 1_000_000_000 nanoseconds)
let timeout_ns = (timeout_secs * 1_000_000_000.0) as u64;

owner_pid: Pid,
aws_thread_set_managed_join_timeout_ns(timeout_ns);

// Release the GIL while waiting for other threads to join, which may acquire GIL, to avoid deadlock
let result = py.allow_threads(|| aws_thread_join_all_managed());

if result != 0 {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
"Failed to join managed threads in {} secs",
timeout_secs
)));
}
}
Ok(())
}

#[pymethods]
Expand Down Expand Up @@ -147,7 +188,7 @@ impl MountpointS3Client {
pub fn head_object(
slf: PyRef<'_, Self>,
bucket: String,
key: String
key: String,
) -> PyResult<PyHeadObjectResult> {
let params = HeadObjectParams::default();
slf.client.head_object(slf.py(), bucket, key, params)
Expand Down Expand Up @@ -213,7 +254,6 @@ impl MountpointS3Client {
client: Arc::new(MountpointS3ClientInnerImpl::new(client)),
user_agent_prefix,
endpoint,
owner_pid: nix::unistd::getpid(),
}
}
}
Expand All @@ -227,16 +267,3 @@ fn auth_config(profile: Option<&str>, unsigned: bool) -> S3ClientAuthConfig {
S3ClientAuthConfig::Default
}
}

impl Drop for MountpointS3Client {
fn drop(&mut self) {
if nix::unistd::getpid() != self.owner_pid {
// We don't want to try to deallocate a client on a different process after a fork, as
// the threads the destructor is expecting to exist actually don't (they didn't survive
// the fork). So we intentionally leak the inner client by bumping its reference count
// and then forgetting it, so the reference count can never reach zero. It's a memory
// leak, but not a big one in practice given how long we expect clients to live.
std::mem::forget(Arc::clone(&self.client));
}
}
}