toandev committed (unverified) · commit 1c84463 · parent: eb5499b

Add Google Gemini API Support (#1330)


* Supports Google Gemini API

* Fixed an error when running svelte-check

* Update: incorporates a contribution from nsarrazin

docs/source/configuration/models/providers/google.md CHANGED
@@ -44,6 +44,42 @@ MODELS=`[
         }
       }]
     }]
+  }
+]`
+```
+
+## GenAI
+
+Or use the Gemini API provider via [`@google/generative-ai`](https://github.com/google-gemini/generative-ai-js#readme):
+
+> Make sure that you have an API key from Google Cloud Platform. To get an API key, follow the instructions [here](https://cloud.google.com/docs/authentication/api-keys).
+
+```ini
+MODELS=`[
+  {
+    "name": "gemini-1.5-flash",
+    "displayName": "Gemini Flash 1.5",
+    "multimodal": true,
+    "endpoints": [
+      {
+        "type": "genai",
+        "apiKey": "abc...xyz"
+      }
+    ],
+
+    // Optional
+    "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
   },
+  {
+    "name": "gemini-1.5-pro",
+    "displayName": "Gemini Pro 1.5",
+    "multimodal": false,
+    "endpoints": [
+      {
+        "type": "genai",
+        "apiKey": "abc...xyz"
+      }
+    ]
+  }
 ]`
 ```
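For reference, the `"safetyThreshold"` key accepts exactly the `HarmBlockThreshold` values enumerated in the schema of `endpointGenAI.ts` below. A minimal TypeScript sketch of that constraint (the `accepted` array mirrors the schema; nothing here is chat-ui API):

```ts
import { HarmBlockThreshold } from "@google/generative-ai";

// The five thresholds the "safetyThreshold" config key may take,
// mirroring endpointGenAIParametersSchema further down in this commit.
const accepted: HarmBlockThreshold[] = [
  HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
  HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
  HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
  HarmBlockThreshold.BLOCK_NONE,
  HarmBlockThreshold.BLOCK_ONLY_HIGH,
];

console.log(accepted.includes("BLOCK_MEDIUM_AND_ABOVE" as HarmBlockThreshold)); // true
```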
package-lock.json CHANGED
@@ -96,6 +96,7 @@
         "@anthropic-ai/sdk": "^0.17.1",
         "@anthropic-ai/vertex-sdk": "^0.3.0",
         "@google-cloud/vertexai": "^1.1.0",
+        "@google/generative-ai": "^0.14.1",
         "aws4fetch": "^1.0.17",
         "cohere-ai": "^7.9.0",
         "openai": "^4.44.0"
@@ -1360,6 +1361,15 @@
         "node": ">=18.0.0"
       }
     },
+    "node_modules/@google/generative-ai": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.14.1.tgz",
+      "integrity": "sha512-pevEyZCb0Oc+dYNlSberW8oZBm4ofeTD5wN01TowQMhTwdAbGAnJMtQzoklh6Blq2AKsx8Ox6FWa44KioZLZiA==",
+      "optional": true,
+      "engines": {
+        "node": ">=18.0.0"
+      }
+    },
     "node_modules/@gradio/client": {
       "version": "0.19.4",
       "resolved": "https://registry.npmjs.org/@gradio/client/-/client-0.19.4.tgz",
package.json CHANGED
@@ -106,6 +106,7 @@
     "@anthropic-ai/sdk": "^0.17.1",
     "@anthropic-ai/vertex-sdk": "^0.3.0",
     "@google-cloud/vertexai": "^1.1.0",
+    "@google/generative-ai": "^0.14.1",
     "aws4fetch": "^1.0.17",
     "cohere-ai": "^7.9.0",
     "openai": "^4.44.0"
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -8,6 +8,7 @@ import { endpointOAIParametersSchema, endpointOai } from "./openai/endpointOai";
 import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/endpointLlamacpp";
 import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
 import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
+import endpointGenAI, { endpointGenAIParametersSchema } from "./google/endpointGenAI";
 
 import {
   endpointAnthropic,
@@ -65,6 +66,7 @@ export const endpoints = {
   llamacpp: endpointLlamacpp,
   ollama: endpointOllama,
   vertex: endpointVertex,
+  genai: endpointGenAI,
   cloudflare: endpointCloudflare,
   cohere: endpointCohere,
   langserve: endpointLangserve,
@@ -79,6 +81,7 @@ export const endpointSchema = z.discriminatedUnion("type", [
   endpointLlamacppParametersSchema,
   endpointOllamaParametersSchema,
   endpointVertexParametersSchema,
+  endpointGenAIParametersSchema,
   endpointCloudflareParametersSchema,
   endpointCohereParametersSchema,
   endpointLangserveParametersSchema,
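The schema registration above matters because `z.discriminatedUnion("type", ...)` picks the validation branch by the `type` literal. A self-contained sketch with simplified stand-in schemas (not the real chat-ui schemas):

```ts
import { z } from "zod";

// Simplified stand-ins for two endpoint parameter schemas.
const vertexSchema = z.object({ type: z.literal("vertex"), location: z.string() });
const genaiSchema = z.object({ type: z.literal("genai"), apiKey: z.string() });

// The discriminated union dispatches validation on the "type" field, so a
// config with type "genai" is checked against genaiSchema only.
const schema = z.discriminatedUnion("type", [vertexSchema, genaiSchema]);

const parsed = schema.parse({ type: "genai", apiKey: "abc...xyz" });
if (parsed.type === "genai") {
  console.log(parsed.apiKey); // narrowed: apiKey is known to exist here
}
```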
src/lib/server/endpoints/google/endpointGenAI.ts ADDED
@@ -0,0 +1,171 @@
+import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from "@google/generative-ai";
+import type { Content, Part, TextPart } from "@google/generative-ai";
+import { z } from "zod";
+import type { Message, MessageFile } from "$lib/types/Message";
+import type { TextGenerationStreamOutput } from "@huggingface/inference";
+import type { Endpoint } from "../endpoints";
+import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
+import type { ImageProcessorOptions } from "../images";
+
+export const endpointGenAIParametersSchema = z.object({
+  weight: z.number().int().positive().default(1),
+  model: z.any(),
+  type: z.literal("genai"),
+  apiKey: z.string(),
+  safetyThreshold: z
+    .enum([
+      HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
+      HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+      HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+      HarmBlockThreshold.BLOCK_NONE,
+      HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    ])
+    .optional(),
+  multimodal: z
+    .object({
+      image: createImageProcessorOptionsValidator({
+        supportedMimeTypes: ["image/png", "image/jpeg", "image/webp"],
+        preferredMimeType: "image/webp",
+        // The 4 / 3 compensates for the 33% increase in size when converting to base64
+        maxSizeInMB: (5 / 4) * 3,
+        maxWidth: 4096,
+        maxHeight: 4096,
+      }),
+    })
+    .default({}),
+});
+
+export function endpointGenAI(input: z.input<typeof endpointGenAIParametersSchema>): Endpoint {
+  const { model, apiKey, safetyThreshold, multimodal } = endpointGenAIParametersSchema.parse(input);
+
+  const genAI = new GoogleGenerativeAI(apiKey);
+
+  return async ({ messages, preprompt, generateSettings }) => {
+    const parameters = { ...model.parameters, ...generateSettings };
+
+    const generativeModel = genAI.getGenerativeModel({
+      model: model.id ?? model.name,
+      safetySettings: safetyThreshold
+        ? [
+            {
+              category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+              threshold: safetyThreshold,
+            },
+            {
+              category: HarmCategory.HARM_CATEGORY_HARASSMENT,
+              threshold: safetyThreshold,
+            },
+            {
+              category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+              threshold: safetyThreshold,
+            },
+            {
+              category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+              threshold: safetyThreshold,
+            },
+            {
+              category: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+              threshold: safetyThreshold,
+            },
+          ]
+        : undefined,
+      generationConfig: {
+        maxOutputTokens: parameters?.max_new_tokens ?? 4096,
+        stopSequences: parameters?.stop,
+        temperature: parameters?.temperature ?? 1,
+      },
+    });
+
+    let systemMessage = preprompt;
+    if (messages[0].from === "system") {
+      systemMessage = messages[0].content;
+      messages.shift();
+    }
+
+    const genAIMessages = await Promise.all(
+      messages.map(async ({ from, content, files }: Omit<Message, "id">): Promise<Content> => {
+        return {
+          role: from === "user" ? "user" : "model",
+          parts: [
+            ...(await Promise.all(
+              (files ?? []).map((file) => fileToImageBlock(file, multimodal.image))
+            )),
+            { text: content },
+          ],
+        };
+      })
+    );
+
+    const result = await generativeModel.generateContentStream({
+      contents: genAIMessages,
+      systemInstruction:
+        systemMessage && systemMessage.trim() !== ""
+          ? {
+              role: "system",
+              parts: [{ text: systemMessage }],
+            }
+          : undefined,
+    });
+
+    let tokenId = 0;
+    return (async function* () {
+      let generatedText = "";
+
+      for await (const data of result.stream) {
+        if (!data?.candidates?.length) break; // Handle case where no candidates are present
+
+        const candidate = data.candidates[0];
+        if (!candidate.content?.parts?.length) continue; // Skip if no parts are present
+
+        const firstPart = candidate.content.parts.find((part) => "text" in part) as
+          | TextPart
+          | undefined;
+        if (!firstPart) continue; // Skip if no text part is found
+
+        const content = firstPart.text;
+        generatedText += content;
+
+        const output: TextGenerationStreamOutput = {
+          token: {
+            id: tokenId++,
+            text: content,
+            logprob: 0,
+            special: false,
+          },
+          generated_text: null,
+          details: null,
+        };
+        yield output;
+      }
+
+      const output: TextGenerationStreamOutput = {
+        token: {
+          id: tokenId++,
+          text: "",
+          logprob: 0,
+          special: true,
+        },
+        generated_text: generatedText,
+        details: null,
+      };
+      yield output;
+    })();
+  };
+}
+
+async function fileToImageBlock(
+  file: MessageFile,
+  opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
+): Promise<Part> {
+  const processor = makeImageProcessor(opts);
+  const { image, mime } = await processor(file);
+
+  return {
+    inlineData: {
+      mimeType: mime,
+      data: image.toString("base64"),
+    },
+  };
+}
+
+export default endpointGenAI;
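To see the streaming contract end to end, here is a hypothetical standalone invocation. The `model` object and message fields follow what the code above actually reads (`model.id`/`name`/`parameters`, `from`, `content`, `files`), but the shapes are simplified and chat-ui normally constructs these arguments itself:

```ts
import endpointGenAI from "./endpointGenAI";

// Hypothetical direct use of the endpoint; chat-ui builds these args internally.
const endpoint = endpointGenAI({
  type: "genai",
  apiKey: process.env.GOOGLE_GENAI_API_KEY ?? "",
  model: { id: "gemini-1.5-flash", parameters: {} },
});

const stream = await endpoint({
  messages: [{ from: "user", content: "Summarize zod in one sentence." }],
  preprompt: "You are a concise assistant.",
  generateSettings: { max_new_tokens: 256 },
});

for await (const output of stream) {
  if (output.token.special) {
    console.log(output.generated_text); // full text arrives with the final token
  } else {
    process.stdout.write(output.token.text); // incremental chunks as they stream
  }
}
```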
src/lib/server/models.ts CHANGED
@@ -230,6 +230,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
       return endpoints.ollama(args);
     case "vertex":
       return await endpoints.vertex(args);
+    case "genai":
+      return await endpoints.genai(args);
     case "cloudflare":
       return await endpoints.cloudflare(args);
     case "cohere":