Skip to content

Commit 805b037

Browse files
committed
Add Xet support for faster downloads
1 parent ef91dd1 commit 805b037

File tree

4 files changed

+283
-8
lines changed

4 files changed

+283
-8
lines changed

Package.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ let package = Package(
2020
)
2121
],
2222
dependencies: [
23-
.package(url: "https://github.com/mattt/EventSource.git", from: "1.0.0")
23+
.package(url: "https://github.com/mattt/EventSource.git", from: "1.0.0"),
24+
.package(url: "https://github.com/mattt/swift-xet.git", branch: "main"),
2425
],
2526
targets: [
2627
.target(
2728
name: "HuggingFace",
2829
dependencies: [
29-
.product(name: "EventSource", package: "EventSource")
30+
.product(name: "EventSource", package: "EventSource"),
31+
.product(name: "Xet", package: "swift-xet"),
3032
],
3133
path: "Sources/HuggingFace"
3234
),

Sources/HuggingFace/Hub/HubClient+Files.swift

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ import UniformTypeIdentifiers
66
import FoundationNetworking
77
#endif
88

9+
#if canImport(Xet)
10+
import Xet
11+
#endif
12+
913
// MARK: - Upload Operations
1014

1115
public extension HubClient {
@@ -183,6 +187,33 @@ public extension HubClient {
183187
useRaw: Bool = false,
184188
cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy
185189
) async throws -> Data {
190+
#if canImport(Xet)
191+
if isXetEnabled {
192+
do {
193+
let tempDirectory = FileManager.default.temporaryDirectory
194+
.appendingPathComponent(UUID().uuidString, isDirectory: true)
195+
let tempFile = tempDirectory.appendingPathComponent(UUID().uuidString)
196+
try FileManager.default.createDirectory(at: tempDirectory, withIntermediateDirectories: true)
197+
defer { try? FileManager.default.removeItem(at: tempDirectory) }
198+
199+
if try await downloadFileWithXet(
200+
repoPath: repoPath,
201+
repo: repo,
202+
revision: revision,
203+
destination: tempFile,
204+
progress: nil
205+
) != nil {
206+
return try Data(contentsOf: tempFile)
207+
} else {
208+
print("⚠️ Xet returned nil for \(repoPath), falling back to LFS")
209+
}
210+
} catch {
211+
print("⚠️ Xet failed for \(repoPath): \(error), falling back to LFS")
212+
}
213+
}
214+
#endif
215+
216+
// Fallback to existing LFS download method
186217
let endpoint = useRaw ? "raw" : "resolve"
187218
let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)"
188219
var request = try httpClient.createRequest(.get, urlPath)
@@ -215,6 +246,27 @@ public extension HubClient {
215246
cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy,
216247
progress: Progress? = nil
217248
) async throws -> URL {
249+
#if canImport(Xet)
250+
if isXetEnabled {
251+
do {
252+
if let downloaded = try await downloadFileWithXet(
253+
repoPath: repoPath,
254+
repo: repo,
255+
revision: revision,
256+
destination: destination,
257+
progress: progress
258+
) {
259+
return downloaded
260+
} else {
261+
print("⚠️ Xet returned nil for \(repoPath), falling back to LFS")
262+
}
263+
} catch {
264+
print("⚠️ Xet failed for \(repoPath): \(error), falling back to LFS")
265+
}
266+
}
267+
#endif
268+
269+
// Fallback to existing LFS download method
218270
let endpoint = useRaw ? "raw" : "resolve"
219271
let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)"
220272
var request = try httpClient.createRequest(.get, urlPath)
@@ -542,6 +594,81 @@ public extension HubClient {
542594
}
543595
}
544596

597+
#if canImport(Xet)
598+
private extension HubClient {
    /// Fetches one repository file through Xet and places it at `destination`.
    ///
    /// The Xet client comes from `getXetClient()` and the CAS JWT from
    /// `getCachedJwt(xetClient:repo:revision:isUpload:)`, so repeated calls on the
    /// same `HubClient` share both.
    ///
    /// - Parameters:
    ///   - repoPath: Path of the file inside the repository.
    ///   - repo: Identifier of the repository that holds the file.
    ///   - revision: Git revision to resolve the file against.
    ///   - destination: Local file URL the downloaded file should end up at.
    ///     Its parent directory is created if missing, and an existing file at
    ///     this URL is replaced.
    ///   - progress: Optional progress object; it is only updated once, to
    ///     100 of 100 units, after the file is in place.
    /// - Returns: `destination` on success, or `nil` when Xet is disabled for this
    ///   client, when no file info is returned for `repoPath`, or when the
    ///   download reports no output path.
    /// - Throws: Any error raised by the Xet client calls or by `FileManager`.
    @discardableResult
    func downloadFileWithXet(
        repoPath: String,
        repo: Repo.ID,
        revision: String,
        destination: URL,
        progress: Progress?
    ) async throws -> URL? {
        if !isXetEnabled {
            return nil
        }

        let client = try getXetClient()
        let repoName = repo.rawValue

        // No file info means there is nothing for Xet to fetch; callers treat
        // `nil` as "fall back to the regular download path".
        let lookup = try client.getFileInfo(
            repo: repoName,
            path: repoPath,
            revision: revision
        )
        guard let info = lookup else {
            return nil
        }

        let token = try getCachedJwt(
            xetClient: client,
            repo: repoName,
            revision: revision,
            isUpload: false
        )

        let fileManager = FileManager.default
        let parentDirectory = destination.deletingLastPathComponent()
        try fileManager.createDirectory(
            at: parentDirectory,
            withIntermediateDirectories: true
        )

        let writtenPaths = try client.downloadFiles(
            fileInfos: [info],
            destinationDir: parentDirectory.path,
            jwtInfo: token
        )
        guard let firstPath = writtenPaths.first else {
            return nil
        }

        // Xet picks the file name inside `parentDirectory`; relocate the result
        // when it did not land exactly on `destination`.
        let writtenURL = URL(fileURLWithPath: firstPath)
        let alreadyInPlace = writtenURL.standardizedFileURL == destination.standardizedFileURL
        if !alreadyInPlace {
            if fileManager.fileExists(atPath: destination.path) {
                try fileManager.removeItem(at: destination)
            }
            try fileManager.moveItem(at: writtenURL, to: destination)
        }

        if let reporter = progress {
            reporter.totalUnitCount = 100
            reporter.completedUnitCount = 100
        }

        return destination
    }
}
670+
#endif
671+
545672
// MARK: - Metadata Helpers
546673

547674
extension HubClient {

Sources/HuggingFace/Hub/HubClient.swift

Lines changed: 139 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ import Foundation
44
import FoundationNetworking
55
#endif
66

7+
#if canImport(Xet)
8+
import Xet
9+
#endif
10+
711
/// A Hugging Face Hub API client.
812
///
913
/// This client provides methods to interact with the Hugging Face Hub API,
@@ -32,8 +36,19 @@ public final class HubClient: Sendable {
3236
/// environment variable (defaults to https://huggingface.co).
3337
public static let `default` = HubClient()
3438

39+
/// Indicates whether Xet acceleration is enabled for this client.
40+
public let isXetEnabled: Bool
41+
3542
/// The underlying HTTP client.
3643
internal let httpClient: HTTPClient
44+
45+
#if canImport(Xet)
46+
/// Xet client instance for connection reuse (created once during initialization)
47+
private let xetClient: XetClient?
48+
49+
/// Thread-safe JWT cache for CAS access tokens
50+
private let jwtCache: JwtCache
51+
#endif
3752

3853
/// The host URL for requests made by the client.
3954
public var host: URL {
@@ -67,13 +82,15 @@ public final class HubClient: Sendable {
6782
/// - userAgent: The value for the `User-Agent` header sent in requests, if any. Defaults to `nil`.
6883
public convenience init(
6984
session: URLSession = URLSession(configuration: .default),
70-
userAgent: String? = nil
85+
userAgent: String? = nil,
86+
enableXet: Bool = HubClient.isXetSupported
7187
) {
7288
self.init(
7389
session: session,
7490
host: Self.detectHost(),
7591
userAgent: userAgent,
76-
tokenProvider: .environment
92+
tokenProvider: .environment,
93+
enableXet: enableXet
7794
)
7895
}
7996

@@ -88,13 +105,15 @@ public final class HubClient: Sendable {
88105
session: URLSession = URLSession(configuration: .default),
89106
host: URL,
90107
userAgent: String? = nil,
91-
bearerToken: String? = nil
108+
bearerToken: String? = nil,
109+
enableXet: Bool = HubClient.isXetSupported
92110
) {
93111
self.init(
94112
session: session,
95113
host: host,
96114
userAgent: userAgent,
97-
tokenProvider: bearerToken.map { .fixed(token: $0) } ?? .none
115+
tokenProvider: bearerToken.map { .fixed(token: $0) } ?? .none,
116+
enableXet: enableXet
98117
)
99118
}
100119

@@ -109,14 +128,28 @@ public final class HubClient: Sendable {
109128
session: URLSession = URLSession(configuration: .default),
110129
host: URL,
111130
userAgent: String? = nil,
112-
tokenProvider: TokenProvider
131+
tokenProvider: TokenProvider,
132+
enableXet: Bool = HubClient.isXetSupported
113133
) {
134+
self.isXetEnabled = enableXet && HubClient.isXetSupported
114135
self.httpClient = HTTPClient(
115136
host: host,
116137
userAgent: userAgent,
117138
tokenProvider: tokenProvider,
118139
session: session
119140
)
141+
142+
#if canImport(Xet)
143+
self.jwtCache = JwtCache()
144+
145+
if self.isXetEnabled {
146+
// Create XetClient once during initialization
147+
let token = try? tokenProvider.getToken()
148+
self.xetClient = try? (token.map { try XetClient.withToken(token: $0) } ?? XetClient())
149+
} else {
150+
self.xetClient = nil
151+
}
152+
#endif
120153
}
121154

122155
// MARK: - Auto-detection
@@ -134,4 +167,104 @@ public final class HubClient: Sendable {
134167
}
135168
return defaultHost
136169
}
137-
}
170+
171+
/// Whether this build can use Xet at all.
///
/// The value is fixed at compile time: `true` when the `Xet` module can be
/// imported, `false` otherwise.
public static var isXetSupported: Bool {
    #if canImport(Xet)
        let supported = true
    #else
        let supported = false
    #endif
    return supported
}
178+
179+
// MARK: - Xet Client
180+
181+
#if canImport(Xet)
182+
/// Lock-protected in-memory cache of CAS JWTs, keyed by repository and revision.
///
/// Every read and write of the backing dictionary happens while `lock` is held;
/// the `@unchecked Sendable` conformance relies on that.
private final class JwtCache: @unchecked Sendable {
    private struct CacheKey: Hashable {
        let repo: String
        let revision: String
    }

    private struct CachedJwt {
        let jwt: CasJwtInfo
        let expiresAt: Date

        var isExpired: Bool {
            Date() >= expiresAt
        }
    }

    /// Seconds subtracted from a token's reported expiry so that a token is
    /// treated as expired slightly before it actually is.
    private static let expirySafetyMargin: TimeInterval = 300

    private var cache: [CacheKey: CachedJwt] = [:]
    private let lock = NSLock()

    /// Returns the cached token for `(repo, revision)` if one exists and has not
    /// reached its (margin-adjusted) expiry.
    ///
    /// An expired entry is removed from the cache as a side effect, so stale
    /// tokens do not accumulate for the lifetime of the client.
    func get(repo: String, revision: String) -> CasJwtInfo? {
        lock.lock()
        defer { lock.unlock() }

        let key = CacheKey(repo: repo, revision: revision)
        guard let cached = cache[key] else {
            return nil
        }
        guard !cached.isExpired else {
            cache.removeValue(forKey: key)
            return nil
        }
        return cached.jwt
    }

    /// Stores `jwt` for `(repo, revision)`, replacing any previous entry.
    ///
    /// The entry's expiry is `jwt.exp()` interpreted as seconds since the Unix
    /// epoch, minus `expirySafetyMargin`.
    func set(jwt: CasJwtInfo, repo: String, revision: String) {
        lock.lock()
        defer { lock.unlock() }

        let key = CacheKey(repo: repo, revision: revision)
        let reportedExpiry = Date(timeIntervalSince1970: TimeInterval(jwt.exp()))
        let expiresAt = reportedExpiry - Self.expirySafetyMargin
        cache[key] = CachedJwt(jwt: jwt, expiresAt: expiresAt)
    }
}
222+
223+
/// Returns the Xet client stored on this `HubClient`.
///
/// The stored client is the one assigned during initialization; this method
/// never creates a new one.
///
/// - Returns: The stored `XetClient`.
/// - Throws: `HTTPClientError.requestError` when Xet is disabled for this
///   client, or when Xet is enabled but no client is available.
internal func getXetClient() throws -> XetClient {
    guard isXetEnabled else {
        throw HTTPClientError.requestError("Xet support is disabled for this client.")
    }
    // `xetClient` can be nil even when Xet is enabled; report that separately
    // so the failure is not mistaken for a configuration choice.
    guard let client = xetClient else {
        throw HTTPClientError.requestError(
            "Xet is enabled but the Xet client could not be initialized."
        )
    }
    return client
}
235+
236+
/// Returns a CAS JWT for the given repository and revision, fetching one only
/// when no unexpired token is cached.
///
/// Tokens are cached separately for uploads and downloads, so a token fetched
/// with `isUpload: false` is never returned for an `isUpload: true` request
/// (or the reverse).
///
/// - Parameters:
///   - xetClient: The Xet client used to fetch a token on a cache miss.
///   - repo: Repository identifier.
///   - revision: Git revision.
///   - isUpload: Whether the token is for upload (`true`) or download (`false`).
/// - Returns: A cached or newly fetched CAS JWT.
/// - Throws: Any error raised by `xetClient.getCasJwt`.
internal func getCachedJwt(
    xetClient: XetClient,
    repo: String,
    revision: String,
    isUpload: Bool
) throws -> CasJwtInfo {
    // The cache is keyed by (repo, revision) only, so give upload tokens their
    // own key. ':' is not permitted in git ref names, so the prefixed key should
    // not collide with a real revision — TODO confirm for every revision form
    // the Hub accepts.
    let cacheRevision = isUpload ? "upload:\(revision)" : revision

    if let cached = jwtCache.get(repo: repo, revision: cacheRevision) {
        return cached
    }

    // Cache miss (or expired entry): request a fresh token for the real revision.
    let jwt = try xetClient.getCasJwt(
        repo: repo,
        revision: revision,
        isUpload: isUpload
    )
    jwtCache.set(jwt: jwt, repo: repo, revision: cacheRevision)

    return jwt
}
269+
#endif
270+
}

Tests/HuggingFaceTests/HubTests/HubClientTests.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,17 @@ struct HubClientTests {
3838

3939
#expect(client.host.path.hasSuffix("/"))
4040
}
41+
42+
/// Verifies that `enableXet` controls `isXetEnabled` on a per-client basis.
///
/// The test is skipped, rather than failed, on builds where Xet is unavailable.
@Test(
    "Xet configuration can be toggled per client",
    .enabled(if: HubClient.isXetSupported, "Xet is not supported on this platform")
)
func testXetConfigurationToggle() throws {
    let host = try #require(URL(string: "https://huggingface.co"))

    let disabledClient = HubClient(host: host, enableXet: false)
    #expect(disabledClient.isXetEnabled == false)

    let enabledClient = HubClient(host: host, enableXet: true)
    #expect(enabledClient.isXetEnabled)
}
4154
}

0 commit comments

Comments
 (0)