|
4 | 4 | from django.core.checks import Error
|
5 | 5 | from django.db import NotSupportedError
|
6 | 6 | from django.db.models import (
|
7 |
| - BooleanField, |
8 |
| - CharField, |
9 |
| - DateField, |
10 |
| - DateTimeField, |
11 | 7 | DecimalField,
|
12 | 8 | FloatField,
|
13 | 9 | Index,
|
14 |
| - IntegerField, |
15 |
| - TextField, |
16 |
| - UUIDField, |
17 | 10 | )
|
18 | 11 | from django.db.models.lookups import BuiltinLookup
|
19 | 12 | from django.db.models.sql.query import Query
|
20 | 13 | from django.db.models.sql.where import AND, XOR, WhereNode
|
21 | 14 | from pymongo import ASCENDING, DESCENDING
|
22 | 15 | from pymongo.operations import IndexModel, SearchIndexModel
|
23 | 16 |
|
24 |
| -from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField |
| 17 | +from django_mongodb_backend.fields import ArrayField |
25 | 18 |
|
26 | 19 | from .query_utils import process_rhs
|
27 | 20 |
|
@@ -161,7 +154,7 @@ def __init__(self, *expressions, similarities="cosine", **kwargs):
|
161 | 154 | # validate the similarities types
|
162 | 155 | self.similarities = similarities
|
163 | 156 |
|
164 |
| - def check(self, model): |
| 157 | + def check(self, model, connection): |
165 | 158 | errors = []
|
166 | 159 | error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
|
167 | 160 | similarities = (
|
@@ -196,28 +189,20 @@ def check(self, model):
|
196 | 189 | id=f"{error_id_prefix}.E002",
|
197 | 190 | )
|
198 | 191 | )
|
199 |
| - # filter - for fields that contain boolean, date, objectId, |
200 |
| - # numeric, string, or UUID values. Reference: |
201 |
| - # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields |
202 |
| - elif not isinstance( |
203 |
| - field_, |
204 |
| - BooleanField |
205 |
| - | IntegerField |
206 |
| - | DateField |
207 |
| - | DateTimeField |
208 |
| - | CharField |
209 |
| - | TextField |
210 |
| - | UUIDField |
211 |
| - | ObjectIdField |
212 |
| - | ObjectIdAutoField, |
213 |
| - ): |
214 |
| - errors.append( |
215 |
| - Error( |
216 |
| - f"Unsupported filter of type {field_.get_internal_type()}.", |
217 |
| - obj=self, |
218 |
| - id="django_mongodb_backend.indexes.VectorSearchIndex.E003", |
| 192 | + else: |
| 193 | + field_type = field_.db_type(connection) |
| 194 | + search_type = self.search_index_data_types(field_, field_type) |
| 195 | + # filter - for fields that contain boolean, date, objectId, |
| 196 | + # numeric, string, or UUID values. Reference: |
| 197 | + # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields |
| 198 | + if search_type not in ("number", "string", "boolean", "objectId", "uuid", "date"): |
| 199 | + errors.append( |
| 200 | + Error( |
| 201 | + f"Unsupported filter of type {field_.get_internal_type()}.", |
| 202 | + obj=self, |
| 203 | + id="django_mongodb_backend.indexes.VectorSearchIndex.E003", |
| 204 | + ) |
219 | 205 | )
|
220 |
| - ) |
221 | 206 | return errors
|
222 | 207 |
|
223 | 208 | def deconstruct(self):
|
|
0 commit comments