Skip to content

Commit f9e029e

Browse files
committed
GH-44910: [Swift] fix ipc stream reader and writer impl
1 parent c3601a9 commit f9e029e

File tree

4 files changed

+130
-20
lines changed

4 files changed

+130
-20
lines changed

Diff for: swift/Arrow/Sources/Arrow/ArrowReader.swift

+76-6
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import FlatBuffers
1919
import Foundation
2020

2121
let FILEMARKER = "ARROW1"
22-
let CONTINUATIONMARKER = -1
22+
let CONTINUATIONMARKER = UInt32(0xFFFFFFFF)
2323

2424
public class ArrowReader { // swiftlint:disable:this type_body_length
2525
private class RecordBatchData {
@@ -216,7 +216,77 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
216216
return .success(RecordBatch(arrowSchema, columns: columns))
217217
}
218218

219-
public func fromStream( // swiftlint:disable:this function_body_length
219+
/*
220+
The Memory stream format is for reading the Arrow streaming protocol. This
221+
format is slightly different from the File format protocol as it doesn't contain
222+
a header and footer
223+
*/
224+
/// Reads a schema and its record batches from data laid out as a sequence of
/// length-prefixed messages (no file header or footer).
///
/// Each message is read as: an optional continuation marker (`CONTINUATIONMARKER`),
/// a 32-bit metadata length, the flatbuffer metadata, then `bodyLength` bytes of body.
/// Reading stops at a zero metadata length or when the data is exhausted exactly on a
/// message boundary.
///
/// - Parameters:
///   - fileData: The bytes of the stream to decode.
///   - useUnalignedBuffers: Not consulted by this method; unaligned flatbuffer reads are
///     always enabled here because each message is sliced out of `fileData` at an
///     arbitrary offset.
/// - Returns: `.success` with the decoded schema and batches, or `.failure` when the
///   data is truncated, a record batch precedes the schema, a header cannot be decoded,
///   or an unhandled message type is encountered.
public func fromMemoryStream( // swiftlint:disable:this function_body_length cyclomatic_complexity
    _ fileData: Data,
    useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
    let result = ArrowReaderResult()
    let prefixSize = MemoryLayout<UInt32>.size
    var offset: Int = 0
    var schemaMessage: org_apache_arrow_flatbuf_Schema?
    while offset < fileData.count {
        // Never read a length prefix that is not fully contained in the data.
        guard fileData.count - offset >= prefixSize else {
            return .failure(.invalid("Truncated stream: incomplete message length prefix"))
        }
        var length = getUInt32(fileData, offset: offset)
        if length == CONTINUATIONMARKER {
            // The real metadata length follows the continuation marker.
            offset += prefixSize
            guard fileData.count - offset >= prefixSize else {
                return .failure(.invalid("Truncated stream: missing length after continuation marker"))
            }
            length = getUInt32(fileData, offset: offset)
        }

        // A zero length marks the end of the stream.
        if length == 0 {
            break
        }

        // `offset` now points at the first byte of the flatbuffer metadata.
        offset += prefixSize
        let metadataLength = Int64(length)
        let remaining = Int64(fileData.count - offset)
        guard metadataLength <= remaining else {
            return .failure(.invalid("Truncated stream: message metadata exceeds available data"))
        }

        let dataBuffer = ByteBuffer(
            data: fileData[offset...],
            allowReadingUnalignedBuffers: true)
        let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
        let bodyLength = message.bodyLength
        guard bodyLength >= 0, bodyLength <= remaining - metadataLength else {
            return .failure(.invalid("Truncated stream: message body exceeds available data"))
        }

        switch message.headerType {
        case .recordbatch:
            guard let schema = schemaMessage, let arrowSchema = result.schema else {
                return .failure(.invalid("Record batch message encountered before schema message"))
            }
            guard let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self) else {
                return .failure(.invalid("Unable to read record batch header"))
            }
            do {
                // The body starts where the metadata ends, so pass that absolute position
                // (not a size) as the message end offset.
                // NOTE(review): loadRecordBatch is defined outside this view; confirm it
                // treats messageEndOffset as an absolute offset into `data`.
                let recordBatch = try loadRecordBatch(
                    rbMessage,
                    schema: schema,
                    arrowSchema: arrowSchema,
                    data: fileData,
                    messageEndOffset: Int64(offset) + metadataLength).get()
                result.batches.append(recordBatch)
            } catch let error as ArrowError {
                return .failure(error)
            } catch {
                return .failure(.unknownError("Unexpected error: \(error)"))
            }
        case .schema:
            guard let header = message.header(type: org_apache_arrow_flatbuf_Schema.self) else {
                return .failure(.invalid("Unable to read schema header"))
            }
            schemaMessage = header
            switch loadSchema(header) {
            case .success(let schema):
                result.schema = schema
            case .failure(let error):
                return .failure(error)
            }
        default:
            return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
        }

        // Skip past this message's metadata and body to the next length prefix.
        offset += Int(metadataLength + bodyLength)
    }
    return .success(result)
}
283+
284+
/*
285+
The File stream format supports random access to the data. This format contains
286+
a header and footer around the streaming format.
287+
*/
288+
289+
public func fromFileStream( // swiftlint:disable:this function_body_length
220290
_ fileData: Data,
221291
useUnalignedBuffers: Bool = false
222292
) -> Result<ArrowReaderResult, ArrowError> {
@@ -242,7 +312,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
242312
for index in 0 ..< footer.recordBatchesCount {
243313
let recordBatch = footer.recordBatches(at: index)!
244314
var messageLength = fileData.withUnsafeBytes { rawBuffer in
245-
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
315+
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
246316
}
247317

248318
var messageOffset: Int64 = 1
@@ -251,7 +321,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
251321
messageLength = fileData.withUnsafeBytes { rawBuffer in
252322
rawBuffer.loadUnaligned(
253323
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
254-
as: Int32.self)
324+
as: UInt32.self)
255325
}
256326
}
257327

@@ -296,7 +366,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
296366
let markerLength = FILEMARKER.utf8.count
297367
let footerLengthEnd = Int(fileData.count - markerLength)
298368
let data = fileData[..<(footerLengthEnd)]
299-
return fromStream(data)
369+
return fromFileStream(data)
300370
} catch {
301371
return .failure(.unknownError("Error loading file: \(error)"))
302372
}
@@ -340,10 +410,10 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
340410
} catch {
341411
return .failure(.unknownError("Unexpected error: \(error)"))
342412
}
343-
344413
default:
345414
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
346415
}
347416
}
348417

