33from __future__ import annotations
44
55import os
6- from typing import Any , Union , Mapping
7- from typing_extensions import Self , override
6+ from typing import Any , Dict , Union , Mapping , cast
7+ from typing_extensions import Self , Literal , override
88
99import httpx
1010
3030 AsyncAPIClient ,
3131)
3232
33- __all__ = ["Timeout" , "Transport" , "ProxiesTypes" , "RequestOptions" , "Kernel" , "AsyncKernel" , "Client" , "AsyncClient" ]
33+ __all__ = [
34+ "ENVIRONMENTS" ,
35+ "Timeout" ,
36+ "Transport" ,
37+ "ProxiesTypes" ,
38+ "RequestOptions" ,
39+ "Kernel" ,
40+ "AsyncKernel" ,
41+ "Client" ,
42+ "AsyncClient" ,
43+ ]
44+
45+ ENVIRONMENTS : Dict [str , str ] = {
46+ "production" : "https://api.onkernel.com/" ,
47+ "development" : "https://localhost:3001/" ,
48+ }
3449
3550
3651class Kernel (SyncAPIClient ):
@@ -42,11 +57,14 @@ class Kernel(SyncAPIClient):
4257 # client options
4358 api_key : str
4459
60+ _environment : Literal ["production" , "development" ] | NotGiven
61+
4562 def __init__ (
4663 self ,
4764 * ,
4865 api_key : str | None = None ,
49- base_url : str | httpx .URL | None = None ,
66+ environment : Literal ["production" , "development" ] | NotGiven = NOT_GIVEN ,
67+ base_url : str | httpx .URL | None | NotGiven = NOT_GIVEN ,
5068 timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
5169 max_retries : int = DEFAULT_MAX_RETRIES ,
5270 default_headers : Mapping [str , str ] | None = None ,
@@ -77,10 +95,31 @@ def __init__(
7795 )
7896 self .api_key = api_key
7997
80- if base_url is None :
81- base_url = os .environ .get ("KERNEL_BASE_URL" )
82- if base_url is None :
83- base_url = f"http://localhost:3001"
98+ self ._environment = environment
99+
100+ base_url_env = os .environ .get ("KERNEL_BASE_URL" )
101+ if is_given (base_url ) and base_url is not None :
102+ # cast required because mypy doesn't understand the type narrowing
103+ base_url = cast ("str | httpx.URL" , base_url ) # pyright: ignore[reportUnnecessaryCast]
104+ elif is_given (environment ):
105+ if base_url_env and base_url is not None :
106+ raise ValueError (
107+ "Ambiguous URL; The `KERNEL_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None" ,
108+ )
109+
110+ try :
111+ base_url = ENVIRONMENTS [environment ]
112+ except KeyError as exc :
113+ raise ValueError (f"Unknown environment: { environment } " ) from exc
114+ elif base_url_env is not None :
115+ base_url = base_url_env
116+ else :
117+ self ._environment = environment = "production"
118+
119+ try :
120+ base_url = ENVIRONMENTS [environment ]
121+ except KeyError as exc :
122+ raise ValueError (f"Unknown environment: { environment } " ) from exc
84123
85124 super ().__init__ (
86125 version = __version__ ,
@@ -122,6 +161,7 @@ def copy(
122161 self ,
123162 * ,
124163 api_key : str | None = None ,
164+ environment : Literal ["production" , "development" ] | None = None ,
125165 base_url : str | httpx .URL | None = None ,
126166 timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
127167 http_client : httpx .Client | None = None ,
@@ -157,6 +197,7 @@ def copy(
157197 return self .__class__ (
158198 api_key = api_key or self .api_key ,
159199 base_url = base_url or self .base_url ,
200+ environment = environment or self ._environment ,
160201 timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
161202 http_client = http_client ,
162203 max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
@@ -212,11 +253,14 @@ class AsyncKernel(AsyncAPIClient):
212253 # client options
213254 api_key : str
214255
256+ _environment : Literal ["production" , "development" ] | NotGiven
257+
215258 def __init__ (
216259 self ,
217260 * ,
218261 api_key : str | None = None ,
219- base_url : str | httpx .URL | None = None ,
262+ environment : Literal ["production" , "development" ] | NotGiven = NOT_GIVEN ,
263+ base_url : str | httpx .URL | None | NotGiven = NOT_GIVEN ,
220264 timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
221265 max_retries : int = DEFAULT_MAX_RETRIES ,
222266 default_headers : Mapping [str , str ] | None = None ,
@@ -247,10 +291,31 @@ def __init__(
247291 )
248292 self .api_key = api_key
249293
250- if base_url is None :
251- base_url = os .environ .get ("KERNEL_BASE_URL" )
252- if base_url is None :
253- base_url = f"http://localhost:3001"
294+ self ._environment = environment
295+
296+ base_url_env = os .environ .get ("KERNEL_BASE_URL" )
297+ if is_given (base_url ) and base_url is not None :
298+ # cast required because mypy doesn't understand the type narrowing
299+ base_url = cast ("str | httpx.URL" , base_url ) # pyright: ignore[reportUnnecessaryCast]
300+ elif is_given (environment ):
301+ if base_url_env and base_url is not None :
302+ raise ValueError (
303+ "Ambiguous URL; The `KERNEL_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None" ,
304+ )
305+
306+ try :
307+ base_url = ENVIRONMENTS [environment ]
308+ except KeyError as exc :
309+ raise ValueError (f"Unknown environment: { environment } " ) from exc
310+ elif base_url_env is not None :
311+ base_url = base_url_env
312+ else :
313+ self ._environment = environment = "production"
314+
315+ try :
316+ base_url = ENVIRONMENTS [environment ]
317+ except KeyError as exc :
318+ raise ValueError (f"Unknown environment: { environment } " ) from exc
254319
255320 super ().__init__ (
256321 version = __version__ ,
@@ -292,6 +357,7 @@ def copy(
292357 self ,
293358 * ,
294359 api_key : str | None = None ,
360+ environment : Literal ["production" , "development" ] | None = None ,
295361 base_url : str | httpx .URL | None = None ,
296362 timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
297363 http_client : httpx .AsyncClient | None = None ,
@@ -327,6 +393,7 @@ def copy(
327393 return self .__class__ (
328394 api_key = api_key or self .api_key ,
329395 base_url = base_url or self .base_url ,
396+ environment = environment or self ._environment ,
330397 timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
331398 http_client = http_client ,
332399 max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
0 commit comments