goupilew committed · Commit 96070f4 · unverified · Parent: 29ea0af

feat: add support for multimodal in Vertex (#1338)


* feat: add support for multimodal in Vertex

* Nit changes and remove tools if multimodal

* revert model name change

* Fix tools/multimodal condition

* chores(lint): fix formatting

---------
Co-authored-by: Thomas <[email protected]>
Co-authored-by: Nathan Sarrazin <[email protected]>

README.md CHANGED
```diff
@@ -775,21 +775,29 @@ MODELS=`[
     {
       "name": "gemini-1.5-pro",
      "displayName": "Vertex Gemini Pro 1.5",
+      "multimodal": true,
       "endpoints" : [{
         "type": "vertex",
         "project": "abc-xyz",
         "location": "europe-west3",
         "model": "gemini-1.5-pro-preview-0409", // model-name
-
         // Optional
         "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
         "apiEndpoint": "", // alternative api endpoint url,
-        // Optional
         "tools": [{
           "googleSearchRetrieval": {
             "disableAttribution": true
           }
-        }]
+        }],
+        "multimodal": {
+          "image": {
+            "supportedMimeTypes": ["image/png", "image/jpeg", "image/webp"],
+            "preferredMimeType": "image/png",
+            "maxSizeInMB": 5,
+            "maxWidth": 2000,
+            "maxHeight": 1000
+          }
+        }
       }]
     },
   ]`
```
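The `multimodal.image` block caps what the endpoint accepts per attached image. As a rough sketch of what checking a candidate file against these limits amounts to (the `ImageOptions` shape mirrors the config fields above; `checkImageConstraints` is a hypothetical helper, not chat-ui's actual API):

```ts
// Hypothetical helper, not chat-ui code: mirrors the README's multimodal
// image options and checks a candidate file against them.
interface ImageOptions {
	supportedMimeTypes: string[];
	preferredMimeType: string;
	maxSizeInMB: number;
	maxWidth: number;
	maxHeight: number;
}

function checkImageConstraints(
	opts: ImageOptions,
	file: { mime: string; sizeInBytes: number; width: number; height: number }
): boolean {
	return (
		opts.supportedMimeTypes.includes(file.mime) &&
		file.sizeInBytes <= opts.maxSizeInMB * 1024 * 1024 &&
		file.width <= opts.maxWidth &&
		file.height <= opts.maxHeight
	);
}
```

The `preferredMimeType` field suggests files are re-encoded rather than merely validated; the actual behavior lives in the `makeImageProcessor` helper used in `endpointVertex.ts` below.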
src/lib/server/endpoints/google/endpointVertex.ts CHANGED
```diff
@@ -9,6 +9,7 @@ import type { Endpoint } from "../endpoints";
 import { z } from "zod";
 import type { Message } from "$lib/types/Message";
 import type { TextGenerationStreamOutput } from "@huggingface/inference";
+import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
 
 export const endpointVertexParametersSchema = z.object({
 	weight: z.number().int().positive().default(1),
@@ -27,10 +28,28 @@ export const endpointVertexParametersSchema = z.object({
 		])
 		.optional(),
 	tools: z.array(z.any()).optional(),
+	multimodal: z
+		.object({
+			image: createImageProcessorOptionsValidator({
+				supportedMimeTypes: [
+					"image/png",
+					"image/jpeg",
+					"image/webp",
+					"image/avif",
+					"image/tiff",
+					"image/gif",
+				],
+				preferredMimeType: "image/webp",
+				maxSizeInMB: Infinity,
+				maxWidth: 4096,
+				maxHeight: 4096,
+			}),
+		})
+		.default({}),
 });
 
 export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
-	const { project, location, model, apiEndpoint, safetyThreshold, tools } =
+	const { project, location, model, apiEndpoint, safetyThreshold, tools, multimodal } =
 		endpointVertexParametersSchema.parse(input);
 
 	const vertex_ai = new VertexAI({
```
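`createImageProcessorOptionsValidator` and `makeImageProcessor` come from `../images`, which is not part of this diff. A plausible zod-based shape for the validator, assuming it just wraps the caller's defaults (a sketch, not the real implementation):

```ts
import { z } from "zod";

// Sketch of what `createImageProcessorOptionsValidator` could look like;
// the real helper lives in `../images` and is not shown in this commit.
// It takes per-endpoint defaults and returns a zod schema carrying them.
interface ImageProcessorOptions {
	supportedMimeTypes: string[];
	preferredMimeType: string;
	maxSizeInMB: number;
	maxWidth: number;
	maxHeight: number;
}

function createImageProcessorOptionsValidator(defaults: ImageProcessorOptions) {
	return z
		.object({
			supportedMimeTypes: z.array(z.string()).default(defaults.supportedMimeTypes),
			preferredMimeType: z.string().default(defaults.preferredMimeType),
			maxSizeInMB: z.number().default(defaults.maxSizeInMB),
			maxWidth: z.number().int().default(defaults.maxWidth),
			maxHeight: z.number().int().default(defaults.maxHeight),
		})
		.default(defaults);
}
```

Because the inner validator carries defaults and the outer object has `.default({})`, an endpoint config that omits `multimodal` entirely still parses to the full set of image defaults.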
```diff
@@ -42,6 +61,8 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
 	return async ({ messages, preprompt, generateSettings }) => {
 		const parameters = { ...model.parameters, ...generateSettings };
 
+		const hasFiles = messages.some((message) => message.files && message.files.length > 0);
+
 		const generativeModel = vertex_ai.getGenerativeModel({
 			model: model.id ?? model.name,
 			safetySettings: safetyThreshold
@@ -73,7 +94,8 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
 				stopSequences: parameters?.stop,
 				temperature: parameters?.temperature ?? 1,
 			},
-			tools,
+			// tools and multimodal are mutually exclusive
+			tools: !hasFiles ? tools : undefined,
 		});
 
 		// Preprompt is the same as the first system message.
```
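The gate above silently drops tool declarations whenever any message carries a file, per the in-diff comment that tools and multimodal are mutually exclusive. A stricter alternative (sketch only, not what the commit does) would surface the conflict instead:

```ts
// Sketch of a fail-fast variant; `hasFiles` and `tools` are the same values
// computed in the diff above. The commit instead drops tools silently.
if (hasFiles && tools && tools.length > 0) {
	throw new Error("Vertex endpoint: tools cannot be combined with image input");
}
```

Dropping rather than throwing keeps tool-enabled configurations usable when a user attaches an image, at the cost of skipping tool calls for that request.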
```diff
@@ -83,16 +105,30 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
 			messages.shift();
 		}
 
-		const vertexMessages = messages.map(({ from, content }: Omit<Message, "id">): Content => {
-			return {
-				role: from === "user" ? "user" : "model",
-				parts: [
-					{
-						text: content,
-					},
-				],
-			};
-		});
+		const vertexMessages = await Promise.all(
+			messages.map(async ({ from, content, files }: Omit<Message, "id">): Promise<Content> => {
+				const imageProcessor = makeImageProcessor(multimodal.image);
+				const processedFiles =
+					files && files.length > 0
+						? await Promise.all(files.map(async (file) => imageProcessor(file)))
+						: [];
+
+				return {
+					role: from === "user" ? "user" : "model",
+					parts: [
+						...processedFiles.map((processedFile) => ({
+							inlineData: {
+								data: processedFile.image.toString("base64"),
+								mimeType: processedFile.mime,
+							},
+						})),
+						{
+							text: content,
+						},
+					],
+				};
+			})
+		);
 
 		const result = await generativeModel.generateContentStream({
 			contents: vertexMessages,
```
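For reference, a user turn with a single attached image maps to a `Content` value shaped like this (`Content` is the type from `@google-cloud/vertexai` that the file already uses; the base64 payload is truncated):

```ts
import type { Content } from "@google-cloud/vertexai";

// Illustrative output of the mapping above for one user message with one PNG.
const example: Content = {
	role: "user",
	parts: [
		{ inlineData: { data: "iVBORw0KGgoAAAANSUhEUg...", mimeType: "image/png" } },
		{ text: "What is in this picture?" },
	],
};
```

Note that the image parts come before the text part, matching the spread order inside `parts`.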