File size: 3,946 Bytes
987575f
 
 
a8a9533
486ffa7
987575f
 
 
 
 
a8a9533
 
987575f
 
 
 
 
 
b831f4b
 
 
 
 
 
987575f
8019701
987575f
 
 
 
 
 
 
 
 
8019701
 
987575f
 
 
8019701
 
 
 
 
987575f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dd92b8
dc98038
987575f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import { z } from "zod";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { env } from "$env/dynamic/private";
import { logger } from "$lib/server/logger";

/**
 * Zod schema for a `cloudflare` endpoint entry in the model configuration.
 *
 * - `weight`: positive integer, defaults to 1.
 * - `model`: accepted as-is (`z.any()`), with no validation and no copy; the parsed value is
 *   the caller's own object.
 * - `type`: must be exactly `"cloudflare"`.
 * - `accountId` / `apiToken`: default to `CLOUDFLARE_ACCOUNT_ID` / `CLOUDFLARE_API_TOKEN`.
 */
export const endpointCloudflareParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("cloudflare"),
	// Defaults are read from `env` once, when this object literal is evaluated at module load.
	// NOTE(review): if the env var is unset the default is `undefined`, which presumably makes
	// `z.string()` reject the field with a generic "Required" error — confirm.
	accountId: z.string().default(env.CLOUDFLARE_ACCOUNT_ID),
	apiToken: z.string().default(env.CLOUDFLARE_API_TOKEN),
});

/**
 * Creates a streaming chat endpoint backed by the Cloudflare Workers AI REST API.
 *
 * Model ids that do not already start with `@` are prefixed with `@hf/` before being placed
 * in the `ai/run/<model>` URL. The configured `model` object itself is never modified.
 *
 * @param input - Endpoint config validated by `endpointCloudflareParametersSchema`;
 *   `accountId` and `apiToken` fall back to `CLOUDFLARE_ACCOUNT_ID` / `CLOUDFLARE_API_TOKEN`.
 * @returns An `Endpoint` that POSTs the conversation with `stream: true`. The endpoint returns
 *   an async generator that yields one non-special token per streamed `response` chunk. When
 *   `data: [DONE]` arrives, it yields a final special token whose `generated_text` is the full
 *   text. If the stream closes without `[DONE]`, the generator ends without that final token.
 * @throws From the returned endpoint: an `Error` containing the response body when the HTTP
 *   status is not OK.
 */
export async function endpointCloudflare(
	input: z.input<typeof endpointCloudflareParametersSchema>
): Promise<Endpoint> {
	const { accountId, apiToken, model } = endpointCloudflareParametersSchema.parse(input);

	// Derive the Workers AI model name locally. `model` is `z.any()`, i.e. the caller's object by
	// reference, so assigning to `model.id` would rename the model for every other consumer.
	const modelId: string = model.id.startsWith("@") ? model.id : `@hf/${model.id}`;

	const apiURL = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${modelId}`;

	return async ({ messages, preprompt, generateSettings }) => {
		let messagesFormatted = messages.map((message) => ({
			role: message.from,
			content: message.content,
		}));

		// Ensure the conversation starts with a system message, using the preprompt (or "").
		if (messagesFormatted?.[0]?.role !== "system") {
			messagesFormatted = [{ role: "system", content: preprompt ?? "" }, ...messagesFormatted];
		}

		// Request-time settings override the model's configured parameters.
		const parameters = { ...model.parameters, ...generateSettings };

		const payload = JSON.stringify({
			messages: messagesFormatted,
			stream: true,
			max_tokens: parameters.max_new_tokens,
			temperature: parameters.temperature,
			top_p: parameters.top_p,
			top_k: parameters.top_k,
			repetition_penalty: parameters.repetition_penalty,
		});

		const res = await fetch(apiURL, {
			method: "POST",
			headers: {
				Authorization: `Bearer ${apiToken}`,
				"Content-Type": "application/json",
			},
			body: payload,
		});

		if (!res.ok) {
			throw new Error(`Failed to generate text: ${await res.text()}`);
		}

		const decoder = new TextDecoderStream();
		const reader = res.body?.pipeThrough(decoder).getReader();

		return (async function* () {
			let generatedText = "";
			let tokenId = 0;
			// Holds a partially received line until its terminating newline arrives.
			let buffer = "";

			try {
				for (;;) {
					const out = await reader?.read();

					// Stream finished, produced no data, or the response had no body.
					if (!out || out.done || !out.value) {
						return;
					}

					buffer += out.value;

					// Process every complete line currently in the buffer; keep the remainder.
					for (let nl = buffer.indexOf("\n"); nl !== -1; nl = buffer.indexOf("\n")) {
						const line = buffer.slice(0, nl).trim();
						buffer = buffer.slice(nl + 1);

						// Only `data: ` lines carry payloads; blank lines and other fields are skipped.
						if (!line.startsWith("data: ")) {
							continue;
						}
						const jsonString = line.slice("data: ".length);

						if (jsonString === "[DONE]") {
							// Emit the final special token carrying the full text, then stop at once so
							// nothing still buffered can be yielded after it.
							yield {
								token: {
									id: tokenId++,
									text: "",
									logprob: 0,
									special: true,
								},
								generated_text: generatedText,
								details: null,
							} satisfies TextGenerationStreamOutput;
							return;
						}

						let data: unknown;
						try {
							data = JSON.parse(jsonString);
						} catch (e) {
							logger.error(e, "Failed to parse JSON");
							logger.error({ jsonString }, "Problematic JSON string:");
							continue; // Skip this event and keep reading the stream.
						}

						// Narrow before use: a `null` or non-object payload must not crash the stream.
						const response =
							typeof data === "object" && data !== null && "response" in data
								? data.response
								: undefined;

						if (typeof response === "string" && response.length > 0) {
							generatedText += response;
							const output: TextGenerationStreamOutput = {
								token: {
									id: tokenId++,
									text: response,
									logprob: 0,
									special: false,
								},
								generated_text: null,
								details: null,
							};
							yield output;
						}
					}
				}
			} finally {
				// Release the underlying HTTP stream on every exit path: normal end, [DONE], a thrown
				// error, or the consumer abandoning the generator early. Best-effort: a cancel
				// rejection (e.g. on an already-errored stream) must not mask the original outcome.
				await reader?.cancel().catch(() => undefined);
			}
		})();
	};
}

// Default export of the Cloudflare endpoint factory (same function as the named export).
export { endpointCloudflare as default };