268 changes: 185 additions & 83 deletions check50/assertions/runtime.py
@@ -3,7 +3,19 @@
import tokenize
import types, builtins
from io import StringIO
import ast
from dataclasses import dataclass
from collections import defaultdict

@dataclass(frozen=True)
class TokenSeq:
kind: int # e.g. OP, NAME, STRING
text: str # e.g. 'foo', '(', ')'

@dataclass(frozen=True)
class KeyPattern:
string: str # key's string representation
tokens: tuple[TokenSeq, ...] # normalized token sequence
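
For orientation, here is a sketch (not part of the diff) of how a single key could be modeled with these dataclasses, assuming the string normalization described in `_tokenize_normalized` below:

```
import tokenize

# Illustrative only: the key 'check50.run("pwd")' as a KeyPattern.
# STRING tokens are stored normalized via repr() of their literal value,
# so 'pwd' and "pwd" compare equal.
example_key = KeyPattern(
    string='check50.run("pwd")',
    tokens=(
        TokenSeq(tokenize.NAME, "check50"),
        TokenSeq(tokenize.OP, "."),
        TokenSeq(tokenize.NAME, "run"),
        TokenSeq(tokenize.OP, "("),
        TokenSeq(tokenize.STRING, repr("pwd")),  # -> "'pwd'"
        TokenSeq(tokenize.OP, ")"),
    ),
)
```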

def check50_assert(src, msg_or_exc=None, cond_type="unknown", left=None, right=None, context=None):
"""
@@ -20,7 +32,7 @@ def check50_assert(src, msg_or_exc=None, cond_type="unknown", left=None, right=N
Used for rewriting assertion statements in check files.

Note:
Exceptions from the check50 library are preferred, since they will be
Exceptions from the `check50` library are preferred, since they will be
handled gracefully and integrated into the check output. Native Python
exceptions are technically supported, but check50 will immediately
terminate on the user's end if the assertion fails.
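
For context, a hypothetical call the check rewriter might emit for `assert check50.run("./hello").stdout() == "hello"`; the argument choices below are an assumption for illustration, only the signature comes from this file:

```
# Hypothetical rewritten assertion (argument values are illustrative):
check50_assert(
    'check50.run("./hello").stdout() == "hello"',   # src: condition source text
    cond_type="eq",
    left='check50.run("./hello").stdout()',         # key string for the left operand
    right='"hello"',                                # right operand as source text
    context={
        'check50.run("./hello")': None,             # candidate expressions; only the
        'check50.run("./hello").stdout()': None,    # keys are used, values are
    },                                              # evaluated lazily by check50_assert
)
```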
@@ -63,121 +75,211 @@ def check50_assert(src, msg_or_exc=None, cond_type="unknown", left=None, right=N
caller_globals = caller_frame.f_globals
caller_locals = caller_frame.f_locals

# Evaluate all variables and functions within the context dict and generate
# a string of these values
context_str = None
if context or (left and right):
for expr_str in context:
try:
context[expr_str] = eval(expr_str, caller_globals, caller_locals)
except Exception as e:
context[expr_str] = f"[error evaluating: {e}]"

# filter out modules, functions, and built-ins, which is needed to avoid
# overwriting function definitions in evaluation and avoid useless string
# output
def is_irrelevant_value(v):
return isinstance(v, (types.ModuleType, types.FunctionType, types.BuiltinFunctionType))

def is_builtin_name(name):
return name in dir(builtins)

filtered_context = {
k: v for k, v in context.items()
if not is_irrelevant_value(v) and not is_builtin_name(k.split("(")[0])
}

# produces a string like "var1 = ..., var2 = ..., foo() = ..."
context_str = ", ".join(f"{k} = {repr(v)}" for k, v in filtered_context.items())
else:
filtered_context = {}
# Build the list of candidate keys
candidate_keys = list(context.keys()) if context else []

# Since we've memoized the functions and variables once, now try and
# evaluate the conditional by substituting the function calls/vars with
# their results
eval_src, eval_context = substitute_expressions(src, filtered_context)
# Plan substitutions and learn which keys are actually used
eval_src, key_to_placeholder = substitute_expressions(src, candidate_keys)

# Merge globals with expression context for evaluation
eval_globals = caller_globals.copy()
eval_globals.update(eval_context)
# Only evaluate the keys that were actually matched
evaluated = {}
for expr_str in key_to_placeholder.keys():
try:
evaluated[expr_str] = eval(expr_str, caller_globals, caller_locals)
except Exception as e:
evaluated[expr_str] = f"[error evaluating: {e}]"

# Merge locals with expression context for evaluation
eval_locals = caller_locals.copy()
eval_locals.update(eval_context)
# Build the eval_context for placeholders
eval_context = {
placeholder: evaluated[key]
for key, placeholder in key_to_placeholder.items()
}

# Merge locals and globals with expression context for evaluation
eval_globals = caller_globals.copy(); eval_globals.update(eval_context)
eval_locals = caller_locals.copy(); eval_locals.update(eval_context)
cond = eval(eval_src, eval_globals, eval_locals)

# Finally, quit if the condition evaluated to True.
if cond:
return

# If `right` or `left` were evaluatable objects, their actual value will be stored in `context`.
# Otherwise, they're still just literals.
right = context.get(right) or right
left = context.get(left) or left
# Filter out modules, functions, and built-ins, which is needed to avoid
# overwriting function definitions in evaluation and avoid useless string
# output
def is_irrelevant_value(v):
return isinstance(v, (
types.ModuleType,
types.FunctionType,
types.BuiltinFunctionType
))
def is_builtin_name(name):
name = name.split("(")[0] # grab `len` from `len(...)`
return name in dir(builtins)

# Since the condition didn't evaluate to True, we can now raise special
# exceptions.
filtered_context = {
k: v for k, v in evaluated.items()
if not is_irrelevant_value(v) and not is_builtin_name(k)
}

# Produces a string like "var1 = ..., var2 = ..., foo() = ..."
context_str = ", ".join(f"{k} = {repr(v)}" for k, v in filtered_context.items()) or None

# If `right` or `left` were evaluatable objects, their actual
# value will be stored in `evaluated`.
if right in evaluated:
right = evaluated[right]
if left in evaluated:
left = evaluated[left]

# Raise check50-specific/user-passed exceptions.
if isinstance(msg_or_exc, str):
raise Failure(msg_or_exc)
elif isinstance(msg_or_exc, BaseException):
raise msg_or_exc
elif cond_type == 'eq' and left and right:
elif cond_type == 'eq' and left is not None and right is not None:
help_msg = f"checked: {src}"
help_msg += f"\n where {context_str}" if context_str else ""
raise Mismatch(right, left, help=help_msg)
elif cond_type == 'in' and left and right:
elif cond_type == 'in' and left is not None and right is not None:
help_msg = f"checked: {src}"
help_msg += f"\n where {context_str}" if context_str else ""
raise Missing(left, right, help=help_msg)
else:
help_msg = f"\n where {context_str}" if context_str else ""
raise Failure(f"check did not pass: {src}" + help_msg)

def substitute_expressions(src: str, context: dict) -> tuple[str, dict]:
def _tokenize_normalized(code: str):
"""
Rewrites `src` by replacing each key in `context` with a placeholder variable name,
and builds a new context dict where those names map to pre-evaluated values.
Tokenize and normalize:
- drop ENCODING, NL, NEWLINE, INDENT, DEDENT, ENDMARKER
- for STRING tokens, normalize to their Python value (so "'pwd'" == "\"pwd\"")
- return both normalized tokens and the original raw tokens (1:1 positions)

Returns both the normalized and the raw token sequences of the code; the
raw sequence also excludes the dropped token types but otherwise preserves
the original text.

For instance, given a `src`:
For instance, the code input "foo.bar()" might output a `norm` of `TokenSeq`s:
```
check50.run('pwd').stdout() == actual
[
TokenSeq(NAME, "foo"), TokenSeq(OP, "."), TokenSeq(NAME, "bar"),
TokenSeq(OP, "("), TokenSeq(OP, ")")
]
```
it will create a new `eval_src` as
In this case there are no string tokens to normalize, so `raw` would
contain the same sequence.
"""
drop = {
tokenize.ENCODING, tokenize.NL, tokenize.NEWLINE,
tokenize.INDENT, tokenize.DEDENT, tokenize.ENDMARKER
}

norm, raw = [], []
for tok in tokenize.generate_tokens(StringIO(code).readline):
# Extract type and string representation from token
tok_type, tok_string, *_ = tok

# Skip structural tokens (encoding, newlines, indentation, end marker)
if tok_type in drop:
continue

raw.append(TokenSeq(tok_type, tok_string))

# Normalize STRING tokens to their Python value
if tok_type == tokenize.STRING:
try:
val = ast.literal_eval(tok_string)
norm.append(TokenSeq(tok_type, repr(val)))
except Exception:
norm.append(TokenSeq(tok_type, tok_string))
else:
norm.append(TokenSeq(tok_type, tok_string))

return norm, raw
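
A quick sanity check of the intended normalization, written against the helper above:

```
# Different quoting styles of the same call should yield identical normalized
# token sequences, while the raw sequences keep the original text.
norm_a, raw_a = _tokenize_normalized("check50.run('pwd')")
norm_b, raw_b = _tokenize_normalized('check50.run("pwd")')

assert norm_a == norm_b   # both STRING tokens normalize to "'pwd'"
assert raw_a != raw_b     # raw preserves the original quoting
```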


def substitute_expressions(src: str, keys: list[str]) -> tuple[str, dict]:
"""
Rewrites `src` by replacing known `keys` (the caller's candidate expressions)
with placeholder variable names, and returns a mapping from each matched key
to its placeholder name.

For instance, let `src` be the string representation of
```
__expr0 == __expr1
assert check50.run("./foo.c").stdout() == "OK"
```
and use the given context to define these variables:
The `keys` might look like
```
eval_context = {
'__expr0': context['check50.run('pwd').stdout()'],
'__expr1': context['actual']
}
["check50.run("./foo.c")", "check50.run("./foo.c").stdout()"]
```
We would want to find the longest match from these keys and output:
```
eval_src: assert __expr0 == "OK"
key_to_placeholder: { 'check50.run("./foo.c").stdout()': "__expr0" }
```
"""
# Parse the src into a stream of tokens
tokens = tokenize.generate_tokens(StringIO(src).readline)

new_tokens = []
new_context = {}
placeholder_map = {} # used for duplicates in src (i.e. x == x => __expr0 == __expr0)
counter = 0

for tok_type, tok_string, start, end, line in tokens:
if tok_string in context:
if tok_string not in placeholder_map:
placeholder = f"__expr{counter}"
placeholder_map[tok_string] = placeholder
new_context[placeholder] = context[tok_string]
counter += 1
else:
# Avoid creating a new __expr{i} variable if it has already been seen
placeholder = placeholder_map[tok_string]
new_tokens.append((tok_type, placeholder))
# Tokenize/normalize the source once
src_norm, src_raw = _tokenize_normalized(src)

# Store a list of KeyPatterns
patterns = []
for key in keys:
key_norm, _ = _tokenize_normalized(key)
if key_norm:
patterns.append(KeyPattern(key, tuple(key_norm)))

# Stores a TokenSeq and every KeyPattern that starts with that TokenSeq
patterns_by_start_token = defaultdict(list)
for pattern in patterns:
start_token = pattern.tokens[0]
patterns_by_start_token[start_token].append(pattern)

# Prefer longest matches first (e.g. foo.bar.baz() is preferred over foo.bar)
for candidates in patterns_by_start_token.values():
candidates.sort(key=lambda p: len(p.tokens), reverse=True)

key_to_placeholder = {}
def get_placeholder(key_str):
"""Return a placeholder `__expr{i}` for a given key."""
if key_str not in key_to_placeholder:
key_to_placeholder[key_str] = f"__expr{len(key_to_placeholder)}"
return key_to_placeholder[key_str]

def longest_match_at(i):
"""Return the longest KeyPattern that matches `src_norm` starting at `i`"""
if i >= len(src_norm):
return None

candidates = patterns_by_start_token.get(src_norm[i], [])

# Iterate through the possible candidates for the longest match
for pattern in candidates:
L = len(pattern.tokens)

# Only check for a match if i + L does not run past the end
if i + L <= len(src_norm) and tuple(src_norm[i:i+L]) == pattern.tokens:
return pattern

# No match
return None

output = []
i = 0
while i < len(src_norm):
# Find the longest matching pattern, if one exists
pattern = longest_match_at(i)
if pattern is not None:
# Create a placeholder var for this specific match
placeholder = get_placeholder(pattern.string)
output.append((tokenize.NAME, placeholder))
# Move forward by the number of tokens in this pattern
i += len(pattern.tokens)
else:
# Anything not found in the context dictionary is placed here,
# including keywords, whitespace, operators, etc.
new_tokens.append((tok_type, tok_string))
# Preserve original lex for unmatched regions
token = src_raw[i]
output.append((token.kind, token.text))
# Move forward by 1 token
i += 1

eval_src = tokenize.untokenize(new_tokens)
return eval_src, new_context
eval_src = tokenize.untokenize(output)
return eval_src, key_to_placeholder
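
Finally, a usage sketch of `substitute_expressions` showing the longest-match preference; the exact spacing of `eval_src` depends on `tokenize.untokenize` and is not guaranteed:

```
src = 'check50.run("./foo").stdout() == "OK"'
keys = ['check50.run("./foo")', 'check50.run("./foo").stdout()']

eval_src, key_to_placeholder = substitute_expressions(src, keys)

# Only the longer key matches, so a single placeholder is introduced:
assert key_to_placeholder == {'check50.run("./foo").stdout()': '__expr0'}
# eval_src is roughly: __expr0 == "OK"   (whitespace may differ after untokenize)
```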