Skip to content

Commit

Permalink
Fix TextUnmarshaler and BytesUnmarshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
goccy committed Jun 23, 2020
1 parent cab0430 commit 3b54bfb
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 40 deletions.
57 changes: 36 additions & 21 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,35 @@ func (d *Decoder) lastNode(node ast.Node) ast.Node {
return node
}

func (d *Decoder) unmarshalableDocument(node ast.Node) []byte {
node = d.resolveAlias(node)
doc := node.String()
last := d.lastNode(node)
if last != nil && last.Type() == ast.LiteralType {
doc += "\n"
}
return []byte(doc)
}

func (d *Decoder) unmarshalableText(node ast.Node) ([]byte, bool) {
node = d.resolveAlias(node)
if node.Type() == ast.AnchorType {
node = node.(*ast.AnchorNode).Value
}
switch n := node.(type) {
case *ast.StringNode:
return []byte(n.Value), true
case *ast.LiteralNode:
return []byte(n.Value.GetToken().Value), true
default:
scalar, ok := n.(ast.ScalarNode)
if ok {
return []byte(fmt.Sprint(scalar.GetValue())), true
}
}
return nil, false
}

func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error {
if src.Type() == ast.AnchorType {
anchorName := src.(*ast.AnchorNode).Name.GetToken().Value
Expand All @@ -444,18 +473,7 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error {
}
valueType := dst.Type()
if unmarshaler, ok := dst.Addr().Interface().(BytesUnmarshaler); ok {
src = d.resolveAlias(src)
var b string
if scalar, isScalar := src.(ast.ScalarNode); isScalar {
b = fmt.Sprint(scalar.GetValue())
} else {
b = src.String()
}
last := d.lastNode(src)
if last != nil && last.Type() == ast.LiteralType {
b += "\n"
}
if err := unmarshaler.UnmarshalYAML([]byte(b)); err != nil {
if err := unmarshaler.UnmarshalYAML(d.unmarshalableDocument(src)); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
Expand All @@ -476,16 +494,13 @@ func (d *Decoder) decodeValue(dst reflect.Value, src ast.Node) error {
} else if _, ok := dst.Addr().Interface().(*time.Time); ok {
return d.decodeTime(dst, src)
} else if unmarshaler, isText := dst.Addr().Interface().(encoding.TextUnmarshaler); isText {
var b string
if scalar, isScalar := src.(ast.ScalarNode); isScalar {
b = scalar.GetValue().(string)
} else {
b = src.String()
}
if err := unmarshaler.UnmarshalText([]byte(b)); err != nil {
return errors.Wrapf(err, "failed to UnmarshalText")
b, ok := d.unmarshalableText(src)
if ok {
if err := unmarshaler.UnmarshalText(b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalText")
}
return nil
}
return nil
}
switch valueType.Kind() {
case reflect.Ptr:
Expand Down
143 changes: 124 additions & 19 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"log"
"math"
"net"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -1654,52 +1655,152 @@ complecated: string
// ^
}

type unmarshalableStringValue string
type unmarshalableYAMLStringValue string

func (v *unmarshalableStringValue) UnmarshalYAML(raw []byte) error {
*v = unmarshalableStringValue(string(raw))
func (v *unmarshalableYAMLStringValue) UnmarshalYAML(b []byte) error {
var s string
if err := yaml.Unmarshal(b, &s); err != nil {
return err
}
*v = unmarshalableYAMLStringValue(s)
return nil
}

type unmarshalableTextStringValue string

func (v *unmarshalableTextStringValue) UnmarshalText(b []byte) error {
*v = unmarshalableTextStringValue(string(b))
return nil
}

type unmarshalableStringContainer struct {
V unmarshalableStringValue `yaml:"value" json:"value"`
A unmarshalableYAMLStringValue `yaml:"a"`
B unmarshalableTextStringValue `yaml:"b"`
}

func TestUnmarshalableString(t *testing.T) {
t.Run("empty string", func(t *testing.T) {
t.Parallel()
yml := `
a: ""
b: ""
`
var container unmarshalableStringContainer
if err := yaml.Unmarshal([]byte(`value: ""`), &container); err != nil {
if err := yaml.Unmarshal([]byte(yml), &container); err != nil {
t.Fatalf("failed to unmarshal %v", err)
}
if container.V != "" {
t.Fatalf("expected empty string, but %q is set", container.V)
if container.A != "" {
t.Fatalf("expected empty string, but %q is set", container.A)
}
if container.B != "" {
t.Fatalf("expected empty string, but %q is set", container.B)
}
})
t.Run("filled string", func(t *testing.T) {
t.Parallel()
yml := `
a: "aaa"
b: "bbb"
`
var container unmarshalableStringContainer
if err := yaml.Unmarshal([]byte(`value: "aaa"`), &container); err != nil {
if err := yaml.Unmarshal([]byte(yml), &container); err != nil {
t.Fatalf("failed to unmarshal %v", err)
}
if container.V != "aaa" {
t.Fatalf("expected \"aaa\", but %q is set", container.V)
if container.A != "aaa" {
t.Fatalf("expected \"aaa\", but %q is set", container.A)
}
if container.B != "bbb" {
t.Fatalf("expected \"bbb\", but %q is set", container.B)
}
})
t.Run("single-quoted string", func(t *testing.T) {
t.Parallel()
yml := `
a: 'aaa'
b: 'bbb'
`
var container unmarshalableStringContainer
if err := yaml.Unmarshal([]byte(yml), &container); err != nil {
t.Fatalf("failed to unmarshal %v", err)
}
if container.A != "aaa" {
t.Fatalf("expected \"aaa\", but %q is set", container.A)
}
if container.B != "bbb" {
t.Fatalf("expected \"aaa\", but %q is set", container.B)
}
})
t.Run("literal", func(t *testing.T) {
t.Parallel()
yml := `
a: |
a
b
c
b: |
a
b
c
`
var container unmarshalableStringContainer
if err := yaml.Unmarshal([]byte(`value: 'aaa'`), &container); err != nil {
if err := yaml.Unmarshal([]byte(yml), &container); err != nil {
t.Fatalf("failed to unmarshal %v", err)
}
if container.V != "aaa" {
t.Fatalf("expected \"aaa\", but %q is set", container.V)
if container.A != "a\nb\nc\n" {
t.Fatalf("expected \"a\nb\nc\n\", but %q is set", container.A)
}
if container.B != "a\nb\nc\n" {
t.Fatalf("expected \"a\nb\nc\n\", but %q is set", container.B)
}
})
t.Run("anchor/alias", func(t *testing.T) {
yml := `
a: &x 1
b: *x
c: &y hello
d: *y
`
var v struct {
A, B, C, D unmarshalableTextStringValue
}
if err := yaml.Unmarshal([]byte(yml), &v); err != nil {
t.Fatal(err)
}
if v.A != "1" {
t.Fatal("failed to unmarshal")
}
if v.B != "1" {
t.Fatal("failed to unmarshal")
}
if v.C != "hello" {
t.Fatal("failed to unmarshal")
}
if v.D != "hello" {
t.Fatal("failed to unmarshal")
}
})
t.Run("net.IP", func(t *testing.T) {
yml := `
a: &a 127.0.0.1
b: *a
`
var v struct {
A, B net.IP
}
if err := yaml.Unmarshal([]byte(yml), &v); err != nil {
t.Fatal(err)
}
if v.A.String() != net.IPv4(127, 0, 0, 1).String() {
t.Fatal("failed to unmarshal")
}
if v.B.String() != net.IPv4(127, 0, 0, 1).String() {
t.Fatal("failed to unmarshal")
}
})
}

type unmarshalablePtrStringContainer struct {
V *string `yaml:"value" json:"value"`
V *string `yaml:"value"`
}

func TestUnmarshalablePtrString(t *testing.T) {
Expand All @@ -1709,7 +1810,7 @@ func TestUnmarshalablePtrString(t *testing.T) {
if err := yaml.Unmarshal([]byte(`value: ""`), &container); err != nil {
t.Fatalf("failed to unmarshal %v", err)
}
if *container.V != "" {
if container.V == nil || *container.V != "" {
t.Fatalf("expected empty string, but %q is set", *container.V)
}
})
Expand Down Expand Up @@ -1738,7 +1839,7 @@ func (v *unmarshalableIntValue) UnmarshalYAML(raw []byte) error {
}

type unmarshalableIntContainer struct {
V unmarshalableIntValue `yaml:"value" json:"value"`
V unmarshalableIntValue `yaml:"value"`
}

func TestUnmarshalableInt(t *testing.T) {
Expand Down Expand Up @@ -1775,7 +1876,7 @@ func TestUnmarshalableInt(t *testing.T) {
}

type unmarshalablePtrIntContainer struct {
V *int `yaml:"value" json:"value"`
V *int `yaml:"value"`
}

func TestUnmarshalablePtrInt(t *testing.T) {
Expand All @@ -1785,7 +1886,7 @@ func TestUnmarshalablePtrInt(t *testing.T) {
if err := yaml.Unmarshal([]byte(`value: 0`), &container); err != nil {
t.Fatalf("failed to unmarshal %v", err)
}
if *container.V != 0 {
if container.V == nil || *container.V != 0 {
t.Fatalf("expected 0, but %q is set", *container.V)
}
})
Expand Down Expand Up @@ -1900,7 +2001,11 @@ k: l
type unmarshalYAMLWithAliasString string

func (v *unmarshalYAMLWithAliasString) UnmarshalYAML(b []byte) error {
*v = unmarshalYAMLWithAliasString(string(b))
var s string
if err := yaml.Unmarshal(b, &s); err != nil {
return err
}
*v = unmarshalYAMLWithAliasString(s)
return nil
}

Expand Down

0 comments on commit 3b54bfb

Please sign in to comment.