Spaces:
Running
Running
Add Google Gemini API Support (#1330)
Browse files* Supports Google Gemini API
* Fixed error when running svelte-check
* update: from the contribution of nsarrazin
docs/source/configuration/models/providers/google.md
CHANGED
@@ -44,6 +44,42 @@ MODELS=`[
|
|
44 |
}
|
45 |
}]
|
46 |
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
]`
|
49 |
```
|
|
|
44 |
}
|
45 |
}]
|
46 |
}]
|
47 |
+
}
|
48 |
+
]`
|
49 |
+
```
|
50 |
+
|
51 |
+
## GenAI
|
52 |
+
|
53 |
+
Or use the Gemini API provider via the [Google generative-ai-js SDK](https://github.com/google-gemini/generative-ai-js#readme):
|
54 |
+
|
55 |
+
> Make sure that you have an API key from Google Cloud Platform. To get an API key, follow the instructions [here](https://cloud.google.com/docs/authentication/api-keys).
|
56 |
+
|
57 |
+
```ini
|
58 |
+
MODELS=`[
|
59 |
+
{
|
60 |
+
"name": "gemini-1.5-flash",
|
61 |
+
"displayName": "Gemini Flash 1.5",
|
62 |
+
"multimodal": true,
|
63 |
+
"endpoints": [
|
64 |
+
{
|
65 |
+
"type": "genai",
|
66 |
+
"apiKey": "abc...xyz"
|
67 |
+
}
|
68 |
+
]
|
69 |
+
|
70 |
+
// Optional
|
71 |
+
"safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
|
72 |
},
|
73 |
+
{
|
74 |
+
"name": "gemini-1.5-pro",
|
75 |
+
"displayName": "Gemini Pro 1.5",
|
76 |
+
"multimodal": false,
|
77 |
+
"endpoints": [
|
78 |
+
{
|
79 |
+
"type": "genai",
|
80 |
+
"apiKey": "abc...xyz"
|
81 |
+
}
|
82 |
+
]
|
83 |
+
}
|
84 |
]`
|
85 |
```
|
package-lock.json
CHANGED
@@ -96,6 +96,7 @@
|
|
96 |
"@anthropic-ai/sdk": "^0.17.1",
|
97 |
"@anthropic-ai/vertex-sdk": "^0.3.0",
|
98 |
"@google-cloud/vertexai": "^1.1.0",
|
|
|
99 |
"aws4fetch": "^1.0.17",
|
100 |
"cohere-ai": "^7.9.0",
|
101 |
"openai": "^4.44.0"
|
@@ -1360,6 +1361,15 @@
|
|
1360 |
"node": ">=18.0.0"
|
1361 |
}
|
1362 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1363 |
"node_modules/@gradio/client": {
|
1364 |
"version": "0.19.4",
|
1365 |
"resolved": "https://registry.npmjs.org/@gradio/client/-/client-0.19.4.tgz",
|
|
|
96 |
"@anthropic-ai/sdk": "^0.17.1",
|
97 |
"@anthropic-ai/vertex-sdk": "^0.3.0",
|
98 |
"@google-cloud/vertexai": "^1.1.0",
|
99 |
+
"@google/generative-ai": "^0.14.1",
|
100 |
"aws4fetch": "^1.0.17",
|
101 |
"cohere-ai": "^7.9.0",
|
102 |
"openai": "^4.44.0"
|
|
|
1361 |
"node": ">=18.0.0"
|
1362 |
}
|
1363 |
},
|
1364 |
+
"node_modules/@google/generative-ai": {
|
1365 |
+
"version": "0.14.1",
|
1366 |
+
"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.14.1.tgz",
|
1367 |
+
"integrity": "sha512-pevEyZCb0Oc+dYNlSberW8oZBm4ofeTD5wN01TowQMhTwdAbGAnJMtQzoklh6Blq2AKsx8Ox6FWa44KioZLZiA==",
|
1368 |
+
"optional": true,
|
1369 |
+
"engines": {
|
1370 |
+
"node": ">=18.0.0"
|
1371 |
+
}
|
1372 |
+
},
|
1373 |
"node_modules/@gradio/client": {
|
1374 |
"version": "0.19.4",
|
1375 |
"resolved": "https://registry.npmjs.org/@gradio/client/-/client-0.19.4.tgz",
|
package.json
CHANGED
@@ -106,6 +106,7 @@
|
|
106 |
"@anthropic-ai/sdk": "^0.17.1",
|
107 |
"@anthropic-ai/vertex-sdk": "^0.3.0",
|
108 |
"@google-cloud/vertexai": "^1.1.0",
|
|
|
109 |
"aws4fetch": "^1.0.17",
|
110 |
"cohere-ai": "^7.9.0",
|
111 |
"openai": "^4.44.0"
|
|
|
106 |
"@anthropic-ai/sdk": "^0.17.1",
|
107 |
"@anthropic-ai/vertex-sdk": "^0.3.0",
|
108 |
"@google-cloud/vertexai": "^1.1.0",
|
109 |
+
"@google/generative-ai": "^0.14.1",
|
110 |
"aws4fetch": "^1.0.17",
|
111 |
"cohere-ai": "^7.9.0",
|
112 |
"openai": "^4.44.0"
|
src/lib/server/endpoints/endpoints.ts
CHANGED
@@ -8,6 +8,7 @@ import { endpointOAIParametersSchema, endpointOai } from "./openai/endpointOai";
|
|
8 |
import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/endpointLlamacpp";
|
9 |
import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
|
10 |
import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
|
|
|
11 |
|
12 |
import {
|
13 |
endpointAnthropic,
|
@@ -65,6 +66,7 @@ export const endpoints = {
|
|
65 |
llamacpp: endpointLlamacpp,
|
66 |
ollama: endpointOllama,
|
67 |
vertex: endpointVertex,
|
|
|
68 |
cloudflare: endpointCloudflare,
|
69 |
cohere: endpointCohere,
|
70 |
langserve: endpointLangserve,
|
@@ -79,6 +81,7 @@ export const endpointSchema = z.discriminatedUnion("type", [
|
|
79 |
endpointLlamacppParametersSchema,
|
80 |
endpointOllamaParametersSchema,
|
81 |
endpointVertexParametersSchema,
|
|
|
82 |
endpointCloudflareParametersSchema,
|
83 |
endpointCohereParametersSchema,
|
84 |
endpointLangserveParametersSchema,
|
|
|
8 |
import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/endpointLlamacpp";
|
9 |
import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
|
10 |
import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
|
11 |
+
import endpointGenAI, { endpointGenAIParametersSchema } from "./google/endpointGenAI";
|
12 |
|
13 |
import {
|
14 |
endpointAnthropic,
|
|
|
66 |
llamacpp: endpointLlamacpp,
|
67 |
ollama: endpointOllama,
|
68 |
vertex: endpointVertex,
|
69 |
+
genai: endpointGenAI,
|
70 |
cloudflare: endpointCloudflare,
|
71 |
cohere: endpointCohere,
|
72 |
langserve: endpointLangserve,
|
|
|
81 |
endpointLlamacppParametersSchema,
|
82 |
endpointOllamaParametersSchema,
|
83 |
endpointVertexParametersSchema,
|
84 |
+
endpointGenAIParametersSchema,
|
85 |
endpointCloudflareParametersSchema,
|
86 |
endpointCohereParametersSchema,
|
87 |
endpointLangserveParametersSchema,
|
src/lib/server/endpoints/google/endpointGenAI.ts
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from "@google/generative-ai";
|
2 |
+
import type { Content, Part, TextPart } from "@google/generative-ai";
|
3 |
+
import { z } from "zod";
|
4 |
+
import type { Message, MessageFile } from "$lib/types/Message";
|
5 |
+
import type { TextGenerationStreamOutput } from "@huggingface/inference";
|
6 |
+
import type { Endpoint } from "../endpoints";
|
7 |
+
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
|
8 |
+
import type { ImageProcessorOptions } from "../images";
|
9 |
+
|
10 |
+
export const endpointGenAIParametersSchema = z.object({
|
11 |
+
weight: z.number().int().positive().default(1),
|
12 |
+
model: z.any(),
|
13 |
+
type: z.literal("genai"),
|
14 |
+
apiKey: z.string(),
|
15 |
+
safetyThreshold: z
|
16 |
+
.enum([
|
17 |
+
HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
|
18 |
+
HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
19 |
+
HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
20 |
+
HarmBlockThreshold.BLOCK_NONE,
|
21 |
+
HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
22 |
+
])
|
23 |
+
.optional(),
|
24 |
+
multimodal: z
|
25 |
+
.object({
|
26 |
+
image: createImageProcessorOptionsValidator({
|
27 |
+
supportedMimeTypes: ["image/png", "image/jpeg", "image/webp"],
|
28 |
+
preferredMimeType: "image/webp",
|
29 |
+
// The 4 / 3 compensates for the 33% increase in size when converting to base64
|
30 |
+
maxSizeInMB: (5 / 4) * 3,
|
31 |
+
maxWidth: 4096,
|
32 |
+
maxHeight: 4096,
|
33 |
+
}),
|
34 |
+
})
|
35 |
+
.default({}),
|
36 |
+
});
|
37 |
+
|
38 |
+
export function endpointGenAI(input: z.input<typeof endpointGenAIParametersSchema>): Endpoint {
|
39 |
+
const { model, apiKey, safetyThreshold, multimodal } = endpointGenAIParametersSchema.parse(input);
|
40 |
+
|
41 |
+
const genAI = new GoogleGenerativeAI(apiKey);
|
42 |
+
|
43 |
+
return async ({ messages, preprompt, generateSettings }) => {
|
44 |
+
const parameters = { ...model.parameters, ...generateSettings };
|
45 |
+
|
46 |
+
const generativeModel = genAI.getGenerativeModel({
|
47 |
+
model: model.id ?? model.name,
|
48 |
+
safetySettings: safetyThreshold
|
49 |
+
? [
|
50 |
+
{
|
51 |
+
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
52 |
+
threshold: safetyThreshold,
|
53 |
+
},
|
54 |
+
{
|
55 |
+
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
|
56 |
+
threshold: safetyThreshold,
|
57 |
+
},
|
58 |
+
{
|
59 |
+
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
60 |
+
threshold: safetyThreshold,
|
61 |
+
},
|
62 |
+
{
|
63 |
+
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
64 |
+
threshold: safetyThreshold,
|
65 |
+
},
|
66 |
+
{
|
67 |
+
category: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
68 |
+
threshold: safetyThreshold,
|
69 |
+
},
|
70 |
+
]
|
71 |
+
: undefined,
|
72 |
+
generationConfig: {
|
73 |
+
maxOutputTokens: parameters?.max_new_tokens ?? 4096,
|
74 |
+
stopSequences: parameters?.stop,
|
75 |
+
temperature: parameters?.temperature ?? 1,
|
76 |
+
},
|
77 |
+
});
|
78 |
+
|
79 |
+
let systemMessage = preprompt;
|
80 |
+
if (messages[0].from === "system") {
|
81 |
+
systemMessage = messages[0].content;
|
82 |
+
messages.shift();
|
83 |
+
}
|
84 |
+
|
85 |
+
const genAIMessages = await Promise.all(
|
86 |
+
messages.map(async ({ from, content, files }: Omit<Message, "id">): Promise<Content> => {
|
87 |
+
return {
|
88 |
+
role: from === "user" ? "user" : "model",
|
89 |
+
parts: [
|
90 |
+
...(await Promise.all(
|
91 |
+
(files ?? []).map((file) => fileToImageBlock(file, multimodal.image))
|
92 |
+
)),
|
93 |
+
{ text: content },
|
94 |
+
],
|
95 |
+
};
|
96 |
+
})
|
97 |
+
);
|
98 |
+
|
99 |
+
const result = await generativeModel.generateContentStream({
|
100 |
+
contents: genAIMessages,
|
101 |
+
systemInstruction:
|
102 |
+
systemMessage && systemMessage.trim() !== ""
|
103 |
+
? {
|
104 |
+
role: "system",
|
105 |
+
parts: [{ text: systemMessage }],
|
106 |
+
}
|
107 |
+
: undefined,
|
108 |
+
});
|
109 |
+
|
110 |
+
let tokenId = 0;
|
111 |
+
return (async function* () {
|
112 |
+
let generatedText = "";
|
113 |
+
|
114 |
+
for await (const data of result.stream) {
|
115 |
+
if (!data?.candidates?.length) break; // Handle case where no candidates are present
|
116 |
+
|
117 |
+
const candidate = data.candidates[0];
|
118 |
+
if (!candidate.content?.parts?.length) continue; // Skip if no parts are present
|
119 |
+
|
120 |
+
const firstPart = candidate.content.parts.find((part) => "text" in part) as
|
121 |
+
| TextPart
|
122 |
+
| undefined;
|
123 |
+
if (!firstPart) continue; // Skip if no text part is found
|
124 |
+
|
125 |
+
const content = firstPart.text;
|
126 |
+
generatedText += content;
|
127 |
+
|
128 |
+
const output: TextGenerationStreamOutput = {
|
129 |
+
token: {
|
130 |
+
id: tokenId++,
|
131 |
+
text: content,
|
132 |
+
logprob: 0,
|
133 |
+
special: false,
|
134 |
+
},
|
135 |
+
generated_text: null,
|
136 |
+
details: null,
|
137 |
+
};
|
138 |
+
yield output;
|
139 |
+
}
|
140 |
+
|
141 |
+
const output: TextGenerationStreamOutput = {
|
142 |
+
token: {
|
143 |
+
id: tokenId++,
|
144 |
+
text: "",
|
145 |
+
logprob: 0,
|
146 |
+
special: true,
|
147 |
+
},
|
148 |
+
generated_text: generatedText,
|
149 |
+
details: null,
|
150 |
+
};
|
151 |
+
yield output;
|
152 |
+
})();
|
153 |
+
};
|
154 |
+
}
|
155 |
+
|
156 |
+
async function fileToImageBlock(
|
157 |
+
file: MessageFile,
|
158 |
+
opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
|
159 |
+
): Promise<Part> {
|
160 |
+
const processor = makeImageProcessor(opts);
|
161 |
+
const { image, mime } = await processor(file);
|
162 |
+
|
163 |
+
return {
|
164 |
+
inlineData: {
|
165 |
+
mimeType: mime,
|
166 |
+
data: image.toString("base64"),
|
167 |
+
},
|
168 |
+
};
|
169 |
+
}
|
170 |
+
|
171 |
+
export default endpointGenAI;
|
src/lib/server/models.ts
CHANGED
@@ -230,6 +230,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
|
|
230 |
return endpoints.ollama(args);
|
231 |
case "vertex":
|
232 |
return await endpoints.vertex(args);
|
|
|
|
|
233 |
case "cloudflare":
|
234 |
return await endpoints.cloudflare(args);
|
235 |
case "cohere":
|
|
|
230 |
return endpoints.ollama(args);
|
231 |
case "vertex":
|
232 |
return await endpoints.vertex(args);
|
233 |
+
case "genai":
|
234 |
+
return await endpoints.genai(args);
|
235 |
case "cloudflare":
|
236 |
return await endpoints.cloudflare(args);
|
237 |
case "cohere":
|