349418
}
419+
// swiftlint:disable:this file_length

Diff for: swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift

+7
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,10 @@ func validateFileData(_ data: Data) -> Bool {
289289
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
290290
return startString == FILEMARKER && endString == FILEMARKER
291291
}
292+
293+
/// Reads the 32-bit unsigned integer stored in little-endian byte order at `offset`.
///
/// The read does not require `offset` to be aligned. The value is converted from
/// little-endian so the result matches what the writer emits with `.littleEndian`,
/// regardless of host byte order.
///
/// - Parameters:
///   - data: The bytes to read from.
///   - offset: Byte position of the first of the four bytes, relative to the start of
///     the bytes exposed by `data.withUnsafeBytes`.
/// - Returns: The decoded value.
/// - Precondition: `offset` is non-negative and at least four bytes are available at
///   `offset`; otherwise the call traps instead of reading out of bounds.
func getUInt32(_ data: Data, offset: Int) -> UInt32 {
    precondition(
        offset >= 0 && data.count - offset >= MemoryLayout<UInt32>.size,
        "getUInt32: offset \(offset) is out of bounds for \(data.count) bytes")
    let rawValue = data.withUnsafeBytes { rawBuffer in
        rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
    }
    return UInt32(littleEndian: rawValue)
}

Diff for: swift/Arrow/Sources/Arrow/ArrowWriter.swift

