@@ -3,12 +3,16 @@ package testrunner
33import (
44 "bytes"
55 "errors"
6+ "fmt"
67 "go/ast"
78 "go/format"
9+ "go/importer"
810 "go/parser"
911 "go/printer"
1012 "go/token"
13+ "go/types"
1114 "log"
15+ "path/filepath"
1216 "regexp"
1317 "strconv"
1418 "strings"
@@ -34,6 +38,7 @@ type rootLevelTest struct {
3438 fileName string
3539 code string
3640 taskID uint64
41+ pkgName string
3742}
3843
3944// FindAllRootLevelTests parses the test file and extracts the name,
@@ -65,6 +70,7 @@ func FindAllRootLevelTests(fileName string) []rootLevelTest {
6570 fileName : fileName ,
6671 code : buf .String (),
6772 taskID : taskID ,
73+ pkgName : file .Name .Name ,
6874 })
6975 }
7076 }
@@ -100,16 +106,19 @@ func findTaskID(doc *ast.CommentGroup) uint64 {
100106}
101107
102108// generate simplified test code corresponding to a subtest
103- func getSubCode (test string , sub string , code string , file string ) string {
109+ func getSubCode (test string , sub string , code string , file string , pkgName string ) string {
110+ pkgLine := fmt .Sprintf ("package %s\n " , pkgName )
104111 fset := token .NewFileSet ()
105112 f , err := parser .ParseFile (
106- fset , file , "package main \n " + code , parser .ParseComments ,
113+ fset , file , pkgLine + code , parser .ParseComments ,
107114 )
108115 if err != nil {
109116 log .Printf ("warning: '%s' not parsed from '%s': %s" , test , file , err )
110117 return ""
111118 }
112119
120+ typeInfo := resolveTestData (fset , f , file )
121+
113122 fAST , ok := f .Decls [0 ].(* ast.FuncDecl )
114123 if ! ok {
115124 log .Println ("warning: first subtest declaration must be a function" )
@@ -118,14 +127,14 @@ func getSubCode(test string, sub string, code string, file string) string {
118127
119128 fbAST := fAST .Body .List // f.Decls[0].Body.List
120129
121- astInfo , err := findTestDataAndRange (fbAST )
130+ astInfo , err := findTestDataAndRange (fbAST , fset , typeInfo )
122131 if err != nil {
123132 log .Printf ("warning: could not find test table and/or range: %v\n " , err )
124133 return ""
125134 }
126135
127136 // process the test data assignment
128- metadata , ok := processTestDataAssgn (sub , astInfo .testDataAst )
137+ metadata , ok := processTestDataAssgn (sub , astInfo .testDataAst , typeInfo )
129138 if ! ok {
130139 return ""
131140 }
@@ -151,36 +160,33 @@ func getSubCode(test string, sub string, code string, file string) string {
151160 log .Println ("warning: failed to format extracted AST for subtest" )
152161 return ""
153162 }
154- return strings .TrimSpace (strings .TrimPrefix (buf .String (), "package main" ))
163+ if astInfo .testDataAstIdx != - 1 { // testDataAst is already in the test function
164+ return strings .TrimSpace (strings .TrimPrefix (buf .String (), pkgLine ))
165+ }
166+ return insertTestDataASTIntoFunc (fset , astInfo .testDataAst , fAST .Body , buf .Bytes (), pkgLine )
155167}
156168
157- func findTestDataAndRange (stmtList []ast.Stmt ) (subTestAstInfo , error ) {
169+ func findTestDataAndRange (stmtList []ast.Stmt , fset * token. FileSet , info * types. Info ) (subTestAstInfo , error ) {
158170 result := subTestAstInfo {}
159-
171+ posToIndex := make ( map [token. Position ] int )
160172 for i := range stmtList {
161- assignCandidate , ok := stmtList [i ].(* ast.AssignStmt )
162- if ok && result .testDataAst == nil {
163- result .testDataAst = assignCandidate
164- result .testDataAstIdx = i
165- } else if ok {
166- identifier , isIdentifier := assignCandidate .Lhs [0 ].(* ast.Ident )
167- if ! isIdentifier {
168- continue
169- }
170- // Overwrite the assignment we already found in case there is an
171- // assignment to a "tests" variable.
172- if identifier .Name == "tests" {
173+ posToIndex [fset .Position (stmtList [i ].Pos ())] = i
174+ if rangeCandidate , ok := stmtList [i ].(* ast.RangeStmt ); ok {
175+ assignCandidate := getTestDataAssignFromRange (rangeCandidate , info )
176+ if assignCandidate != nil {
177+ // check if assignCandidate is in the same function with rangeCandidate
178+ if idx , ok := posToIndex [fset .Position (assignCandidate .Pos ())]; ok &&
179+ fset .File (assignCandidate .Pos ()).Name () == fset .File (rangeCandidate .Pos ()).Name () {
180+ result .testDataAstIdx = idx
181+ } else {
182+ result .testDataAstIdx = - 1
183+ }
173184 result .testDataAst = assignCandidate
174- result .testDataAstIdx = i
185+ result .rangeAst = rangeCandidate
186+ result .rangeAstIdx = i
187+ return result , nil
175188 }
176- }
177-
178- rangeCandidate , ok := stmtList [i ].(* ast.RangeStmt )
179- // If we found a range after we already found an assignment, we are good to go.
180- if ok && result .testDataAst != nil {
181- result .rangeAst = rangeCandidate
182- result .rangeAstIdx = i
183- return result , nil
189+ return subTestAstInfo {}, errors .New ("failed to find assignment in sub-test" )
184190 }
185191 }
186192
@@ -190,17 +196,66 @@ func findTestDataAndRange(stmtList []ast.Stmt) (subTestAstInfo, error) {
190196
191197 return subTestAstInfo {}, errors .New ("failed to find range statement in sub-test" )
192198}
199+ func getTestDataAssignFromRange (rangeAst * ast.RangeStmt , info * types.Info ) * ast.AssignStmt {
200+ // Get the identifier being ranged over
201+ ident , ok := rangeAst .X .(* ast.Ident )
202+ if ! ok {
203+ return nil
204+ }
205+
206+ // Look up the object this identifier refers to
207+ obj := info .Uses [ident ]
208+ if obj == nil {
209+ // If not in Uses, check Defs (in case it's defined in the same expression)
210+ obj = info .Defs [ident ]
211+ }
212+ if obj == nil {
213+ return nil
214+ }
215+
216+ // Find the declaration AST node by looking through Defs
217+ for id , def := range info .Defs {
218+ if def == obj {
219+ // Found the defining identifier, now get its declaration
220+ if id .Obj != nil && id .Obj .Decl != nil {
221+ spec := id .Obj .Decl
222+ if assignStmt , ok := spec .(* ast.AssignStmt ); ok {
223+ return assignStmt
224+ }
225+ if valueSpec , ok := spec .(* ast.ValueSpec ); ok {
226+ lhs := make ([]ast.Expr , len (valueSpec .Names ))
227+ for i , name := range valueSpec .Names {
228+ lhs [i ] = name
229+ }
230+ return & ast.AssignStmt {
231+ Lhs : lhs ,
232+ Tok : token .DEFINE ,
233+ Rhs : valueSpec .Values ,
234+ }
235+ }
236+ }
237+ }
238+ }
239+ return nil
240+ }
193241
194242// validate the test data assignment and return the associated metadata
195- func processTestDataAssgn (sub string , assgn * ast.AssignStmt ) (* subTData , bool ) {
243+ func processTestDataAssgn (sub string , assgn * ast.AssignStmt , info * types. Info ) (* subTData , bool ) {
196244 lhs1 , ok := assgn .Lhs [0 ].(* ast.Ident ) // f.Decls[0].Body.List[0].Lhs[0]
197245 if ! ok {
198246 log .Println ("warning: test data assignment not found" )
199247 return nil , false
200248 }
201- if ast .Var != lhs1 .Obj .Kind {
202- log .Println ("warning: test data assignment must be a var" )
203- return nil , false
249+ // Check if this is a variable using type information
250+ obj := info .Defs [lhs1 ]
251+ if obj == nil {
252+ obj = info .Uses [lhs1 ]
253+ }
254+ if obj != nil {
255+ if _ , ok := obj .(* types.Var ); ! ok {
256+ log .Println ("warning: test data assignment must be a var" )
257+ return nil , false
258+ }
204259 }
205260 metadata := subTData {origTDName : lhs1 .Name }
206261
@@ -315,3 +370,60 @@ func processRange(metadata *subTData, rastmt *ast.RangeStmt) bool {
315370 metadata .subTest = body
316371 return true
317372}
373+
374+ // resolveTestData resolves test data variable declared in cases_test.go (if exists)
375+ // and returns type information for identifier resolution
376+ func resolveTestData (fset * token.FileSet , f * ast.File , file string ) * types.Info {
377+ filedata := filepath .Join (filepath .Dir (file ), "cases_test.go" )
378+ fdata , _ := parser .ParseFile (fset , filedata , nil , parser .ParseComments )
379+
380+ // Prepare files for type checking
381+ files := []* ast.File {f }
382+ if fdata != nil {
383+ files = append (files , fdata )
384+ }
385+
386+ // Configure type checker
387+ conf := types.Config {
388+ Importer : importer .Default (),
389+ Error : func (err error ) {}, // Ignore type errors - we only need identifier resolution
390+ }
391+
392+ // Type check the package
393+ info := & types.Info {
394+ Defs : make (map [* ast.Ident ]types.Object ),
395+ Uses : make (map [* ast.Ident ]types.Object ),
396+ }
397+
398+ // Type check - ignore errors since files may have missing imports
399+ _ , _ = conf .Check ("" , fset , files , info )
400+
401+ return info
402+ }
403+
404+ // insertTestDataASTIntoFunc inserts testDataAst into the first line of fbAST function's body
405+ func insertTestDataASTIntoFunc (fset * token.FileSet , testDataAst * ast.AssignStmt , fbAST * ast.BlockStmt , fileText []byte , pkgLine string ) string {
406+ buf := bytes.Buffer {}
407+
408+ p := fset .Position (fbAST .Lbrace ).Offset + 1
409+
410+ // write the beginning of fileText to func (...) {
411+ buf .Write (fileText [:p + 1 ])
412+
413+ // write test data assign stmt
414+ if err := format .Node (& buf , fset , testDataAst ); err != nil {
415+ log .Println ("warning: failed to format extracted AST for subtest" )
416+ return ""
417+ }
418+ // write the rest of fileText
419+ buf .Write (fileText [p + 1 :])
420+
421+ // because assign stmt is extracted from different file, its indentation is different from fileText
422+ // so need to reformat
423+ src , err := format .Source ((buf .Bytes ()))
424+ if err != nil {
425+ log .Println ("warning: failed to format extracted AST for subtest" )
426+ return ""
427+ }
428+ return strings .TrimSpace (strings .TrimPrefix (string (src ), pkgLine ))
429+ }
0 commit comments