Skip to content

refactor: implement native lock objects #14028

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ddtrace/_monkey.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib
import os
import threading
from types import ModuleType
from typing import TYPE_CHECKING # noqa:F401
from typing import Union
Expand All @@ -9,6 +8,7 @@

from ddtrace.appsec._listeners import load_common_appsec_modules
from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE
from ddtrace.internal.threads import Lock
from ddtrace.settings._config import config
from ddtrace.settings.asm import config as asm_config
from ddtrace.vendor.debtcollector import deprecate
Expand Down Expand Up @@ -130,7 +130,7 @@
}


_LOCK = threading.Lock()
_LOCK = Lock()
_PATCHED_MODULES = set()

# Module names that need to be patched for a given integration. If the module
Expand Down
11 changes: 4 additions & 7 deletions ddtrace/_trace/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import re
import threading
from typing import Any
from typing import Dict
from typing import List
Expand All @@ -19,6 +18,7 @@
from ddtrace.internal.constants import W3C_TRACEPARENT_KEY
from ddtrace.internal.constants import W3C_TRACESTATE_KEY
from ddtrace.internal.logger import get_logger
from ddtrace.internal.threads import RLock
from ddtrace.internal.utils.http import w3c_get_dd_list_member as _w3c_get_dd_list_member


Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(
sampling_priority: Optional[float] = None,
meta: Optional[_MetaDictType] = None,
metrics: Optional[_MetricDictType] = None,
lock: Optional[threading.RLock] = None,
lock: Optional[RLock] = None,
span_links: Optional[List[SpanLink]] = None,
baggage: Optional[Dict[str, Any]] = None,
is_remote: bool = True,
Expand All @@ -91,10 +91,7 @@ def __init__(
if lock is not None:
self._lock = lock
else:
# DEV: A `forksafe.RLock` is not necessary here since Contexts
# are recreated by the tracer after fork
# https://github.com/DataDog/dd-trace-py/blob/a1932e8ddb704d259ea8a3188d30bf542f59fd8d/ddtrace/tracer.py#L489-L508
self._lock = threading.RLock()
self._lock = RLock()

def __getstate__(self) -> _ContextState:
return (
Expand All @@ -121,7 +118,7 @@ def __setstate__(self, state: _ContextState) -> None:
self._reactivate,
) = state
# We cannot serialize and lock, so we must recreate it unless we already have one
self._lock = threading.RLock()
self._lock = RLock()

def __enter__(self) -> "Context":
self._lock.acquire()
Expand Down
4 changes: 2 additions & 2 deletions ddtrace/_trace/processor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import abc
from collections import defaultdict
from itertools import chain
from threading import RLock
from typing import Any
from typing import DefaultDict
from typing import Dict
Expand Down Expand Up @@ -29,6 +28,7 @@
from ddtrace.internal.service import ServiceStatusError
from ddtrace.internal.telemetry.constants import TELEMETRY_LOG_LEVEL
from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE
from ddtrace.internal.threads import RLock
from ddtrace.internal.writer import AgentResponse
from ddtrace.internal.writer import create_trace_writer
from ddtrace.settings._config import config
Expand Down Expand Up @@ -280,7 +280,7 @@ def __init__(
self.writer = create_trace_writer(response_callback=self._agent_response_callback)
# Initialize the trace buffer and lock
self._traces: DefaultDict[int, _Trace] = defaultdict(lambda: _Trace())
self._lock: RLock = RLock()
self._lock = RLock()
# Track telemetry span metrics by span api
# ex: otel api, opentracing api, datadog api
self._span_metrics: Dict[str, DefaultDict] = {
Expand Down
2 changes: 1 addition & 1 deletion ddtrace/_trace/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
import os
from os import getpid
from threading import RLock
from typing import Any
from typing import Callable
from typing import Dict
Expand Down Expand Up @@ -53,6 +52,7 @@
from ddtrace.internal.processor.endpoint_call_counter import EndpointCallCounterProcessor
from ddtrace.internal.runtime import get_runtime_id
from ddtrace.internal.schema.processor import BaseServiceProcessor
from ddtrace.internal.threads import RLock
from ddtrace.internal.utils import _get_metas_to_propagate
from ddtrace.internal.utils.formats import format_trace_id
from ddtrace.internal.writer import AgentWriterInterface
Expand Down
5 changes: 3 additions & 2 deletions ddtrace/appsec/_iast/_overhead_control_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
limit. It will measure operations being executed in a request and it will deactivate detection
(and therefore reduce the overhead to nearly 0) if a certain threshold is reached.
"""

from ddtrace._trace.sampler import RateSampler
from ddtrace._trace.span import Span
from ddtrace.appsec._iast._utils import _is_iast_debug_enabled
from ddtrace.internal._unpatched import _threading as threading
from ddtrace.internal.logger import get_logger
from ddtrace.internal.threads import Lock
from ddtrace.settings.asm import config as asm_config


Expand All @@ -24,7 +25,7 @@ class OverheadControl(object):
The goal is to do sampling at different levels of the IAST analysis (per process, per request, etc)
"""

_lock = threading.Lock()
_lock = Lock()
_request_quota = asm_config._iast_max_concurrent_requests
_sampler = RateSampler(sample_rate=get_request_sampling_value() / 100.0)

Expand Down
2 changes: 1 addition & 1 deletion ddtrace/contrib/internal/subprocess/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from ddtrace.contrib.internal.subprocess.constants import COMMANDS
from ddtrace.ext import SpanTypes
from ddtrace.internal import core
from ddtrace.internal.forksafe import RLock
from ddtrace.internal.logger import get_logger
from ddtrace.internal.threads import RLock
from ddtrace.settings._config import config
from ddtrace.settings.asm import config as asm_config

Expand Down
4 changes: 2 additions & 2 deletions ddtrace/debugging/_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from ddtrace.debugging._config import di_config
from ddtrace.debugging._signal.log import LogSignal
from ddtrace.debugging._signal.snapshot import Snapshot
from ddtrace.internal import forksafe
from ddtrace.internal._encoding import BufferFull
from ddtrace.internal.logger import get_logger
from ddtrace.internal.threads import Lock
from ddtrace.internal.utils.formats import format_trace_id


Expand Down Expand Up @@ -310,7 +310,7 @@ def __init__(
) -> None:
self._encoder = encoder
self._buffer = JsonBuffer(buffer_size)
self._lock = forksafe.Lock()
self._lock = Lock()
self._on_full = on_full
self.count = 0
self.max_size = buffer_size - self._buffer.size
Expand Down
4 changes: 2 additions & 2 deletions ddtrace/debugging/_probe/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from ddtrace.debugging._probe.model import Probe
from ddtrace.debugging._probe.model import ProbeLocationMixin
from ddtrace.debugging._probe.status import ProbeStatusLogger
from ddtrace.internal import forksafe
from ddtrace.internal.logger import get_logger
from ddtrace.internal.threads import RLock


logger = get_logger(__name__)
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self, status_logger: ProbeStatusLogger, *args: Any, **kwargs: Any)
# Used to keep track of probes pending installation
self._pending: Dict[str, List[Probe]] = defaultdict(list)

self._lock = forksafe.RLock()
self._lock = RLock()

def register(self, *probes: Probe) -> None:
"""Register a probe."""
Expand Down
9 changes: 5 additions & 4 deletions ddtrace/internal/_encoding.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ from libc cimport stdint
from libc.string cimport strlen

from json import dumps as json_dumps
import threading
from json import dumps as json_dumps

from ._utils cimport PyBytesLike_Check
Expand All @@ -26,6 +25,8 @@ from .constants import MAX_UINT_64BITS
from .._trace._limits import MAX_SPAN_META_VALUE_LEN
from .._trace._limits import TRUNCATED_SPAN_ATTRIBUTE_LEN
from ..settings._agent import config as agent_config
from ddtrace.internal.threads import Lock
from ddtrace.internal.threads import RLock


DEF MSGPACK_ARRAY_LENGTH_PREFIX_SIZE = 5
Expand Down Expand Up @@ -256,7 +257,7 @@ cdef class MsgpackStringTable(StringTable):
self.max_size = max_size
self.pk.length = MSGPACK_STRING_TABLE_LENGTH_PREFIX_SIZE
self._sp_len = 0
self._lock = threading.RLock()
self._lock = RLock()
super(MsgpackStringTable, self).__init__()

self.index(ORIGIN_KEY)
Expand Down Expand Up @@ -371,7 +372,7 @@ cdef class BufferedEncoder(object):
def __cinit__(self, size_t max_size, size_t max_item_size):
self.max_size = max_size
self.max_item_size = max_item_size
self._lock = threading.Lock()
self._lock = Lock()

# ---- Abstract methods ----

Expand Down Expand Up @@ -443,7 +444,7 @@ cdef class MsgpackEncoderBase(BufferedEncoder):
self.max_size = max_size
self.pk.buf_size = buf_size
self.max_item_size = max_item_size if max_item_size < max_size else max_size
self._lock = threading.RLock()
self._lock = RLock()
self._reset_buffer()

def __dealloc__(self):
Expand Down
26 changes: 26 additions & 0 deletions ddtrace/internal/_threads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ class PyRef
PyObject* _obj;
};

// ----------------------------------------------------------------------------

#include "_threads/lock.hpp"

// ----------------------------------------------------------------------------
class Event
{
Expand Down Expand Up @@ -511,6 +515,7 @@ static PyTypeObject PeriodicThreadType = {

// ----------------------------------------------------------------------------
static PyMethodDef _threads_methods[] = {
{ "reset_locks", (PyCFunction)lock_reset_locks, METH_NOARGS, "Reset all locks (generally after a fork)" },
{ NULL, NULL, 0, NULL } /* Sentinel */
};

Expand All @@ -533,6 +538,12 @@ PyInit__threads(void)
if (PyType_Ready(&PeriodicThreadType) < 0)
return NULL;

if (PyType_Ready(&LockType) < 0)
return NULL;

if (PyType_Ready(&RLockType) < 0)
return NULL;

_periodic_threads = PyDict_New();
if (_periodic_threads == NULL)
return NULL;
Expand All @@ -541,6 +552,7 @@ PyInit__threads(void)
if (m == NULL)
goto error;

// Periodic thread
Py_INCREF(&PeriodicThreadType);
if (PyModule_AddObject(m, "PeriodicThread", (PyObject*)&PeriodicThreadType) < 0) {
Py_DECREF(&PeriodicThreadType);
Expand All @@ -550,6 +562,20 @@ PyInit__threads(void)
if (PyModule_AddObject(m, "periodic_threads", _periodic_threads) < 0)
goto error;

// Lock
Py_INCREF(&LockType);
if (PyModule_AddObject(m, "Lock", (PyObject*)&LockType) < 0) {
Py_DECREF(&LockType);
goto error;
}

// RLock
Py_INCREF(&RLockType);
if (PyModule_AddObject(m, "RLock", (PyObject*)&RLockType) < 0) {
Py_DECREF(&RLockType);
goto error;
}

return m;

error:
Expand Down
15 changes: 15 additions & 0 deletions ddtrace/internal/_threads.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
import typing as t

class _BaseLock:
def __init__(self, reentrant: bool = False) -> None: ...
def acquire(self, timeout: t.Optional[float] = None) -> bool: ...
def release(self) -> None: ...
def locked(self) -> bool: ...
def __enter__(self) -> None: ...
def __exit__(self, exc_type, exc_value, traceback) -> t.Literal[False]: ...

class Lock(_BaseLock): ...
class RLock(_BaseLock): ...

def reset_locks() -> None: ...
def begin_reset_locks() -> None: ...
def end_reset_locks() -> None: ...

class PeriodicThread:
name: str
ident: int
Expand Down
Loading
Loading