From 7d9fb0e893f73f6d2d25a15237ee8ed1c46838be Mon Sep 17 00:00:00 2001 From: aacebo Date: Sun, 6 Oct 2024 16:34:54 -0400 Subject: [PATCH] cleanup --- any.go | 14 ++++++------ float.go | 24 +++++++++------------ int.go | 48 +++++++++++++++++++++++------------------ object.go | 24 +++++++++++++++------ string.go | 64 ++++++++++++++++++++++++++++++------------------------- time.go | 48 +++++++++++++++++++++++------------------ union.go | 16 +++++++++++++- 7 files changed, 137 insertions(+), 101 deletions(-) diff --git a/any.go b/any.go index 16a05da..74c8249 100644 --- a/any.go +++ b/any.go @@ -41,10 +41,10 @@ func (self *AnySchema) Rule(key string, value any, rule RuleFn) *AnySchema { func (self *AnySchema) Required() *AnySchema { return self.Rule("required", true, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, errors.New("required") + return nil, errors.New("required") } - return value, nil + return value.Interface(), nil }) } @@ -52,11 +52,11 @@ func (self *AnySchema) Enum(values ...any) *AnySchema { return self.Rule("enum", values, func(value reflect.Value) (any, error) { for _, v := range values { if value.Equal(reflect.Indirect(reflect.ValueOf(v))) { - return value, nil + return value.Interface(), nil } } - return value, fmt.Errorf("must be one of %v", values) + return nil, fmt.Errorf("must be one of %v", values) }) } @@ -71,7 +71,7 @@ func (self AnySchema) MarshalJSON() ([]byte, error) { } func (self AnySchema) Validate(value any) error { - return self.validate("", reflect.Indirect(reflect.ValueOf(value))) + return self.validate("", reflect.Indirect(reflect.ValueOf(value))) } func (self AnySchema) validate(key string, value reflect.Value) error { @@ -85,9 +85,7 @@ func (self AnySchema) validate(key string, value reflect.Value) error { continue } - if value.CanSet() { - value.Set(reflect.ValueOf(v)) - } + value = reflect.ValueOf(v) } if len(err.Errors) > 0 { diff --git a/float.go b/float.go index 83f130b..acc46f6 100644 --- a/float.go +++ b/float.go @@ -19,11 +19,11 @@ func Float() *FloatSchema { } if value.CanConvert(reflect.TypeFor[float64]()) { - value.Set(value.Convert(reflect.TypeFor[float64]())) + value = value.Convert(reflect.TypeFor[float64]()) } if value.Kind() != reflect.Float64 { - return nil, errors.New("must be a float") + return value.Interface(), errors.New("must be a float") } return value.Interface(), nil @@ -60,28 +60,28 @@ func (self *FloatSchema) Enum(values ...float64) *FloatSchema { func (self *FloatSchema) Min(min float64) *FloatSchema { return self.Rule("min", min, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if value.Float() < min { - return value, fmt.Errorf("must have value of at least %f", min) + return value.Interface(), fmt.Errorf("must have value of at least %f", min) } - return value, nil + return value.Interface(), nil }) } func (self *FloatSchema) Max(max float64) *FloatSchema { return self.Rule("max", max, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if value.Float() > max { - return value, fmt.Errorf("must have value of at most %f", max) + return value.Interface(), fmt.Errorf("must have value of at most %f", max) } - return value, nil + return value.Interface(), nil }) } @@ -90,13 +90,9 @@ func (self FloatSchema) MarshalJSON() ([]byte, error) { } func (self FloatSchema) Validate(value any) error { - return self.validate("", reflect.Indirect(reflect.ValueOf(value))) + return self.validate("", reflect.ValueOf(value)) } func (self FloatSchema) validate(key string, value reflect.Value) error { - if err := self.schema.validate(key, value); err != nil { - return err - } - - return nil + return self.schema.validate(key, value) } diff --git a/int.go b/int.go index fe0a852..5688388 100644 --- a/int.go +++ b/int.go @@ -2,6 +2,7 @@ package owl import ( "encoding/json" + "errors" "fmt" "reflect" ) @@ -11,7 +12,24 @@ type IntSchema struct { } func Int() *IntSchema { - return &IntSchema{Any()} + self := &IntSchema{Any()} + self.Rule("type", self.Type(), func(value reflect.Value) (any, error) { + if !value.IsValid() { + return nil, nil + } + + if value.CanConvert(reflect.TypeFor[int]()) { + value = value.Convert(reflect.TypeFor[int]()) + } + + if value.Kind() != reflect.Int { + return value.Interface(), errors.New("must be an int") + } + + return value.Interface(), nil + }) + + return self } func (self IntSchema) Type() string { @@ -42,28 +60,28 @@ func (self *IntSchema) Enum(values ...int) *IntSchema { func (self *IntSchema) Min(min int) *IntSchema { return self.Rule("min", min, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if value.Int() < int64(min) { - return value, fmt.Errorf("must have value of at least %d", min) + return value.Interface(), fmt.Errorf("must have value of at least %d", min) } - return value, nil + return value.Interface(), nil }) } func (self *IntSchema) Max(max int) *IntSchema { return self.Rule("max", max, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if value.Int() > int64(max) { - return value, fmt.Errorf("must have value of at most %d", max) + return value.Interface(), fmt.Errorf("must have value of at most %d", max) } - return value, nil + return value.Interface(), nil }) } @@ -72,21 +90,9 @@ func (self IntSchema) MarshalJSON() ([]byte, error) { } func (self IntSchema) Validate(value any) error { - return self.validate("", reflect.Indirect(reflect.ValueOf(value))) + return self.validate("", reflect.ValueOf(value)) } func (self IntSchema) validate(key string, value reflect.Value) error { - if value.IsValid() && value.CanConvert(reflect.TypeFor[int]()) { - value = value.Convert(reflect.TypeFor[int]()) - } - - if err := self.schema.validate(key, value); err != nil { - return err - } - - if value.IsValid() && value.Kind() != reflect.Int { - return newError(key, "must be an integer") - } - - return nil + return self.schema.validate(key, value) } diff --git a/object.go b/object.go index 3c6f973..ef67457 100644 --- a/object.go +++ b/object.go @@ -12,7 +12,20 @@ type ObjectSchema struct { } func Object() *ObjectSchema { - return &ObjectSchema{Any(), map[string]Schema{}} + self := &ObjectSchema{Any(), map[string]Schema{}} + self.Rule("type", self.Type(), func(value reflect.Value) (any, error) { + if !value.IsValid() { + return nil, nil + } + + if value.Kind() != reflect.Struct && value.Kind() != reflect.Map { + return value.Interface(), errors.New("must be an object") + } + + return value.Interface(), nil + }) + + return self } func (self ObjectSchema) Type() string { @@ -39,7 +52,7 @@ func (self ObjectSchema) MarshalJSON() ([]byte, error) { } func (self ObjectSchema) Validate(value any) error { - return self.validate("", reflect.Indirect(reflect.ValueOf(value))) + return self.validate("", reflect.Indirect(reflect.ValueOf(value))) } func (self ObjectSchema) validate(key string, value reflect.Value) error { @@ -55,14 +68,11 @@ func (self ObjectSchema) validate(key string, value reflect.Value) error { value = value.Elem() } - switch value.Kind() { - case reflect.Map: + if value.Kind() == reflect.Map { return self.validateMap(key, value) - case reflect.Struct: - return self.validateStruct(key, value) } - return newError(key, "must be an object") + return self.validateStruct(key, value) } func (self ObjectSchema) validateMap(key string, value reflect.Value) error { diff --git a/string.go b/string.go index dd6c2af..6066d89 100644 --- a/string.go +++ b/string.go @@ -2,6 +2,7 @@ package owl import ( "encoding/json" + "errors" "fmt" "net/mail" "net/url" @@ -14,7 +15,20 @@ type StringSchema struct { } func String() *StringSchema { - return &StringSchema{Any()} + self := &StringSchema{Any()} + self.Rule("type", self.Type(), func(value reflect.Value) (any, error) { + if !value.IsValid() { + return nil, nil + } + + if value.Kind() != reflect.String { + return value.Interface(), errors.New("must be a string") + } + + return value.Interface(), nil + }) + + return self } func (self StringSchema) Type() string { @@ -45,93 +59,93 @@ func (self *StringSchema) Enum(values ...string) *StringSchema { func (self *StringSchema) Min(min int) *StringSchema { return self.Rule("min", min, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if value.Len() < min { - return value, fmt.Errorf("must have length of at least %d", min) + return value.Interface(), fmt.Errorf("must have length of at least %d", min) } - return value, nil + return value.Interface(), nil }) } func (self *StringSchema) Max(max int) *StringSchema { return self.Rule("max", max, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if value.Len() > max { - return value, fmt.Errorf("must have length of at most %d", max) + return value.Interface(), fmt.Errorf("must have length of at most %d", max) } - return value, nil + return value.Interface(), nil }) } func (self *StringSchema) Regex(re *regexp.Regexp) *StringSchema { return self.Rule("regex", re.String(), func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if !re.MatchString(value.String()) { - return value, fmt.Errorf("must match regex %s", re.String()) + return value.Interface(), fmt.Errorf("must match regex %s", re.String()) } - return value, nil + return value.Interface(), nil }) } func (self *StringSchema) Email() *StringSchema { return self.Rule("email", true, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if _, err := mail.ParseAddress(value.String()); err != nil { - return value, fmt.Errorf( + return value.Interface(), fmt.Errorf( `"%s" does not match email format`, value.String(), ) } - return value, nil + return value.Interface(), nil }) } func (self *StringSchema) UUID() *StringSchema { return self.Rule("uuid", true, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if !uuid.MatchString(value.String()) { - return value, fmt.Errorf( + return value.Interface(), fmt.Errorf( `"%s" does not match uuid format`, value.String(), ) } - return value, nil + return value.Interface(), nil }) } func (self *StringSchema) URL() *StringSchema { return self.Rule("url", true, func(value reflect.Value) (any, error) { if !value.IsValid() { - return value, nil + return nil, nil } if _, err := url.ParseRequestURI(value.String()); err != nil { - return value, fmt.Errorf( + return value.Interface(), fmt.Errorf( `"%s" does not match url format`, value.String(), ) } - return value, nil + return value.Interface(), nil }) } @@ -140,17 +154,9 @@ func (self StringSchema) MarshalJSON() ([]byte, error) { } func (self StringSchema) Validate(value any) error { - return self.validate("", reflect.Indirect(reflect.ValueOf(value))) + return self.validate("", reflect.Indirect(reflect.ValueOf(value))) } func (self StringSchema) validate(key string, value reflect.Value) error { - if err := self.schema.validate(key, value); err != nil { - return err - } - - if value.IsValid() && value.Kind() != reflect.String { - return newError(key, "must be a string") - } - - return nil + return self.schema.validate(key, value) } diff --git a/time.go b/time.go index e554c70..8698fb2 100644 --- a/time.go +++ b/time.go @@ -2,6 +2,7 @@ package owl import ( "encoding/json" + "errors" "fmt" "reflect" "time" @@ -13,7 +14,30 @@ type TimeSchema struct { } func Time() *TimeSchema { - return &TimeSchema{Any(), time.RFC3339} + self := &TimeSchema{Any(), time.RFC3339} + self.Rule("type", self.Type(), func(value reflect.Value) (any, error) { + if !value.IsValid() { + return nil, nil + } + + if value.Kind() != reflect.String && value.Type() != reflect.TypeFor[time.Time]() { + return value.Interface(), errors.New("must be a string or time.Time") + } + + if value.Kind() == reflect.String { + parsed, err := time.Parse(self.layout, value.String()) + + if err != nil { + return value.Interface(), err + } + + value = reflect.ValueOf(parsed) + } + + return value.Interface(), nil + }) + + return self } func (self TimeSchema) Type() string { @@ -72,27 +96,9 @@ func (self TimeSchema) MarshalJSON() ([]byte, error) { } func (self TimeSchema) Validate(value any) error { - return self.validate("", reflect.Indirect(reflect.ValueOf(value))) + return self.validate("", reflect.Indirect(reflect.ValueOf(value))) } func (self TimeSchema) validate(key string, value reflect.Value) error { - if value.IsValid() && value.Kind() != reflect.String && value.Type() != reflect.TypeFor[time.Time]() { - return newError(key, "must be a string or time.Time") - } - - if value.IsValid() && value.Kind() == reflect.String { - parsed, err := time.Parse(self.layout, value.String()) - - if err != nil { - return newError(key, err.Error()) - } - - value = reflect.ValueOf(parsed) - } - - if err := self.schema.validate(key, value); err != nil { - return err - } - - return nil + return self.schema.validate(key, value) } diff --git a/union.go b/union.go index 036caa7..6876433 100644 --- a/union.go +++ b/union.go @@ -2,6 +2,7 @@ package owl import ( "encoding/json" + "errors" "fmt" "reflect" "strings" @@ -13,7 +14,20 @@ type UnionSchema struct { } func Union(anyOf ...Schema) *UnionSchema { - return &UnionSchema{Any(), anyOf} + self := &UnionSchema{Any(), anyOf} + self.Rule("type", self.Type(), func(value reflect.Value) (any, error) { + for _, schema := range self.anyOf { + e := schema.Validate(value.Interface()) + + if e == nil { + return value.Interface(), nil + } + } + + return value.Interface(), errors.New("must match one or more types in union") + }) + + return self } func (self UnionSchema) Type() string {