- thi
ng
diff --git a/internal/sanitizer/sanitizer.go b/internal/sanitizer/sanitizer.go index d35645051..741fd9e5f 100644 --- a/internal/sanitizer/sanitizer.go +++ b/internal/sanitizer/sanitizer.go @@ -34,12 +34,7 @@ func SanitizeBytes(b []byte) []byte { if err != nil { return []byte{} } - var keepNodes []*html.Node - for _, n := range nodes { - if sanitize(n) { - keepNodes = append(keepNodes, n) - } - } + keepNodes := sanitizeNodes(nodes) var buf bytes.Buffer for _, n := range keepNodes { html.Render(&buf, n) @@ -48,22 +43,28 @@ func SanitizeBytes(b []byte) []byte { } // sanitize sanitizes the attributes and children of n. -// It returns false if the entire node should be cut out. -func sanitize(n *html.Node) bool { +// It returns false if the node should be cut out, and a list +// of parent-less nodes the node should be replaced with. +func sanitize(n *html.Node) ([]*html.Node, bool) { switch n.Type { case html.CommentNode: - return false + return nil, false case html.DoctypeNode: - return false + return nil, false case html.TextNode: - return true // Assume text nodes are safe + return nil, true // Assume text nodes are safe case html.ElementNode: if n.Namespace != "" { - return false + return nil, false } n.Data = strings.ToLower(n.Data) if !allowElemsMap[n.Data] { - return false + switch n.Data { + case "frame", "frameset", "iframe", "noembed", "noframes", "noscript", "nostyle", "object", "script", "style", "title": + return nil, false + default: + return extractSanitizedChildren(n), false + } } keepAttr := []html.Attribute{} for _, attr := range n.Attr { @@ -93,22 +94,56 @@ func sanitize(n *html.Node) bool { keepAttr = append(keepAttr, attr) } if n.Data == "a" { + if len(keepAttr) == 0 { + return extractSanitizedChildren(n), false + } keepAttr = addRelNoFollow(keepAttr) } + if n.Data == "img" { + if len(keepAttr) == 0 { + return nil, false + } + } n.Attr = keepAttr - var removeChildren []*html.Node + replaceChildren := make(map[*html.Node][]*html.Node) for child := n.FirstChild; child != nil; child = child.NextSibling { - if !sanitize(child) { - removeChildren = append(removeChildren, child) + if replace, ok := sanitize(child); !ok { + replaceChildren[child] = replace } } - for _, child := range removeChildren { + for child, replace := range replaceChildren { + for _, r := range replace { + n.InsertBefore(r, child) + } n.RemoveChild(child) } - return true + return nil, true default: - return false + return extractSanitizedChildren(n), false + } +} + +func extractSanitizedChildren(node *html.Node) []*html.Node { + var children []*html.Node + for child := node.FirstChild; child != nil; child = child.NextSibling { + children = append(children, child) + } + for _, child := range children { + node.RemoveChild(child) + } + return sanitizeNodes(children) +} + +func sanitizeNodes(nodes []*html.Node) []*html.Node { + var keepNodes []*html.Node + for _, n := range nodes { + if replace, ok := sanitize(n); ok { + keepNodes = append(keepNodes, n) + } else { + keepNodes = append(keepNodes, replace...) + } } + return keepNodes } func addRelNoFollow(attrs []html.Attribute) []html.Attribute { diff --git a/internal/sanitizer/sanitizer_test.go b/internal/sanitizer/sanitizer_test.go index 9830e764c..00a3f06d5 100644 --- a/internal/sanitizer/sanitizer_test.go +++ b/internal/sanitizer/sanitizer_test.go @@ -16,13 +16,17 @@ func TestSanitizeBytes(t *testing.T) { "", "", }, + { + "", + "", + }, { `body`, - `body`, + `body`, }, { `body`, - `body`, + `body`, }, { `
`, @@ -52,6 +56,8 @@ func TestSanitizeBytes(t *testing.T) {AB
`, }, { ` @@ -135,6 +141,8 @@ func TestSanitizeBytes(t *testing.T) { `, }, + {` middle hello middle