diff --git a/spec/payload_spec.cr b/spec/payload_spec.cr index c25d48c..2eda7f4 100644 --- a/spec/payload_spec.cr +++ b/spec/payload_spec.cr @@ -74,4 +74,29 @@ describe MQTT::Protocol::Payload do (one == two).should be_false end end + + describe "IOPayload" do + it "#to_slice should peek if possible" do + io = IO::Memory.new("foo".to_slice) + io.rewind + + obj = MQTT::Protocol::IOPayload.new(io, 3) + data = obj.to_slice + + obj.@data.should be_nil + end + + it "#to_io should not affect position" do + io = IO::Memory.new("foo".to_slice) + io.rewind + + obj = MQTT::Protocol::IOPayload.new(io, 3) + + dst = IO::Memory.new + obj.to_io(dst) + + obj.@data.should be_nil + obj.@io.pos.should eq(0) + end + end end diff --git a/src/mqtt/protocol/payload.cr b/src/mqtt/protocol/payload.cr index ef7ad45..e9722a1 100644 --- a/src/mqtt/protocol/payload.cr +++ b/src/mqtt/protocol/payload.cr @@ -47,6 +47,8 @@ module MQTT end struct IOPayload < Payload + alias IOWithPosition = ::IO::Memory | ::IO::FileDescriptor + getter bytesize : Int32 @data : Bytes? = nil @@ -62,7 +64,7 @@ module MQTT if peeked = @io.peek.try &.[0, bytesize]? return peeked end - return @data || begin + @data ||= begin data = Bytes.new(bytesize) @io.read(data) data @@ -70,11 +72,20 @@ module MQTT end def to_io(io, format : ::IO::ByteFormat = ::IO::ByteFormat::SystemEndian) + # Use data that has already been copied to memory if data = @data io.write data else - copied = ::IO.copy(@io, io, bytesize) - raise "Failed to copy payload" if copied != bytesize + # else try to copy + if @io.io.is_a?(IOWithPosition) + pos = @io.pos + copied = ::IO.copy(@io, io, bytesize) + raise "Failed to copy payload" if copied != bytesize + @io.pos = pos + else + # copy to memory and write + io.write to_slice + end end end end