Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions lib/typeprof/core/ast/sig_type.rb
Original file line number Diff line number Diff line change
Expand Up @@ -487,16 +487,55 @@ def show
end

class SigTyRecordNode < SigTyNode
def initialize(raw_decl, lenv)
super(raw_decl, lenv)
@fields = raw_decl.fields.transform_values { |val| AST.create_rbs_type(val, lenv) }
end

attr_reader :fields
def subnodes = { fields: }

def covariant_vertex0(genv, changes, vtx, subst)
raise NotImplementedError
field_vertices = {}
@fields.each do |key, field_node|
field_vertices[key] = field_node.covariant_vertex(genv, changes, subst)
end

# Create base Hash type for Record
key_vtx = Source.new(genv.symbol_type)
# Create union of all field values for the Hash value type
val_vtx = changes.new_covariant_vertex(genv, [self, :union])
field_vertices.each_value do |field_vtx|
changes.add_edge(genv, field_vtx, val_vtx)
end
base_hash_type = genv.gen_hash_type(key_vtx, val_vtx)

changes.add_edge(genv, Source.new(Type::Record.new(genv, field_vertices, base_hash_type)), vtx)
end

def contravariant_vertex0(genv, changes, vtx, subst)
raise NotImplementedError
field_vertices = {}
@fields.each do |key, field_node|
field_vertices[key] = field_node.contravariant_vertex(genv, changes, subst)
end

# Create base Hash type for Record
key_vtx = Source.new(genv.symbol_type)
# Create union of all field values for the Hash value type
val_vtx = changes.new_contravariant_vertex(genv, [self, :union])
field_vertices.each_value do |field_vtx|
changes.add_edge(genv, field_vtx, val_vtx)
end
base_hash_type = genv.gen_hash_type(key_vtx, val_vtx)

changes.add_edge(genv, Source.new(Type::Record.new(genv, field_vertices, base_hash_type)), vtx)
end

def show
"(...record...)"
field_strs = @fields.map do |key, field_node|
"#{ key }: #{ field_node.show }"
end
"{ #{ field_strs.join(", ") } }"
end
end

Expand Down
10 changes: 8 additions & 2 deletions lib/typeprof/core/builtin.rb
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,16 @@ def array_push(changes, node, ty, a_args, ret)
def hash_aref(changes, node, ty, a_args, ret)
if a_args.positionals.size == 1
case ty
when Type::Hash
when Type::Hash, Type::Record
idx = node.positional_args[0]
idx = idx.is_a?(AST::SymbolNode) ? idx.lit : nil
changes.add_edge(@genv, ty.get_value(idx), ret)
value = ty.get_value(idx)
if value
changes.add_edge(@genv, value, ret)
else
# Return untyped for unknown fields
changes.add_edge(@genv, Source.new(), ret)
end
true
else
false
Expand Down
90 changes: 89 additions & 1 deletion lib/typeprof/core/type.rb
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,28 @@ def base_type(genv)
end

def check_match(genv, changes, vtx)
# TODO: implement
vtx.each_type do |other_ty|
case other_ty
when Record
# Hash can match with Record if Hash has Symbol keys
# and Hash value type can accept all Record field values
key_ty = get_key
val_ty = get_value

# Check if this Hash has Symbol keys
return false unless key_ty && Source.new(genv.symbol_type).check_match(genv, changes, key_ty)

# Check if Hash value type contains all required types for Record fields
other_ty.fields.each do |_key, field_val_vtx|
# For each Record field, check if field type can match with Hash value type
return false unless val_ty && field_val_vtx.check_match(genv, changes, val_ty)
end

return true
end
end

# Fall back to base_type check for other cases
@base_type.check_match(genv, changes, vtx)
end

Expand Down Expand Up @@ -348,5 +369,72 @@ def show
"var[#{ @name }]"
end
end

class Record < Type
#: (GlobalEnv, ::Hash[Symbol, Vertex], Instance) -> void
def initialize(genv, fields, base_type)
@fields = fields
@base_type = base_type
raise unless base_type.is_a?(Instance)
end

attr_reader :fields

def get_value(key = nil)
if key
# Return specific field value if it exists
@fields[key]
elsif @fields.empty?
# Empty record has no values
nil
else
# Return union of all field values if no specific key
@base_type.args[1]
end
end

def base_type(genv)
@base_type
end

def check_match(genv, changes, vtx)
vtx.each_type do |other_ty|
case other_ty
when Record
# Check if all fields match
return false unless @fields.size == other_ty.fields.size
@fields.each do |key, val_vtx|
other_val_vtx = other_ty.fields[key]
return false unless other_val_vtx
return false unless val_vtx.check_match(genv, changes, other_val_vtx)
end
return true
when Hash
# Record can match with Hash only if the Hash has Symbol keys
# and all record values can match with the Hash value type
key_vtx = other_ty.get_key
val_vtx = other_ty.get_value

