Skip to content

Commit bc20c1a

Browse files
Authored by lstocchi (Luca Stocchi) and jeffmaury
use inference server when model service image is equivalent to inference server image (#1503)
* feat: use inference server when model service image is equivalent to inference server image Signed-off-by: lstocchi <lstocchi@redhat.com> * fix: use backend to decide about inference server usage Signed-off-by: Luca Stocchi <luca@MacBook-Pro-di-Luca.local> * fix: fix failing unit tests Signed-off-by: Jeff MAURY <jmaury@redhat.com> * fix: refactor from @axel7083 review Signed-off-by: Jeff MAURY <jmaury@redhat.com> --------- Signed-off-by: lstocchi <lstocchi@redhat.com> Signed-off-by: Luca Stocchi <luca@MacBook-Pro-di-Luca.local> Signed-off-by: Jeff MAURY <jmaury@redhat.com> Co-authored-by: Luca Stocchi <luca@MacBook-Pro-di-Luca.local> Co-authored-by: Jeff MAURY <jmaury@redhat.com>
1 parent 157d4b0 commit bc20c1a

File tree

8 files changed

+169
-65
lines changed

8 files changed

+169
-65
lines changed

packages/backend/src/managers/application/applicationManager.spec.ts

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ beforeEach(() => {
126126
vi.resetAllMocks();
127127

128128
vi.mocked(webviewMock.postMessage).mockResolvedValue(true);
129-
vi.mocked(recipeManager.buildRecipe).mockResolvedValue([recipeImageInfoMock]);
129+
vi.mocked(recipeManager.buildRecipe).mockResolvedValue({ images: [recipeImageInfoMock] });
130130
vi.mocked(podManager.createPod).mockResolvedValue({ engineId: 'test-engine-id', Id: 'test-pod-id' });
131131
vi.mocked(podManager.getPod).mockResolvedValue({ engineId: 'test-engine-id', Id: 'test-pod-id' } as PodInfo);
132132
vi.mocked(podManager.getPodsWithLabels).mockResolvedValue([]);
@@ -312,7 +312,7 @@ describe('pullApplication', () => {
312312
'model-id': remoteModelMock.id,
313313
});
314314
// build the recipe
315-
expect(recipeManager.buildRecipe).toHaveBeenCalledWith(connectionMock, recipeMock, {
315+
expect(recipeManager.buildRecipe).toHaveBeenCalledWith(connectionMock, recipeMock, remoteModelMock, {
316316
'test-label': 'test-value',
317317
'recipe-id': recipeMock.id,
318318
'model-id': remoteModelMock.id,
@@ -374,18 +374,20 @@ describe('pullApplication', () => {
374374
test('qemu connection should have specific flag', async () => {
375375
vi.mocked(podManager.findPodByLabelsValues).mockResolvedValue(undefined);
376376

377-
vi.mocked(recipeManager.buildRecipe).mockResolvedValue([
378-
recipeImageInfoMock,
379-
{
380-
modelService: true,
381-
ports: ['8888'],
382-
name: 'llamacpp',
383-
id: 'llamacpp',
384-
appName: 'llamacpp',
385-
engineId: recipeImageInfoMock.engineId,
386-
recipeId: recipeMock.id,
387-
},
388-
]);
377+
vi.mocked(recipeManager.buildRecipe).mockResolvedValue({
378+
images: [
379+
recipeImageInfoMock,
380+
{
381+
modelService: true,
382+
ports: ['8888'],
383+
name: 'llamacpp',
384+
id: 'llamacpp',
385+
appName: 'llamacpp',
386+
engineId: recipeImageInfoMock.engineId,
387+
recipeId: recipeMock.id,
388+
},
389+
],
390+
});
389391

390392
await getInitializedApplicationManager().pullApplication(connectionMock, recipeMock, remoteModelMock);
391393

packages/backend/src/managers/application/applicationManager.ts

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
* SPDX-License-Identifier: Apache-2.0
1717
***********************************************************************/
1818

19-
import type { Recipe, RecipeImage } from '@shared/src/models/IRecipe';
19+
import type { Recipe, RecipeComponents, RecipeImage } from '@shared/src/models/IRecipe';
2020
import * as path from 'node:path';
2121
import { containerEngine, Disposable, window, ProgressLocation } from '@podman-desktop/api';
2222
import type {
@@ -187,7 +187,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
187187
});
188188

189189
// build all images, one per container (for a basic sample we should have 2 containers = sample app + model service)
190-
const images = await this.recipeManager.buildRecipe(connection, recipe, {
190+
const recipeComponents = await this.recipeManager.buildRecipe(connection, recipe, model, {
191191
...labels,
192192
'recipe-id': recipe.id,
193193
'model-id': model.id,
@@ -199,7 +199,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
199199
}
200200

201201
// create a pod containing all the containers to run the application
202-
return this.createApplicationPod(connection, recipe, model, images, modelPath, {
202+
return this.createApplicationPod(connection, recipe, model, recipeComponents, modelPath, {
203203
...labels,
204204
'recipe-id': recipe.id,
205205
'model-id': model.id,
@@ -253,7 +253,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
253253
connection: ContainerProviderConnection,
254254
recipe: Recipe,
255255
model: ModelInfo,
256-
images: RecipeImage[],
256+
components: RecipeComponents,
257257
modelPath: string,
258258
labels?: { [key: string]: string },
259259
): Promise<PodInfo> {
@@ -262,7 +262,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
262262
// create empty pod
263263
let podInfo: PodInfo;
264264
try {
265-
podInfo = await this.createPod(connection, recipe, model, images);
265+
podInfo = await this.createPod(connection, recipe, model, components.images);
266266
task.labels = {
267267
...task.labels,
268268
'pod-id': podInfo.Id,
@@ -277,7 +277,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
277277
}
278278

279279
try {
280-
await this.createContainerAndAttachToPod(connection, podInfo, images, model, modelPath);
280+
await this.createContainerAndAttachToPod(connection, podInfo, components, model, modelPath);
281281
task.state = 'success';
282282
} catch (e) {
283283
console.error(`error when creating pod ${podInfo.Id}`, e);
@@ -294,14 +294,14 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
294294
protected async createContainerAndAttachToPod(
295295
connection: ContainerProviderConnection,
296296
podInfo: PodInfo,
297-
images: RecipeImage[],
297+
components: RecipeComponents,
298298
modelInfo: ModelInfo,
299299
modelPath: string,
300300
): Promise<void> {
301301
const vmType = connection.vmType ?? VMType.UNKNOWN;
302302
// temporary check to set Z flag or not - to be removed when switching to podman 5
303303
await Promise.all(
304-
images.map(async image => {
304+
components.images.map(async image => {
305305
let hostConfig: HostConfig | undefined = undefined;
306306
let envs: string[] = [];
307307
let healthcheck: HealthConfig | undefined = undefined;
@@ -321,11 +321,15 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
321321
envs = [`MODEL_PATH=/${modelName}`];
322322
envs.push(...getModelPropertiesForEnvironment(modelInfo));
323323
} else {
324-
// TODO: remove static port
325-
const modelService = images.find(image => image.modelService);
326-
if (modelService && modelService.ports.length > 0) {
327-
const endPoint = `http://localhost:${modelService.ports[0]}`;
324+
if (components.inferenceServer) {
325+
const endPoint = `http://host.containers.internal:${components.inferenceServer.connection.port}`;
328326
envs = [`MODEL_ENDPOINT=${endPoint}`];
327+
} else {
328+
const modelService = components.images.find(image => image.modelService);
329+
if (modelService && modelService.ports.length > 0) {
330+
const endPoint = `http://localhost:${modelService.ports[0]}`;
331+
envs = [`MODEL_ENDPOINT=${endPoint}`];
332+
}
329333
}
330334
}
331335
if (image.ports.length > 0) {

packages/backend/src/managers/inference/inferenceManager.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,22 @@ export class InferenceManager extends Publisher<InferenceServer[]> implements Di
101101
return this.#servers.get(containerId);
102102
}
103103

104+
/**
105+
* return the first inference server which is using the specific model
106+
* it throws if the model backend is not currently supported
107+
*/
108+
public findServerByModel(model: ModelInfo): InferenceServer | undefined {
109+
// check if model backend is supported
110+
const backend: InferenceType = getInferenceType([model]);
111+
const providers: InferenceProvider[] = this.inferenceProviderRegistry
112+
.getByType(backend)
113+
.filter(provider => provider.enabled());
114+
if (providers.length === 0) {
115+
throw new Error('no enabled provider could be found.');
116+
}
117+
return this.getServers().find(s => s.models.some(m => m.id === model.id));
118+
}
119+
104120
/**
105121
* Creating an inference server can be heavy task (pulling image, uploading model to WSL etc.)
106122
* The frontend cannot wait endlessly, therefore we provide a method returning a tracking identifier

packages/backend/src/managers/recipes/RecipeManager.spec.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ import { existsSync, statSync } from 'node:fs';
2828
import { AIConfigFormat, parseYamlFile } from '../../models/AIConfig';
2929
import { goarch } from '../../utils/arch';
3030
import { VMType } from '@shared/src/models/IPodman';
31+
import type { InferenceManager } from '../inference/inferenceManager';
32+
import type { ModelInfo } from '@shared/src/models/IModelInfo';
3133

3234
const taskRegistryMock = {
3335
createTask: vi.fn(),
@@ -46,6 +48,8 @@ const localRepositoriesMock = {
4648
register: vi.fn(),
4749
} as unknown as LocalRepositoryRegistry;
4850

51+
const inferenceManagerMock = {} as unknown as InferenceManager;
52+
4953
const recipeMock: Recipe = {
5054
id: 'recipe-test',
5155
name: 'Test Recipe',
@@ -60,6 +64,12 @@ const connectionMock: ContainerProviderConnection = {
6064
vmType: VMType.UNKNOWN,
6165
} as unknown as ContainerProviderConnection;
6266

67+
const modelInfoMock: ModelInfo = {
68+
id: 'modelId',
69+
name: 'Model',
70+
description: 'model to test',
71+
} as unknown as ModelInfo;
72+
6373
vi.mock('../../models/AIConfig', () => ({
6474
AIConfigFormat: {
6575
CURRENT: 'current',
@@ -123,6 +133,7 @@ async function getInitializedRecipeManager(): Promise<RecipeManager> {
123133
taskRegistryMock,
124134
builderManagerMock,
125135
localRepositoriesMock,
136+
inferenceManagerMock,
126137
);
127138
manager.init();
128139
return manager;
@@ -180,14 +191,14 @@ describe('buildRecipe', () => {
180191
const manager = await getInitializedRecipeManager();
181192

182193
await expect(() => {
183-
return manager.buildRecipe(connectionMock, recipeMock);
194+
return manager.buildRecipe(connectionMock, recipeMock, modelInfoMock);
184195
}).rejects.toThrowError('build error');
185196
});
186197

187198
test('labels should be propagated', async () => {
188199
const manager = await getInitializedRecipeManager();
189200

190-
await manager.buildRecipe(connectionMock, recipeMock, {
201+
await manager.buildRecipe(connectionMock, recipeMock, modelInfoMock, {
191202
'test-label': 'test-value',
192203
});
193204

packages/backend/src/managers/recipes/RecipeManager.ts

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
***********************************************************************/
1818
import type { GitCloneInfo, GitManager } from '../gitManager';
1919
import type { TaskRegistry } from '../../registries/TaskRegistry';
20-
import type { Recipe, RecipeImage } from '@shared/src/models/IRecipe';
20+
import type { Recipe, RecipeComponents } from '@shared/src/models/IRecipe';
2121
import path from 'node:path';
2222
import type { Task } from '@shared/src/models/ITask';
2323
import type { LocalRepositoryRegistry } from '../../registries/LocalRepositoryRegistry';
@@ -28,6 +28,10 @@ import { goarch } from '../../utils/arch';
2828
import type { BuilderManager } from './BuilderManager';
2929
import type { ContainerProviderConnection, Disposable } from '@podman-desktop/api';
3030
import { CONFIG_FILENAME } from '../../utils/RecipeConstants';
31+
import type { InferenceManager } from '../inference/inferenceManager';
32+
import type { ModelInfo } from '@shared/src/models/IModelInfo';
33+
import { withDefaultConfiguration } from '../../utils/inferenceUtils';
34+
import type { InferenceServer } from '@shared/src/models/IInference';
3135

3236
export interface AIContainers {
3337
aiConfigFile: AIConfigFile;
@@ -41,6 +45,7 @@ export class RecipeManager implements Disposable {
4145
private taskRegistry: TaskRegistry,
4246
private builderManager: BuilderManager,
4347
private localRepositories: LocalRepositoryRegistry,
48+
private inferenceManager: InferenceManager,
4449
) {}
4550

4651
dispose(): void {}
@@ -94,17 +99,63 @@ export class RecipeManager implements Disposable {
9499
public async buildRecipe(
95100
connection: ContainerProviderConnection,
96101
recipe: Recipe,
102+
model: ModelInfo,
97103
labels?: { [key: string]: string },
98-
): Promise<RecipeImage[]> {
104+
): Promise<RecipeComponents> {
99105
const localFolder = path.join(this.appUserDirectory, recipe.id);
100106

107+
let inferenceServer: InferenceServer | undefined;
108+
// if the recipe has a defined backend, we gives priority to using an inference server
109+
if (recipe.backend && recipe.backend === model.backend) {
110+
let task: Task | undefined;
111+
try {
112+
inferenceServer = this.inferenceManager.findServerByModel(model);
113+
task = this.taskRegistry.createTask('Starting Inference server', 'loading', labels);
114+
if (!inferenceServer) {
115+
const inferenceContainerId = await this.inferenceManager.createInferenceServer(
116+
await withDefaultConfiguration({
117+
modelsInfo: [model],
118+
}),
119+
);
120+
inferenceServer = this.inferenceManager.get(inferenceContainerId);
121+
this.taskRegistry.updateTask({
122+
...task,
123+
labels: {
124+
...task.labels,
125+
containerId: inferenceContainerId,
126+
},
127+
});
128+
} else if (inferenceServer.status === 'stopped') {
129+
await this.inferenceManager.startInferenceServer(inferenceServer.container.containerId);
130+
}
131+
task.state = 'success';
132+
} catch (e) {
133+
// we only skip the task update if the error is that we do not support this backend.
134+
// If so, we build the image for the model service
135+
if (task && String(e) !== 'no enabled provider could be found.') {
136+
task.state = 'error';
137+
task.error = `Something went wrong while starting the inference server: ${String(e)}`;
138+
throw e;
139+
}
140+
} finally {
141+
if (task) {
142+
this.taskRegistry.updateTask(task);
143+
}
144+
}
145+
}
146+
101147
// load and parse the recipe configuration file and filter containers based on architecture
102-
const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.basedir, localFolder, {
103-
...labels,
104-
'recipe-id': recipe.id,
105-
});
148+
const configAndFilteredContainers = this.getConfigAndFilterContainers(
149+
recipe.basedir,
150+
localFolder,
151+
!!inferenceServer,
152+
{
153+
...labels,
154+
'recipe-id': recipe.id,
155+
},
156+
);
106157

107-
return await this.builderManager.build(
158+
const images = await this.builderManager.build(
108159
connection,
109160
recipe,
110161
configAndFilteredContainers.containers,
@@ -114,11 +165,17 @@ export class RecipeManager implements Disposable {
114165
'recipe-id': recipe.id,
115166
},
116167
);
168+
169+
return {
170+
images,
171+
inferenceServer,
172+
};
117173
}
118174

119175
private getConfigAndFilterContainers(
120176
recipeBaseDir: string | undefined,
121177
localFolder: string,
178+
useInferenceServer: boolean,
122179
labels?: { [key: string]: string },
123180
): AIContainers {
124181
// Adding loading configuration task
@@ -135,7 +192,11 @@ export class RecipeManager implements Disposable {
135192
}
136193

137194
// filter the containers based on architecture, gpu accelerator and backend (that define which model supports)
138-
const filteredContainers: ContainerConfig[] = this.filterContainers(aiConfigFile.aiConfig);
195+
let filteredContainers: ContainerConfig[] = this.filterContainers(aiConfigFile.aiConfig);
196+
// if we are using the inference server we can remove the model service
197+
if (useInferenceServer) {
198+
filteredContainers = filteredContainers.filter(c => !c.modelService);
199+
}
139200
if (filteredContainers.length > 0) {
140201
// Mark as success.
141202
task.state = 'success';

packages/backend/src/models/AIConfig.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ export interface ContainerConfig {
2828
gpu_env: string[];
2929
ports?: number[];
3030
image?: string;
31+
backend?: string[];
3132
}
3233

3334
export enum AIConfigFormat {
@@ -130,6 +131,7 @@ export function parseYamlFile(filepath: string, defaultArch: string): AIConfig {
130131
? container['ports'].map(port => parseInt(port))
131132
: [],
132133
image: 'image' in container && isString(container['image']) ? container['image'] : undefined,
134+
backend: 'backend' in container && Array.isArray(container['backend']) ? container['backend'] : undefined,
133135
};
134136
}),
135137
},

0 commit comments

Comments (0)