Skip to content

Commit 0c1b278

Browse files
json: update cycle checking aspects from stdlib golang_encode_test.go
1 parent 8e958e9 commit 0c1b278

File tree

2 files changed

+105
-11
lines changed

2 files changed

+105
-11
lines changed

json/golang_encode_test.go

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,21 +136,85 @@ func TestEncodeRenamedByteSlice(t *testing.T) {
136136
}
137137
}
138138

139-
var unsupportedValues = []any{
140-
math.NaN(),
141-
math.Inf(-1),
142-
math.Inf(1),
139+
type SamePointerNoCycle struct {
140+
Ptr1, Ptr2 *SamePointerNoCycle
141+
}
142+
143+
var samePointerNoCycle = &SamePointerNoCycle{}
144+
145+
type PointerCycle struct {
146+
Ptr *PointerCycle
147+
}
148+
149+
var pointerCycle = &PointerCycle{}
150+
151+
type PointerCycleIndirect struct {
152+
Ptrs []any
153+
}
154+
155+
type RecursiveSlice []RecursiveSlice
156+
157+
var (
158+
pointerCycleIndirect = &PointerCycleIndirect{}
159+
mapCycle = make(map[string]any)
160+
sliceCycle = []any{nil}
161+
sliceNoCycle = []any{nil, nil}
162+
recursiveSliceCycle = []RecursiveSlice{nil}
163+
)
164+
165+
func init() {
166+
ptr := &SamePointerNoCycle{}
167+
samePointerNoCycle.Ptr1 = ptr
168+
samePointerNoCycle.Ptr2 = ptr
169+
170+
pointerCycle.Ptr = pointerCycle
171+
pointerCycleIndirect.Ptrs = []any{pointerCycleIndirect}
172+
173+
mapCycle["x"] = mapCycle
174+
sliceCycle[0] = sliceCycle
175+
sliceNoCycle[1] = sliceNoCycle[:1]
176+
for i := startDetectingCyclesAfter; i > 0; i-- {
177+
sliceNoCycle = []any{sliceNoCycle}
178+
}
179+
recursiveSliceCycle[0] = recursiveSliceCycle
180+
}
181+
182+
func TestSamePointerNoCycle(t *testing.T) {
183+
if _, err := Marshal(samePointerNoCycle); err != nil {
184+
t.Fatalf("Marshal error: %v", err)
185+
}
186+
}
187+
188+
func TestSliceNoCycle(t *testing.T) {
189+
if _, err := Marshal(sliceNoCycle); err != nil {
190+
t.Fatalf("Marshal error: %v", err)
191+
}
143192
}
144193

145194
func TestUnsupportedValues(t *testing.T) {
146-
for _, v := range unsupportedValues {
147-
if _, err := Marshal(v); err != nil {
148-
if _, ok := err.(*UnsupportedValueError); !ok {
149-
t.Errorf("for %v, got %T want UnsupportedValueError", v, err)
195+
tests := []struct {
196+
CaseName
197+
in any
198+
}{
199+
{Name(""), math.NaN()},
200+
{Name(""), math.Inf(-1)},
201+
{Name(""), math.Inf(1)},
202+
{Name(""), pointerCycle},
203+
{Name(""), pointerCycleIndirect},
204+
{Name(""), mapCycle},
205+
{Name(""), sliceCycle},
206+
{Name(""), recursiveSliceCycle},
207+
}
208+
for _, tt := range tests {
209+
t.Run(tt.Name, func(t *testing.T) {
210+
if _, err := Marshal(tt.in); err != nil {
211+
if _, ok := err.(*UnsupportedValueError); !ok {
212+
t.Errorf("%s: Marshal error:\n\tgot: %T\n\twant: %T", tt.Where, err, new(UnsupportedValueError))
213+
}
214+
} else {
215+
t.Errorf("%s: Marshal error: got nil, want non-nil", tt.Where)
150216
}
151-
} else {
152-
t.Errorf("for %v, expected error", v)
153-
}
217+
})
154218
}
155219
}
156220

json/golang_shim_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ package json
44

55
import (
66
"bytes"
7+
"fmt"
8+
"path"
79
"reflect"
10+
"runtime"
811
"sync"
912
"testing"
1013
)
@@ -68,3 +71,30 @@ func errorWithPrefixes(t *testing.T, prefixes []any, format string, elements ...
6871
}
6972
t.Errorf(fullFormat, allElements...)
7073
}
74+
75+
// =============================================================================
76+
// Copyright 2010 The Go Authors. All rights reserved.
77+
// Use of this source code is governed by a BSD-style
78+
// license that can be found in the LICENSE file.
79+
80+
// CaseName is a case name annotated with a file and line.
81+
type CaseName struct {
82+
Name string
83+
Where CasePos
84+
}
85+
86+
// Name annotates a case name with the file and line of the caller.
87+
func Name(s string) (c CaseName) {
88+
c.Name = s
89+
runtime.Callers(2, c.Where.pc[:])
90+
return c
91+
}
92+
93+
// CasePos represents a file and line number.
94+
type CasePos struct{ pc [1]uintptr }
95+
96+
func (pos CasePos) String() string {
97+
frames := runtime.CallersFrames(pos.pc[:])
98+
frame, _ := frames.Next()
99+
return fmt.Sprintf("%s:%d", path.Base(frame.File), frame.Line)
100+
}

0 commit comments

Comments
 (0)