Skip to content
Merged
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,35 @@ deduplicated_dataframe = pd.DataFrame(deduplicated_records)

</details>

<details>
<summary> Initializing from embeddings </summary>
<br>
You can also initialize SemHash from pre-computed embeddings. The following code snippet shows how to do this:

```python
from datasets import load_dataset
from model2vec import StaticModel
from semhash import SemHash

# Load a dataset
texts = load_dataset("ag_news", split="train")["text"]

# Load an embedding model
model = StaticModel.from_pretrained("minishlab/potion-base-8M")

# Create embeddings: one per text, in the same order as `texts`
# (`from_embeddings` raises a ValueError if the two lengths differ)
embeddings = model.encode(texts)

# Initialize SemHash from embeddings; the model is still passed because
# SemHash uses it to embed any new records given to it later
semhash = SemHash.from_embeddings(embeddings=embeddings, records=texts, model=model)

# Deduplicate, filter outliers, and find representative samples
deduplicated_texts = semhash.self_deduplicate().selected
filtered_texts = semhash.self_filter_outliers().selected
representative_texts = semhash.self_find_representative().selected
```
</details>

NOTE: By default, we use the ANN (approximate nearest neighbors) backend for deduplication. We recommend keeping this default: recall on smaller datasets is ~100%, and larger datasets (>1M samples) take too long to deduplicate without ANN. If you want to use the flat/exact-matching backend instead, pass `use_ann=False` when creating the SemHash instance (e.g. to `SemHash.from_records`):

