Skip to content

Commit 0ef0da9

Browse files
committed
Abstract file tracking methods to struct
1 parent d243234 commit 0ef0da9

File tree

1 file changed

+55
-46
lines changed

1 file changed

+55
-46
lines changed

api_routine_test.go

+55-46
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
package routine
22

33
import (
4+
"fmt"
45
"io"
56
"os"
6-
"path"
7-
"strconv"
87
"strings"
98
"sync"
109
"testing"
@@ -138,8 +137,9 @@ func TestWrapTask_HasContext(t *testing.T) {
138137
}
139138

140139
func TestWrapTask_Complete_ThenFail(t *testing.T) {
141-
newStdout, oldStdout := captureStdout()
142-
defer restoreStdout(newStdout, oldStdout)
140+
tracker := NewFileTracker(&os.Stdout)
141+
tracker.Begin()
142+
defer tracker.End()
143143
//
144144
run := false
145145
wg := &sync.WaitGroup{}
@@ -167,8 +167,7 @@ func TestWrapTask_Complete_ThenFail(t *testing.T) {
167167
assert.True(t, run)
168168
//
169169
time.Sleep(10 * time.Millisecond)
170-
output := readAll(newStdout)
171-
assert.Equal(t, "", output)
170+
assert.Equal(t, "", tracker.Value())
172171
}
173172

174173
func TestWrapWaitTask_NoContext(t *testing.T) {
@@ -344,8 +343,9 @@ func TestWrapWaitTask_HasContext_Cancel(t *testing.T) {
344343
}
345344

346345
func TestWrapWaitTask_Complete_ThenFail(t *testing.T) {
347-
newStdout, oldStdout := captureStdout()
348-
defer restoreStdout(newStdout, oldStdout)
346+
tracker := NewFileTracker(&os.Stdout)
347+
tracker.Begin()
348+
defer tracker.End()
349349
//
350350
run := false
351351
wg := &sync.WaitGroup{}
@@ -373,8 +373,7 @@ func TestWrapWaitTask_Complete_ThenFail(t *testing.T) {
373373
assert.True(t, run)
374374
//
375375
time.Sleep(10 * time.Millisecond)
376-
output := readAll(newStdout)
377-
assert.Equal(t, "", output)
376+
assert.Equal(t, "", tracker.Value())
378377
}
379378

380379
func TestWrapWaitResultTask_NoContext(t *testing.T) {
@@ -554,8 +553,9 @@ func TestWrapWaitResultTask_HasContext_Cancel(t *testing.T) {
554553
}
555554

556555
func TestWrapWaitResultTask_Complete_ThenFail(t *testing.T) {
557-
newStdout, oldStdout := captureStdout()
558-
defer restoreStdout(newStdout, oldStdout)
556+
tracker := NewFileTracker(&os.Stdout)
557+
tracker.Begin()
558+
defer tracker.End()
559559
//
560560
run := false
561561
wg := &sync.WaitGroup{}
@@ -583,13 +583,13 @@ func TestWrapWaitResultTask_Complete_ThenFail(t *testing.T) {
583583
assert.True(t, run)
584584
//
585585
time.Sleep(10 * time.Millisecond)
586-
output := readAll(newStdout)
587-
assert.Equal(t, "", output)
586+
assert.Equal(t, "", tracker.Value())
588587
}
589588

590589
func TestGo_Error(t *testing.T) {
591-
newStdout, oldStdout := captureStdout()
592-
defer restoreStdout(newStdout, oldStdout)
590+
tracker := NewFileTracker(&os.Stdout)
591+
tracker.Begin()
592+
defer tracker.End()
593593
//
594594
run := false
595595
assert.NotPanics(t, func() {
@@ -605,8 +605,7 @@ func TestGo_Error(t *testing.T) {
605605
assert.True(t, run)
606606
//
607607
time.Sleep(10 * time.Millisecond)
608-
output := readAll(newStdout)
609-
lines := strings.Split(output, newLine)
608+
lines := strings.Split(tracker.Value(), newLine)
610609
assert.Equal(t, 7, len(lines))
611610
//
612611
line := lines[0]
@@ -728,7 +727,7 @@ func TestGoWait_Error(t *testing.T) {
728727
//
729728
line = lines[1]
730729
assert.True(t, strings.HasPrefix(line, " at github.com/timandy/routine.TestGoWait_Error."))
731-
assert.True(t, strings.HasSuffix(line, "api_routine_test.go:711"))
730+
assert.True(t, strings.HasSuffix(line, "api_routine_test.go:710"))
732731
//
733732
line = lines[2]
734733
assert.True(t, strings.HasPrefix(line, " at github.com/timandy/routine.inheritedWaitTask.run()"))
@@ -834,7 +833,7 @@ func TestGoWaitResult_Error(t *testing.T) {
834833
//
835834
line = lines[1]
836835
assert.True(t, strings.HasPrefix(line, " at github.com/timandy/routine.TestGoWaitResult_Error."))
837-
assert.True(t, strings.HasSuffix(line, "api_routine_test.go:815"))
836+
assert.True(t, strings.HasSuffix(line, "api_routine_test.go:814"))
838837
//
839838
line = lines[2]
840839
assert.True(t, strings.HasPrefix(line, " at github.com/timandy/routine.inheritedWaitResultTask[...].run()"))
@@ -925,44 +924,54 @@ func TestGoWaitResult_Cross(t *testing.T) {
925924
assert.Equal(t, "", result)
926925
}
927926

928-
func captureStdout() (newStdout, oldStdout *os.File) {
929-
oldStdout = os.Stdout
930-
fileName := path.Join(os.TempDir(), "go_test_"+strconv.FormatInt(time.Now().UnixNano(), 10)+".txt")
931-
file, err := os.Create(fileName)
927+
//===
928+
929+
type FileTracker struct {
930+
target **os.File
931+
oldValue *os.File
932+
tempValue *os.File
933+
}
934+
935+
func NewFileTracker(target **os.File) *FileTracker {
936+
return &FileTracker{target: target, oldValue: *target}
937+
}
938+
939+
func (f *FileTracker) Begin() {
940+
file, err := os.CreateTemp("", "go_test_*.txt")
932941
if err != nil {
933942
panic(err)
934943
}
935-
os.Stdout = file
936-
newStdout = file
937-
return
944+
*f.target = file
945+
f.tempValue = file
938946
}
939947

940-
func restoreStdout(newStdout, oldStdout *os.File) {
941-
os.Stdout = oldStdout
942-
if err := newStdout.Close(); err != nil {
948+
func (f *FileTracker) End() {
949+
*f.target = f.oldValue
950+
if err := f.tempValue.Close(); err != nil {
943951
panic(err)
944952
}
945-
if err := os.Remove(newStdout.Name()); err != nil {
953+
if err := os.Remove(f.tempValue.Name()); err != nil {
946954
panic(err)
947955
}
948956
}
949957

950-
func readAll(rs io.ReadSeeker) string {
951-
if _, err := rs.Seek(0, io.SeekStart); err != nil {
958+
func (f *FileTracker) Value() string {
959+
if _, err := f.tempValue.Seek(0, io.SeekStart); err != nil {
952960
panic(err)
953961
}
954-
b := make([]byte, 0, 512)
955-
for {
956-
if len(b) == cap(b) {
957-
b = append(b, 0)[:len(b)]
958-
}
959-
n, err := rs.Read(b[len(b):cap(b)])
960-
b = b[:len(b)+n]
961-
if err != nil {
962-
if err == io.EOF {
963-
return string(b)
964-
}
965-
panic(err)
966-
}
962+
buff, err := io.ReadAll(f.tempValue)
963+
if err != nil {
964+
panic(err)
967965
}
966+
return string(buff)
967+
}
968+
969+
func TestFileTracker(t *testing.T) {
970+
origin := os.Stdout
971+
tracker := NewFileTracker(&os.Stdout)
972+
tracker.Begin()
973+
fmt.Println("hello world")
974+
assert.Equal(t, "hello world\n", tracker.Value())
975+
tracker.End()
976+
assert.Same(t, origin, os.Stdout)
968977
}

0 commit comments

Comments
 (0)