Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion test/infinicore/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base import TestConfig, TestRunner, BaseOperatorTest
from .test_case import TestCase, TestResult
from .entities import TestCase
from .benchmark import BenchmarkUtils, BenchmarkResult
from .config import (
add_common_test_args,
Expand All @@ -11,6 +11,9 @@
from .devices import InfiniDeviceEnum, InfiniDeviceNames, torch_device_map
from .runner import GenericTestRunner
from .tensor import TensorSpec, TensorInitializer
from .types import TestTiming, OperatorTestResult, TestResult
from .driver import TestDriver
from .printer import ConsolePrinter
from .utils import (
compare_results,
create_test_comparator,
Expand Down Expand Up @@ -38,6 +41,10 @@
"TestResult",
"TestRunner",
"TestReporter",
"TestTiming",
"OperatorTestResult",
"TestDriver",
"ConsolePrinter",
# Core functions
"add_common_test_args",
"compare_results",
Expand Down
3 changes: 2 additions & 1 deletion test/infinicore/framework/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import traceback
from abc import ABC, abstractmethod

from .test_case import TestCase, TestResult
from .entities import TestCase
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'TestCase' is not used.

Suggested change
from .entities import TestCase

Copilot uses AI. Check for mistakes.
from .types import TestResult
from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
Expand Down
2 changes: 1 addition & 1 deletion test/infinicore/framework/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import infinicore

from dataclasses import dataclass, field
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删?


Comment on lines +3 to 4
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'dataclass' is not used.
Import of 'field' is not used.

Suggested change
from dataclasses import dataclass, field

Copilot uses AI. Check for mistakes.
def to_torch_dtype(infini_dtype):
"""Convert infinicore data type to PyTorch data type"""
Expand Down
105 changes: 105 additions & 0 deletions test/infinicore/framework/driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import sys
import importlib.util
from io import StringIO
from contextlib import contextmanager
from .types import OperatorTestResult, TestTiming
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'TestTiming' is not used.

Suggested change
from .types import OperatorTestResult, TestTiming
from .types import OperatorTestResult

Copilot uses AI. Check for mistakes.

@contextmanager
def capture_output():
"""Context manager: captures stdout and stderr."""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield new_out, new_err
finally:
sys.stdout, sys.stderr = old_out, old_err

class TestDriver:
def drive(self, file_path) -> OperatorTestResult:
result = OperatorTestResult(name=file_path.stem)

try:
# 1. Dynamically import the module
module = self._import_module(file_path)

# 2. Look for TestRunner
if not hasattr(module, "GenericTestRunner"):
raise ImportError("No GenericTestRunner found in module")

# 3. Look for TestClass (subclass of BaseOperatorTest)
test_class = self._find_test_class(module)
if not test_class:
raise ImportError("No BaseOperatorTest subclass found")

test_instance = test_class()
runner_class = module.GenericTestRunner
runner = runner_class(test_instance.__class__)

# 4. Execute and capture output
with capture_output() as (out, err):
success, internal_runner = runner.run()

# 5. Populate results
result.success = success
result.stdout = out.getvalue()
result.stderr = err.getvalue()

# Extract detailed results from internal_runner
test_results = internal_runner.get_test_results() if internal_runner else []
self._analyze_return_code(result, test_results)
self._extract_timing(result, test_results)

except Exception as e:
result.success = False
result.error_message = str(e)
result.stderr += f"\nExecutor Error: {str(e)}"
result.return_code = -1

return result

def _import_module(self, path):
module_name = f"op_test_{path.stem}"
spec = importlib.util.spec_from_file_location(module_name, path)
if not spec or not spec.loader:
raise ImportError(f"Could not load spec from {path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module

def _find_test_class(self, module):
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, type) and hasattr(attr, "__bases__"):
# Simple check for base class name
if any("BaseOperatorTest" in str(b) for b in attr.__bases__):
return attr
return None

def _analyze_return_code(self, result, test_results):
# Logic consistent with original code: determine if all passed, partially passed, or skipped
if result.success:
result.return_code = 0
return

has_failures = any(r.return_code == -1 for r in test_results)
has_partial = any(r.return_code == -3 for r in test_results)
has_skipped = any(r.return_code == -2 for r in test_results)

if has_failures:
result.return_code = -1
elif has_partial:
result.return_code = -3
elif has_skipped:
result.return_code = -2
else:
result.return_code = -1

def _extract_timing(self, result, test_results):
# Accumulate timing
t = result.timing
t.torch_host = sum(r.torch_host_time for r in test_results)
t.torch_device = sum(r.torch_device_time for r in test_results)
t.infini_host = sum(r.infini_host_time for r in test_results)
t.infini_device = sum(r.infini_device_time for r in test_results)
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,6 @@
from .tensor import TensorSpec


@dataclass
class TestResult:
"""Test result data structure"""

success: bool
return_code: int # 0: success, -1: failure, -2: skipped, -3: partial
torch_host_time: float = 0.0
torch_device_time: float = 0.0
infini_host_time: float = 0.0
infini_device_time: float = 0.0
error_message: str = ""
test_case: Any = None
device: Any = None


class TestCase:
"""Test case with all configuration included"""

Expand Down
73 changes: 73 additions & 0 deletions test/infinicore/framework/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from pathlib import Path

class TestDiscoverer:
def __init__(self, ops_dir_path=None):
self.ops_dir = self._resolve_dir(ops_dir_path)

def _resolve_dir(self, path):
if path:
p = Path(path)
if p.exists(): return p

# Default fallback logic: 'ops' directory under the parent of the current file's parent.
# Note: Since this file is in 'framework/', we look at parent.parent.
# It is recommended to pass an explicit path in run.py.
fallback = Path(__file__).parent.parent / "ops"
return fallback if fallback.exists() else None

def get_available_operators(self):
"""Returns a list of names of all available operators."""
if not self.ops_dir: return []
files = self.scan()
return sorted([f.stem for f in files])

def get_raw_python_files(self):
"""
Get all .py files in the directory (excluding run.py) without content validation.
Used for debugging: helps identify files that exist but failed validation.
"""
if not self.ops_dir or not self.ops_dir.exists():
return []

files = list(self.ops_dir.glob("*.py"))
# Exclude run.py itself and __init__.py
return [f.name for f in files if f.name != "run.py" and not f.name.startswith("__")]

def scan(self, specific_ops=None):
"""Scans and returns a list of Path objects that meet the criteria."""
if not self.ops_dir or not self.ops_dir.exists():
return []

# 1. Find all .py files
files = list(self.ops_dir.glob("*.py"))

target_ops_set = set(specific_ops) if specific_ops else None

# 2. Filter out non-test files (via content check)
valid_files = []
for f in files:
# A. Basic Name Filtering
if f.name.startswith("_") or f.name == "run.py":
continue

# B. Specific Ops Filtering
if target_ops_set and f.stem not in target_ops_set:
continue

# C. Content Check (Expensive I/O)
# Only perform this check if the file passed the name filters above.
if self._is_operator_test(f):
valid_files.append(f)

return valid_files

def _is_operator_test(self, file_path):
"""Checks if the file content contains operator test characteristics."""
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return "infinicore" in content and (
"BaseOperatorTest" in content or "GenericTestRunner" in content
)
except:
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Except block directly handles BaseException.

Suggested change
except:
except Exception:

Copilot uses AI. Check for mistakes.
return False
Loading