|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | + |
| 6 | +from skllm import FewShotGPTClassifier |
| 7 | +from skllm.memory import AnnoyMemoryIndex |
| 8 | +from skllm.models.gpt_few_shot_clf import _TRAINING_SAMPLE_PROMPT_TEMPLATE |
| 9 | +from skllm.preprocessing import GPTVectorizer |
| 10 | +from skllm.prompts.builders import build_few_shot_prompt_slc |
| 11 | +from skllm.utils import to_numpy |
| 12 | + |
| 13 | + |
class DynamicFewShotGPTClassifier(FewShotGPTClassifier):
    """Few-shot single-label classifier that picks its examples per input.

    During ``fit`` the training samples are grouped by class, embedded with a
    ``GPTVectorizer`` and stored in one ``AnnoyMemoryIndex`` per class. When a
    prompt is built, up to ``n_examples`` samples are retrieved from every
    class index for the embedded input and placed in the prompt as training
    examples.

    Parameters
    ----------
    n_examples : int, optional
        number of examples per class, by default 3
    openai_key : Optional[str] , default : None
        Your OpenAI API key. If None, the key will be read from the
        SKLLM_CONFIG_OPENAI_KEY environment variable.
    openai_org : Optional[str] , default : None
        Your OpenAI organization. If None, the organization will be read from
        the SKLLM_CONFIG_OPENAI_ORG environment variable.
    openai_model : str , default : "gpt-3.5-turbo"
        The OpenAI model to use. See
        https://beta.openai.com/docs/api-reference/available-models for a list
        of available models.
    default_label : Optional[Union[List[str], str]] , default : 'Random'
        The default label to use if the LLM could not generate a response for
        a sample. If set to 'Random' a random label will be chosen based on
        probabilities from the training set.
    """

    def __init__(
        self,
        n_examples: int = 3,
        openai_key: str | None = None,
        openai_org: str | None = None,
        openai_model: str = "gpt-3.5-turbo",
        default_label: str | None = "Random",
    ):
        # The parent receives its arguments positionally, in this exact order.
        super().__init__(
            openai_key,
            openai_org,
            openai_model,
            default_label,
        )
        self.n_examples = n_examples

    def fit(
        self,
        X: np.ndarray | pd.Series | list[str],
        y: np.ndarray | pd.Series | list[str],
    ) -> DynamicFewShotGPTClassifier:
        """Embed the training data and build one retrieval index per class.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            training data
        y : Union[np.ndarray, pd.Series, List[str]]
            training labels

        Returns
        -------
        DynamicFewShotGPTClassifier
            self
        """
        X = to_numpy(X)
        y = to_numpy(y)

        # NOTE(review): the vectorizer is created with its default arguments,
        # so the openai_key / openai_org given to this classifier are not
        # forwarded here -- confirm that this is intended.
        vectorizer = GPTVectorizer()
        self.embedding_model_ = vectorizer.fit(X)

        targets = self._get_unique_targets(y)
        self.classes_, self.probabilities_ = targets

        self.data_ = {}
        for label in self.classes_:
            print(f"Building index for class `{label}` ...")
            entry = {}
            self.data_[label] = entry

            # Keep the raw samples of this class next to their index so that
            # retrieved ids can be mapped back to the original text.
            samples = X[y == label]
            entry["partition"] = samples

            vectors = self.embedding_model_.transform(samples)
            memory_index = AnnoyMemoryIndex(vectors.shape[1])
            position = 0
            for vector in vectors:
                # Ids are positions inside ``samples``.
                memory_index.add(position, vector)
                position += 1
            memory_index.build()
            entry["index"] = memory_index

        return self

    def _get_prompt(self, x: str) -> str:
        """Build the few-shot prompt for a single input.

        Parameters
        ----------
        x : str
            sample to classify

        Returns
        -------
        str
            final prompt
        """
        query = self.embedding_model_.transform([x])

        examples = []
        for label in self.classes_:
            class_data = self.data_[label]
            memory_index = class_data["index"]
            samples = class_data["partition"]

            # Never request more examples than the class actually holds.
            n_requested = min(self.n_examples, len(samples))
            retrieved = memory_index.retrieve(query, n_requested)

            # Only one query was embedded, so only the first row is used.
            for sample_id in retrieved[0]:
                examples.append(
                    _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(
                        x=samples[sample_id], label=label
                    )
                )

        return build_few_shot_prompt_slc(
            x=x,
            training_data="\n".join(examples),
            labels=repr(self.classes_),
        )
0 commit comments