1+ import contextlib
12import json
23import shutil
3- from typing import Any , Callable , Dict , Optional , Union
4+ from contextvars import ContextVar
5+ from typing import Any , Callable , Dict , Generator , Optional , Union
46
57import requests
68
79import audiostack
810from audiostack .helpers .request_types import RequestTypes
911
12+ _current_trace_id : ContextVar [Optional [str ]] = ContextVar (
13+ "current_trace_id" , default = None
14+ )
15+
1016
1117def remove_empty (data : Any ) -> Any :
1218 if not (isinstance (data , dict ) or isinstance (data , list )):
@@ -32,14 +38,19 @@ def __init__(self, family: str) -> None:
3238 self .family = family
3339
3440 @staticmethod
35- def make_header () -> dict :
36- header = {
41+ def make_header (headers : Optional [ dict ] = None ) -> dict :
42+ new_headers = {
3743 "x-api-key" : audiostack .api_key ,
3844 "x-python-sdk-version" : audiostack .sdk_version ,
3945 }
46+ current_trace_id = _current_trace_id .get ()
47+ if current_trace_id is not None :
48+ new_headers ["x-customer-trace-id" ] = current_trace_id
4049 if audiostack .assume_org_id :
41- header ["x-assume-org" ] = audiostack .assume_org_id
42- return header
50+ new_headers ["x-assume-org" ] = audiostack .assume_org_id
51+ if headers :
52+ new_headers .update (headers )
53+ return new_headers
4354
4455 def resolve_response (self , r : Any ) -> dict :
4556 if self .DEBUG_PRINT :
@@ -82,6 +93,7 @@ def send_request(
8293 path_parameters : Optional [Union [dict , str ]] = None ,
8394 query_parameters : Optional [Union [dict , str ]] = None ,
8495 overwrite_base_url : Optional [str ] = None ,
96+ headers : Optional [dict ] = None ,
8597 ) -> Any :
8698 if overwrite_base_url :
8799 url = overwrite_base_url
@@ -111,15 +123,15 @@ def send_request(
111123 }
112124
113125 return self .resolve_response (
114- FUNC_MAP [rtype ](url = url , json = json , headers = self .make_header ())
126+ FUNC_MAP [rtype ](url = url , json = json , headers = self .make_header (headers ))
115127 )
116128 elif rtype == RequestTypes .GET :
117129 if path_parameters :
118130 url = f"{ url } /{ path_parameters } "
119131
120132 return self .resolve_response (
121133 requests .get (
122- url = url , params = query_parameters , headers = self .make_header ()
134+ url = url , params = query_parameters , headers = self .make_header (headers )
123135 )
124136 )
125137 elif rtype == RequestTypes .DELETE :
@@ -128,7 +140,7 @@ def send_request(
128140
129141 return self .resolve_response (
130142 requests .delete (
131- url = url , params = query_parameters , headers = self .make_header ()
143+ url = url , params = query_parameters , headers = self .make_header (headers )
132144 )
133145 )
134146
@@ -142,3 +154,12 @@ def download_url(cls, url: str, name: str, destination: str) -> None:
142154 local_filename = f"{ destination } /{ name } "
143155 with open (local_filename , "wb" ) as f :
144156 shutil .copyfileobj (r .raw , f )
157+
158+
159+ @contextlib .contextmanager
160+ def use_trace (trace_id : str ) -> Generator [None , None , None ]:
161+ token = _current_trace_id .set (trace_id )
162+ try :
163+ yield
164+ finally :
165+ _current_trace_id .reset (token )
0 commit comments