Skip to content

Commit b868f59

Browse files
WaVEVtimgraham
authored andcommitted
Replace decimal in vector index for integer
1 parent c74604e commit b868f59

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

django_mongodb_backend/indexes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.core.checks import Error, Warning
55
from django.db import NotSupportedError
6-
from django.db.models import DecimalField, FloatField, Index
6+
from django.db.models import FloatField, Index, IntegerField
77
from django.db.models.lookups import BuiltinLookup
88
from django.db.models.sql.query import Query
99
from django.db.models.sql.where import AND, XOR, WhereNode
@@ -200,12 +200,12 @@ def check(self, model, connection):
200200
id=f"{self._error_id_prefix}.E002",
201201
)
202202
)
203-
if not isinstance(field_.base_field, FloatField | DecimalField):
203+
if not isinstance(field_.base_field, FloatField | IntegerField):
204204
errors.append(
205205
Error(
206206
"VectorSearchIndex requires the base field of "
207207
f"ArrayField '{field_.name}' to be FloatField or "
208-
"DecimalField but is "
208+
"IntegerField but is "
209209
f"{field_.base_field.get_internal_type()}.",
210210
obj=model,
211211
id=f"{self._error_id_prefix}.E003",

tests/indexes_/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ class Article(models.Model):
1515
data = models.JSONField()
1616
embedded = EmbeddedModelField(Data)
1717
created_at = models.DateTimeField(auto_now=True)
18-
description_semantic = ArrayField(models.DecimalField(decimal_places=3, max_digits=10), size=10)
18+
description_semantic = ArrayField(models.IntegerField(), size=10)

tests/indexes_/test_checks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class Meta:
9797
[
9898
checks.Error(
9999
"VectorSearchIndex requires the base field of ArrayField "
100-
"'title_embedded' to be FloatField or DecimalField but is CharField.",
100+
"'title_embedded' to be FloatField or IntegerField but is CharField.",
101101
id="django_mongodb_backend.indexes.VectorSearchIndex.E003",
102102
obj=Article,
103103
)
@@ -125,7 +125,7 @@ class Meta:
125125

126126
def test_invalid_number_similarity_function_singular(self):
127127
class Article(models.Model):
128-
vector_data = ArrayField(models.DecimalField(), size=10)
128+
vector_data = ArrayField(models.FloatField(), size=10)
129129

130130
class Meta:
131131
indexes = [
@@ -150,8 +150,8 @@ class Meta:
150150

151151
def test_invalid_number_similarity_function_plural(self):
152152
class Article(models.Model):
153-
vector1 = ArrayField(models.DecimalField(), size=10)
154-
vector2 = ArrayField(models.DecimalField(), size=10)
153+
vector1 = ArrayField(models.FloatField(), size=10)
154+
vector2 = ArrayField(models.FloatField(), size=10)
155155

156156
class Meta:
157157
indexes = [
@@ -176,7 +176,7 @@ class Meta:
176176

177177
def test_simple(self):
178178
class Article(models.Model):
179-
vector_data = ArrayField(models.DecimalField(), size=10)
179+
vector_data = ArrayField(models.FloatField(), size=10)
180180

181181
class Meta:
182182
indexes = [VectorSearchIndex(fields=["vector_data"])]

0 commit comments

Comments
 (0)