Skip to content

use tag 28/29 for cyclic references #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions src/CBOR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,34 @@ end
Base.:(==)(a::Tag, b::Tag) = a.id == b.id && a.data == b.data
Tag(id::Integer, data) = Tag(Int(id), data)

"""
    Decoder(io::IO)

Wrap the stream `io` for decoding and keep, alongside it, the list of values that
have been registered as shared while decoding.

# Fields
- `io`: the wrapped stream that bytes are read from.
- `reference_cache`: values registered so far; empty for a fresh decoder.
"""
struct Decoder{S}
    io::S
    reference_cache::Vector{Any}
end

function Decoder(io::S) where {S<:IO}
    return Decoder{S}(io, Any[])
end

# Reads on a `Decoder` are forwarded unchanged to the stream it wraps.
Base.read(io::Decoder, T::Type) = read(io.io, T)
Base.read(io::Decoder, n::Integer) = read(io.io, n)

# Advance the wrapped stream by `amount` bytes.
# This used to be implemented with `read`, which allocated (and returned) a byte
# vector only to throw it away; seeking does the same job without the allocation.
# The decoder itself is returned, mirroring how `skip` on a stream returns the stream.
function Base.skip(io::Decoder, amount)
    skip(io.io, amount)
    return io
end

"""
    Encoder(io::IO, encode_references = false)

Wrap the output stream `io` for encoding.

# Fields
- `io`: the wrapped stream that encoded bytes are written to.
- `encode_references`: flag stored from the constructor argument (defaults to `false`).
- `references`: identity-keyed table from objects to an `Int`; starts out empty.
  NOTE(review): presumably the index assigned to each shared object — confirm in
  the encoding code.
"""
struct Encoder{S}
    io::S
    encode_references::Bool
    references::IdDict{Any, Int}
end

# Writing to an `Encoder` writes straight through to the wrapped stream.
function Base.write(io::Encoder, data)
    return write(io.io, data)
end

function Encoder(io::S, encode_references = false) where {S<:IO}
    table = IdDict{Any, Int}()
    return Encoder{S}(io, encode_references, table)
end

"""
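Placeholder for a CBOR shared-value reference (tag 29) whose target was not yet in the
decoder's reference cache when the reference was read.

`index` is the 1-based position in that cache; `replace_references!` later swaps the
placeholder for the value stored at `refs[index]`.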
"""
struct Reference
index::Int
end

include("constants.jl")
include("encode.jl")
include("decode.jl")
Expand All @@ -44,14 +72,41 @@ export encode
export decode, decode_with_iana
export Simple, Null, Undefined

function decode(data::Array{UInt8, 1})
# Fallback: a value that is neither a placeholder nor a handled container is
# returned untouched.
replace_references!(refs, x) = x

# A placeholder resolves to the cached value it points at.
function replace_references!(refs, x::Reference)
    return refs[x.index]
end

# Vectors are rewritten in place, slot by slot, and the same vector is returned.
function replace_references!(refs, x::Vector)
    for i in eachindex(x)
        x[i] = replace_references!(refs, x[i])
    end
    return x
end

"""
    replace_references!(refs, dict::Dict)

Resolve `Reference` placeholders inside `dict` in place and return `dict`.

A key that is a `Reference` is removed and re-inserted under the value it resolves
to in `refs`; every value is passed through `replace_references!` and stored back.
Keys that are not themselves a `Reference` are left as they are.
"""
function replace_references!(refs, dict::Dict)
    # Walk a snapshot of the entries: deleting and inserting keys while iterating
    # the live Dict can trigger a rehash, which may skip entries or visit them twice.
    for (k, v) in collect(dict)
        if k isa Reference
            delete!(dict, k)
            k = replace_references!(refs, k)
        end
        dict[k] = replace_references!(refs, v)
    end
    return dict
end


"""
    decode(data::Vector{UInt8})

Decode from an in-memory byte vector by wrapping it in an `IOBuffer` and handing
that to the stream-based `decode`.
"""
decode(data::Vector{UInt8}) = decode(IOBuffer(data))

function encode(data)
"""
    decode(io::IO)

Decode from the stream `io`.

The stream is wrapped in a `Decoder`, the value is decoded through it, and the
result is then passed through `replace_references!` together with the decoder's
`reference_cache` so that any remaining `Reference` placeholders are resolved.
"""
function decode(io::IO)
    decoder = Decoder(io)
    decoded = decode(decoder)
    return replace_references!(decoder.reference_cache, decoded)
end


"""
    encode(data; with_references = false)

Encode `data` into a fresh in-memory buffer and return the resulting bytes.

# Keywords
- `with_references`: forwarded unchanged to the stream-based `encode`.
"""
function encode(data; with_references = false)
    buffer = IOBuffer()
    # Encode exactly once. A leftover plain `encode(io, data)` call used to precede
    # this one, which wrote the payload twice and ignored `with_references` for the
    # first copy.
    encode(buffer, data; with_references = with_references)
    return take!(buffer)
end

"""
    encode(io::IO, data; with_references = false)

Encode `data` to the stream `io` by wrapping `io` in an `Encoder` (constructed with
`with_references` as its second argument) and delegating to the `Encoder`-based
`encode`.
"""
function encode(io::IO, data; with_references = false)
    encoder = Encoder(io, with_references)
    return encode(encoder, data)
end

end
6 changes: 6 additions & 0 deletions src/constants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@ const CBOR_UNDEF_BYTE = UInt8(TYPE_7 | 23)


const CUSTOM_LANGUAGE_TYPE = 27

# Tag 28: marks the tagged value as shared. When decoding, every value carrying
# this tag is appended to the decoder's reference cache.
const MARK_SHARED_VALUE = 28

# Tag 29: refers to a shared value by index. When decoding, 1 is added to the
# encoded index before it is used to look up the reference cache.
const REFERENCE_SHARED_VALUE = 29
48 changes: 31 additions & 17 deletions src/decode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ function type_from_fields(::Type{T}, fields) where T
ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), T, fields, length(fields))
end

function peekbyte(io::IO)
mark(io)
function peekbyte(io::Decoder)
mark(io.io)
byte = read(io, UInt8)
reset(io)
reset(io.io)
return byte
end

struct UndefIter{IO, F}
struct UndefIter{Decoder, F}
f::F
io::IO
io::Decoder
end
Base.IteratorSize(::Type{<: UndefIter}) = Base.SizeUnknown()

Expand All @@ -42,7 +42,7 @@ function Base.iterate(x::UndefIter, state = nothing)
return x.f(x.io), nothing
end

function decode_ntimes(f, io::IO)
function decode_ntimes(f, io::Decoder)
first_byte = peekbyte(io)
if (first_byte & ADDNTL_INFO_MASK) == ADDNTL_INFO_INDEF
skip(io, 1) # skip first byte
Expand All @@ -52,7 +52,7 @@ function decode_ntimes(f, io::IO)
end
end

function decode_unsigned(io::IO)
function decode_unsigned(io::Decoder)
addntl_info = read(io, UInt8) & ADDNTL_INFO_MASK
if addntl_info < SINGLE_BYTE_UINT_PLUS_ONE
return addntl_info
Expand All @@ -71,9 +71,9 @@ end



# Items dispatched on `Val{TYPE_0}` are decoded by handing the stream straight to
# `decode_unsigned`. An `IO` method and a `Decoder` method are both listed here
# with identical bodies.
decode(io::IO, ::Val{TYPE_0}) = decode_unsigned(io)
decode(io::Decoder, ::Val{TYPE_0}) = decode_unsigned(io)

function decode(io::IO, ::Val{TYPE_1})
function decode(io::Decoder, ::Val{TYPE_1})
data = signed(decode_unsigned(io))
if (i = Int128(data) + one(data)) > typemax(Int64)
return -i
Expand All @@ -85,7 +85,7 @@ end
"""
Decode Byte Array
"""
function decode(io::IO, ::Val{TYPE_2})
function decode(io::Decoder, ::Val{TYPE_2})
if (peekbyte(io) & ADDNTL_INFO_MASK) == ADDNTL_INFO_INDEF
skip(io, 1)
result = IOBuffer()
Expand All @@ -101,7 +101,7 @@ end
"""
Decode String
"""
function decode(io::IO, ::Val{TYPE_3})
function decode(io::Decoder, ::Val{TYPE_3})
if (peekbyte(io) & ADDNTL_INFO_MASK) == ADDNTL_INFO_INDEF
skip(io, 1)
result = IOBuffer()
Expand All @@ -117,23 +117,24 @@ end
"""
Decode Vector of arbitrary elements
"""
function decode(io::IO, ::Val{TYPE_4})
function decode(io::Decoder, ::Val{TYPE_4})
return map(identity, decode_ntimes(decode, io))
end

"""
Decode Dict
"""
function decode(io::IO, ::Val{TYPE_5})
function decode(io::Decoder, ::Val{TYPE_5})
return Dict(decode_ntimes(io) do io
decode(io) => decode(io)
end)
end


"""
Decode Tagged type
"""
function decode(io::IO, ::Val{TYPE_6})
function decode(io::Decoder, ::Val{TYPE_6})
tag = decode_unsigned(io)
data = decode(io)
if tag in (POS_BIG_INT_TAG, NEG_BIG_INT_TAG)
Expand All @@ -145,7 +146,20 @@ function decode(io::IO, ::Val{TYPE_6})
end
return big_int
end

if tag == MARK_SHARED_VALUE
push!(io.reference_cache, data)
return data
end
if tag == REFERENCE_SHARED_VALUE
if checkbounds(Bool, io.reference_cache, data + 1)
# if the index is in bounds, we already know the referenced object
# and can immediately return it
return io.reference_cache[data + 1]
else
# The reference points to an object that isn't constructed yet
return Reference(data + 1)
end
end
if tag == CUSTOM_LANGUAGE_TYPE # Type Tag
name = data[1]
object_serialized = data[2]
Expand All @@ -157,7 +171,7 @@ function decode(io::IO, ::Val{TYPE_6})
return Tag(tag, data)
end

function decode(io::IO, ::Val{TYPE_7})
function decode(io::Decoder, ::Val{TYPE_7})
first_byte = read(io, UInt8)
addntl_info = first_byte & ADDNTL_INFO_MASK
if addntl_info < SINGLE_BYTE_SIMPLE_PLUS_ONE + 1
Expand Down Expand Up @@ -190,7 +204,7 @@ function decode(io::IO, ::Val{TYPE_7})
end
end

function decode(io::IO)
function decode(io::Decoder)
# leave startbyte in io
first_byte = peekbyte(io)
typ = first_byte & TYPE_BITS_MASK
Expand Down
Loading