Skip to content

Commit 18b9e95

Browse files
fix: check for valid schema before querying by constrained field (#75)
**Description** This PR fixes #74. If querying (modusDB.Get) for a struct's constrained field before the schema contained that type, the worker would panic (the bug contains that trace). This change ensures the schema contains the type before launching the query. Note that querying by uid does not suffer from this issue as the discovery of the tablets (groups) are not involved in the query. Also added tests (as well as expanding existing unit tests) and changed the api file to use errors.New instead of the fmt package as no formatting was being applied. **Checklist** - [x] Code compiles correctly and linting passes locally - [ ] For all _code_ changes, an entry added to the `CHANGELOG.md` file describing and linking to this PR - [x] Tests added for new functionality, or regression tests for bug fixes added as applicable --------- Co-authored-by: Ryan Fox-Tyler <[email protected]>
1 parent 6d80353 commit 18b9e95

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

api.go

+20-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ package modusdb
77

88
import (
99
"context"
10-
"fmt"
10+
"errors"
11+
"reflect"
1112

1213
"github.com/hypermodeinc/dgraph/v24/dql"
1314
"github.com/hypermodeinc/dgraph/v24/schema"
@@ -20,7 +21,7 @@ func Create[T any](ctx context.Context, engine *Engine, object T,
2021
engine.mutex.Lock()
2122
defer engine.mutex.Unlock()
2223
if len(nsId) > 1 {
23-
return 0, object, fmt.Errorf("only one namespace is allowed")
24+
return 0, object, errors.New("only one namespace is allowed")
2425
}
2526
ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...)
2627
if err != nil {
@@ -59,7 +60,7 @@ func Upsert[T any](ctx context.Context, engine *Engine, object T,
5960
engine.mutex.Lock()
6061
defer engine.mutex.Unlock()
6162
if len(nsId) > 1 {
62-
return 0, object, false, fmt.Errorf("only one namespace is allowed")
63+
return 0, object, false, errors.New("only one namespace is allowed")
6364
}
6465

6566
ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...)
@@ -131,7 +132,7 @@ func Get[T any, R UniqueField](ctx context.Context, engine *Engine, uniqueField
131132
defer engine.mutex.Unlock()
132133
var obj T
133134
if len(nsId) > 1 {
134-
return 0, obj, fmt.Errorf("only one namespace is allowed")
135+
return 0, obj, errors.New("only one namespace is allowed")
135136
}
136137
ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...)
137138
if err != nil {
@@ -142,18 +143,28 @@ func Get[T any, R UniqueField](ctx context.Context, engine *Engine, uniqueField
142143
}
143144

144145
if cf, ok := any(uniqueField).(ConstrainedField); ok {
145-
return getByConstrainedField[T](ctx, ns, cf)
146+
objType := reflect.TypeOf(obj)
147+
sch, err := getSchema(ctx, ns)
148+
if err != nil {
149+
return 0, obj, err
150+
}
151+
for _, t := range sch.Types {
152+
if t.Name == objType.Name() {
153+
return getByConstrainedField[T](ctx, ns, cf)
154+
}
155+
}
156+
return 0, obj, errors.New("type not found")
146157
}
147158

148-
return 0, obj, fmt.Errorf("invalid unique field type")
159+
return 0, obj, errors.New("invalid unique field type")
149160
}
150161

151162
func Query[T any](ctx context.Context, engine *Engine, queryParams QueryParams,
152163
nsId ...uint64) ([]uint64, []T, error) {
153164
engine.mutex.Lock()
154165
defer engine.mutex.Unlock()
155166
if len(nsId) > 1 {
156-
return nil, nil, fmt.Errorf("only one namespace is allowed")
167+
return nil, nil, errors.New("only one namespace is allowed")
157168
}
158169
ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...)
159170
if err != nil {
@@ -169,7 +180,7 @@ func Delete[T any, R UniqueField](ctx context.Context, engine *Engine, uniqueFie
169180
defer engine.mutex.Unlock()
170181
var zeroObj T
171182
if len(nsId) > 1 {
172-
return 0, zeroObj, fmt.Errorf("only one namespace is allowed")
183+
return 0, zeroObj, errors.New("only one namespace is allowed")
173184
}
174185
ctx, ns, err := getDefaultNamespace(ctx, engine, nsId...)
175186
if err != nil {
@@ -207,5 +218,5 @@ func Delete[T any, R UniqueField](ctx context.Context, engine *Engine, uniqueFie
207218
return uid, obj, nil
208219
}
209220

210-
return 0, zeroObj, fmt.Errorf("invalid unique field type")
221+
return 0, zeroObj, errors.New("invalid unique field type")
211222
}

unit_test/api_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,40 @@ func TestFirstTimeUser(t *testing.T) {
5959
require.Equal(t, "A", queriedUser2.Name)
6060
require.Equal(t, "123", queriedUser2.ClerkId)
6161

62+
// Search for a non-existent record
63+
_, _, err = modusdb.Get[User](context.Background(), engine, modusdb.ConstrainedField{
64+
Key: "clerk_id",
65+
Value: "456",
66+
})
67+
require.Error(t, err)
68+
require.Equal(t, "no object found", err.Error())
69+
6270
_, _, err = modusdb.Delete[User](context.Background(), engine, gid)
6371
require.NoError(t, err)
6472

6573
_, queriedUser3, err := modusdb.Get[User](context.Background(), engine, gid)
6674
require.Error(t, err)
6775
require.Equal(t, "no object found", err.Error())
6876
require.Equal(t, queriedUser3, User{})
77+
}
6978

79+
func TestGetBeforeObjectWrite(t *testing.T) {
80+
ctx := context.Background()
81+
engine, err := modusdb.NewEngine(modusdb.NewDefaultConfig(t.TempDir()))
82+
require.NoError(t, err)
83+
defer engine.Close()
84+
ns, err := engine.CreateNamespace()
85+
require.NoError(t, err)
86+
87+
_, _, err = modusdb.Get[User](ctx, engine, uint64(1), ns.ID())
88+
require.Error(t, err)
89+
90+
_, _, err = modusdb.Get[User](ctx, engine, modusdb.ConstrainedField{
91+
Key: "name",
92+
Value: "test",
93+
}, ns.ID())
94+
require.Error(t, err)
95+
require.Equal(t, "type not found", err.Error())
7096
}
7197

7298
func TestCreateApi(t *testing.T) {
@@ -125,6 +151,14 @@ func TestCreateApiWithNonStruct(t *testing.T) {
125151
_, _, err = modusdb.Create[*User](context.Background(), engine, &user, ns1.ID())
126152
require.Error(t, err)
127153
require.Equal(t, "expected struct, got ptr", err.Error())
154+
155+
_, _, err = modusdb.Create[[]string](context.Background(), engine, []string{"foo", "bar"}, ns1.ID())
156+
require.Error(t, err)
157+
require.Equal(t, "expected struct, got slice", err.Error())
158+
159+
_, _, err = modusdb.Create[float32](context.Background(), engine, 3.1415, ns1.ID())
160+
require.Error(t, err)
161+
require.Equal(t, "expected struct, got float32", err.Error())
128162
}
129163

130164
func TestGetApi(t *testing.T) {

0 commit comments

Comments
 (0)