diff --git a/README.md b/README.md index 2ea8afe1..b72f86fb 100644 --- a/README.md +++ b/README.md @@ -308,27 +308,8 @@ import ( "fmt" "github.com/goccy/go-yaml" - "github.com/goccy/go-yaml/parser" - "github.com/goccy/go-yaml/printer" ) -func yamlSourceByPath(originalSource string, pathStr string) (string, error) { - file, err := parser.ParseBytes([]byte(originalSource), 0) - if err != nil { - return "", err - } - path, err := yaml.PathString(pathStr) - if err != nil { - return "", err - } - node, err := path.FilterFile(file) - if err != nil { - return "", err - } - var p printer.Printer - return p.PrintErrorToken(node.GetToken(), true), nil -} - func main() { yml := ` a: 1 @@ -343,19 +324,20 @@ b: "hello" } if v.A != 2 { // output error with YAML source - source, err := yamlSourceByPath(yml, "$.a") + path, err := yaml.PathString("$.a") if err != nil { panic(err) } - fmt.Printf("a value expected 2 but actual %d:\n%s\n", v.A, source) + source, err := path.AnnotateSource([]byte(yml), true) + if err != nil { + panic(err) + } + fmt.Printf("a value expected 2 but actual %d:\n%s\n", v.A, string(source)) } } ``` -`printer.PrintErrorToken` can output YAML source with error point, -and you can get `token.Token` of error point by `yaml.Path` . - -output result is following +output result is the following. diff --git a/ast/ast.go b/ast/ast.go index a7834bce..c36ea786 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -2,6 +2,7 @@ package ast import ( "fmt" + "io" "math" "strconv" "strings" @@ -11,7 +12,9 @@ import ( ) var ( - ErrInvalidTokenType = xerrors.New("invalid token type") + ErrInvalidTokenType = xerrors.New("invalid token type") + ErrInvalidAnchorName = xerrors.New("invalid anchor name") + ErrInvalidAliasName = xerrors.New("invalid alias name") ) // NodeType type identifier of node @@ -109,6 +112,7 @@ func (t NodeType) String() string { // Node type of node type Node interface { + io.Reader // String node to text String() string // GetToken returns token instance @@ -121,83 +125,72 @@ type Node interface { SetComment(*token.Token) error // Comment returns comment token instance GetComment() *token.Token + // already read length + readLen() int + // append read length + addReadLen(int) } -// File contains all documents in YAML file -type File struct { - Name string - Docs []*Document +// ScalarNode type for scalar node +type ScalarNode interface { + Node + GetValue() interface{} } -// String all documents to text -func (f *File) String() string { - docs := []string{} - for _, doc := range f.Docs { - docs = append(docs, doc.String()) - } - return strings.Join(docs, "\n") +type BaseNode struct { + Comment *token.Token + read int } -// Document type of Document -type Document struct { - Comment *token.Token // position of Comment ( `#comment` ) - Start *token.Token // position of DocumentHeader ( `---` ) - End *token.Token // position of DocumentEnd ( `...` ) - Body Node +func (n *BaseNode) readLen() int { + return n.read } -// GetToken returns token instance -func (d *Document) GetToken() *token.Token { - return d.Body.GetToken() +func (n *BaseNode) addReadLen(len int) { + n.read += len } // GetComment returns comment token instance -func (d *Document) GetComment() *token.Token { - return d.Comment -} - -// AddColumn add column number to child nodes recursively -func (d *Document) AddColumn(col int) { - if d.Body != nil { - d.Body.AddColumn(col) - } +func (n *BaseNode) GetComment() *token.Token { + return n.Comment } // SetComment set comment token -func (d *Document) SetComment(tk *token.Token) error { +func (n *BaseNode) SetComment(tk *token.Token) error { if tk.Type != token.CommentType { return ErrInvalidTokenType } - d.Comment = tk + n.Comment = tk return nil } -// Type returns DocumentType -func (d *Document) Type() NodeType { return DocumentType } - -// String document to text -func (d *Document) String() string { - doc := []string{} - if d.Start != nil { - doc = append(doc, d.Start.Value) +func min(a, b int) int { + if a < b { + return a } - doc = append(doc, d.Body.String()) - if d.End != nil { - doc = append(doc, d.End.Value) - } - return strings.Join(doc, "\n") + return b } -// ScalarNode type for scalar node -type ScalarNode interface { - Node - GetValue() interface{} +func readNode(p []byte, node Node) (int, error) { + s := node.String() + readLen := node.readLen() + remain := len(s) - readLen + if remain == 0 { + return 0, io.EOF + } + size := min(remain, len(p)) + for idx, b := range s[readLen : readLen+size] { + p[idx] = byte(b) + } + node.addReadLen(size) + return size, nil } // Null create node for null value func Null(tk *token.Token) Node { return &NullNode{ - Token: tk, + BaseNode: &BaseNode{}, + Token: tk, } } @@ -205,15 +198,12 @@ func Null(tk *token.Token) Node { func Bool(tk *token.Token) Node { b, _ := strconv.ParseBool(tk.Value) return &BoolNode{ - Token: tk, - Value: b, + BaseNode: &BaseNode{}, + Token: tk, + Value: b, } } -func removeUnderScoreFromNumber(num string) string { - return strings.ReplaceAll(num, "_", "") -} - // Integer create node for integer value func Integer(tk *token.Token) Node { value := removeUnderScoreFromNumber(tk.Value) @@ -228,10 +218,18 @@ func Integer(tk *token.Token) Node { } if len(negativePrefix) > 0 { i, _ := strconv.ParseInt(negativePrefix+value[skipCharacterNum:], 2, 64) - return &IntegerNode{Token: tk, Value: i} + return &IntegerNode{ + BaseNode: &BaseNode{}, + Token: tk, + Value: i, + } } i, _ := strconv.ParseUint(negativePrefix+value[skipCharacterNum:], 2, 64) - return &IntegerNode{Token: tk, Value: i} + return &IntegerNode{ + BaseNode: &BaseNode{}, + Token: tk, + Value: i, + } case token.OctetIntegerType: // octet token starts with '0o' or '-0o' or '0' or '-0' skipCharacterNum := 1 @@ -249,10 +247,18 @@ func Integer(tk *token.Token) Node { } if len(negativePrefix) > 0 { i, _ := strconv.ParseInt(negativePrefix+value[skipCharacterNum:], 8, 64) - return &IntegerNode{Token: tk, Value: i} + return &IntegerNode{ + BaseNode: &BaseNode{}, + Token: tk, + Value: i, + } } i, _ := strconv.ParseUint(value[skipCharacterNum:], 8, 64) - return &IntegerNode{Token: tk, Value: i} + return &IntegerNode{ + BaseNode: &BaseNode{}, + Token: tk, + Value: i, + } case token.HexIntegerType: // hex token starts with '0x' or '-0x' skipCharacterNum := 2 @@ -263,32 +269,50 @@ func Integer(tk *token.Token) Node { } if len(negativePrefix) > 0 { i, _ := strconv.ParseInt(negativePrefix+value[skipCharacterNum:], 16, 64) - return &IntegerNode{Token: tk, Value: i} + return &IntegerNode{ + BaseNode: &BaseNode{}, + Token: tk, + Value: i, + } } i, _ := strconv.ParseUint(value[skipCharacterNum:], 16, 64) - return &IntegerNode{Token: tk, Value: i} + return &IntegerNode{ + BaseNode: &BaseNode{}, + Token: tk, + Value: i, + } } if value[0] == '-' || value[0] == '+' { i, _ := strconv.ParseInt(value, 10, 64) - return &IntegerNode{Token: tk, Value: i} + return &IntegerNode{ + BaseNode: &BaseNode{}, + Token: tk, + Value: i, + } } i, _ := strconv.ParseUint(value, 10, 64) - return &IntegerNode{Token: tk, Value: i} + return &IntegerNode{ + BaseNode: &BaseNode{}, + Token: tk, + Value: i, + } } // Float create node for float value func Float(tk *token.Token) Node { f, _ := strconv.ParseFloat(removeUnderScoreFromNumber(tk.Value), 64) return &FloatNode{ - Token: tk, - Value: f, + BaseNode: &BaseNode{}, + Token: tk, + Value: f, } } // Infinity create node for .inf or -.inf value -func Infinity(tk *token.Token) Node { +func Infinity(tk *token.Token) *InfinityNode { node := &InfinityNode{ - Token: tk, + BaseNode: &BaseNode{}, + Token: tk, } switch tk.Value { case ".inf", ".Inf", ".INF": @@ -300,62 +324,204 @@ func Infinity(tk *token.Token) Node { } // Nan create node for .nan value -func Nan(tk *token.Token) Node { - return &NanNode{Token: tk} +func Nan(tk *token.Token) *NanNode { + return &NanNode{ + BaseNode: &BaseNode{}, + Token: tk, + } } // String create node for string value -func String(tk *token.Token) Node { +func String(tk *token.Token) *StringNode { return &StringNode{ - Token: tk, - Value: tk.Value, + BaseNode: &BaseNode{}, + Token: tk, + Value: tk.Value, } } // Comment create node for comment -func Comment(tk *token.Token) Node { - return &CommentNode{Comment: tk} +func Comment(tk *token.Token) *CommentNode { + return &CommentNode{ + BaseNode: &BaseNode{Comment: tk}, + } } // MergeKey create node for merge key ( << ) -func MergeKey(tk *token.Token) Node { +func MergeKey(tk *token.Token) *MergeKeyNode { return &MergeKeyNode{ - Token: tk, + BaseNode: &BaseNode{}, + Token: tk, } } // Mapping create node for map -func Mapping(tk *token.Token, isFlowStyle bool) *MappingNode { - return &MappingNode{ +func Mapping(tk *token.Token, isFlowStyle bool, values ...*MappingValueNode) *MappingNode { + node := &MappingNode{ + BaseNode: &BaseNode{}, Start: tk, IsFlowStyle: isFlowStyle, Values: []*MappingValueNode{}, } + node.Values = append(node.Values, values...) + return node +} + +// MappingValue create node for mapping value +func MappingValue(tk *token.Token, key Node, value Node) *MappingValueNode { + return &MappingValueNode{ + BaseNode: &BaseNode{}, + Start: tk, + Key: key, + Value: value, + } } // MappingKey create node for map key ( '?' ). func MappingKey(tk *token.Token) *MappingKeyNode { return &MappingKeyNode{ - Start: tk, + BaseNode: &BaseNode{}, + Start: tk, } } // Sequence create node for sequence func Sequence(tk *token.Token, isFlowStyle bool) *SequenceNode { return &SequenceNode{ + BaseNode: &BaseNode{}, Start: tk, IsFlowStyle: isFlowStyle, Values: []Node{}, } } +func Anchor(tk *token.Token) *AnchorNode { + return &AnchorNode{ + BaseNode: &BaseNode{}, + Start: tk, + } +} + +func Alias(tk *token.Token) *AliasNode { + return &AliasNode{ + BaseNode: &BaseNode{}, + Start: tk, + } +} + +func Document(tk *token.Token, body Node) *DocumentNode { + return &DocumentNode{ + BaseNode: &BaseNode{}, + Start: tk, + Body: body, + } +} + +func Directive(tk *token.Token) *DirectiveNode { + return &DirectiveNode{ + BaseNode: &BaseNode{}, + Start: tk, + } +} + +func Literal(tk *token.Token) *LiteralNode { + return &LiteralNode{ + BaseNode: &BaseNode{}, + Start: tk, + } +} + +func Tag(tk *token.Token) *TagNode { + return &TagNode{ + BaseNode: &BaseNode{}, + Start: tk, + } +} + +// File contains all documents in YAML file +type File struct { + Name string + Docs []*DocumentNode +} + +// Read implements (io.Reader).Read +func (f *File) Read(p []byte) (int, error) { + for _, doc := range f.Docs { + n, err := doc.Read(p) + if err == io.EOF { + continue + } + return n, nil + } + return 0, io.EOF +} + +// String all documents to text +func (f *File) String() string { + docs := []string{} + for _, doc := range f.Docs { + docs = append(docs, doc.String()) + } + return strings.Join(docs, "\n") +} + +// DocumentNode type of Document +type DocumentNode struct { + *BaseNode + Start *token.Token // position of DocumentHeader ( `---` ) + End *token.Token // position of DocumentEnd ( `...` ) + Body Node +} + +// Read implements (io.Reader).Read +func (d *DocumentNode) Read(p []byte) (int, error) { + return readNode(p, d) +} + +// Type returns DocumentNodeType +func (d *DocumentNode) Type() NodeType { return DocumentType } + +// GetToken returns token instance +func (d *DocumentNode) GetToken() *token.Token { + return d.Body.GetToken() +} + +// AddColumn add column number to child nodes recursively +func (d *DocumentNode) AddColumn(col int) { + if d.Body != nil { + d.Body.AddColumn(col) + } +} + +// String document to text +func (d *DocumentNode) String() string { + doc := []string{} + if d.Start != nil { + doc = append(doc, d.Start.Value) + } + doc = append(doc, d.Body.String()) + if d.End != nil { + doc = append(doc, d.End.Value) + } + return strings.Join(doc, "\n") +} + +func removeUnderScoreFromNumber(num string) string { + return strings.ReplaceAll(num, "_", "") +} + // NullNode type of null node type NullNode struct { - ScalarNode + *BaseNode Comment *token.Token // position of Comment ( `#comment` ) Token *token.Token } +// Read implements (io.Reader).Read +func (n *NullNode) Read(p []byte) (int, error) { + return readNode(p, n) +} + // Type returns NullType func (n *NullNode) Type() NodeType { return NullType } @@ -364,11 +530,6 @@ func (n *NullNode) GetToken() *token.Token { return n.Token } -// GetComment returns comment token instance -func (n *NullNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *NullNode) AddColumn(col int) { n.Token.AddColumn(col) @@ -395,10 +556,14 @@ func (n *NullNode) String() string { // IntegerNode type of integer node type IntegerNode struct { - ScalarNode - Comment *token.Token // position of Comment ( `#comment` ) - Token *token.Token - Value interface{} // int64 or uint64 value + *BaseNode + Token *token.Token + Value interface{} // int64 or uint64 value +} + +// Read implements (io.Reader).Read +func (n *IntegerNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns IntegerType @@ -409,25 +574,11 @@ func (n *IntegerNode) GetToken() *token.Token { return n.Token } -// GetComment returns comment token instance -func (n *IntegerNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *IntegerNode) AddColumn(col int) { n.Token.AddColumn(col) } -// SetComment set comment token -func (n *IntegerNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // GetValue returns int64 value func (n *IntegerNode) GetValue() interface{} { return n.Value @@ -440,13 +591,17 @@ func (n *IntegerNode) String() string { // FloatNode type of float node type FloatNode struct { - ScalarNode - Comment *token.Token // position of Comment ( `#comment` ) + *BaseNode Token *token.Token Precision int Value float64 } +// Read implements (io.Reader).Read +func (n *FloatNode) Read(p []byte) (int, error) { + return readNode(p, n) +} + // Type returns FloatType func (n *FloatNode) Type() NodeType { return FloatType } @@ -455,25 +610,11 @@ func (n *FloatNode) GetToken() *token.Token { return n.Token } -// GetComment returns comment token instance -func (n *FloatNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *FloatNode) AddColumn(col int) { n.Token.AddColumn(col) } -// SetComment set comment token -func (n *FloatNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // GetValue returns float64 value func (n *FloatNode) GetValue() interface{} { return n.Value @@ -486,10 +627,14 @@ func (n *FloatNode) String() string { // StringNode type of string node type StringNode struct { - ScalarNode - Comment *token.Token // position of Comment ( `#comment` ) - Token *token.Token - Value string + *BaseNode + Token *token.Token + Value string +} + +// Read implements (io.Reader).Read +func (n *StringNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns StringType @@ -500,25 +645,11 @@ func (n *StringNode) GetToken() *token.Token { return n.Token } -// GetComment returns comment token instance -func (n *StringNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *StringNode) AddColumn(col int) { n.Token.AddColumn(col) } -// SetComment set comment token -func (n *StringNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // GetValue returns string value func (n *StringNode) GetValue() interface{} { return n.Value @@ -553,10 +684,14 @@ func (n *StringNode) String() string { // LiteralNode type of literal node type LiteralNode struct { - ScalarNode - Comment *token.Token // position of Comment ( `#comment` ) - Start *token.Token - Value *StringNode + *BaseNode + Start *token.Token + Value *StringNode +} + +// Read implements (io.Reader).Read +func (n *LiteralNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns LiteralType @@ -567,11 +702,6 @@ func (n *LiteralNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *LiteralNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *LiteralNode) AddColumn(col int) { n.Start.AddColumn(col) @@ -580,15 +710,6 @@ func (n *LiteralNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *LiteralNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // GetValue returns string value func (n *LiteralNode) GetValue() interface{} { return n.String() @@ -602,9 +723,13 @@ func (n *LiteralNode) String() string { // MergeKeyNode type of merge key node type MergeKeyNode struct { - ScalarNode - Comment *token.Token // position of Comment ( `#comment` ) - Token *token.Token + *BaseNode + Token *token.Token +} + +// Read implements (io.Reader).Read +func (n *MergeKeyNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns MergeKeyType @@ -615,11 +740,6 @@ func (n *MergeKeyNode) GetToken() *token.Token { return n.Token } -// GetComment returns comment token instance -func (n *MergeKeyNode) GetComment() *token.Token { - return n.Comment -} - // GetValue returns '<<' value func (n *MergeKeyNode) GetValue() interface{} { return n.Token.Value @@ -635,21 +755,16 @@ func (n *MergeKeyNode) AddColumn(col int) { n.Token.AddColumn(col) } -// SetComment set comment token -func (n *MergeKeyNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // BoolNode type of boolean node type BoolNode struct { - ScalarNode - Comment *token.Token // position of Comment ( `#comment` ) - Token *token.Token - Value bool + *BaseNode + Token *token.Token + Value bool +} + +// Read implements (io.Reader).Read +func (n *BoolNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns BoolType @@ -660,25 +775,11 @@ func (n *BoolNode) GetToken() *token.Token { return n.Token } -// GetComment returns comment token instance -func (n *BoolNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *BoolNode) AddColumn(col int) { n.Token.AddColumn(col) } -// SetComment set comment token -func (n *BoolNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // GetValue returns boolean value func (n *BoolNode) GetValue() interface{} { return n.Value @@ -691,10 +792,14 @@ func (n *BoolNode) String() string { // InfinityNode type of infinity node type InfinityNode struct { - ScalarNode - Comment *token.Token // position of Comment ( `#comment` ) - Token *token.Token - Value float64 + *BaseNode + Token *token.Token + Value float64 +} + +// Read implements (io.Reader).Read +func (n *InfinityNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns InfinityType @@ -705,25 +810,11 @@ func (n *InfinityNode) GetToken() *token.Token { return n.Token } -// GetComment returns comment token instance -func (n *InfinityNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *InfinityNode) AddColumn(col int) { n.Token.AddColumn(col) } -// SetComment set comment token -func (n *InfinityNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // GetValue returns math.Inf(0) or math.Inf(-1) func (n *InfinityNode) GetValue() interface{} { return n.Value @@ -736,9 +827,13 @@ func (n *InfinityNode) String() string { // NanNode type of nan node type NanNode struct { - ScalarNode - Comment *token.Token // position of Comment ( `#comment` ) - Token *token.Token + *BaseNode + Token *token.Token +} + +// Read implements (io.Reader).Read +func (n *NanNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns NanType @@ -749,25 +844,11 @@ func (n *NanNode) GetToken() *token.Token { return n.Token } -// GetComment returns comment token instance -func (n *NanNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *NanNode) AddColumn(col int) { n.Token.AddColumn(col) } -// SetComment set comment token -func (n *NanNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // GetValue returns math.NaN() func (n *NanNode) GetValue() interface{} { return math.NaN() @@ -813,13 +894,44 @@ func (m *MapNodeIter) Value() Node { // MappingNode type of mapping node type MappingNode struct { - Comment *token.Token // position of Comment ( `#comment` ) + *BaseNode Start *token.Token End *token.Token IsFlowStyle bool Values []*MappingValueNode } +func (n *MappingNode) startPos() *token.Position { + if len(n.Values) == 0 { + return n.Start.Position + } + return n.Values[0].Key.GetToken().Position +} + +// Merge merge key/value of map. +func (n *MappingNode) Merge(target *MappingNode) { + keyToMapValueMap := map[string]*MappingValueNode{} + for _, value := range n.Values { + key := value.Key.String() + keyToMapValueMap[key] = value + } + column := n.startPos().Column - target.startPos().Column + target.AddColumn(column) + for _, value := range target.Values { + mapValue, exists := keyToMapValueMap[value.Key.String()] + if exists { + mapValue.Value = value.Value + } else { + n.Values = append(n.Values, value) + } + } +} + +// Read implements (io.Reader).Read +func (n *MappingNode) Read(p []byte) (int, error) { + return readNode(p, n) +} + // Type returns MappingType func (n *MappingNode) Type() NodeType { return MappingType } @@ -828,11 +940,6 @@ func (n *MappingNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *MappingNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *MappingNode) AddColumn(col int) { n.Start.AddColumn(col) @@ -842,15 +949,6 @@ func (n *MappingNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *MappingNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - func (n *MappingNode) flowStyleString() string { if len(n.Values) == 0 { return "{}" @@ -891,9 +989,14 @@ func (n *MappingNode) MapRange() *MapNodeIter { // MappingKeyNode type of tag node type MappingKeyNode struct { - Comment *token.Token // position of Comment ( `#comment` ) - Start *token.Token - Value Node + *BaseNode + Start *token.Token + Value Node +} + +// Read implements (io.Reader).Read +func (n *MappingKeyNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns MappingKeyType @@ -904,11 +1007,6 @@ func (n *MappingKeyNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *MappingKeyNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *MappingKeyNode) AddColumn(col int) { n.Start.AddColumn(col) @@ -917,15 +1015,6 @@ func (n *MappingKeyNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *MappingKeyNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // String tag to text func (n *MappingKeyNode) String() string { return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String()) @@ -933,10 +1022,23 @@ func (n *MappingKeyNode) String() string { // MappingValueNode type of mapping value type MappingValueNode struct { - Comment *token.Token // position of Comment ( `#comment` ) - Start *token.Token - Key Node - Value Node + *BaseNode + Start *token.Token + Key Node + Value Node +} + +// Replace replace value node. +func (n *MappingValueNode) Replace(value Node) error { + column := n.Value.GetToken().Position.Column - value.GetToken().Position.Column + value.AddColumn(column) + n.Value = value + return nil +} + +// Read implements (io.Reader).Read +func (n *MappingValueNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns MappingValueType @@ -947,11 +1049,6 @@ func (n *MappingValueNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *MappingValueNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *MappingValueNode) AddColumn(col int) { n.Start.AddColumn(col) @@ -963,15 +1060,6 @@ func (n *MappingValueNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *MappingValueNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // String mapping value to text func (n *MappingValueNode) String() string { space := strings.Repeat(" ", n.Key.GetToken().Position.Column-1) @@ -1032,13 +1120,41 @@ func (m *ArrayNodeIter) Len() int { // SequenceNode type of sequence node type SequenceNode struct { - Comment *token.Token // position of Comment ( `#comment` ) + *BaseNode Start *token.Token End *token.Token IsFlowStyle bool Values []Node } +// Replace replace value node. +func (n *SequenceNode) Replace(idx int, value Node) error { + if len(n.Values) <= idx { + return xerrors.Errorf( + "invalid index for sequence: sequence length is %d, but specified %d index", + len(n.Values), idx, + ) + } + column := n.Values[idx].GetToken().Position.Column - value.GetToken().Position.Column + value.AddColumn(column) + n.Values[idx] = value + return nil +} + +// Merge merge sequence value. +func (n *SequenceNode) Merge(target *SequenceNode) { + column := n.Start.Position.Column - target.Start.Position.Column + target.AddColumn(column) + for _, value := range target.Values { + n.Values = append(n.Values, value) + } +} + +// Read implements (io.Reader).Read +func (n *SequenceNode) Read(p []byte) (int, error) { + return readNode(p, n) +} + // Type returns SequenceType func (n *SequenceNode) Type() NodeType { return SequenceType } @@ -1047,11 +1163,6 @@ func (n *SequenceNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *SequenceNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *SequenceNode) AddColumn(col int) { n.Start.AddColumn(col) @@ -1061,15 +1172,6 @@ func (n *SequenceNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *SequenceNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - func (n *SequenceNode) flowStyleString() string { values := []string{} for _, value := range n.Values { @@ -1120,10 +1222,27 @@ func (n *SequenceNode) ArrayRange() *ArrayNodeIter { // AnchorNode type of anchor node type AnchorNode struct { - Comment *token.Token // position of Comment ( `#comment` ) - Start *token.Token - Name Node - Value Node + *BaseNode + Start *token.Token + Name Node + Value Node +} + +func (n *AnchorNode) SetName(name string) error { + if n.Name == nil { + return ErrInvalidAnchorName + } + s, ok := n.Name.(*StringNode) + if !ok { + return ErrInvalidAnchorName + } + s.Value = name + return nil +} + +// Read implements (io.Reader).Read +func (n *AnchorNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns AnchorType @@ -1134,11 +1253,6 @@ func (n *AnchorNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *AnchorNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *AnchorNode) AddColumn(col int) { n.Start.AddColumn(col) @@ -1150,15 +1264,6 @@ func (n *AnchorNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *AnchorNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // String anchor to text func (n *AnchorNode) String() string { value := n.Value.String() @@ -1174,9 +1279,26 @@ func (n *AnchorNode) String() string { // AliasNode type of alias node type AliasNode struct { - Comment *token.Token // position of Comment ( `#comment` ) - Start *token.Token - Value Node + *BaseNode + Start *token.Token + Value Node +} + +func (n *AliasNode) SetName(name string) error { + if n.Value == nil { + return ErrInvalidAliasName + } + s, ok := n.Value.(*StringNode) + if !ok { + return ErrInvalidAliasName + } + s.Value = name + return nil +} + +// Read implements (io.Reader).Read +func (n *AliasNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns AliasType @@ -1187,11 +1309,6 @@ func (n *AliasNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *AliasNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *AliasNode) AddColumn(col int) { n.Start.AddColumn(col) @@ -1200,15 +1317,6 @@ func (n *AliasNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *AliasNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // String alias to text func (n *AliasNode) String() string { return fmt.Sprintf("*%s", n.Value.String()) @@ -1216,9 +1324,14 @@ func (n *AliasNode) String() string { // DirectiveNode type of directive node type DirectiveNode struct { - Comment *token.Token // position of Comment ( `#comment` ) - Start *token.Token - Value Node + *BaseNode + Start *token.Token + Value Node +} + +// Read implements (io.Reader).Read +func (n *DirectiveNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns DirectiveType @@ -1229,11 +1342,6 @@ func (n *DirectiveNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *DirectiveNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *DirectiveNode) AddColumn(col int) { if n.Value != nil { @@ -1241,15 +1349,6 @@ func (n *DirectiveNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *DirectiveNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // String directive to text func (n *DirectiveNode) String() string { return fmt.Sprintf("%s%s", n.Start.Value, n.Value.String()) @@ -1257,9 +1356,14 @@ func (n *DirectiveNode) String() string { // TagNode type of tag node type TagNode struct { - Comment *token.Token // position of Comment ( `#comment` ) - Start *token.Token - Value Node + *BaseNode + Start *token.Token + Value Node +} + +// Read implements (io.Reader).Read +func (n *TagNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns TagType @@ -1270,11 +1374,6 @@ func (n *TagNode) GetToken() *token.Token { return n.Start } -// GetComment returns comment token instance -func (n *TagNode) GetComment() *token.Token { - return n.Comment -} - // AddColumn add column number to child nodes recursively func (n *TagNode) AddColumn(col int) { n.Start.AddColumn(col) @@ -1283,15 +1382,6 @@ func (n *TagNode) AddColumn(col int) { } } -// SetComment set comment token -func (n *TagNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // String tag to text func (n *TagNode) String() string { return fmt.Sprintf("%s %s", n.Start.Value, n.Value.String()) @@ -1299,7 +1389,12 @@ func (n *TagNode) String() string { // CommentNode type of comment node type CommentNode struct { - Comment *token.Token // position of Comment ( `#comment` ) + *BaseNode +} + +// Read implements (io.Reader).Read +func (n *CommentNode) Read(p []byte) (int, error) { + return readNode(p, n) } // Type returns TagType @@ -1308,23 +1403,11 @@ func (n *CommentNode) Type() NodeType { return CommentType } // GetToken returns token instance func (n *CommentNode) GetToken() *token.Token { return n.Comment } -// GetComment returns comment token instance -func (n *CommentNode) GetComment() *token.Token { return n.Comment } - // AddColumn add column number to child nodes recursively func (n *CommentNode) AddColumn(col int) { n.Comment.AddColumn(col) } -// SetComment set comment token -func (n *CommentNode) SetComment(tk *token.Token) error { - if tk.Type != token.CommentType { - return ErrInvalidTokenType - } - n.Comment = tk - return nil -} - // String comment to text func (n *CommentNode) String() string { return n.Comment.Value @@ -1357,6 +1440,8 @@ func Walk(v Visitor, node Node) { case *NanNode: case *TagNode: Walk(v, n.Value) + case *DocumentNode: + Walk(v, n.Body) case *MappingNode: for _, value := range n.Values { Walk(v, value) @@ -1377,3 +1462,72 @@ func Walk(v Visitor, node Node) { Walk(v, n.Value) } } + +type filterWalker struct { + typ NodeType + results []Node +} + +func (v *filterWalker) Visit(n Node) Visitor { + if v.typ == n.Type() { + v.results = append(v.results, n) + } + return v +} + +// Filter returns a list of nodes that match the given type. +func Filter(typ NodeType, node Node) []Node { + walker := &filterWalker{typ: typ} + Walk(walker, node) + return walker.results +} + +// FilterFile returns a list of nodes that match the given type. +func FilterFile(typ NodeType, file *File) []Node { + results := []Node{} + for _, doc := range file.Docs { + walker := &filterWalker{typ: typ} + Walk(walker, doc) + results = append(results, walker.results...) + } + return results +} + +type ErrInvalidMergeType struct { + dst Node + src Node +} + +func (e *ErrInvalidMergeType) Error() string { + return fmt.Sprintf("cannot merge %s into %s", e.src.Type(), e.dst.Type()) +} + +// Merge merge document, map, sequence node. +func Merge(dst Node, src Node) error { + if doc, ok := src.(*DocumentNode); ok { + src = doc.Body + } + err := &ErrInvalidMergeType{dst: dst, src: src} + switch dst.Type() { + case DocumentType: + node := dst.(*DocumentNode) + return Merge(node.Body, src) + case MappingType: + node := dst.(*MappingNode) + target, ok := src.(*MappingNode) + if !ok { + return err + } + node.Merge(target) + return nil + case SequenceType: + node := dst.(*SequenceNode) + target, ok := src.(*SequenceNode) + if !ok { + return err + } + node.Merge(target) + return nil + } + return err +} diff --git a/decode.go b/decode.go index 4a55a5ca..8a560ca0 100644 --- a/decode.go +++ b/decode.go @@ -861,10 +861,8 @@ func (d *Decoder) decodeStruct(dst reflect.Value, src ast.Node) error { } mapNode := ast.Mapping(nil, false) for k, v := range keyToNodeMap { - mapNode.Values = append(mapNode.Values, &ast.MappingValueNode{ - Key: &ast.StringNode{Value: k}, - Value: v, - }) + key := &ast.StringNode{BaseNode: &ast.BaseNode{}, Value: k} + mapNode.Values = append(mapNode.Values, ast.MappingValue(nil, key, v)) } newFieldValue, err := d.createDecodedNewValue(fieldValue.Type(), mapNode) if d.disallowUnknownField { diff --git a/decode_test.go b/decode_test.go index cb1efcec..f4b8d6f1 100644 --- a/decode_test.go +++ b/decode_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/parser" "golang.org/x/xerrors" ) @@ -2203,3 +2204,27 @@ func TestDecoder_Canonical(t *testing.T) { t.Fatalf("failed to decode canonical yaml: %+v", m) } } + +func TestDecoder_DecodeFromFile(t *testing.T) { + yml := ` +a: b +c: d +` + file, err := parser.ParseBytes([]byte(yml), 0) + if err != nil { + t.Fatal(err) + } + var v map[string]string + if err := yaml.NewDecoder(file).Decode(&v); err != nil { + t.Fatal(err) + } + if len(v) != 2 { + t.Fatal("failed to decode from ast.File") + } + if v["a"] != "b" { + t.Fatal("failed to decode from ast.File") + } + if v["c"] != "d" { + t.Fatal("failed to decode from ast.File") + } +} diff --git a/encode.go b/encode.go index 20ac90be..e3fef008 100644 --- a/encode.go +++ b/encode.go @@ -68,18 +68,27 @@ func (e *Encoder) Close() error { // // See the documentation for Marshal for details about the conversion of Go values to YAML. func (e *Encoder) Encode(v interface{}) error { + node, err := e.EncodeToNode(v) + if err != nil { + return errors.Wrapf(err, "failed to encode to node") + } + var p printer.Printer + e.writer.Write(p.PrintNode(node)) + return nil +} + +// EncodeToNode convert v to ast.Node. +func (e *Encoder) EncodeToNode(v interface{}) (ast.Node, error) { for _, opt := range e.opts { if err := opt(e); err != nil { - return errors.Wrapf(err, "failed to run option for encoder") + return nil, errors.Wrapf(err, "failed to run option for encoder") } } node, err := e.encodeValue(reflect.ValueOf(v), 1) if err != nil { - return errors.Wrapf(err, "failed to encode value") + return nil, errors.Wrapf(err, "failed to encode value") } - var p printer.Printer - e.writer.Write(p.PrintNode(node)) - return nil + return node, nil } func (e *Encoder) encodeDocument(doc []byte) (ast.Node, error) { @@ -155,10 +164,9 @@ func (e *Encoder) encodeValue(v reflect.Value, column int) (ast.Node, error) { anchorName := e.anchorPtrToNameMap[v.Pointer()] if anchorName != "" { aliasName := anchorName - return &ast.AliasNode{ - Start: token.New("*", "*", e.pos(column)), - Value: ast.String(token.New(aliasName, aliasName, e.pos(column))), - }, nil + alias := ast.Alias(token.New("*", "*", e.pos(column))) + alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column))) + return alias, nil } return e.encodeValue(v.Elem(), column) case reflect.Interface: @@ -276,11 +284,11 @@ func (e *Encoder) encodeMapItem(item MapItem, column int) (*ast.MappingValueNode if m, ok := value.(*ast.MappingNode); ok { m.AddColumn(e.indent) } - return &ast.MappingValueNode{ - Start: token.New("", "", e.pos(column)), - Key: e.encodeString(k.Interface().(string), column), - Value: value, - }, nil + return ast.MappingValue( + token.New("", "", e.pos(column)), + e.encodeString(k.Interface().(string), column), + value, + ), nil } func (e *Encoder) encodeMapSlice(value MapSlice, column int) (ast.Node, error) { @@ -312,10 +320,11 @@ func (e *Encoder) encodeMap(value reflect.Value, column int) ast.Node { if m, ok := value.(*ast.MappingNode); ok { m.AddColumn(e.indent) } - node.Values = append(node.Values, &ast.MappingValueNode{ - Key: e.encodeString(k.Interface().(string), column), - Value: value, - }) + node.Values = append(node.Values, ast.MappingValue( + nil, + e.encodeString(k.Interface().(string), column), + value, + )) } return node } @@ -376,11 +385,9 @@ func (e *Encoder) encodeTime(v time.Time, column int) ast.Node { } func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue reflect.Value, column int) (ast.Node, error) { - anchorNode := &ast.AnchorNode{ - Start: token.New("&", "&", e.pos(column)), - Name: ast.String(token.New(anchorName, anchorName, e.pos(column))), - Value: value, - } + anchorNode := ast.Anchor(token.New("&", "&", e.pos(column))) + anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column))) + anchorNode.Value = value if e.anchorCallback != nil { if err := e.anchorCallback(anchorNode, fieldValue.Interface()); err != nil { return nil, errors.Wrapf(err, "failed to marshal anchor") @@ -451,20 +458,18 @@ func (e *Encoder) encodeStruct(value reflect.Value, column int) (ast.Node, error ) } aliasName := anchorName - value = &ast.AliasNode{ - Start: token.New("*", "*", e.pos(column)), - Value: ast.String(token.New(aliasName, aliasName, e.pos(column))), - } + alias := ast.Alias(token.New("*", "*", e.pos(column))) + alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column))) + value = alias if structField.IsInline { // if both used alias and inline, output `<<: *alias` key = ast.MergeKey(token.New("<<", "<<", e.pos(column))) } case structField.AliasName != "": aliasName := structField.AliasName - value = &ast.AliasNode{ - Start: token.New("*", "*", e.pos(column)), - Value: ast.String(token.New(aliasName, aliasName, e.pos(column))), - } + alias := ast.Alias(token.New("*", "*", e.pos(column))) + alias.Value = ast.String(token.New(aliasName, aliasName, e.pos(column))) + value = alias if structField.IsInline { // if both used alias and inline, output `<<: *alias` key = ast.MergeKey(token.New("<<", "<<", e.pos(column))) @@ -492,10 +497,7 @@ func (e *Encoder) encodeStruct(value reflect.Value, column int) (ast.Node, error } key.AddColumn(-e.indent) value.AddColumn(-e.indent) - node.Values = append(node.Values, &ast.MappingValueNode{ - Key: key, - Value: value, - }) + node.Values = append(node.Values, ast.MappingValue(nil, key, value)) } continue case structField.IsAutoAnchor: @@ -505,19 +507,14 @@ func (e *Encoder) encodeStruct(value reflect.Value, column int) (ast.Node, error } value = anchorNode } - node.Values = append(node.Values, &ast.MappingValueNode{ - Key: key, - Value: value, - }) + node.Values = append(node.Values, ast.MappingValue(nil, key, value)) } if hasInlineAnchorField { node.AddColumn(e.indent) anchorName := "anchor" - anchorNode := &ast.AnchorNode{ - Start: token.New("&", "&", e.pos(column)), - Name: ast.String(token.New(anchorName, anchorName, e.pos(column))), - Value: node, - } + anchorNode := ast.Anchor(token.New("&", "&", e.pos(column))) + anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column))) + anchorNode.Value = node if e.anchorCallback != nil { if err := e.anchorCallback(anchorNode, value.Addr().Interface()); err != nil { return nil, errors.Wrapf(err, "failed to marshal anchor") diff --git a/parser/parser.go b/parser/parser.go index f4a370fb..bd107ebc 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -66,7 +66,7 @@ func (p *parser) parseSequence(ctx *context) (ast.Node, error) { func (p *parser) parseTag(ctx *context) (ast.Node, error) { tagToken := ctx.currentToken() - node := &ast.TagNode{Start: tagToken} + node := ast.Tag(tagToken) ctx.progress(1) // skip tag token var ( value ast.Node @@ -213,8 +213,8 @@ func (p *parser) parseMappingValue(ctx *context) (ast.Node, error) { return nil, errors.Wrapf(err, "failed to validate map value") } - mvnode := &ast.MappingValueNode{Start: tk, Key: key, Value: value} - node := &ast.MappingNode{Start: tk, Values: []*ast.MappingValueNode{mvnode}} + mvnode := ast.MappingValue(tk, key, value) + node := ast.Mapping(tk, false, mvnode) ntk := ctx.nextNotCommentToken() antk := ctx.afterNextNotCommentToken() @@ -253,10 +253,7 @@ func (p *parser) parseMappingValue(ctx *context) (ast.Node, error) { func (p *parser) parseSequenceEntry(ctx *context) (ast.Node, error) { tk := ctx.currentToken() - sequenceNode := &ast.SequenceNode{ - Start: tk, - Values: []ast.Node{}, - } + sequenceNode := ast.Sequence(tk, false) curColumn := tk.Position.Column for tk.Type == token.SequenceEntryType { ctx.progress(1) // skip sequence token @@ -282,7 +279,7 @@ func (p *parser) parseSequenceEntry(ctx *context) (ast.Node, error) { func (p *parser) parseAnchor(ctx *context) (ast.Node, error) { tk := ctx.currentToken() - anchor := &ast.AnchorNode{Start: tk} + anchor := ast.Anchor(tk) ntk := ctx.nextToken() if ntk == nil { return nil, errors.ErrSyntax("unexpected anchor. anchor name is undefined", tk) @@ -308,7 +305,7 @@ func (p *parser) parseAnchor(ctx *context) (ast.Node, error) { func (p *parser) parseAlias(ctx *context) (ast.Node, error) { tk := ctx.currentToken() - alias := &ast.AliasNode{Start: tk} + alias := ast.Alias(tk) ntk := ctx.nextToken() if ntk == nil { return nil, errors.ErrSyntax("unexpected alias. alias name is undefined", tk) @@ -385,7 +382,7 @@ func (p *parser) parseScalarValue(tk *token.Token) ast.Node { } func (p *parser) parseDirective(ctx *context) (ast.Node, error) { - node := &ast.DirectiveNode{Start: ctx.currentToken()} + node := ast.Directive(ctx.currentToken()) ctx.progress(1) // skip directive token value, err := p.parseToken(ctx, ctx.currentToken()) if err != nil { @@ -400,7 +397,7 @@ func (p *parser) parseDirective(ctx *context) (ast.Node, error) { } func (p *parser) parseLiteral(ctx *context) (ast.Node, error) { - node := &ast.LiteralNode{Start: ctx.currentToken()} + node := ast.Literal(ctx.currentToken()) ctx.progress(1) // skip literal/folded token value, err := p.parseToken(ctx, ctx.currentToken()) if err != nil { @@ -435,14 +432,14 @@ func (p *parser) setSameLineCommentIfExists(ctx *context, node ast.Node) error { return nil } -func (p *parser) parseDocument(ctx *context) (*ast.Document, error) { - node := &ast.Document{Start: ctx.currentToken()} +func (p *parser) parseDocument(ctx *context) (*ast.DocumentNode, error) { + startTk := ctx.currentToken() ctx.progress(1) // skip document header token body, err := p.parseToken(ctx, ctx.currentToken()) if err != nil { return nil, errors.Wrapf(err, "failed to parse document body") } - node.Body = body + node := ast.Document(startTk, body) if ntk := ctx.nextToken(); ntk != nil && ntk.Type == token.DocumentEndType { node.End = ntk ctx.progress(1) @@ -541,7 +538,7 @@ func (p *parser) parseToken(ctx *context, tk *token.Token) (ast.Node, error) { func (p *parser) parse(tokens token.Tokens, mode Mode) (*ast.File, error) { ctx := newContext(tokens, mode) - file := &ast.File{Docs: []*ast.Document{}} + file := &ast.File{Docs: []*ast.DocumentNode{}} for ctx.next() { node, err := p.parseToken(ctx, ctx.currentToken()) if err != nil { @@ -551,10 +548,10 @@ func (p *parser) parse(tokens token.Tokens, mode Mode) (*ast.File, error) { if node == nil { continue } - if doc, ok := node.(*ast.Document); ok { + if doc, ok := node.(*ast.DocumentNode); ok { file.Docs = append(file.Docs, doc) } else { - file.Docs = append(file.Docs, &ast.Document{Body: node}) + file.Docs = append(file.Docs, ast.Document(nil, node)) } } return file, nil diff --git a/path.go b/path.go index dfe802a2..cec46ded 100644 --- a/path.go +++ b/path.go @@ -9,6 +9,7 @@ import ( "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/internal/errors" "github.com/goccy/go-yaml/parser" + "github.com/goccy/go-yaml/printer" "golang.org/x/xerrors" ) @@ -25,7 +26,7 @@ var ( // . : child operator // .. : recursive descent // [num] : object/element of array by number -// [*] : all objects/elements for array +// [*] : all objects/elements for array. func PathString(s string) (*Path, error) { buf := []rune(s) length := len(buf) @@ -148,17 +149,17 @@ func parsePathIndex(b *PathBuilder, buf []rune, cursor int) (*PathBuilder, int, return nil, 0, errors.Wrapf(ErrInvalidPathString, "invalid character %s at %d", c, cursor) } -// Path represent YAMLPath ( like a JSONPath ) +// Path represent YAMLPath ( like a JSONPath ). type Path struct { node pathNode } -// String path to text +// String path to text. func (p *Path) String() string { return p.node.String() } -// Read decode from r and set extracted value by YAMLPath to v +// Read decode from r and set extracted value by YAMLPath to v. func (p *Path) Read(r io.Reader, v interface{}) error { node, err := p.ReadNode(r) if err != nil { @@ -170,7 +171,7 @@ func (p *Path) Read(r io.Reader, v interface{}) error { return nil } -// ReadNode create AST from r and extract node by YAMLPath +// ReadNode create AST from r and extract node by YAMLPath. func (p *Path) ReadNode(r io.Reader) (ast.Node, error) { if p.node == nil { return nil, ErrInvalidPath @@ -190,7 +191,7 @@ func (p *Path) ReadNode(r io.Reader) (ast.Node, error) { return node, nil } -// Filter filter from target by YAMLPath and set it to v +// Filter filter from target by YAMLPath and set it to v. func (p *Path) Filter(target, v interface{}) error { b, err := Marshal(target) if err != nil { @@ -202,7 +203,7 @@ func (p *Path) Filter(target, v interface{}) error { return nil } -// FilterFile filter from ast.File by YAMLPath +// FilterFile filter from ast.File by YAMLPath. func (p *Path) FilterFile(f *ast.File) (ast.Node, error) { for _, doc := range f.Docs { node, err := p.FilterNode(doc.Body) @@ -216,7 +217,7 @@ func (p *Path) FilterFile(f *ast.File) (ast.Node, error) { return nil, nil } -// FilterNode filter from node by YAMLPath +// FilterNode filter from node by YAMLPath. func (p *Path) FilterNode(node ast.Node) (ast.Node, error) { n, err := p.node.filter(node) if err != nil { @@ -225,43 +226,138 @@ func (p *Path) FilterNode(node ast.Node) (ast.Node, error) { return n, nil } -// PathBuilder represent builder for YAMLPath +// MergeFromReader merge YAML text into ast.File. +func (p *Path) MergeFromReader(dst *ast.File, src io.Reader) error { + var buf bytes.Buffer + if _, err := io.Copy(&buf, src); err != nil { + return errors.Wrapf(err, "failed to copy from reader") + } + file, err := parser.ParseBytes(buf.Bytes(), 0) + if err != nil { + return errors.Wrapf(err, "failed to parse") + } + if err := p.MergeFromFile(dst, file); err != nil { + return errors.Wrapf(err, "failed to merge file") + } + return nil +} + +// MergeFromFile merge ast.File into ast.File. +func (p *Path) MergeFromFile(dst *ast.File, src *ast.File) error { + base, err := p.FilterFile(dst) + if err != nil { + return errors.Wrapf(err, "failed to filter file") + } + for _, doc := range src.Docs { + if err := ast.Merge(base, doc); err != nil { + return errors.Wrapf(err, "failed to merge") + } + } + return nil +} + +// MergeFromNode merge ast.Node into ast.File. +func (p *Path) MergeFromNode(dst *ast.File, src ast.Node) error { + base, err := p.FilterFile(dst) + if err != nil { + return errors.Wrapf(err, "failed to filter file") + } + if err := ast.Merge(base, src); err != nil { + return errors.Wrapf(err, "failed to merge") + } + return nil +} + +// ReplaceWithReader replace ast.File with io.Reader. +func (p *Path) ReplaceWithReader(dst *ast.File, src io.Reader) error { + var buf bytes.Buffer + if _, err := io.Copy(&buf, src); err != nil { + return errors.Wrapf(err, "failed to copy from reader") + } + file, err := parser.ParseBytes(buf.Bytes(), 0) + if err != nil { + return errors.Wrapf(err, "failed to parse") + } + if err := p.ReplaceWithFile(dst, file); err != nil { + return errors.Wrapf(err, "failed to replace file") + } + return nil +} + +// ReplaceWithFile replace ast.File with ast.File. +func (p *Path) ReplaceWithFile(dst *ast.File, src *ast.File) error { + for _, doc := range src.Docs { + if err := p.ReplaceWithNode(dst, doc); err != nil { + return errors.Wrapf(err, "failed to replace file by path ( %s )", p.node) + } + } + return nil +} + +// ReplaceNode replace ast.File with ast.Node. +func (p *Path) ReplaceWithNode(dst *ast.File, node ast.Node) error { + for _, doc := range dst.Docs { + if node.Type() == ast.DocumentType { + node = node.(*ast.DocumentNode).Body + } + if err := p.node.replace(doc.Body, node); err != nil { + return errors.Wrapf(err, "failed to replace node by path ( %s )", p.node) + } + } + return nil +} + +// AnnotateSource add annotation to passed source ( see section 5.1 in README.md ). +func (p *Path) AnnotateSource(source []byte, colored bool) ([]byte, error) { + file, err := parser.ParseBytes([]byte(source), 0) + if err != nil { + return nil, err + } + node, err := p.FilterFile(file) + if err != nil { + return nil, err + } + var pp printer.Printer + return []byte(pp.PrintErrorToken(node.GetToken(), colored)), nil +} + +// PathBuilder represent builder for YAMLPath. type PathBuilder struct { root *rootNode node pathNode } -// Root add '$' to current path +// Root add '$' to current path. func (b *PathBuilder) Root() *PathBuilder { root := newRootNode() return &PathBuilder{root: root, node: root} } -// IndexAll add '[*]' to current path +// IndexAll add '[*]' to current path. func (b *PathBuilder) IndexAll() *PathBuilder { b.node = b.node.chain(newIndexAllNode()) return b } -// Recursive add '..selector' to current path +// Recursive add '..selector' to current path. func (b *PathBuilder) Recursive(selector string) *PathBuilder { b.node = b.node.chain(newRecursiveNode(selector)) return b } -// Child add '.name' to current path +// Child add '.name' to current path. func (b *PathBuilder) Child(name string) *PathBuilder { b.node = b.node.chain(newSelectorNode(name)) return b } -// Index add '[idx]' to current path +// Index add '[idx]' to current path. func (b *PathBuilder) Index(idx uint) *PathBuilder { b.node = b.node.chain(newIndexNode(idx)) return b } -// Build build YAMLPath +// Build build YAMLPath. func (b *PathBuilder) Build() *Path { return &Path{node: b.root} } @@ -270,6 +366,7 @@ type pathNode interface { fmt.Stringer chain(pathNode) pathNode filter(ast.Node) (ast.Node, error) + replace(ast.Node, ast.Node) error } type basePathNode struct { @@ -308,6 +405,16 @@ func (n *rootNode) filter(node ast.Node) (ast.Node, error) { return filtered, nil } +func (n *rootNode) replace(node ast.Node, target ast.Node) error { + if n.child == nil { + return nil + } + if err := n.child.replace(node, target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + return nil +} + type selectorNode struct { *basePathNode selector string @@ -355,6 +462,42 @@ func (n *selectorNode) filter(node ast.Node) (ast.Node, error) { return nil, nil } +func (n *selectorNode) replaceMapValue(value *ast.MappingValueNode, target ast.Node) error { + key := value.Key.GetToken().Value + if key != n.selector { + return nil + } + if n.child == nil { + if err := value.Replace(target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + } else { + if err := n.child.replace(value.Value, target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + } + return nil +} + +func (n *selectorNode) replace(node ast.Node, target ast.Node) error { + switch node.Type() { + case ast.MappingType: + for _, value := range node.(*ast.MappingNode).Values { + if err := n.replaceMapValue(value, target); err != nil { + return errors.Wrapf(err, "failed to replace map value") + } + } + case ast.MappingValueType: + value := node.(*ast.MappingValueNode) + if err := n.replaceMapValue(value, target); err != nil { + return errors.Wrapf(err, "failed to replace map value") + } + default: + return errors.Wrapf(ErrInvalidQuery, "expected node type is map or map value. but got %s", node.Type()) + } + return nil +} + func (n *selectorNode) String() string { s := fmt.Sprintf(".%s", n.selector) if n.child != nil { @@ -394,6 +537,26 @@ func (n *indexNode) filter(node ast.Node) (ast.Node, error) { return filtered, nil } +func (n *indexNode) replace(node ast.Node, target ast.Node) error { + if node.Type() != ast.SequenceType { + return errors.Wrapf(ErrInvalidQuery, "expected sequence type node. but got %s", node.Type()) + } + sequence := node.(*ast.SequenceNode) + if n.selector >= uint(len(sequence.Values)) { + return errors.Wrapf(ErrInvalidQuery, "expected index is %d. but got sequences has %d items", n.selector, sequence.Values) + } + if n.child == nil { + if err := sequence.Replace(int(n.selector), target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + return nil + } + if err := n.child.replace(sequence.Values[n.selector], target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + return nil +} + func (n *indexNode) String() string { s := fmt.Sprintf("[%d]", n.selector) if n.child != nil { @@ -440,6 +603,27 @@ func (n *indexAllNode) filter(node ast.Node) (ast.Node, error) { return &out, nil } +func (n *indexAllNode) replace(node ast.Node, target ast.Node) error { + if node.Type() != ast.SequenceType { + return errors.Wrapf(ErrInvalidQuery, "expected sequence type node. but got %s", node.Type()) + } + sequence := node.(*ast.SequenceNode) + if n.child == nil { + for idx := range sequence.Values { + if err := sequence.Replace(idx, target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + } + return nil + } + for _, value := range sequence.Values { + if err := n.child.replace(value, target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + } + return nil +} + type recursiveNode struct { *basePathNode selector string @@ -501,3 +685,38 @@ func (n *recursiveNode) filter(node ast.Node) (ast.Node, error) { sequence.Start = node.GetToken() return sequence, nil } + +func (n *recursiveNode) replaceNode(node ast.Node, target ast.Node) error { + switch typedNode := node.(type) { + case *ast.MappingNode: + for _, value := range typedNode.Values { + if err := n.replaceNode(value, target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + } + case *ast.MappingValueNode: + key := typedNode.Key.GetToken().Value + if n.selector == key { + if err := typedNode.Replace(target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + } + if err := n.replaceNode(typedNode.Value, target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + case *ast.SequenceNode: + for _, value := range typedNode.Values { + if err := n.replaceNode(value, target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + } + } + return nil +} + +func (n *recursiveNode) replace(node ast.Node, target ast.Node) error { + if err := n.replaceNode(node, target); err != nil { + return errors.Wrapf(err, "failed to replace") + } + return nil +} diff --git a/path_test.go b/path_test.go index 1e5be825..02cd329d 100644 --- a/path_test.go +++ b/path_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/parser" ) func builder() *yaml.PathBuilder { return &yaml.PathBuilder{} } @@ -114,7 +115,336 @@ store: }) } -func Example_YAMLPath() { +func TestPath_Merge(t *testing.T) { + tests := []struct { + path string + dst string + src string + expected string + }{ + { + "$.c", + ` +a: 1 +b: 2 +c: + d: 3 + e: 4 +`, + ` +f: 5 +g: 6 +`, + ` +a: 1 +b: 2 +c: + d: 3 + e: 4 + f: 5 + g: 6 +`, + }, + { + "$.a.b", + ` +a: + b: + - 1 + - 2 +`, + ` +- 3 +- map: + - 4 + - 5 +`, + ` +a: + b: + - 1 + - 2 + - 3 + - map: + - 4 + - 5 +`, + }, + } + for _, test := range tests { + t.Run(test.path, func(t *testing.T) { + path, err := yaml.PathString(test.path) + if err != nil { + t.Fatalf("%+v", err) + } + t.Run("FromReader", func(t *testing.T) { + file, err := parser.ParseBytes([]byte(test.dst), 0) + if err != nil { + t.Fatalf("%+v", err) + } + if err := path.MergeFromReader(file, strings.NewReader(test.src)); err != nil { + t.Fatalf("%+v", err) + } + actual := "\n" + file.String() + "\n" + if test.expected != actual { + t.Fatalf("expected: %q. but got %q", test.expected, actual) + } + }) + t.Run("FromFile", func(t *testing.T) { + file, err := parser.ParseBytes([]byte(test.dst), 0) + if err != nil { + t.Fatalf("%+v", err) + } + src, err := parser.ParseBytes([]byte(test.src), 0) + if err != nil { + t.Fatalf("%+v", err) + } + if err := path.MergeFromFile(file, src); err != nil { + t.Fatalf("%+v", err) + } + actual := "\n" + file.String() + "\n" + if test.expected != actual { + t.Fatalf("expected: %q. but got %q", test.expected, actual) + } + }) + t.Run("FromNode", func(t *testing.T) { + file, err := parser.ParseBytes([]byte(test.dst), 0) + if err != nil { + t.Fatalf("%+v", err) + } + src, err := parser.ParseBytes([]byte(test.src), 0) + if err != nil { + t.Fatalf("%+v", err) + } + if len(src.Docs) == 0 { + t.Fatalf("failed to parse") + } + if err := path.MergeFromNode(file, src.Docs[0]); err != nil { + t.Fatalf("%+v", err) + } + actual := "\n" + file.String() + "\n" + if test.expected != actual { + t.Fatalf("expected: %q. but got %q", test.expected, actual) + } + }) + }) + } +} + +func TestPath_Replace(t *testing.T) { + tests := []struct { + path string + dst string + src string + expected string + }{ + { + "$.a", + ` +a: 1 +b: 2 +`, + `3`, + ` +a: 3 +b: 2 +`, + }, + { + "$.b", + ` +b: 1 +c: 2 +`, + ` +d: e +f: + g: h + i: j +`, + ` +b: + d: e + f: + g: h + i: j +c: 2 +`, + }, + { + "$.a.b[0]", + ` +a: + b: + - hello +c: 2 +`, + `world`, + ` +a: + b: + - world +c: 2 +`, + }, + + { + "$.books[*].author", + ` +books: + - name: book_a + author: none + - name: book_b + author: none +pictures: + - name: picture_a + author: none + - name: picture_b + author: none +building: + author: none +`, + `ken`, + ` +books: + - name: book_a + author: ken + - name: book_b + author: ken +pictures: + - name: picture_a + author: none + - name: picture_b + author: none +building: + author: none +`, + }, + { + "$..author", + ` +books: + - name: book_a + author: none + - name: book_b + author: none +pictures: + - name: picture_a + author: none + - name: picture_b + author: none +building: + author: none +`, + `ken`, + ` +books: + - name: book_a + author: ken + - name: book_b + author: ken +pictures: + - name: picture_a + author: ken + - name: picture_b + author: ken +building: + author: ken +`, + }, + } + for _, test := range tests { + t.Run(test.path, func(t *testing.T) { + path, err := yaml.PathString(test.path) + if err != nil { + t.Fatalf("%+v", err) + } + t.Run("WithReader", func(t *testing.T) { + file, err := parser.ParseBytes([]byte(test.dst), 0) + if err != nil { + t.Fatalf("%+v", err) + } + if err := path.ReplaceWithReader(file, strings.NewReader(test.src)); err != nil { + t.Fatalf("%+v", err) + } + actual := "\n" + file.String() + "\n" + if test.expected != actual { + t.Fatalf("expected: %q. but got %q", test.expected, actual) + } + }) + t.Run("WithFile", func(t *testing.T) { + file, err := parser.ParseBytes([]byte(test.dst), 0) + if err != nil { + t.Fatalf("%+v", err) + } + src, err := parser.ParseBytes([]byte(test.src), 0) + if err != nil { + t.Fatalf("%+v", err) + } + if err := path.ReplaceWithFile(file, src); err != nil { + t.Fatalf("%+v", err) + } + actual := "\n" + file.String() + "\n" + if test.expected != actual { + t.Fatalf("expected: %q. but got %q", test.expected, actual) + } + }) + t.Run("WithNode", func(t *testing.T) { + file, err := parser.ParseBytes([]byte(test.dst), 0) + if err != nil { + t.Fatalf("%+v", err) + } + src, err := parser.ParseBytes([]byte(test.src), 0) + if err != nil { + t.Fatalf("%+v", err) + } + if len(src.Docs) == 0 { + t.Fatalf("failed to parse") + } + if err := path.ReplaceWithNode(file, src.Docs[0]); err != nil { + t.Fatalf("%+v", err) + } + actual := "\n" + file.String() + "\n" + if test.expected != actual { + t.Fatalf("expected: %q. but got %q", test.expected, actual) + } + }) + }) + } +} + +func ExamplePath_AnnotateSource() { + yml := ` +a: 1 +b: "hello" +` + var v struct { + A int + B string + } + if err := yaml.Unmarshal([]byte(yml), &v); err != nil { + panic(err) + } + if v.A != 2 { + // output error with YAML source + path, err := yaml.PathString("$.a") + if err != nil { + log.Fatal(err) + } + source, err := path.AnnotateSource([]byte(yml), false) + if err != nil { + log.Fatal(err) + } + fmt.Printf("a value expected 2 but actual %d:\n%s\n", v.A, string(source)) + } + // OUTPUT: + // a value expected 2 but actual 1: + // > 2 | a: 1 + // ^ + // 3 | b: "hello" +} + +func ExamplePath_PathString() { yml := ` store: book: diff --git a/yaml.go b/yaml.go index c582567e..b4e76b48 100644 --- a/yaml.go +++ b/yaml.go @@ -4,6 +4,7 @@ import ( "bytes" "io" + "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/internal/errors" "golang.org/x/xerrors" ) @@ -43,6 +44,15 @@ type MapItem struct { // The order of keys is preserved when encoding and decoding. type MapSlice []MapItem +// ToMap convert to map[interface{}]interface{}. +func (s MapSlice) ToMap() map[interface{}]interface{} { + v := map[interface{}]interface{}{} + for _, item := range s { + v[item.Key] = item.Value + } + return v +} + // Marshal serializes the value provided into a YAML document. The structure // of the generated document will reflect the structure of the value itself. // Maps and pointers (to struct, string, int, etc) are accepted as the in value. @@ -94,14 +104,29 @@ type MapSlice []MapItem // yaml.Marshal(&T{F: 1}) // Returns "a: 1\nb: 0\n" // func Marshal(v interface{}) ([]byte, error) { + return MarshalWithOptions(v) +} + +// MarshalWithOptions serializes the value provided into a YAML document with EncodeOptions. +func MarshalWithOptions(v interface{}, opts ...EncodeOption) ([]byte, error) { var buf bytes.Buffer - enc := NewEncoder(&buf) + enc := NewEncoder(&buf, opts...) if err := enc.Encode(v); err != nil { return nil, errors.Wrapf(err, "failed to marshal") } return buf.Bytes(), nil } +// ValueToNode convert from value to ast.Node. +func ValueToNode(v interface{}, opts ...EncodeOption) (ast.Node, error) { + var buf bytes.Buffer + node, err := NewEncoder(&buf, opts...).EncodeToNode(v) + if err != nil { + return nil, errors.Wrapf(err, "failed to convert value to node") + } + return node, nil +} + // Unmarshal decodes the first document found within the in byte slice // and assigns decoded values into the out value. // @@ -126,7 +151,13 @@ func Marshal(v interface{}) ([]byte, error) { // supported tag options. // func Unmarshal(data []byte, v interface{}) error { - dec := NewDecoder(bytes.NewBuffer(data)) + return UnmarshalWithOptions(data, v) +} + +// UnmarshalWithOptions decodes with DecodeOptions the first document found within the in byte slice +// and assigns decoded values into the out value. +func UnmarshalWithOptions(data []byte, v interface{}, opts ...DecodeOption) error { + dec := NewDecoder(bytes.NewBuffer(data), opts...) if err := dec.Decode(v); err != nil { if err == io.EOF { return nil @@ -151,5 +182,4 @@ func FormatError(e error, colored, inclSource bool) string { } return e.Error() - } diff --git a/yaml_test.go b/yaml_test.go index e3993742..b2c34d67 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -263,11 +263,11 @@ collection: ` var v rootObject if err := yaml.Unmarshal([]byte(yml), &v); err != nil { - panic(err) + t.Fatal(err) } opt := yaml.MarshalAnchor(func(anchor *ast.AnchorNode, value interface{}) error { if o, ok := value.(*ObjectDecl); ok { - anchor.Name.(*ast.StringNode).Value = o.Name + return anchor.SetName(o.Name) } return nil }) @@ -280,3 +280,67 @@ collection: t.Fatalf("failed to marshal: expected:[%s] actual:[%s]", yml, actual) } } + +func TestMapSlice_Map(t *testing.T) { + yml := ` +a: b +c: d +` + var v yaml.MapSlice + if err := yaml.Unmarshal([]byte(yml), &v); err != nil { + t.Fatal(err) + } + m := v.ToMap() + if len(m) != 2 { + t.Fatal("failed to convert MapSlice to map") + } + if m["a"] != "b" { + t.Fatal("failed to convert MapSlice to map") + } + if m["c"] != "d" { + t.Fatal("failed to convert MapSlice to map") + } +} + +func TestMarshalWithModifiedAnchorAlias(t *testing.T) { + yml := ` +a: &a 1 +b: *a +` + var v struct { + A *int `yaml:"a,anchor"` + B *int `yaml:"b,alias"` + } + if err := yaml.Unmarshal([]byte(yml), &v); err != nil { + t.Fatal(err) + } + node, err := yaml.ValueToNode(v) + if err != nil { + t.Fatal(err) + } + anchors := ast.Filter(ast.AnchorType, node) + if len(anchors) != 1 { + t.Fatal("failed to filter node") + } + anchor := anchors[0].(*ast.AnchorNode) + if err := anchor.SetName("b"); err != nil { + t.Fatal(err) + } + aliases := ast.Filter(ast.AliasType, node) + if len(anchors) != 1 { + t.Fatal("failed to filter node") + } + alias := aliases[0].(*ast.AliasNode) + if err := alias.SetName("b"); err != nil { + t.Fatal(err) + } + + expected := ` +a: &b 1 +b: *b` + + actual := "\n" + node.String() + if expected != actual { + t.Fatalf("failed to marshal: expected:[%q] but got [%q]", expected, actual) + } +}