# Check if Hash key type is Symbol
return false unless key_vtx && Source.new(genv.symbol_type).check_match(genv, changes, key_vtx)

# Check if all record field values can match with Hash value type
@fields.each do |_key, field_val_vtx|
return false unless field_val_vtx.check_match(genv, changes, val_vtx)
end

return true
end
end
return false
end

def show
field_strs = @fields.map do |key, val_vtx|
"#{ key }: #{ Type.strip_parens(val_vtx.show) }"
end
"{ #{ field_strs.join(", ") } }"
end
end
end
end
17 changes: 17 additions & 0 deletions scenario/rbs/record-arrays.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## update: test.rbs
class RecordArrays
def get_users: -> Array[{ id: Integer, name: String }]
end

## update: test.rb
class RecordArrays
def first_user
users = get_users
users.first
end
end

## assert: test.rb
class RecordArrays
def first_user: -> { id: Integer, name: String }
end
16 changes: 16 additions & 0 deletions scenario/rbs/record-basic.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
## update: test.rbs
class BasicRecord
def simple_record: -> { name: String, age: Integer }
end

## update: test.rb
class BasicRecord
def get_record
simple_record
end
end

## assert: test.rb
class BasicRecord
def get_record: -> { name: String, age: Integer }
end
16 changes: 16 additions & 0 deletions scenario/rbs/record-empty.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
## update: test.rbs
class EmptyRecord
def empty_record: -> { }
end

## update: test.rb
class EmptyRecord
def get_empty
empty_record
end
end

## assert: test.rb
class EmptyRecord
def get_empty: -> { }
end
17 changes: 17 additions & 0 deletions scenario/rbs/record-field-access.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## update: test.rbs
class RecordFieldAccess
def get_person: -> { name: String, age: Integer }
end

## update: test.rb
class RecordFieldAccess
def get_name
person = get_person
person[:name]
end
end

## assert: test.rb
class RecordFieldAccess
def get_name: -> String
end
17 changes: 17 additions & 0 deletions scenario/rbs/record-field-error.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## update: test.rbs
class RecordFieldError
def get_person: -> { name: String, age: Integer }
end

## update: test.rb
class RecordFieldError
def get_unknown_field
person = get_person
person[:unknown] # Access non-existent field
end
end

## assert: test.rb
class RecordFieldError
def get_unknown_field: -> untyped
end
18 changes: 18 additions & 0 deletions scenario/rbs/record-hash-compat.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
## update: test.rbs
class RecordHashCompat
def get_symbol_hash: -> Hash[Symbol, String | Integer]
def accept_record: ({ name: String, age: Integer }) -> void
end

## update: test.rb
class RecordHashCompat
def test_hash_to_record
hash_data = get_symbol_hash
accept_record(hash_data)
end
end

## assert: test.rb
class RecordHashCompat
def test_hash_to_record: -> Object
end
17 changes: 17 additions & 0 deletions scenario/rbs/record-nested.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## update: test.rbs
class NestedRecord
def get_company: -> { name: String, owner: { name: String, age: Integer } }
end

## update: test.rb
class NestedRecord
def get_owner_name
company = get_company
company[:owner][:name]
end
end

## assert: test.rb
class NestedRecord
def get_owner_name: -> String
end
16 changes: 16 additions & 0 deletions scenario/rbs/record-optional.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
## update: test.rbs
class OptionalRecord
def maybe_person: -> { name: String, age: Integer }?
end

## update: test.rb
class OptionalRecord
def check_optional
maybe_person
end
end

## assert: test.rb
class OptionalRecord
def check_optional: -> { name: String, age: Integer }?
end
28 changes: 28 additions & 0 deletions scenario/rbs/record-type-checking.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
## update: test.rbs
class RecordTypeChecking
def create_person: -> { name: String, age: Integer }
def accept_person: ({ name: String, age: Integer }) -> String
def process_user: ({ id: Integer, name: String, active: bool }) -> String
end

## update: test.rb
class RecordTypeChecking
# Test case: Exact type matching
# Record type created by method matches the parameter type exactly
def test_exact_match
person = create_person
accept_person(person)
end

# Test case: Untyped parameter
# Parameter without type annotation is passed to method expecting Record type
def test_untyped_param(user_data)
process_user(user_data)
end
end

## assert: test.rb
class RecordTypeChecking
def test_exact_match: -> String
def test_untyped_param: (untyped) -> String
end