```python
Expand Down
7 changes: 2 additions & 5 deletions semhash/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
from collections.abc import Hashable, Sequence
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Generic, TypeAlias, TypeVar
from typing import Generic

from frozendict import frozendict

from semhash.utils import to_frozendict

Record = TypeVar("Record", str, dict[str, Any])
DuplicateList: TypeAlias = list[tuple[Record, float]]
from semhash.utils import DuplicateList, Record, to_frozendict


@dataclass
Expand Down
167 changes: 82 additions & 85 deletions semhash/semhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, Record
from semhash.index import Index
from semhash.records import add_scores_to_records, map_deduplication_result_to_strings
from semhash.utils import Encoder, compute_candidate_limit, to_frozendict
from semhash.utils import (
Encoder,
compute_candidate_limit,
featurize,
prepare_records,
remove_exact_duplicates,
to_frozendict,
)


class SemHash(Generic[Record]):
Expand All @@ -33,70 +40,6 @@ def __init__(self, index: Index, model: Encoder, columns: Sequence[str], was_str
self._was_string = was_string
self._ranking_cache: FilterResult | None = None

@staticmethod
def _featurize(
records: Sequence[dict[str, str]],
columns: Sequence[str],
model: Encoder,
) -> np.ndarray:
"""
Featurize a list of records using the model.

:param records: A list of records.
:param columns: Columns to featurize.
:param model: An Encoder model.
:return: The embeddings of the records.
"""
# Extract the embeddings for each column across all records
embeddings_per_col = []
for col in columns:
col_texts = [r[col] for r in records]
col_emb = model.encode(col_texts)
embeddings_per_col.append(np.asarray(col_emb))

return np.concatenate(embeddings_per_col, axis=1)

@classmethod
def _remove_exact_duplicates(
    cls,
    records: Sequence[dict[str, str]],
    columns: Sequence[str],
    reference_records: list[list[dict[str, str]]] | None = None,
) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]:
    """
    Split records into unique ones and exact duplicates, comparing only the given columns.

    Without reference_records, each record is compared against the records that came
    before it in the same list. With reference_records, records are compared against
    the reference groups only, and never against each other.

    :param records: A list of records to check for exact duplicates.
    :param columns: The columns that make up the comparison key.
    :param reference_records: Optional groups of already-unpacked records to compare against.
        The first record of each group provides the key for that group.
    :return: A list of deduplicated records and a list of (record, matching records) pairs.
    """
    wanted = set(columns)
    # Records only register themselves as "seen" when there is no reference set
    track_new = reference_records is None

    known: defaultdict[frozendict[str, str], list[dict[str, str]]] = defaultdict(list)
    for group in reference_records or []:
        known[to_frozendict(group[0], wanted)] = list(group)

    unique = []
    repeated = []
    for candidate in records:
        key = frozendict({name: value for name, value in candidate.items() if name in wanted})
        # .get() avoids creating empty entries in the defaultdict on a miss
        matches = known.get(key)
        if matches:
            repeated.append((candidate, matches))
            continue
        unique.append(candidate)
        if track_new:
            known[key].append(candidate)

    return unique, repeated

@classmethod
def from_records(
cls,
Expand All @@ -119,26 +62,16 @@ def from_records(
:param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
:param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
:return: A SemHash instance with a fitted vicinity index.
:raises ValueError: If columns are not provided for dictionary records.
"""
if columns is None and isinstance(records[0], dict):
raise ValueError("Columns must be specified when passing dictionaries.")

if isinstance(records[0], str):
# If records are strings, convert to dictionaries with a single column
columns = ["text"]
dict_records: list[dict[str, str]] = [{"text": record} for record in records]
was_string = True
else:
dict_records = list(records)
was_string = False
# Prepare and validate records
dict_records, columns, was_string = prepare_records(records, columns)

# If no model is provided, load the default model
if model is None:
model = StaticModel.from_pretrained("minishlab/potion-base-8M")

# Remove exact duplicates
deduplicated_records, duplicates = cls._remove_exact_duplicates(dict_records, columns)
deduplicated_records, duplicates = remove_exact_duplicates(dict_records, columns)

col_set = set(columns)
duplicate_map = defaultdict(list)
Expand All @@ -149,12 +82,12 @@ def from_records(
items: list[list[dict[str, str]]] = []
for record in deduplicated_records:
i = [record]
frozen_record = to_frozendict(record, set(columns))
frozen_record = to_frozendict(record, col_set)
i.extend(duplicate_map[frozen_record])
items.append(i)

# Create embeddings and unpack records
embeddings = cls._featurize(deduplicated_records, columns, model)
# Create embeddings for deduplicated records only
embeddings = featurize(deduplicated_records, columns, model)

# Build the Vicinity index
backend = ann_backend if use_ann else Backend.BASIC
Expand All @@ -167,6 +100,70 @@ def from_records(

return cls(index=index, columns=columns, model=model, was_string=was_string)

@classmethod
def from_embeddings(
    cls,
    embeddings: np.ndarray,
    records: Sequence[Record],
    model: Encoder,
    columns: Sequence[str] | None = None,
    use_ann: bool = True,
    ann_backend: Backend | str = Backend.USEARCH,
    **kwargs: Any,
) -> SemHash:
    """
    Initialize a SemHash instance from pre-computed embeddings.

    This removes exact duplicates and fits a vicinity index using the provided embeddings.
    Only the embedding of the first occurrence of each exact-duplicate group is indexed;
    later occurrences are stored alongside it as items without a vector of their own.

    :param embeddings: Pre-computed embeddings as a numpy array of shape (n_records, embedding_dim).
    :param records: A list of records (strings or dictionaries) corresponding to the embeddings.
    :param model: The Encoder model used for creating the embeddings.
    :param columns: Columns to use if records are dictionaries. If None and records are strings,
        defaults to ["text"].
    :param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
    :param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
    :param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
    :return: A SemHash instance with a fitted vicinity index.
    :raises ValueError: If the number of embeddings doesn't match the number of records.
    :raises ValueError: If columns are not provided for dictionary records.
    """
    if len(embeddings) != len(records):
        raise ValueError(f"Number of embeddings ({len(embeddings)}) must match number of records ({len(records)})")

    # Prepare and validate records
    dict_records, columns, was_string = prepare_records(records, columns)

    # Group exact duplicates and pick the embedding rows to keep in one pass.
    # Records are keyed on the selected columns only. Equality checks such as
    # `record in deduplicated_records` cannot be used for this: a duplicate is equal
    # to the record it duplicates, so its embedding row would be selected as well and
    # the vectors would no longer line up with the items.
    col_set = set(columns)
    item_position: dict[Any, int] = {}
    items: list[list[dict[str, str]]] = []
    embedding_indices: list[int] = []
    for i, record in enumerate(dict_records):
        key = to_frozendict(record, col_set)
        position = item_position.get(key)
        if position is None:
            # First occurrence: starts a new item and keeps its embedding row
            item_position[key] = len(items)
            items.append([record])
            embedding_indices.append(i)
        else:
            # Exact duplicate: attach to the existing item, drop its embedding row
            items[position].append(record)

    # Select embeddings for non-exact-duplicate records (one row per item)
    deduplicated_embeddings = np.asarray(embeddings)[embedding_indices]

    # Create the index
    backend_type = ann_backend if use_ann else Backend.BASIC
    index = Index.from_vectors_and_items(
        vectors=deduplicated_embeddings, items=items, backend_type=backend_type, **kwargs
    )

    return cls(index=index, model=model, columns=columns, was_string=was_string)

def deduplicate(
self,
records: Sequence[Record],
Expand All @@ -186,7 +183,7 @@ def deduplicate(
dict_records = self._validate_if_strings(records)

# Remove exact duplicates before embedding
dict_records, exact_duplicates = self._remove_exact_duplicates(
dict_records, exact_duplicates = remove_exact_duplicates(
records=dict_records, columns=self.columns, reference_records=self.index.items
)
duplicate_records = []
Expand All @@ -202,7 +199,7 @@ def deduplicate(
)

# Compute embeddings for the new records
embeddings = self._featurize(records=dict_records, columns=self.columns, model=self.model)
embeddings = featurize(records=dict_records, columns=self.columns, model=self.model)
# Query the fitted index
results = self.index.query_threshold(embeddings, threshold=threshold)

Expand Down Expand Up @@ -459,7 +456,7 @@ def _rank_by_average_similarity(
:return: A FilterResult containing the ranking (records sorted and their average similarity scores).
"""
dict_records = self._validate_if_strings(records)
embeddings = self._featurize(records=dict_records, columns=self.columns, model=self.model)
embeddings = featurize(records=dict_records, columns=self.columns, model=self.model)
results = self.index.query_top_k(embeddings, k=100, vectors_are_in_index=False)

# Compute the average similarity for each record.
Expand Down Expand Up @@ -523,7 +520,7 @@ def _diversify(
if not candidates:
return FilterResult(selected=[], filtered=[], scores_selected=[], scores_filtered=[])

embeddings = self._featurize(records=candidates, columns=self.columns, model=self.model)
embeddings = featurize(records=candidates, columns=self.columns, model=self.model)
result = diversify(
embeddings=embeddings,
scores=np.array(relevance),
Expand Down
Loading