1717 ***********************************************************************/
1818import {
1919 containerEngine ,
20- provider ,
2120 type Webview ,
2221 type TelemetryLogger ,
23- type ImageInfo ,
2422 type ContainerInfo ,
2523 type ContainerInspectInfo ,
26- type ProviderContainerConnection ,
2724} from '@podman-desktop/api' ;
2825import type { ContainerRegistry } from '../../registries/ContainerRegistry' ;
2926import type { PodmanConnection } from '../podmanConnection' ;
3027import { beforeEach , expect , describe , test , vi } from 'vitest' ;
3128import { InferenceManager } from './inferenceManager' ;
3229import type { ModelsManager } from '../modelsManager' ;
33- import { LABEL_INFERENCE_SERVER , INFERENCE_SERVER_IMAGE } from '../../utils/inferenceUtils' ;
30+ import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils' ;
3431import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig' ;
3532import type { TaskRegistry } from '../../registries/TaskRegistry' ;
3633import { Messages } from '@shared/Messages' ;
34+ import type { InferenceProviderRegistry } from '../../registries/InferenceProviderRegistry' ;
35+ import type { InferenceProvider } from '../../workers/provider/InferenceProvider' ;
3736
3837vi . mock ( '@podman-desktop/api' , async ( ) => {
3938 return {
4039 containerEngine : {
4140 startContainer : vi . fn ( ) ,
4241 stopContainer : vi . fn ( ) ,
43- listContainers : vi . fn ( ) ,
4442 inspectContainer : vi . fn ( ) ,
45- pullImage : vi . fn ( ) ,
46- listImages : vi . fn ( ) ,
47- createContainer : vi . fn ( ) ,
4843 deleteContainer : vi . fn ( ) ,
44+ listContainers : vi . fn ( ) ,
4945 } ,
5046 Disposable : {
5147 from : vi . fn ( ) ,
5248 create : vi . fn ( ) ,
5349 } ,
54- provider : {
55- getContainerConnections : vi . fn ( ) ,
56- } ,
5750 } ;
5851} ) ;
5952
@@ -87,6 +80,11 @@ const taskRegistryMock = {
8780 getTasksByLabels : vi . fn ( ) ,
8881} as unknown as TaskRegistry ;
8982
83+ const inferenceProviderRegistryMock = {
84+ getAll : vi . fn ( ) ,
85+ get : vi . fn ( ) ,
86+ } as unknown as InferenceProviderRegistry ;
87+
9088const getInitializedInferenceManager = async ( ) : Promise < InferenceManager > => {
9189 const manager = new InferenceManager (
9290 webviewMock ,
@@ -95,6 +93,7 @@ const getInitializedInferenceManager = async (): Promise<InferenceManager> => {
9593 modelsManager ,
9694 telemetryMock ,
9795 taskRegistryMock ,
96+ inferenceProviderRegistryMock ,
9897 ) ;
9998 manager . init ( ) ;
10099 await vi . waitUntil ( manager . isInitialize . bind ( manager ) , {
@@ -119,26 +118,6 @@ beforeEach(() => {
119118 Health : undefined ,
120119 } ,
121120 } as unknown as ContainerInspectInfo ) ;
122- vi . mocked ( provider . getContainerConnections ) . mockReturnValue ( [
123- {
124- providerId : 'test@providerId' ,
125- connection : {
126- type : 'podman' ,
127- name : 'test@connection' ,
128- status : ( ) => 'started' ,
129- } ,
130- } as unknown as ProviderContainerConnection ,
131- ] ) ;
132- vi . mocked ( containerEngine . listImages ) . mockResolvedValue ( [
133- {
134- Id : 'dummyImageId' ,
135- engineId : 'dummyEngineId' ,
136- RepoTags : [ INFERENCE_SERVER_IMAGE ] ,
137- } ,
138- ] as unknown as ImageInfo [ ] ) ;
139- vi . mocked ( containerEngine . createContainer ) . mockResolvedValue ( {
140- id : 'dummyCreatedContainerId' ,
141- } ) ;
142121 vi . mocked ( taskRegistryMock . getTasksByLabels ) . mockReturnValue ( [ ] ) ;
143122 vi . mocked ( modelsManager . getLocalModelPath ) . mockReturnValue ( '/local/model.guff' ) ;
144123 vi . mocked ( modelsManager . uploadModelToPodmanMachine ) . mockResolvedValue ( '/mnt/path/model.guff' ) ;
@@ -233,119 +212,59 @@ describe('init Inference Manager', () => {
233212 * Testing the creation logic
234213 */
235214describe ( 'Create Inference Server' , ( ) => {
236- test ( 'unknown providerId' , async ( ) => {
237- const inferenceManager = await getInitializedInferenceManager ( ) ;
238- await expect (
239- inferenceManager . createInferenceServer (
240- {
241- providerId : 'unknown' ,
242- } as unknown as InferenceServerConfig ,
243- 'dummyTrackingId' ,
244- ) ,
245- ) . rejects . toThrowError ( 'cannot find any started container provider.' ) ;
215+ test ( 'no provider available should throw an error' , async ( ) => {
216+ vi . mocked ( inferenceProviderRegistryMock . getAll ) . mockReturnValue ( [ ] ) ;
246217
247- expect ( provider . getContainerConnections ) . toHaveBeenCalled ( ) ;
248- } ) ;
249-
250- test ( 'unknown imageId' , async ( ) => {
251218 const inferenceManager = await getInitializedInferenceManager ( ) ;
252219 await expect (
253- inferenceManager . createInferenceServer (
254- {
255- providerId : 'test@providerId' ,
256- image : 'unknown' ,
257- } as unknown as InferenceServerConfig ,
258- 'dummyTrackingId' ,
259- ) ,
260- ) . rejects . toThrowError ( 'image unknown not found.' ) ;
261-
262- expect ( containerEngine . listImages ) . toHaveBeenCalled ( ) ;
220+ inferenceManager . createInferenceServer ( {
221+ inferenceProvider : undefined ,
222+ labels : { } ,
223+ modelsInfo : [ ] ,
224+ port : 8888 ,
225+ } ) ,
226+ ) . rejects . toThrowError ( 'no enabled provider could be found.' ) ;
263227 } ) ;
264228
265- test ( 'empty modelsInfo' , async ( ) => {
229+ test ( 'inference provider provided should use get from InferenceProviderRegistry' , async ( ) => {
230+ vi . mocked ( inferenceProviderRegistryMock . get ) . mockReturnValue ( {
231+ enabled : ( ) => false ,
232+ } as unknown as InferenceProvider ) ;
233+
266234 const inferenceManager = await getInitializedInferenceManager ( ) ;
267235 await expect (
268- inferenceManager . createInferenceServer (
269- {
270- providerId : 'test@providerId' ,
271- image : INFERENCE_SERVER_IMAGE ,
272- modelsInfo : [ ] ,
273- } as unknown as InferenceServerConfig ,
274- 'dummyTrackingId' ,
275- ) ,
276- ) . rejects . toThrowError ( 'Need at least one model info to start an inference server.' ) ;
236+ inferenceManager . createInferenceServer ( {
237+ inferenceProvider : 'dummy-inference-provider' ,
238+ labels : { } ,
239+ modelsInfo : [ ] ,
240+ port : 8888 ,
241+ } ) ,
242+ ) . rejects . toThrowError ( 'provider requested is not enabled.' ) ;
243+ expect ( inferenceProviderRegistryMock . get ) . toHaveBeenCalledWith ( 'dummy-inference-provider' ) ;
277244 } ) ;
278245
279- test ( 'valid InferenceServerConfig' , async ( ) => {
246+ test ( 'selected inference provider should receive config' , async ( ) => {
247+ const provider : InferenceProvider = {
248+ enabled : ( ) => true ,
249+ name : 'dummy-inference-provider' ,
250+ dispose : ( ) => { } ,
251+ perform : vi . fn ( ) . mockResolvedValue ( { id : 'dummy-container-id' , engineId : 'dummy-engine-id' } ) ,
252+ } as unknown as InferenceProvider ;
253+ vi . mocked ( inferenceProviderRegistryMock . get ) . mockReturnValue ( provider ) ;
254+
280255 const inferenceManager = await getInitializedInferenceManager ( ) ;
281- await inferenceManager . createInferenceServer (
282- {
283- port : 8888 ,
284- providerId : 'test@providerId' ,
285- image : INFERENCE_SERVER_IMAGE ,
286- modelsInfo : [
287- {
288- id : 'dummyModelId' ,
289- file : {
290- file : 'model.guff' ,
291- path : '/mnt/path' ,
292- } ,
293- } ,
294- ] ,
295- } as unknown as InferenceServerConfig ,
296- 'dummyTrackingId' ,
297- ) ;
298256
299- expect ( modelsManager . uploadModelToPodmanMachine ) . toHaveBeenCalledWith (
300- {
301- id : 'dummyModelId' ,
302- file : {
303- file : 'model.guff' ,
304- path : '/mnt/path' ,
305- } ,
306- } ,
307- {
308- trackingId : 'dummyTrackingId' ,
309- } ,
310- ) ;
311- expect ( taskRegistryMock . createTask ) . toHaveBeenNthCalledWith (
312- 1 ,
313- expect . stringContaining (
314- 'Pulling ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat:' ,
315- ) ,
316- 'loading' ,
317- {
318- trackingId : 'dummyTrackingId' ,
319- } ,
320- ) ;
321- expect ( taskRegistryMock . createTask ) . toHaveBeenNthCalledWith ( 2 , 'Creating container.' , 'loading' , {
322- trackingId : 'dummyTrackingId' ,
323- } ) ;
324- expect ( taskRegistryMock . updateTask ) . toHaveBeenLastCalledWith ( {
325- state : 'success' ,
326- } ) ;
327- expect ( containerEngine . createContainer ) . toHaveBeenCalled ( ) ;
328- expect ( inferenceManager . getServers ( ) ) . toStrictEqual ( [
329- {
330- connection : {
331- port : 8888 ,
332- } ,
333- container : {
334- containerId : 'dummyCreatedContainerId' ,
335- engineId : 'dummyEngineId' ,
336- } ,
337- models : [
338- {
339- file : {
340- file : 'model.guff' ,
341- path : '/mnt/path' ,
342- } ,
343- id : 'dummyModelId' ,
344- } ,
345- ] ,
346- status : 'running' ,
347- } ,
348- ] ) ;
257+ const config : InferenceServerConfig = {
258+ inferenceProvider : 'dummy-inference-provider' ,
259+ labels : { } ,
260+ modelsInfo : [ ] ,
261+ port : 8888 ,
262+ } ;
263+ const result = await inferenceManager . createInferenceServer ( config ) ;
264+
265+ expect ( provider . perform ) . toHaveBeenCalledWith ( config ) ;
266+
267+ expect ( result ) . toBe ( 'dummy-container-id' ) ;
349268 } ) ;
350269} ) ;
351270
@@ -511,33 +430,6 @@ describe('Request Create Inference Server', () => {
511430 trackingId : identifier ,
512431 } ) ;
513432 } ) ;
514-
515- test ( 'Pull image error should be reflected in task registry' , async ( ) => {
516- vi . mocked ( containerEngine . pullImage ) . mockRejectedValue ( new Error ( 'dummy pull image error' ) ) ;
517-
518- const inferenceManager = await getInitializedInferenceManager ( ) ;
519- inferenceManager . requestCreateInferenceServer ( {
520- port : 8888 ,
521- providerId : 'test@providerId' ,
522- image : 'quay.io/bootsy/playground:v0' ,
523- modelsInfo : [
524- {
525- id : 'dummyModelId' ,
526- file : {
527- file : 'dummyFile' ,
528- path : 'dummyPath' ,
529- } ,
530- } ,
531- ] ,
532- } as unknown as InferenceServerConfig ) ;
533-
534- await vi . waitFor ( ( ) => {
535- expect ( taskRegistryMock . updateTask ) . toHaveBeenLastCalledWith ( {
536- state : 'error' ,
537- error : 'Something went wrong while trying to create an inference server Error: dummy pull image error.' ,
538- } ) ;
539- } ) ;
540- } ) ;
541433} ) ;
542434
543435describe ( 'containerRegistry events' , ( ) => {
0 commit comments