+37-4
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
123123
let startIndex = writer.count
124124
switch writeRecordBatch(batch: batch) {
125125
case .success(let rbResult):
126+
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
126127
withUnsafeBytes(of: rbResult.1.o.littleEndian) {writer.append(Data($0))}
127128
writer.append(rbResult.0)
128129
switch writeRecordBatchData(&writer, batch: batch) {
@@ -232,7 +233,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
232233
return .success(fbb.data)
233234
}
234235

235-
private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
236+
private func writeFileStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
236237
var fbb: FlatBufferBuilder = FlatBufferBuilder()
237238
switch writeSchema(&fbb, schema: info.schema) {
238239
case .success(let schemaOffset):
@@ -264,9 +265,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
264265
return .success(true)
265266
}
266267

267-
public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
268+
/// Serializes the schema and record batches in `info` into an in-memory buffer as a
/// sequence of framed messages.
///
/// Every message is preceded by `CONTINUATIONMARKER` and the little-endian 32-bit byte
/// count of its metadata. After the last batch, a continuation marker followed by a
/// zero length is written.
///
/// - Parameter info: The schema and batches to write.
/// - Returns: `.success` with the written bytes, or `.failure` with the error reported
///   by `toMessage`, or when the writer cannot be cast back to `InMemDataWriter`.
public func toMemoryStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
    let writer: any DataWriter = InMemDataWriter()

    // Writes the continuation marker followed by a 32-bit little-endian length.
    func appendPrefix(length: UInt32) {
        withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))}
        withUnsafeBytes(of: length.littleEndian) {writer.append(Data($0))}
    }

    // Schema message comes first.
    let schemaResult = toMessage(info.schema)
    switch schemaResult {
    case .failure(let error):
        return .failure(error)
    case .success(let schemaData):
        appendPrefix(length: UInt32(schemaData.count))
        writer.append(schemaData)
    }

    // Then one message per record batch: element 0 is length-prefixed,
    // element 1 is appended directly after it.
    for batch in info.batches {
        let batchResult = toMessage(batch)
        switch batchResult {
        case .failure(let error):
            return .failure(error)
        case .success(let parts):
            let metadata = parts[0]
            appendPrefix(length: UInt32(metadata.count))
            writer.append(metadata)
            writer.append(parts[1])
        }
    }

    // Terminate with a continuation marker and a zero length.
    appendPrefix(length: UInt32(0))

    guard let memWriter = writer as? InMemDataWriter else {
        return .failure(.invalid("Unable to cast writer"))
    }
    return .success(memWriter.data)
}
299+
300+
public func toFileStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
268301
var writer: any DataWriter = InMemDataWriter()
269-
switch writeStream(&writer, info: info) {
302+
switch writeFileStream(&writer, info: info) {
270303
case .success:
271304
if let memWriter = writer as? InMemDataWriter {
272305
return .success(memWriter.data)
@@ -293,7 +326,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
293326

294327
var writer: any DataWriter = FileDataWriter(fileHandle)
295328
writer.append(FILEMARKER.data(using: .utf8)!)
296-
switch writeStream(&writer, info: info) {
329+
switch writeFileStream(&writer, info: info) {
297330
case .success:
298331
writer.append(FILEMARKER.data(using: .utf8)!)
299332
case .failure(let error):

Diff for: swift/Arrow/Tests/ArrowTests/IPCTests.swift

+10-10
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,10 @@ final class IPCFileReaderTests: XCTestCase {
167167
let arrowWriter = ArrowWriter()
168168
// write data from file to a stream
169169
let writerInfo = ArrowWriter.Info(.recordbatch, schema: fileRBs[0].schema, batches: fileRBs)
170-
switch arrowWriter.toStream(writerInfo) {
170+
switch arrowWriter.toFileStream(writerInfo) {
171171
case .success(let writeData):
172172
// read stream back into recordbatches
173-
try checkBoolRecordBatch(arrowReader.fromStream(writeData))
173+
try checkBoolRecordBatch(arrowReader.fromFileStream(writeData))
174174
case .failure(let error):
175175
throw error
176176
}
@@ -190,10 +190,10 @@ final class IPCFileReaderTests: XCTestCase {
190190
let recordBatch = try makeRecordBatch()
191191
let arrowWriter = ArrowWriter()
192192
let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch])
193-
switch arrowWriter.toStream(writerInfo) {
193+
switch arrowWriter.toFileStream(writerInfo) {
194194
case .success(let writeData):
195195
let arrowReader = ArrowReader()
196-
switch arrowReader.fromStream(writeData) {
196+
switch arrowReader.fromFileStream(writeData) {
197197
case .success(let result):
198198
let recordBatches = result.batches
199199
XCTAssertEqual(recordBatches.count, 1)
@@ -242,10 +242,10 @@ final class IPCFileReaderTests: XCTestCase {
242242
let schema = makeSchema()
243243
let arrowWriter = ArrowWriter()
244244
let writerInfo = ArrowWriter.Info(.schema, schema: schema)
245-
switch arrowWriter.toStream(writerInfo) {
245+
switch arrowWriter.toFileStream(writerInfo) {
246246
case .success(let writeData):
247247
let arrowReader = ArrowReader()
248-
switch arrowReader.fromStream(writeData) {
248+
switch arrowReader.fromFileStream(writeData) {
249249
case .success(let result):
250250
XCTAssertNotNil(result.schema)
251251
let schema = result.schema!
@@ -325,10 +325,10 @@ final class IPCFileReaderTests: XCTestCase {
325325
let dataset = try makeBinaryDataset()
326326
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
327327
let arrowWriter = ArrowWriter()
328-
switch arrowWriter.toStream(writerInfo) {
328+
switch arrowWriter.toFileStream(writerInfo) {
329329
case .success(let writeData):
330330
let arrowReader = ArrowReader()
331-
switch arrowReader.fromStream(writeData) {
331+
switch arrowReader.fromFileStream(writeData) {
332332
case .success(let result):
333333
XCTAssertNotNil(result.schema)
334334
let schema = result.schema!
@@ -354,10 +354,10 @@ final class IPCFileReaderTests: XCTestCase {
354354
let dataset = try makeTimeDataset()
355355
let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1])
356356
let arrowWriter = ArrowWriter()
357-
switch arrowWriter.toStream(writerInfo) {
357+
switch arrowWriter.toFileStream(writerInfo) {
358358
case .success(let writeData):
359359
let arrowReader = ArrowReader()
360-
switch arrowReader.fromStream(writeData) {
360+
switch arrowReader.fromFileStream(writeData) {
361361
case .success(let result):
362362
XCTAssertNotNil(result.schema)
363363
let schema = result.schema!

0 commit comments

Comments
 (0)