Skip to content

Commit 6a59ec0

Browse files
authored
Merge pull request #3324 from dmvolod/issue-3322
✨ Able to set WithContextFunc in WebhookBuilder
2 parents 7c50567 + 5c9496f commit 6a59ec0

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

pkg/builder/webhook.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package builder
1818

1919
import (
20+
"context"
2021
"errors"
2122
"net/http"
2223
"net/url"
@@ -49,6 +50,7 @@ type WebhookBuilder struct {
4950
config *rest.Config
5051
recoverPanic *bool
5152
logConstructor func(base logr.Logger, req *admission.Request) logr.Logger
53+
contextFunc func(context.Context, *http.Request) context.Context
5254
err error
5355
}
5456

@@ -90,6 +92,12 @@ func (blder *WebhookBuilder) WithLogConstructor(logConstructor func(base logr.Lo
9092
return blder
9193
}
9294

95+
// WithContextFunc overrides the webhook's WithContextFunc.
96+
func (blder *WebhookBuilder) WithContextFunc(contextFunc func(context.Context, *http.Request) context.Context) *WebhookBuilder {
97+
blder.contextFunc = contextFunc
98+
return blder
99+
}
100+
93101
// RecoverPanic indicates whether panics caused by the webhook should be recovered.
94102
// Defaults to true.
95103
func (blder *WebhookBuilder) RecoverPanic(recoverPanic bool) *WebhookBuilder {
@@ -205,6 +213,7 @@ func (blder *WebhookBuilder) registerDefaultingWebhook() error {
205213
mwh := blder.getDefaultingWebhook()
206214
if mwh != nil {
207215
mwh.LogConstructor = blder.logConstructor
216+
mwh.WithContextFunc = blder.contextFunc
208217
path := generateMutatePath(blder.gvk)
209218
if blder.customDefaulterCustomPath != "" {
210219
generatedCustomPath, err := generateCustomPath(blder.customDefaulterCustomPath)
@@ -243,6 +252,7 @@ func (blder *WebhookBuilder) registerValidatingWebhook() error {
243252
vwh := blder.getValidatingWebhook()
244253
if vwh != nil {
245254
vwh.LogConstructor = blder.logConstructor
255+
vwh.WithContextFunc = blder.contextFunc
246256
path := generateValidatePath(blder.gvk)
247257
if blder.customValidatorCustomPath != "" {
248258
generatedCustomPath, err := generateCustomPath(blder.customValidatorCustomPath)

pkg/builder/webhook_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,14 @@ const (
4949
svcBaseAddr = "http://svc-name.svc-ns.svc"
5050

5151
customPath = "/custom-path"
52+
53+
userAgentHeader = "User-Agent"
54+
userAgentCtxKey agentCtxKey = "UserAgent"
55+
userAgentValue = "test"
5256
)
5357

58+
type agentCtxKey string
59+
5460
var _ = Describe("webhook", func() {
5561
Describe("New", func() {
5662
Context("v1 AdmissionReview", func() {
@@ -315,6 +321,9 @@ func runTests(admissionReviewVersion string) {
315321
WithLogConstructor(func(base logr.Logger, req *admission.Request) logr.Logger {
316322
return admission.DefaultLogConstructor(testingLogger, req)
317323
}).
324+
WithContextFunc(func(ctx context.Context, request *http.Request) context.Context {
325+
return context.WithValue(ctx, userAgentCtxKey, request.Header.Get(userAgentHeader))
326+
}).
318327
Complete()
319328
ExpectWithOffset(1, err).NotTo(HaveOccurred())
320329
svr := m.GetWebhookServer()
@@ -344,6 +353,30 @@ func runTests(admissionReviewVersion string) {
344353
}
345354
}
346355
}`)
356+
readerWithCxt := strings.NewReader(admissionReviewGV + admissionReviewVersion + `",
357+
"request":{
358+
"uid":"07e52e8d-4513-11e9-a716-42010a800270",
359+
"kind":{
360+
"group":"foo.test.org",
361+
"version":"v1",
362+
"kind":"TestValidator"
363+
},
364+
"resource":{
365+
"group":"foo.test.org",
366+
"version":"v1",
367+
"resource":"testvalidator"
368+
},
369+
"namespace":"default",
370+
"name":"foo",
371+
"operation":"UPDATE",
372+
"object":{
373+
"replica":1
374+
},
375+
"oldObject":{
376+
"replica":1
377+
}
378+
}
379+
}`)
347380

348381
ctx, cancel := context.WithCancel(specCtx)
349382
cancel()
@@ -373,6 +406,20 @@ func runTests(admissionReviewVersion string) {
373406
ExpectWithOffset(1, w.Body).To(ContainSubstring(`"allowed":false`))
374407
ExpectWithOffset(1, w.Body).To(ContainSubstring(`"code":403`))
375408
EventuallyWithOffset(1, logBuffer).Should(gbytes.Say(`"msg":"Validating object","object":{"name":"foo","namespace":"default"},"namespace":"default","name":"foo","resource":{"group":"foo.test.org","version":"v1","resource":"testvalidator"},"user":"","requestID":"07e52e8d-4513-11e9-a716-42010a800270"`))
409+
410+
By("sending a request to a validating webhook with context header validation")
411+
path = generateValidatePath(testValidatorGVK)
412+
_, err = readerWithCxt.Seek(0, 0)
413+
ExpectWithOffset(1, err).NotTo(HaveOccurred())
414+
req = httptest.NewRequest("POST", svcBaseAddr+path, readerWithCxt)
415+
req.Header.Add("Content-Type", "application/json")
416+
req.Header.Add(userAgentHeader, userAgentValue)
417+
w = httptest.NewRecorder()
418+
svr.WebhookMux().ServeHTTP(w, req)
419+
ExpectWithOffset(1, w.Code).To(Equal(http.StatusOK))
420+
By("sanity checking the response contains reasonable field")
421+
ExpectWithOffset(1, w.Body).To(ContainSubstring(`"allowed":true`))
422+
ExpectWithOffset(1, w.Body).To(ContainSubstring(`"code":200`))
376423
})
377424

378425
It("should scaffold a custom validating webhook with a custom path", func(specCtx SpecContext) {
@@ -1009,6 +1056,7 @@ func (*TestCustomDefaulter) Default(ctx context.Context, obj runtime.Object) err
10091056
if d.Replica < 2 {
10101057
d.Replica = 2
10111058
}
1059+
10121060
return nil
10131061
}
10141062

@@ -1035,6 +1083,7 @@ func (*TestCustomValidator) ValidateCreate(ctx context.Context, obj runtime.Obje
10351083
if v.Replica < 0 {
10361084
return nil, errors.New("number of replica should be greater than or equal to 0")
10371085
}
1086+
10381087
return nil, nil
10391088
}
10401089

@@ -1056,6 +1105,12 @@ func (*TestCustomValidator) ValidateUpdate(ctx context.Context, oldObj, newObj r
10561105
if v.Replica < old.Replica {
10571106
return nil, fmt.Errorf("new replica %v should not be fewer than old replica %v", v.Replica, old.Replica)
10581107
}
1108+
1109+
userAgent, ok := ctx.Value(userAgentCtxKey).(string)
1110+
if ok && userAgent != userAgentValue {
1111+
return nil, fmt.Errorf("expected %s value is %q in TestCustomValidator got %q", userAgentCtxKey, userAgentValue, userAgent)
1112+
}
1113+
10591114
return nil, nil
10601115
}
10611116

0 commit comments

Comments
 (0)