Skip to content

GH-44910: [Swift] Fix IPC stream reader and writer impl #45029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 75 additions & 6 deletions swift/Arrow/Sources/Arrow/ArrowReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import FlatBuffers
import Foundation

let FILEMARKER = "ARROW1"
let CONTINUATIONMARKER = -1
let CONTINUATIONMARKER = UInt32(0xFFFFFFFF)

public class ArrowReader { // swiftlint:disable:this type_body_length
private class RecordBatchData {
Expand Down Expand Up @@ -216,7 +216,76 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
return .success(RecordBatch(arrowSchema, columns: columns))
}

public func fromStream( // swiftlint:disable:this function_body_length
/*
The Memory stream format is for reading the arrow streaming protocol. This
format is slightly different from the File format protocol as it doesn't contain
a header and footer
*/
public func fromMemoryStream( // swiftlint:disable:this function_body_length
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment that explains the difference between fromMemoryStream and fromFileStream?

_ fileData: Data,
useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
let result = ArrowReaderResult()
var offset: Int = 0
var length = getUInt32(fileData, offset: offset)
var streamData = fileData
var schemaMessage: org_apache_arrow_flatbuf_Schema?
while length != 0 {
if length == CONTINUATIONMARKER {
offset += Int(MemoryLayout<Int32>.size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the size of length data?
How about using UInt32 not Int32 because length data is UInt32 not Int32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at the length var and it is already UInt32. From a couple of lines above: var length = getUInt32(fileData, offset: offset). Please let me know if this matches what you are seeing.

length = getUInt32(fileData, offset: offset)
if length == 0 {
return .success(result)
}
}

offset += Int(MemoryLayout<Int32>.size)
streamData = fileData[offset...]
let dataBuffer = ByteBuffer(
data: streamData,
allowReadingUnalignedBuffers: true)
let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
switch message.headerType {
case .recordbatch:
do {
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
offset += Int(message.bodyLength + Int64(length))
let recordBatch = try loadRecordBatch(
rbMessage,
schema: schemaMessage!,
arrowSchema: result.schema!,
data: fileData,
messageEndOffset: (message.bodyLength + Int64(length))).get()
result.batches.append(recordBatch)
length = getUInt32(fileData, offset: offset)
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}
case .schema:
schemaMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
let schemaResult = loadSchema(schemaMessage!)
switch schemaResult {
case .success(let schema):
result.schema = schema
case .failure(let error):
return .failure(error)
}
offset += Int(message.bodyLength + Int64(length))
length = getUInt32(fileData, offset: offset)
default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
}
}
return .success(result)
}

/*
The File stream format supports random accessing the data. This format contains
a header and footer around the streaming format.
*/
public func fromFileStream( // swiftlint:disable:this function_body_length
_ fileData: Data,
useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
Expand All @@ -242,7 +311,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
for index in 0 ..< footer.recordBatchesCount {
let recordBatch = footer.recordBatches(at: index)!
var messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
}

var messageOffset: Int64 = 1
Expand All @@ -251,7 +320,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
as: Int32.self)
as: UInt32.self)
}
}

Expand Down Expand Up @@ -296,7 +365,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
let markerLength = FILEMARKER.utf8.count
let footerLengthEnd = Int(fileData.count - markerLength)
let data = fileData[..<(footerLengthEnd)]
return fromStream(data)
return fromFileStream(data)
} catch {
return .failure(.unknownError("Error loading file: \(error)"))
}
Expand Down Expand Up @@ -340,10 +409,10 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}

default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
}
}

}
// swiftlint:disable:this file_length
7 changes: 7 additions & 0 deletions swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,10 @@ func validateFileData(_ data: Data) -> Bool {
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
return startString == FILEMARKER && endString == FILEMARKER
}

