Skip to content

Commit b497a08

Browse files
committed
Add custom options to client bulkWrite.
1 parent 54bab6d commit b497a08

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

mongo/client.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,16 @@ func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite,
963963
op.rawData = &rawData
964964
}
965965
}
966+
if bypassEmptyTsReplacementOpt := optionsutil.Value(bwo.Internal, "bypassEmptyTsReplacement"); bypassEmptyTsReplacementOpt != nil {
967+
if bypassEmptyTsReplacement, ok := bypassEmptyTsReplacementOpt.(bool); ok {
968+
op.bypassEmptyTsReplacement = &bypassEmptyTsReplacement
969+
}
970+
}
971+
if commandCallbackOpt := optionsutil.Value(bwo.Internal, "commandCallback"); commandCallbackOpt != nil {
972+
if commandCallback, ok := commandCallbackOpt.(func([]byte, description.SelectedServer) ([]byte, error)); ok {
973+
op.commandCallback = commandCallback
974+
}
975+
}
966976
if bwo.VerboseResults == nil || !(*bwo.VerboseResults) {
967977
op.errorsOnly = true
968978
} else if !acknowledged {

mongo/client_bulk_write.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ type clientBulkWrite struct {
4545
selector description.ServerSelector
4646
writeConcern *writeconcern.WriteConcern
4747
rawData *bool
48+
bypassEmptyTsReplacement *bool
49+
commandCallback func([]byte, description.SelectedServer) ([]byte, error)
4850

4951
result ClientBulkWriteResult
5052
}
@@ -122,7 +124,8 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
122124
}
123125

124126
func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) ([]byte, error) {
125-
return func(dst []byte, desc description.SelectedServer) ([]byte, error) {
127+
return func(cmd []byte, desc description.SelectedServer) ([]byte, error) {
128+
var dst []byte
126129
dst = bsoncore.AppendInt32Element(dst, "bulkWrite", 1)
127130

128131
dst = bsoncore.AppendBooleanElement(dst, "errorsOnly", bw.errorsOnly)
@@ -148,7 +151,19 @@ func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer)
148151
if bw.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
149152
dst = bsoncore.AppendBooleanElement(dst, "rawData", *bw.rawData)
150153
}
151-
return dst, nil
154+
if bw.bypassEmptyTsReplacement != nil {
155+
dst = bsoncore.AppendBooleanElement(dst, "bypassEmptyTsReplacement", *bw.rawData)
156+
}
157+
if bw.commandCallback != nil {
158+
var err error
159+
dst, err = bw.commandCallback(dst, desc)
160+
if err != nil {
161+
return nil, err
162+
}
163+
}
164+
165+
cmd = append(cmd, dst...)
166+
return cmd, nil
152167
}
153168
}
154169

x/mongo/driver/xoptions/options.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"go.mongodb.org/mongo-driver/v2/internal/optionsutil"
1313
"go.mongodb.org/mongo-driver/v2/mongo/options"
1414
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
15+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
1516
)
1617

1718
// SetInternalClientOptions sets internal options for ClientOptions.
@@ -101,6 +102,24 @@ func SetInternalClientBulkWriteOptions(a *options.ClientBulkWriteOptionsBuilder,
101102
opts.Internal = optionsutil.WithValue(opts.Internal, key, b)
102103
return nil
103104
})
105+
case "bypassEmptyTsReplacement":
106+
b, ok := option.(bool)
107+
if !ok {
108+
return typeErrFunc("bool")
109+
}
110+
a.Opts = append(a.Opts, func(opts *options.ClientBulkWriteOptions) error {
111+
opts.Internal = optionsutil.WithValue(opts.Internal, key, b)
112+
return nil
113+
})
114+
case "commandCallback":
115+
cb, ok := option.(func([]byte, description.SelectedServer) ([]byte, error))
116+
if !ok {
117+
return typeErrFunc("func([]byte, description.SelectedServer) ([]byte, error)")
118+
}
119+
a.Opts = append(a.Opts, func(opts *options.ClientBulkWriteOptions) error {
120+
opts.Internal = optionsutil.WithValue(opts.Internal, key, cb)
121+
return nil
122+
})
104123
default:
105124
return fmt.Errorf("unsupported option: %q", key)
106125
}

0 commit comments

Comments
 (0)