Commit 66de4ab

Merge pull request #33 from iryna-kondr/feature_memory
Dynamic few shot classifier
2 parents 26dd797 + 7bae44e commit 66de4ab

File tree

10 files changed, +420 -29 lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion

@@ -13,7 +13,6 @@ repos:
       - id: check-executables-have-shebangs
       - id: check-case-conflict
       - id: check-added-large-files
-      - id: detect-aws-credentials
       - id: detect-private-key
   # Formatter for Json and Yaml files
   - repo: https://github.com/pre-commit/mirrors-prettier

README.md

Lines changed: 23 additions & 1 deletion

@@ -62,7 +62,8 @@ ZeroShotGPTClassifier(openai_model="gpt4all::ggml-gpt4all-j-v1.3-groovy")
 
 When running for the first time, the model file will be downloaded automatially.
 
-At the moment only the following estimators support gpt4all as a backend:
+At the moment only the following estimators support gpt4all as a backend:
+
 - `ZeroShotGPTClassifier`
 - `MultiLabelZeroShotGPTClassifier`
 - `FewShotGPTClassifier`
@@ -179,6 +180,27 @@ While the api remains the same as for the zero shot classifier, there are a few
 
 Note: as the model is not being re-trained, but uses the training data during inference, one could say that this is still a (different) zero-shot approach.
 
+### Dynamic Few-Shot Text Classification
+
+`DynamicFewShotGPTClassifier` dynamically selects N samples per class to include in the prompt. This allows the few-shot classifier to scale to datasets that are too large for the standard context window of LLMs.
+
+*How does it work?*
+
+During fitting, the whole dataset is partitioned by class, vectorized, and stored.
+
+During inference, the [annoy](https://github.com/spotify/annoy) library is used for fast neighbor lookup, which allows including only the most similar examples in the prompt.
+
+```python
+from skllm import DynamicFewShotGPTClassifier
+from skllm.datasets import get_classification_dataset
+
+X, y = get_classification_dataset()
+
+clf = DynamicFewShotGPTClassifier(n_examples=3)
+clf.fit(X, y)
+labels = clf.predict(X)
+```
+
 ### Text Vectorization
 
 As an alternative to using GPT as a classifier, it can be used solely for data preprocessing. `GPTVectorizer` allows to embed a chunk of text of arbitrary length to a fixed-dimensional vector, that can be used with virtually any classification or regression model.
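The retrieval step described in the README section above boils down to an approximate nearest-neighbour query against annoy. The sketch below is illustrative only and is not part of this commit; the dimensionality, data, and tree count are invented:

```python
# Illustrative sketch of the annoy lookup behind dynamic example selection.
# Not part of this commit; dimensions, data, and n_trees are arbitrary.
import numpy as np
from annoy import AnnoyIndex

dim = 4
embeddings = np.random.rand(100, dim)  # stand-in for the embeddings of one class partition

index = AnnoyIndex(dim, "euclidean")
for i, vector in enumerate(embeddings):
    index.add_item(i, vector)
index.build(10)  # 10 trees; more trees improve recall at the cost of a larger index

query = np.random.rand(dim)  # stand-in for the embedded test sample
nearest_ids = index.get_nns_by_vector(query, 3)  # indices of the 3 most similar examples
print(nearest_ids)
```

In the files added below, this lookup is wrapped by `AnnoyMemoryIndex` (`skllm/memory/_annoy.py`) and performed once per class, so that every class contributes its most similar training samples to the prompt.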

pyproject.toml

Lines changed: 4 additions & 2 deletions

@@ -8,9 +8,10 @@ dependencies = [
     "pandas>=1.5.0",
     "openai>=0.27.0",
     "tqdm>=4.60.0",
+    "annoy>=1.17.2",
 ]
 name = "scikit-llm"
-version = "0.1.1"
+version = "0.2.0"
 authors = [
     { name="Oleg Kostromin", email="[email protected]" },
     { name="Iryna Kondrashchenko", email="[email protected]" },
@@ -79,12 +80,13 @@ target-version = ['py38', 'py39', 'py310', 'py311']
 profile = "black"
 filter_files = true
 known_first_party = ["skllm", "skllm.*"]
+skip = ["__init__.py"]
 
 [tool.docformatter]
 close-quotes-on-newline = true # D209
 
 [tool.interrogate]
-fail-under = 80
+fail-under = 65
 ignore-module = true
 ignore-nested-functions = true
 ignore-private = true

skllm/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -1,5 +1,7 @@
-from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
+# ordering is important here to prevent circular imports
 from skllm.models.gpt_zero_shot_clf import (
     MultiLabelZeroShotGPTClassifier,
     ZeroShotGPTClassifier,
 )
+from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
+from skllm.models.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier

skllm/memory/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from skllm.memory._annoy import AnnoyMemoryIndex

skllm/memory/_annoy.py

Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
+import os
+import tempfile
+from typing import Any, List
+
+from annoy import AnnoyIndex
+from numpy import ndarray
+
+from skllm.memory.base import _BaseMemoryIndex
+
+
+class AnnoyMemoryIndex(_BaseMemoryIndex):
+    """Memory index using Annoy.
+
+    Parameters
+    ----------
+    dim : int
+        dimensionality of the vectors
+    metric : str, optional
+        metric to use, by default "euclidean"
+    """
+
+    def __init__(self, dim: int, metric: str = "euclidean", **kwargs: Any) -> None:
+        self._index = AnnoyIndex(dim, metric)
+        self.metric = metric
+        self.dim = dim
+        self.built = False
+
+    def add(self, id: int, vector: ndarray) -> None:
+        """Adds a vector to the index.
+
+        Parameters
+        ----------
+        id : Any
+            identifier for the vector
+        vector : ndarray
+            vector to add to the index
+        """
+        if self.built:
+            raise RuntimeError("Cannot add vectors after index is built.")
+        self._index.add_item(id, vector)
+
+    def build(self) -> None:
+        """Builds the index.
+
+        No new vectors can be added after building.
+        """
+        self._index.build(-1)
+        self.built = True
+
+    def retrieve(self, vectors: ndarray, k: int) -> List[List[int]]:
+        """Retrieves the k nearest neighbors for each vector.
+
+        Parameters
+        ----------
+        vectors : ndarray
+            vectors to retrieve nearest neighbors for
+        k : int
+            number of nearest neighbors to retrieve
+
+        Returns
+        -------
+        List
+            ids of retrieved nearest neighbors
+        """
+        if not self.built:
+            raise RuntimeError("Cannot retrieve vectors before the index is built.")
+        return [
+            self._index.get_nns_by_vector(v, k, search_k=-1, include_distances=False)
+            for v in vectors
+        ]
+
+    def __getstate__(self) -> dict:
+        """Returns the state of the object. To store the actual annoy index, it
+        has to be written to a temporary file.
+
+        Returns
+        -------
+        dict
+            state of the object
+        """
+        state = self.__dict__.copy()
+
+        # save index to temporary file
+        with tempfile.NamedTemporaryFile(delete=False) as tmp:
+            temp_filename = tmp.name
+            self._index.save(temp_filename)
+
+        # read bytes from the file
+        with open(temp_filename, "rb") as tmp:
+            index_bytes = tmp.read()
+
+        # store bytes representation in state
+        state["_index"] = index_bytes
+
+        # remove temporary file
+        os.remove(temp_filename)
+
+        return state
+
+    def __setstate__(self, state: dict) -> None:
+        """Sets the state of the object. It restores the annoy index from the
+        bytes representation.
+
+        Parameters
+        ----------
+        state : dict
+            state of the object
+        """
+        self.__dict__.update(state)
+        # restore index from bytes
+        with tempfile.NamedTemporaryFile(delete=False) as tmp:
+            temp_filename = tmp.name
+            tmp.write(self._index)
+
+        self._index = AnnoyIndex(self.dim, self.metric)
+        self._index.load(temp_filename)
+
+        # remove temporary file
+        os.remove(temp_filename)
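A minimal usage sketch of this class follows; it is not part of the commit, and the vectors and dimensionality are invented. It relies only on the methods defined above, including the pickling support provided by `__getstate__`/`__setstate__`:

```python
# Sketch only (not in this commit): exercising AnnoyMemoryIndex with random vectors.
import pickle

import numpy as np

from skllm.memory import AnnoyMemoryIndex

rng = np.random.default_rng(0)
vectors = rng.normal(size=(100, 8))  # 100 vectors of dimension 8 (made up)

index = AnnoyMemoryIndex(dim=8)
for i, v in enumerate(vectors):
    index.add(i, v)
index.build()  # after this point, add() raises RuntimeError

# ids of the 3 nearest stored vectors for each of the two query vectors
neighbors = index.retrieve(vectors[:2], k=3)

# pickling works because __getstate__ saves the annoy index to a temporary
# file and __setstate__ reloads it from the stored bytes
restored = pickle.loads(pickle.dumps(index))
assert restored.retrieve(vectors[:2], k=3) == neighbors
```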

skllm/memory/base.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+from numpy import ndarray
+
+
+class _BaseMemoryIndex(ABC):
+    @abstractmethod
+    def add(self, id: Any, vector: ndarray):
+        """Adds a vector to the index.
+
+        Parameters
+        ----------
+        id : Any
+            identifier for the vector
+        vector : ndarray
+            vector to add to the index
+        """
+        pass
+
+    @abstractmethod
+    def retrieve(self, vectors: ndarray, k: int) -> List:
+        """Retrieves the k nearest neighbors for each vector.
+
+        Parameters
+        ----------
+        vectors : ndarray
+            vectors to retrieve nearest neighbors for
+        k : int
+            number of nearest neighbors to retrieve
+
+        Returns
+        -------
+        List
+            ids of retrieved nearest neighbors
+        """
+        pass
+
+    @abstractmethod
+    def build(self) -> None:
+        """Builds the index.
+
+        All build parameters should be passed to the constructor.
+        """
+        pass
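`_BaseMemoryIndex` only fixes the `add`/`build`/`retrieve` contract; the backend is interchangeable. As a purely hypothetical illustration (not part of this commit), a brute-force implementation of the same interface could look like the following; the class name `ExactMemoryIndex` and the exact-search strategy are invented for the example:

```python
# Hypothetical example (not in this commit): exact nearest-neighbour search
# implementing the _BaseMemoryIndex contract, e.g. for tests or tiny datasets.
from typing import Any, List

import numpy as np
from numpy import ndarray

from skllm.memory.base import _BaseMemoryIndex


class ExactMemoryIndex(_BaseMemoryIndex):
    def __init__(self) -> None:
        self._ids: List[Any] = []
        self._vectors: List[ndarray] = []
        self.built = False

    def add(self, id: Any, vector: ndarray) -> None:
        # store raw vectors; nothing is precomputed at add time
        self._ids.append(id)
        self._vectors.append(np.asarray(vector, dtype=float))

    def build(self) -> None:
        # stack once so retrieval can use vectorized distance computations
        self._matrix = np.vstack(self._vectors)
        self.built = True

    def retrieve(self, vectors: ndarray, k: int) -> List[List[Any]]:
        # exact euclidean search: O(n) per query, versus annoy's approximate lookup
        results = []
        for v in np.atleast_2d(vectors):
            distances = np.linalg.norm(self._matrix - v, axis=1)
            nearest = np.argsort(distances)[:k]
            results.append([self._ids[i] for i in nearest])
        return results
```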

skllm/models/gpt_dyn_few_shot_clf.py

Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+import numpy as np
+import pandas as pd
+
+from skllm import FewShotGPTClassifier
+from skllm.memory import AnnoyMemoryIndex
+from skllm.models.gpt_few_shot_clf import _TRAINING_SAMPLE_PROMPT_TEMPLATE
+from skllm.preprocessing import GPTVectorizer
+from skllm.prompts.builders import build_few_shot_prompt_slc
+from skllm.utils import to_numpy
+
+
+class DynamicFewShotGPTClassifier(FewShotGPTClassifier):
+    """Dynamic few-shot single-label classifier.
+
+    Parameters
+    ----------
+    n_examples : int, optional
+        number of examples per class, by default 3
+    openai_key : Optional[str] , default : None
+        Your OpenAI API key. If None, the key will be read from the SKLLM_CONFIG_OPENAI_KEY environment variable.
+    openai_org : Optional[str] , default : None
+        Your OpenAI organization. If None, the organization will be read from the SKLLM_CONFIG_OPENAI_ORG
+        environment variable.
+    openai_model : str , default : "gpt-3.5-turbo"
+        The OpenAI model to use. See https://beta.openai.com/docs/api-reference/available-models for a list of
+        available models.
+    default_label : Optional[Union[List[str], str]] , default : 'Random'
+        The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
+        label will be chosen based on probabilities from the training set.
+    """
+
+    def __init__(
+        self,
+        n_examples: int = 3,
+        openai_key: str | None = None,
+        openai_org: str | None = None,
+        openai_model: str = "gpt-3.5-turbo",
+        default_label: str | None = "Random",
+    ):
+        super().__init__(openai_key, openai_org, openai_model, default_label)
+        self.n_examples = n_examples
+
+    def fit(
+        self,
+        X: np.ndarray | pd.Series | list[str],
+        y: np.ndarray | pd.Series | list[str],
+    ) -> DynamicFewShotGPTClassifier:
+        """Fits the model to the given data.
+
+        Parameters
+        ----------
+        X : Union[np.ndarray, pd.Series, List[str]]
+            training data
+        y : Union[np.ndarray, pd.Series, List[str]]
+            training labels
+
+        Returns
+        -------
+        DynamicFewShotGPTClassifier
+            self
+        """
+        X = to_numpy(X)
+        y = to_numpy(y)
+        self.embedding_model_ = GPTVectorizer().fit(X)
+        self.classes_, self.probabilities_ = self._get_unique_targets(y)
+
+        self.data_ = {}
+        for cls in self.classes_:
+            print(f"Building index for class `{cls}` ...")
+            self.data_[cls] = {}
+            partition = X[y == cls]
+            self.data_[cls]["partition"] = partition
+            embeddings = self.embedding_model_.transform(partition)
+            index = AnnoyMemoryIndex(embeddings.shape[1])
+            for i, embedding in enumerate(embeddings):
+                index.add(i, embedding)
+            index.build()
+            self.data_[cls]["index"] = index
+
+        return self
+
+    def _get_prompt(self, x: str) -> str:
+        """Generates the prompt for the given input.
+
+        Parameters
+        ----------
+        x : str
+            sample to classify
+
+        Returns
+        -------
+        str
+            final prompt
+        """
+        embedding = self.embedding_model_.transform([x])
+        training_data = []
+        for cls in self.classes_:
+            index = self.data_[cls]["index"]
+            partition = self.data_[cls]["partition"]
+            neighbors = index.retrieve(embedding, min(self.n_examples, len(partition)))
+            neighbors = [partition[i] for i in neighbors[0]]
+            training_data.extend(
+                [
+                    _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=neighbor, label=cls)
+                    for neighbor in neighbors
+                ]
+            )
+
+        training_data_str = "\n".join(training_data)
+
+        return build_few_shot_prompt_slc(
+            x=x, training_data=training_data_str, labels=repr(self.classes_)
+        )
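To see the per-class retrieval in action, one can inspect the prompt built for a single sample. This is a sketch, not part of the commit; it assumes a configured OpenAI key (fitting embeds the training data with `GPTVectorizer`) and uses the toy dataset already referenced in the README:

```python
# Sketch (not part of this commit): inspecting the dynamically assembled prompt.
from skllm import DynamicFewShotGPTClassifier
from skllm.datasets import get_classification_dataset

X, y = get_classification_dataset()

clf = DynamicFewShotGPTClassifier(n_examples=2)
clf.fit(X, y)  # partitions X by class and builds one AnnoyMemoryIndex per class

# _get_prompt retrieves the 2 most similar training samples from every class
# and formats them with _TRAINING_SAMPLE_PROMPT_TEMPLATE ahead of the query.
print(clf._get_prompt(X[0]))
```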