func getUInt32(_ data: Data, offset: Int) -> UInt32 {
let token = data.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
}
return token
}
41 changes: 37 additions & 4 deletions swift/Arrow/Sources/Arrow/ArrowWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
let startIndex = writer.count
switch writeRecordBatch(batch: batch) {
case .success(let rbResult):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
withUnsafeBytes(of: rbResult.1.o.littleEndian) {writer.append(Data($0))}
writer.append(rbResult.0)
switch writeRecordBatchData(&writer, batch: batch) {
Expand Down Expand Up @@ -232,7 +233,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(fbb.data)
}

private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
private func writeFileStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
var fbb: FlatBufferBuilder = FlatBufferBuilder()
switch writeSchema(&fbb, schema: info.schema) {
case .success(let schemaOffset):
Expand Down Expand Up @@ -264,9 +265,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(true)
}

public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
public func toMemoryStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
let writer: any DataWriter = InMemDataWriter()
switch toMessage(info.schema) {
case .success(let schemaData):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
withUnsafeBytes(of: UInt32(schemaData.count).littleEndian) {writer.append(Data($0))}
writer.append(schemaData)
case .failure(let error):
return .failure(error)
}

for batch in info.batches {
switch toMessage(batch) {
case .success(let batchData):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
withUnsafeBytes(of: UInt32(batchData[0].count).littleEndian) {writer.append(Data($0))}
writer.append(batchData[0])
writer.append(batchData[1])
case .failure(let error):
return .failure(error)
}
}

withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
withUnsafeBytes(of: UInt32(0).littleEndian) {writer.append(Data($0))}
if let memWriter = writer as? InMemDataWriter {
return .success(memWriter.data)
} else {
return .failure(.invalid("Unable to cast writer"))
}
}

public func toFileStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
var writer: any DataWriter = InMemDataWriter()
switch writeStream(&writer, info: info) {
switch writeFileStream(&writer, info: info) {
case .success:
if let memWriter = writer as? InMemDataWriter {
return .success(memWriter.data)
Expand All @@ -293,7 +326,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length

var writer: any DataWriter = FileDataWriter(fileHandle)
writer.append(FILEMARKER.data(using: .utf8)!)
switch writeStream(&writer, info: info) {
switch writeFileStream(&writer, info: info) {
case .success:
writer.append(FILEMARKER.data(using: .utf8)!)
case .failure(let error):
Expand Down
20 changes: 10 additions & 10 deletions swift/Arrow/Tests/ArrowTests/IPCTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ final class IPCFileReaderTests: XCTestCase {
let arrowWriter = ArrowWriter()
// write data from file to a stream
let writerInfo = ArrowWriter.Info(.recordbatch, schema: fileRBs[0].schema, batches: fileRBs)
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
// read stream back into recordbatches
try checkBoolRecordBatch(arrowReader.fromStream(writeData))
try checkBoolRecordBatch(arrowReader.fromFileStream(writeData))
case .failure(let error):
throw error
}
Expand All @@ -190,10 +190,10 @@ final class IPCFileReaderTests: XCTestCase {
let recordBatch = try makeRecordBatch()
let arrowWriter = ArrowWriter()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch])
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
let recordBatches = result.batches
XCTAssertEqual(recordBatches.count, 1)
Expand Down Expand Up @@ -242,10 +242,10 @@ final class IPCFileReaderTests: XCTestCase {
let schema = makeSchema()
let arrowWriter = ArrowWriter()
let writerInfo = ArrowWriter.Info(.schema, schema: schema)
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand Down Expand Up @@ -325,10 +325,10 @@ final class IPCFileReaderTests: XCTestCase {
let dataset = try makeBinaryDataset()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
let arrowWriter = ArrowWriter()
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand All @@ -354,10 +354,10 @@ final class IPCFileReaderTests: XCTestCase {
let dataset = try makeTimeDataset()
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
let arrowWriter = ArrowWriter()
switch arrowWriter.toStream(writerInfo) {
switch arrowWriter.toFileStream(writerInfo) {
case .success(let writeData):
let arrowReader = ArrowReader()
switch arrowReader.fromStream(writeData) {
switch arrowReader.fromFileStream(writeData) {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
Expand Down
Loading