Skip to content

Commit ac808fb

Browse files
committed
Merge remote-tracking branch 'origin/main' into feature/ppca
2 parents 190a499 + e7a8c24 commit ac808fb

File tree

5 files changed

+40
-25
lines changed

5 files changed

+40
-25
lines changed

Package.resolved

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
{
22
"object": {
33
"pins": [
4-
{
5-
"package": "Benchmark",
6-
"repositoryURL": "https://github.com/google/swift-benchmark.git",
7-
"state": {
8-
"branch": null,
9-
"revision": "8e0ef8bb7482ab97dcd2cd1d6855bd38921c345d",
10-
"version": "0.1.0"
11-
}
12-
},
134
{
145
"package": "CSV.swift",
156
"repositoryURL": "https://github.com/yaslab/CSV.swift.git",
@@ -46,6 +37,33 @@
4637
"version": "0.3.1"
4738
}
4839
},
40+
{
41+
"package": "Benchmark",
42+
"repositoryURL": "https://github.com/google/swift-benchmark.git",
43+
"state": {
44+
"branch": null,
45+
"revision": "8e0ef8bb7482ab97dcd2cd1d6855bd38921c345d",
46+
"version": "0.1.0"
47+
}
48+
},
49+
{
50+
"package": "swift-models",
51+
"repositoryURL": "https://github.com/tensorflow/swift-models.git",
52+
"state": {
53+
"branch": null,
54+
"revision": "b2fc0325bf9d476bf2d7a4cd0a09d36486c506e4",
55+
"version": null
56+
}
57+
},
58+
{
59+
"package": "SwiftProtobuf",
60+
"repositoryURL": "https://github.com/apple/swift-protobuf.git",
61+
"state": {
62+
"branch": null,
63+
"revision": "da9a52be9cd36c63993291ce3f1b65dafcd1e826",
64+
"version": "1.14.0"
65+
}
66+
},
4967
{
5068
"package": "swift-tools-support-core",
5169
"repositoryURL": "https://github.com/apple/swift-tools-support-core.git",

Sources/SwiftFusion/Core/TensorVector.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,11 @@ extension TensorVector: Vector {
128128

129129
/// Returns the result of calling `body` on the scalars of `self`.
130130
public mutating func withUnsafeMutableBufferPointer<R>(
131-
_ body: (UnsafeMutableBufferPointer<Double>) throws -> R
131+
_ body: (inout UnsafeMutableBufferPointer<Double>) throws -> R
132132
) rethrows -> R {
133133
var scalars = self.tensor.scalars
134134
let r = try scalars.withUnsafeMutableBufferPointer { b in
135-
try body(b)
135+
try body(&b)
136136
}
137137
self.tensor = Tensor(shape: self.shape, scalars: scalars)
138138
return r

Sources/SwiftFusion/Core/Vector.swift

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public protocol Vector: Differentiable where Self.TangentVector == Self {
7878
/// A default is provided that is correct for types that are represented as contiguous scalars
7979
/// in memory.
8080
mutating func withUnsafeMutableBufferPointer<R>(
81-
_ body: (UnsafeMutableBufferPointer<Double>) throws -> R
81+
_ body: (inout UnsafeMutableBufferPointer<Double>) throws -> R
8282
) rethrows -> R
8383
#endif
8484
}
@@ -233,14 +233,13 @@ extension Vector {
233233

234234
/// Returns the result of calling `body` on the scalars of `self`.
235235
public mutating func withUnsafeMutableBufferPointer<R>(
236-
_ body: (UnsafeMutableBufferPointer<Double>) throws -> R
236+
_ body: (inout UnsafeMutableBufferPointer<Double>) throws -> R
237237
) rethrows -> R {
238-
return try withUnsafeMutablePointer(to: &self) { [dimension = self.dimension] p in
239-
try body(
240-
UnsafeMutableBufferPointer<Double>(
241-
start: UnsafeMutableRawPointer(p)
242-
.assumingMemoryBound(to: Double.self),
243-
count: dimension))
238+
try withUnsafeMutablePointer(to: &self) { [dimension = self.dimension] p in
239+
var b = UnsafeMutableBufferPointer<Double>(
240+
start: UnsafeMutableRawPointer(p).assumingMemoryBound(to: Double.self),
241+
count: dimension)
242+
return try body(&b)
244243
}
245244
}
246245
}

Sources/SwiftFusion/Inference/AnyArrayBuffer+Vector.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,10 @@ extension AnyVectorArrayBuffer: Vector {
180180

181181
/// Returns the result of calling `body` on the scalars of `self`.
182182
public mutating func withUnsafeMutableBufferPointer<R>(
183-
_ body: (UnsafeMutableBufferPointer<Double>) throws -> R
183+
_ body: (inout UnsafeMutableBufferPointer<Double>) throws -> R
184184
) rethrows -> R {
185185
var buffer = Array(scalars)
186-
let r = try buffer.withUnsafeMutableBufferPointer { try body($0) }
186+
let r = try buffer.withUnsafeMutableBufferPointer { try body(&$0) }
187187
scalars.assign(buffer)
188188
return r
189189
}

Tests/SwiftFusionTests/Core/VectorTests.swift

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,8 @@ extension Vector {
100100
}
101101

102102
mutableSelf = self
103-
mutableSelf.withUnsafeMutableBufferPointer { b in
104-
for (i, j) in zip(b.indices, distinctScalars.indices) {
105-
b[i] = distinctScalars[j]
106-
}
103+
_ = mutableSelf.withUnsafeMutableBufferPointer { b in
104+
b.assign(distinctScalars)
107105
}
108106
XCTAssertTrue(mutableSelf.scalars.elementsEqual(distinctScalars))
109107
}

0 commit comments

Comments
 (0)