Skip to content
26 changes: 22 additions & 4 deletions logfire/_internal/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from opentelemetry import trace
from opentelemetry.util import types as otel_types
from typing_extensions import LiteralString, ParamSpec

Expand Down Expand Up @@ -50,6 +51,7 @@ def instrument(
extract_args: bool | Iterable[str],
record_return: bool,
allow_generator: bool,
new_trace: bool,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
from .main import set_user_attributes_on_raw_span

Expand All @@ -61,7 +63,7 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
)

attributes = get_attributes(func, msg_template, tags)
open_span = get_open_span(logfire, attributes, span_name, extract_args, func)
open_span = get_open_span(logfire, attributes, span_name, extract_args, func, new_trace)

if inspect.isgeneratorfunction(func):
if not allow_generator:
Expand Down Expand Up @@ -119,6 +121,7 @@ def get_open_span(
span_name: str | None,
extract_args: bool | Iterable[str],
func: Callable[P, R],
new_trace: bool,
) -> Callable[P, AbstractContextManager[Any]]:
final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore

Expand All @@ -136,9 +139,24 @@ def get_logfire():
def get_logfire():
return logfire

if new_trace:

def extra_span_kwargs() -> dict[str, Any]:
prev_context = trace.get_current_span().get_span_context()
if not prev_context.is_valid:
return {}
return {
'links': [trace.Link(prev_context)],
'context': trace.set_span_in_context(trace.INVALID_SPAN),
}
else:

def extra_span_kwargs() -> dict[str, Any]:
return {}

# This is the fast case for when there are no arguments to extract
def open_span(*_: P.args, **__: P.kwargs): # type: ignore
return get_logfire()._fast_span(final_span_name, attributes) # type: ignore
return get_logfire()._fast_span(final_span_name, attributes, **extra_span_kwargs()) # type: ignore

if extract_args is True:
sig = inspect.signature(func)
Expand All @@ -149,7 +167,7 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
bound.apply_defaults()
args_dict = bound.arguments
return get_logfire()._instrument_span_with_args( # type: ignore
final_span_name, attributes, args_dict
final_span_name, attributes, args_dict, **extra_span_kwargs()
)

return open_span
Expand Down Expand Up @@ -180,7 +198,7 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
args_dict = {k: args_dict[k] for k in extract_args_final}

return get_logfire()._instrument_span_with_args( # type: ignore
final_span_name, attributes, args_dict
final_span_name, attributes, args_dict, **extra_span_kwargs()
)

return open_span
Expand Down
16 changes: 11 additions & 5 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,20 +246,20 @@ def _span(
log_internal_error()
return NoopSpan() # type: ignore

def _fast_span(self, name: str, attributes: otel_types.Attributes) -> FastLogfireSpan:
def _fast_span(self, name: str, attributes: otel_types.Attributes, **kwargs: Any) -> FastLogfireSpan:
"""A simple version of `_span` optimized for auto-tracing that doesn't support message formatting.

Returns a similarly simplified version of `LogfireSpan` which must immediately be used as a context manager.
"""
try:
span = self._spans_tracer.start_span(name=name, attributes=attributes)
span = self._spans_tracer.start_span(name=name, attributes=attributes, **kwargs)
return FastLogfireSpan(span)
except Exception: # pragma: no cover
log_internal_error()
return NoopSpan() # type: ignore

def _instrument_span_with_args(
self, name: str, attributes: dict[str, otel_types.AttributeValue], function_args: dict[str, Any]
self, name: str, attributes: dict[str, otel_types.AttributeValue], function_args: dict[str, Any], **kwargs: Any
) -> FastLogfireSpan:
"""A version of `_span` used by `@instrument` with `extract_args=True`.

Expand All @@ -272,7 +272,7 @@ def _instrument_span_with_args(
if json_schema_properties := attributes_json_schema_properties(function_args): # pragma: no branch
attributes[ATTRIBUTES_JSON_SCHEMA_KEY] = attributes_json_schema(json_schema_properties)
attributes.update(prepare_otlp_attributes(function_args))
return self._fast_span(name, attributes)
return self._fast_span(name, attributes, **kwargs)
except Exception: # pragma: no cover
log_internal_error()
return NoopSpan() # type: ignore
Expand Down Expand Up @@ -580,6 +580,7 @@ def instrument(
extract_args: bool | Iterable[str] = True,
record_return: bool = False,
allow_generator: bool = False,
new_trace: bool = False,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator for instrumenting a function as a span.

Expand All @@ -603,6 +604,8 @@ def my_function(a: int):
Ignored for generators.
allow_generator: Set to `True` to prevent a warning when instrumenting a generator function.
Read https://logfire.pydantic.dev/docs/guides/advanced/generators/#using-logfireinstrument first.
new_trace: Set to `True` to start a new trace with a span link to the current span
instead of creating a child of the current span.
"""

@overload
Expand All @@ -629,6 +632,7 @@ def instrument( # type: ignore[reportInconsistentOverload]
extract_args: bool | Iterable[str] = True,
record_return: bool = False,
allow_generator: bool = False,
new_trace: bool = False,
) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
"""Decorator for instrumenting a function as a span.

Expand All @@ -652,11 +656,13 @@ def my_function(a: int):
Ignored for generators.
allow_generator: Set to `True` to prevent a warning when instrumenting a generator function.
Read https://logfire.pydantic.dev/docs/guides/advanced/generators/#using-logfireinstrument first.
new_trace: Set to `True` to start a new trace with a span link to the current span
instead of creating a child of the current span.
"""
if callable(msg_template):
return self.instrument()(msg_template)
return instrument(
self, tuple(self._tags), msg_template, span_name, extract_args, record_return, allow_generator
self, tuple(self._tags), msg_template, span_name, extract_args, record_return, allow_generator, new_trace
)

def log(
Expand Down
Loading
Loading