Skip to content

Commit cc8c0b6

Browse files
committed
use abc for provider
protocol is awesome for typing, but pycharm tooling doesn't fully understand it && behavior difference between 3.9 and 3.11 makes it cumbersome
1 parent c2d5b9b commit cc8c0b6

File tree

8 files changed

+16
-31
lines changed

8 files changed

+16
-31
lines changed

src/shelloracle/providers/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import abc
34
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
5+
from typing import TYPE_CHECKING, Generic, TypeVar
56

67
if TYPE_CHECKING:
78
from collections.abc import AsyncIterator
@@ -22,7 +23,7 @@ class ProviderError(Exception):
2223
"""LLM providers raise this error to gracefully indicate something has gone wrong."""
2324

2425

25-
class Provider(Protocol):
26+
class Provider(abc.ABC):
2627
"""
2728
LLM Provider Protocol
2829
@@ -38,6 +39,7 @@ def __init__(self, config: Configuration) -> None:
3839
:param config: the configuration object
3940
:return: none
4041
"""
42+
self.config = config
4143

4244
@abstractmethod
4345
def generate(self, prompt: str) -> AsyncIterator[str]:

src/shelloracle/providers/deepseek.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,15 @@
99
if TYPE_CHECKING:
1010
from collections.abc import AsyncIterator
1111

12-
from shelloracle.config import Configuration
13-
1412

1513
class Deepseek(Provider):
1614
name = "Deepseek"
1715

1816
api_key = Setting(default="")
1917
model = Setting(default="deepseek-chat")
2018

21-
def __init__(self, config: Configuration) -> None:
22-
self.config = config
19+
def __init__(self, *args, **kwargs) -> None:
20+
super().__init__(*args, **kwargs)
2321
if not self.api_key:
2422
msg = "No API key provided"
2523
raise ProviderError(msg)

src/shelloracle/providers/google.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,15 @@
99
if TYPE_CHECKING:
1010
from collections.abc import AsyncIterator
1111

12-
from shelloracle.config import Configuration
13-
1412

1513
class Google(Provider):
1614
name = "Google"
1715

1816
api_key = Setting(default="")
1917
model = Setting(default="gemini-2.0-flash") # Assuming a default model name
2018

21-
def __init__(self, config: Configuration) -> None:
22-
self.config = config
19+
def __init__(self, *args, **kwargs) -> None:
20+
super().__init__(*args, **kwargs)
2321
if not self.api_key:
2422
msg = "No API key provided"
2523
raise ProviderError(msg)

src/shelloracle/providers/localai.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
if TYPE_CHECKING:
1010
from collections.abc import AsyncIterator
1111

12-
from shelloracle.config import Configuration
13-
1412

1513
class LocalAI(Provider):
1614
name = "LocalAI"
@@ -23,8 +21,8 @@ class LocalAI(Provider):
2321
def endpoint(self) -> str:
2422
return f"http://{self.host}:{self.port}"
2523

26-
def __init__(self, config: Configuration) -> None:
27-
self.config = config
24+
def __init__(self, *args, **kwargs) -> None:
25+
super().__init__(*args, **kwargs)
2826
# Use a placeholder API key so the client will work
2927
self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint)
3028

src/shelloracle/providers/ollama.py

-5
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
if TYPE_CHECKING:
1212
from collections.abc import AsyncIterator
1313

14-
from shelloracle.config import Configuration
15-
1614

1715
def dataclass_to_json(obj: Any) -> dict[str, Any]:
1816
"""Convert dataclass to a json dict
@@ -60,9 +58,6 @@ class Ollama(Provider):
6058
port = Setting(default=11434)
6159
model = Setting(default="dolphin-mistral")
6260

63-
def __init__(self, config: Configuration) -> None:
64-
self.config = config
65-
6661
@property
6762
def endpoint(self) -> str:
6863
# computed property because python descriptors need to be bound to an instance before access

src/shelloracle/providers/openai.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,15 @@
99
if TYPE_CHECKING:
1010
from collections.abc import AsyncIterator
1111

12-
from shelloracle.config import Configuration
13-
1412

1513
class OpenAI(Provider):
1614
name = "OpenAI"
1715

1816
api_key = Setting(default="")
1917
model = Setting(default="gpt-3.5-turbo")
2018

21-
def __init__(self, config: Configuration) -> None:
22-
self.config = config
19+
def __init__(self, *args, **kwargs) -> None:
20+
super().__init__(*args, **kwargs)
2321
if not self.api_key:
2422
msg = "No API key provided"
2523
raise ProviderError(msg)

src/shelloracle/providers/openai_compat.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
if TYPE_CHECKING:
1010
from collections.abc import AsyncIterator
1111

12-
from shelloracle.config import Configuration
13-
1412

1513
class OpenAICompat(Provider):
1614
name = "OpenAICompat"
@@ -19,8 +17,8 @@ class OpenAICompat(Provider):
1917
api_key = Setting(default="")
2018
model = Setting(default="")
2119

22-
def __init__(self, config: Configuration) -> None:
23-
self.config = config
20+
def __init__(self, *args, **kwargs) -> None:
21+
super().__init__(*args, **kwargs)
2422
if not self.api_key:
2523
msg = "No API key provided. Use a dummy placeholder if no key is required"
2624
raise ProviderError(msg)

src/shelloracle/providers/xai.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,15 @@
99
if TYPE_CHECKING:
1010
from collections.abc import AsyncIterator
1111

12-
from shelloracle.config import Configuration
13-
1412

1513
class XAI(Provider):
1614
name = "XAI"
1715

1816
api_key = Setting(default="")
1917
model = Setting(default="grok-beta")
2018

21-
def __init__(self, config: Configuration) -> None:
22-
self.config = config
19+
def __init__(self, *args, **kwargs) -> None:
20+
super().__init__(*args, **kwargs)
2321
if not self.api_key:
2422
msg = "No API key provided"
2523
raise ProviderError(msg)

0 commit comments

Comments
 (0)