6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
+ import ExecuTorchLLM
9
10
import SwiftUI
10
11
import UniformTypeIdentifiers
11
12
12
- import LLaMARunner
13
-
14
13
class RunnerHolder : ObservableObject {
15
- var llamaRunner : LLaMARunner ?
16
- var llavaRunner : LLaVARunner ?
14
+ var textRunner : TextRunner ?
15
+ var multimodalRunner : MultimodalRunner ?
17
16
}
18
17
19
18
extension UIImage {
@@ -347,15 +346,34 @@ struct ContentView: View {
347
346
348
347
switch modelType {
349
348
case . llama, . qwen3, . phi4:
350
- runnerHolder. llamaRunner = runnerHolder. llamaRunner ?? LLaMARunner ( modelPath: modelPath, tokenizerPath: tokenizerPath)
349
+ runnerHolder. textRunner = runnerHolder. textRunner ?? TextRunner (
350
+ modelPath: modelPath,
351
+ tokenizerPath: tokenizerPath,
352
+ specialTokens: [
353
+ " <|begin_of_text|> " ,
354
+ " <|end_of_text|> " ,
355
+ " <|reserved_special_token_0|> " ,
356
+ " <|reserved_special_token_1|> " ,
357
+ " <|finetune_right_pad_id|> " ,
358
+ " <|step_id|> " ,
359
+ " <|start_header_id|> " ,
360
+ " <|end_header_id|> " ,
361
+ " <|eom_id|> " ,
362
+ " <|eot_id|> " ,
363
+ " <|python_tag|> "
364
+ ] + ( 2 ..< 256 ) . map { " <|reserved_special_token_ \( $0) |> " }
365
+ )
351
366
case . llava:
352
- runnerHolder. llavaRunner = runnerHolder. llavaRunner ?? LLaVARunner ( modelPath: modelPath, tokenizerPath: tokenizerPath)
367
+ runnerHolder. multimodalRunner = runnerHolder. multimodalRunner ?? MultimodalRunner (
368
+ modelPath: modelPath,
369
+ tokenizerPath: tokenizerPath
370
+ )
353
371
}
354
372
355
373
guard !shouldStopGenerating else { return }
356
374
switch modelType {
357
375
case . llama, . qwen3, . phi4:
358
- if let runner = runnerHolder. llamaRunner , !runner. isLoaded ( ) {
376
+ if let runner = runnerHolder. textRunner , !runner. isLoaded ( ) {
359
377
var error : Error ?
360
378
let startLoadTime = Date ( )
361
379
do {
@@ -385,7 +403,7 @@ struct ContentView: View {
385
403
}
386
404
}
387
405
case . llava:
388
- if let runner = runnerHolder. llavaRunner , !runner. isLoaded ( ) {
406
+ if let runner = runnerHolder. multimodalRunner , !runner. isLoaded ( ) {
389
407
var error : Error ?
390
408
let startLoadTime = Date ( )
391
409
do {
@@ -426,25 +444,21 @@ struct ContentView: View {
426
444
}
427
445
do {
428
446
var tokens : [ String ] = [ ]
429
- var rgbArray : [ UInt8 ] ?
430
- let MAX_WIDTH = 336.0
431
- var newHeight = 0.0
432
- var imageBuffer : UnsafeMutableRawPointer ?
433
447
434
448
if let img = selectedImage {
435
449
let llava_prompt = " \( text) ASSISTANT "
436
-
437
- newHeight = MAX_WIDTH * img. size. height / img. size. width
450
+ let MAX_WIDTH = 336.0
451
+ let newHeight = MAX_WIDTH * img. size. height / img. size. width
438
452
let resizedImage = img. resized ( to: CGSize ( width: MAX_WIDTH, height: newHeight) )
439
- rgbArray = resizedImage. toRGBArray ( )
440
- imageBuffer = UnsafeMutableRawPointer ( mutating: rgbArray)
441
-
442
- try runnerHolder. llavaRunner? . generate ( imageBuffer!, width: MAX_WIDTH, height: newHeight, prompt: llava_prompt, sequenceLength: seq_len) { token in
443
453
454
+ try runnerHolder. multimodalRunner? . generate ( [
455
+ MultimodalInput ( Image ( data: Data ( resizedImage. toRGBArray ( ) ?? [ ] ) , width: Int ( MAX_WIDTH) , height: Int ( newHeight. rounded ( ) ) , channels: 3 ) ) ,
456
+ MultimodalInput ( llava_prompt) ,
457
+ ] , sequenceLength: seq_len) { token in
444
458
if token != llava_prompt {
445
459
if token == " </s> " {
446
460
shouldStopGenerating = true
447
- runnerHolder. llavaRunner ? . stop ( )
461
+ runnerHolder. multimodalRunner ? . stop ( )
448
462
} else {
449
463
tokens. append ( token)
450
464
if tokens. count > 2 {
@@ -460,7 +474,7 @@ struct ContentView: View {
460
474
}
461
475
}
462
476
if shouldStopGenerating {
463
- runnerHolder. llavaRunner ? . stop ( )
477
+ runnerHolder. multimodalRunner ? . stop ( )
464
478
}
465
479
}
466
480
}
@@ -481,7 +495,7 @@ struct ContentView: View {
481
495
prompt = String ( format: Constants . phi4PromptTemplate, text)
482
496
}
483
497
484
- try runnerHolder. llamaRunner ? . generate ( prompt, sequenceLength: seq_len) { token in
498
+ try runnerHolder. textRunner ? . generate ( prompt, sequenceLength: seq_len) { token in
485
499
486
500
if token != prompt {
487
501
if token == " <|eot_id|> " {
@@ -534,7 +548,7 @@ struct ContentView: View {
534
548
}
535
549
}
536
550
if shouldStopGenerating {
537
- runnerHolder. llamaRunner ? . stop ( )
551
+ runnerHolder. textRunner ? . stop ( )
538
552
}
539
553
}
540
554
}
@@ -577,8 +591,8 @@ struct ContentView: View {
577
591
return
578
592
}
579
593
runnerQueue. async {
580
- runnerHolder. llamaRunner = nil
581
- runnerHolder. llavaRunner = nil
594
+ runnerHolder. textRunner = nil
595
+ runnerHolder. multimodalRunner = nil
582
596
}
583
597
switch pickerType {
584
598
case . model:
0 commit comments