@huggingface/inference
Comparing version 2.8.1 to 3.0.0
@@ -43,26 +43,99 @@ var __defProp = Object.defineProperty; | ||
// src/utils/pick.ts | ||
function pick(o, props) { | ||
return Object.assign( | ||
{}, | ||
...props.map((prop) => { | ||
if (o[prop] !== void 0) { | ||
return { [prop]: o[prop] }; | ||
} | ||
}) | ||
); | ||
} | ||
// src/config.ts | ||
var HF_HUB_URL = "https://huggingface.co"; | ||
var HF_INFERENCE_API_URL = "https://api-inference.huggingface.co"; | ||
// src/utils/typedInclude.ts | ||
function typedInclude(arr, v) { | ||
return arr.includes(v); | ||
} | ||
// src/providers/fal-ai.ts | ||
var FAL_AI_API_BASE_URL = "https://fal.run"; | ||
var FAL_AI_SUPPORTED_MODEL_IDS = { | ||
"text-to-image": { | ||
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell", | ||
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev" | ||
}, | ||
"automatic-speech-recognition": { | ||
"openai/whisper-large-v3": "fal-ai/whisper" | ||
} | ||
}; | ||
// src/utils/omit.ts | ||
function omit(o, props) { | ||
const propsArr = Array.isArray(props) ? props : [props]; | ||
const letsKeep = Object.keys(o).filter((prop) => !typedInclude(propsArr, prop)); | ||
return pick(o, letsKeep); | ||
} | ||
// src/providers/replicate.ts | ||
var REPLICATE_API_BASE_URL = "https://api.replicate.com"; | ||
var REPLICATE_SUPPORTED_MODEL_IDS = { | ||
"text-to-image": { | ||
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell", | ||
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637" | ||
} | ||
// "text-to-speech": { | ||
// "SWivid/F5-TTS": "x-lance/f5-tts:87faf6dd7a692dd82043f662e76369cab126a2cf1937e25a9d41e0b834fd230e" | ||
// }, | ||
}; | ||
// src/providers/sambanova.ts | ||
var SAMBANOVA_API_BASE_URL = "https://api.sambanova.ai"; | ||
var SAMBANOVA_SUPPORTED_MODEL_IDS = { | ||
/** Chat completion / conversational */ | ||
conversational: { | ||
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", | ||
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct", | ||
"Qwen/QwQ-32B-Preview": "QwQ-32B-Preview", | ||
"meta-llama/Llama-3.3-70B-Instruct": "Meta-Llama-3.3-70B-Instruct", | ||
"meta-llama/Llama-3.2-1B-Instruct": "Meta-Llama-3.2-1B-Instruct", | ||
"meta-llama/Llama-3.2-3B-Instruct": "Meta-Llama-3.2-3B-Instruct", | ||
"meta-llama/Llama-3.2-11B-Vision-Instruct": "Llama-3.2-11B-Vision-Instruct", | ||
"meta-llama/Llama-3.2-90B-Vision-Instruct": "Llama-3.2-90B-Vision-Instruct", | ||
"meta-llama/Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct", | ||
"meta-llama/Llama-3.1-70B-Instruct": "Meta-Llama-3.1-70B-Instruct", | ||
"meta-llama/Llama-3.1-405B-Instruct": "Meta-Llama-3.1-405B-Instruct", | ||
"meta-llama/Llama-Guard-3-8B": "Meta-Llama-Guard-3-8B" | ||
} | ||
}; | ||
// src/providers/together.ts | ||
var TOGETHER_API_BASE_URL = "https://api.together.xyz"; | ||
var TOGETHER_SUPPORTED_MODEL_IDS = { | ||
"text-to-image": { | ||
"black-forest-labs/FLUX.1-Canny-dev": "black-forest-labs/FLUX.1-canny", | ||
"black-forest-labs/FLUX.1-Depth-dev": "black-forest-labs/FLUX.1-depth", | ||
"black-forest-labs/FLUX.1-dev": "black-forest-labs/FLUX.1-dev", | ||
"black-forest-labs/FLUX.1-Redux-dev": "black-forest-labs/FLUX.1-redux", | ||
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/FLUX.1-pro", | ||
"stabilityai/stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0" | ||
}, | ||
conversational: { | ||
"databricks/dbrx-instruct": "databricks/dbrx-instruct", | ||
"deepseek-ai/deepseek-llm-67b-chat": "deepseek-ai/deepseek-llm-67b-chat", | ||
"google/gemma-2-9b-it": "google/gemma-2-9b-it", | ||
"google/gemma-2b-it": "google/gemma-2-27b-it", | ||
"llava-hf/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf", | ||
"meta-llama/Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf", | ||
"meta-llama/Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf", | ||
"meta-llama/Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf", | ||
"meta-llama/Llama-3.2-11B-Vision-Instruct": "meta-llama/Llama-Vision-Free", | ||
"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo", | ||
"meta-llama/Llama-3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", | ||
"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct-Turbo", | ||
"meta-llama/Meta-Llama-3-70B-Instruct": "meta-llama/Llama-3-70b-chat-hf", | ||
"meta-llama/Meta-Llama-3-8B-Instruct": "togethercomputer/Llama-3-8b-chat-hf-int4", | ||
"meta-llama/Meta-Llama-3.1-405B-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", | ||
"meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", | ||
"meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo-128K", | ||
"microsoft/WizardLM-2-8x22B": "microsoft/WizardLM-2-8x22B", | ||
"mistralai/Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3", | ||
"mistralai/Mixtral-8x22B-Instruct-v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1", | ||
"mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1", | ||
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", | ||
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", | ||
"Qwen/Qwen2-72B-Instruct": "Qwen/Qwen2-72B-Instruct", | ||
"Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct-Turbo", | ||
"Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct-Turbo", | ||
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen/Qwen2.5-Coder-32B-Instruct", | ||
"Qwen/QwQ-32B-Preview": "Qwen/QwQ-32B-Preview", | ||
"scb10x/llama-3-typhoon-v1.5-8b-instruct": "scb10x/scb10x-llama3-typhoon-v1-5-8b-instruct", | ||
"scb10x/llama-3-typhoon-v1.5x-70b-instruct-awq": "scb10x/scb10x-llama3-typhoon-v1-5x-4f316" | ||
}, | ||
"text-generation": { | ||
"meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B", | ||
"mistralai/Mixtral-8x7B-v0.1": "mistralai/Mixtral-8x7B-v0.1" | ||
} | ||
}; | ||
// src/lib/isUrl.ts | ||
@@ -73,66 +146,40 @@ function isUrl(modelOrUrl) { | ||
// src/lib/getDefaultTask.ts | ||
var taskCache = /* @__PURE__ */ new Map(); | ||
var CACHE_DURATION = 10 * 60 * 1e3; | ||
var MAX_CACHE_ITEMS = 1e3; | ||
var HF_HUB_URL = "https://huggingface.co"; | ||
async function getDefaultTask(model, accessToken, options) { | ||
if (isUrl(model)) { | ||
return null; | ||
} | ||
const key = `${model}:${accessToken}`; | ||
let cachedTask = taskCache.get(key); | ||
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) { | ||
taskCache.delete(key); | ||
cachedTask = void 0; | ||
} | ||
if (cachedTask === void 0) { | ||
const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, { | ||
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {} | ||
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null); | ||
if (!modelTask) { | ||
return null; | ||
} | ||
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() }; | ||
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() }); | ||
if (taskCache.size > MAX_CACHE_ITEMS) { | ||
taskCache.delete(taskCache.keys().next().value); | ||
} | ||
} | ||
return cachedTask.task; | ||
} | ||
// src/lib/makeRequestOptions.ts | ||
var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; | ||
var HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`; | ||
var tasks = null; | ||
async function makeRequestOptions(args, options) { | ||
const { accessToken, endpointUrl, ...otherArgs } = args; | ||
let { model } = args; | ||
const { | ||
forceTask: task, | ||
includeCredentials, | ||
taskHint, | ||
wait_for_model, | ||
use_cache, | ||
dont_load_model, | ||
chatCompletion: chatCompletion2 | ||
} = options ?? {}; | ||
const headers = {}; | ||
if (accessToken) { | ||
headers["Authorization"] = `Bearer ${accessToken}`; | ||
const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args; | ||
const provider = maybeProvider ?? "hf-inference"; | ||
const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion: chatCompletion2 } = options ?? {}; | ||
if (endpointUrl && provider !== "hf-inference") { | ||
throw new Error(`Cannot use endpointUrl with a third-party provider.`); | ||
} | ||
if (!model && !tasks && taskHint) { | ||
const res = await fetch(`${HF_HUB_URL}/api/tasks`); | ||
if (res.ok) { | ||
tasks = await res.json(); | ||
} | ||
if (forceTask && provider !== "hf-inference") { | ||
throw new Error(`Cannot use forceTask with a third-party provider.`); | ||
} | ||
if (!model && tasks && taskHint) { | ||
const taskInfo = tasks[taskHint]; | ||
if (taskInfo) { | ||
model = taskInfo.models[0].id; | ||
if (maybeModel && isUrl(maybeModel)) { | ||
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`); | ||
} | ||
let model; | ||
if (!maybeModel) { | ||
if (taskHint) { | ||
model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion: chatCompletion2 }); | ||
} else { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
} | ||
} else { | ||
model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion: chatCompletion2 }); | ||
} | ||
if (!model) { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
const authMethod = accessToken ? accessToken.startsWith("hf_") ? "hf-token" : "provider-key" : includeCredentials === "include" ? "credentials-include" : "none"; | ||
const url = endpointUrl ? chatCompletion2 ? endpointUrl + `/v1/chat/completions` : endpointUrl : makeUrl({ | ||
authMethod, | ||
chatCompletion: chatCompletion2 ?? false, | ||
forceTask, | ||
model, | ||
provider: provider ?? "hf-inference", | ||
taskHint | ||
}); | ||
const headers = {}; | ||
if (accessToken) { | ||
headers["Authorization"] = provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`; | ||
} | ||
@@ -143,30 +190,16 @@ const binary = "data" in args && !!args.data; | ||
} | ||
if (wait_for_model) { | ||
headers["X-Wait-For-Model"] = "true"; | ||
} | ||
if (use_cache === false) { | ||
headers["X-Use-Cache"] = "false"; | ||
} | ||
if (dont_load_model) { | ||
headers["X-Load-Model"] = "0"; | ||
} | ||
let url = (() => { | ||
if (endpointUrl && isUrl(model)) { | ||
throw new TypeError("Both model and endpointUrl cannot be URLs"); | ||
if (provider === "hf-inference") { | ||
if (wait_for_model) { | ||
headers["X-Wait-For-Model"] = "true"; | ||
} | ||
if (isUrl(model)) { | ||
console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead"); | ||
return model; | ||
if (use_cache === false) { | ||
headers["X-Use-Cache"] = "false"; | ||
} | ||
if (endpointUrl) { | ||
return endpointUrl; | ||
if (dont_load_model) { | ||
headers["X-Load-Model"] = "0"; | ||
} | ||
if (task) { | ||
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`; | ||
} | ||
return `${HF_INFERENCE_API_BASE_URL}/models/${model}`; | ||
})(); | ||
if (chatCompletion2 && !url.endsWith("/chat/completions")) { | ||
url += "/v1/chat/completions"; | ||
} | ||
if (provider === "replicate") { | ||
headers["Prefer"] = "wait"; | ||
} | ||
let credentials; | ||
@@ -178,2 +211,6 @@ if (typeof includeCredentials === "string") { | ||
} | ||
if (provider === "replicate" && model.includes(":")) { | ||
const version = model.split(":")[1]; | ||
otherArgs.version = version; | ||
} | ||
const info = { | ||
@@ -183,5 +220,6 @@ headers, | ||
body: binary ? args.data : JSON.stringify({ | ||
...otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs | ||
...otherArgs, | ||
...chatCompletion2 || provider === "together" ? { model } : void 0 | ||
}), | ||
...credentials && { credentials }, | ||
...credentials ? { credentials } : void 0, | ||
signal: options?.signal | ||
@@ -191,2 +229,90 @@ }; | ||
} | ||
function mapModel(params) { | ||
if (params.provider === "hf-inference") { | ||
return params.model; | ||
} | ||
if (!params.taskHint) { | ||
throw new Error("taskHint must be specified when using a third-party provider"); | ||
} | ||
const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint; | ||
const model = (() => { | ||
switch (params.provider) { | ||
case "fal-ai": | ||
return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model]; | ||
case "replicate": | ||
return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model]; | ||
case "sambanova": | ||
return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model]; | ||
case "together": | ||
return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model]; | ||
} | ||
})(); | ||
if (!model) { | ||
throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`); | ||
} | ||
return model; | ||
} | ||
function makeUrl(params) { | ||
if (params.authMethod === "none" && params.provider !== "hf-inference") { | ||
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken"); | ||
} | ||
const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key"; | ||
switch (params.provider) { | ||
case "fal-ai": { | ||
const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : FAL_AI_API_BASE_URL; | ||
return `${baseUrl}/${params.model}`; | ||
} | ||
case "replicate": { | ||
const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : REPLICATE_API_BASE_URL; | ||
if (params.model.includes(":")) { | ||
return `${baseUrl}/v1/predictions`; | ||
} | ||
return `${baseUrl}/v1/models/${params.model}/predictions`; | ||
} | ||
case "sambanova": { | ||
const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : SAMBANOVA_API_BASE_URL; | ||
if (params.taskHint === "text-generation" && params.chatCompletion) { | ||
return `${baseUrl}/v1/chat/completions`; | ||
} | ||
return baseUrl; | ||
} | ||
case "together": { | ||
const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) : TOGETHER_API_BASE_URL; | ||
if (params.taskHint === "text-to-image") { | ||
return `${baseUrl}/v1/images/generations`; | ||
} | ||
if (params.taskHint === "text-generation") { | ||
if (params.chatCompletion) { | ||
return `${baseUrl}/v1/chat/completions`; | ||
} | ||
return `${baseUrl}/v1/completions`; | ||
} | ||
return baseUrl; | ||
} | ||
default: { | ||
const url = params.forceTask ? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` : `${HF_INFERENCE_API_URL}/models/${params.model}`; | ||
if (params.taskHint === "text-generation" && params.chatCompletion) { | ||
return url + `/v1/chat/completions`; | ||
} | ||
return url; | ||
} | ||
} | ||
} | ||
async function loadDefaultModel(task) { | ||
if (!tasks) { | ||
tasks = await loadTaskInfo(); | ||
} | ||
const taskInfo = tasks[task]; | ||
if ((taskInfo?.models.length ?? 0) <= 0) { | ||
throw new Error(`No default model defined for task ${task}, please define the model explicitly.`); | ||
} | ||
return taskInfo.models[0].id; | ||
} | ||
async function loadTaskInfo() { | ||
const res = await fetch(`${HF_HUB_URL}/api/tasks`); | ||
if (!res.ok) { | ||
throw new Error("Failed to load tasks definitions from Hugging Face Hub."); | ||
} | ||
return await res.json(); | ||
} | ||
@@ -204,12 +330,18 @@ // src/tasks/custom/request.ts | ||
if (!response.ok) { | ||
if (response.headers.get("Content-Type")?.startsWith("application/json")) { | ||
const contentType = response.headers.get("Content-Type"); | ||
if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) { | ||
const output = await response.json(); | ||
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) { | ||
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`); | ||
throw new Error( | ||
`Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}` | ||
); | ||
} | ||
if (output.error) { | ||
throw new Error(JSON.stringify(output.error)); | ||
if (output.error || output.detail) { | ||
throw new Error(JSON.stringify(output.error ?? output.detail)); | ||
} else { | ||
throw new Error(output); | ||
} | ||
} | ||
throw new Error("An error occurred while fetching the blob"); | ||
const message = contentType?.startsWith("text/plain;") ? await response.text() : void 0; | ||
throw new Error(message ?? "An error occurred while fetching the blob"); | ||
} | ||
@@ -337,5 +469,8 @@ if (response.headers.get("Content-Type")?.startsWith("application/json")) { | ||
} | ||
if (output.error) { | ||
if (typeof output.error === "string") { | ||
throw new Error(output.error); | ||
} | ||
if (output.error && "message" in output.error && typeof output.error.message === "string") { | ||
throw new Error(output.error.message); | ||
} | ||
} | ||
@@ -369,4 +504,5 @@ throw new Error(`Server response contains error: ${response.status}`); | ||
const { done, value } = await reader.read(); | ||
if (done) | ||
if (done) { | ||
return; | ||
} | ||
onChunk(value); | ||
@@ -380,3 +516,4 @@ for (const event of events) { | ||
if (typeof data === "object" && data !== null && "error" in data) { | ||
throw new Error(data.error); | ||
const errorStr = typeof data.error === "string" ? data.error : typeof data.error === "object" && data.error && "message" in data.error && typeof data.error.message === "string" ? data.error.message : JSON.stringify(data.error); | ||
throw new Error(`Error forwarded from backend: ` + errorStr); | ||
} | ||
@@ -416,4 +553,25 @@ yield data; | ||
// src/utils/base64FromBytes.ts | ||
function base64FromBytes(arr) { | ||
if (globalThis.Buffer) { | ||
return globalThis.Buffer.from(arr).toString("base64"); | ||
} else { | ||
const bin = []; | ||
arr.forEach((byte) => { | ||
bin.push(String.fromCharCode(byte)); | ||
}); | ||
return globalThis.btoa(bin.join("")); | ||
} | ||
} | ||
// src/tasks/audio/automaticSpeechRecognition.ts | ||
async function automaticSpeechRecognition(args, options) { | ||
if (args.provider === "fal-ai") { | ||
const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg"; | ||
const base64audio = base64FromBytes( | ||
new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer()) | ||
); | ||
args.audio_url = `data:${contentType};base64,${base64audio}`; | ||
delete args.data; | ||
} | ||
const res = await request(args, { | ||
@@ -515,2 +673,10 @@ ...options, | ||
async function textToImage(args, options) { | ||
if (args.provider === "together" || args.provider === "fal-ai") { | ||
args.prompt = args.inputs; | ||
args.inputs = ""; | ||
args.response_format = "base64"; | ||
} else if (args.provider === "replicate") { | ||
args.input = { prompt: args.inputs }; | ||
delete args.inputs; | ||
} | ||
const res = await request(args, { | ||
@@ -520,2 +686,19 @@ ...options, | ||
}); | ||
if (res && typeof res === "object") { | ||
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) { | ||
const image = await fetch(res.images[0].url); | ||
return await image.blob(); | ||
} | ||
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) { | ||
const base64Data = res.data[0].b64_json; | ||
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`); | ||
const blob = await base64Response.blob(); | ||
return blob; | ||
} | ||
if ("output" in res && Array.isArray(res.output)) { | ||
const urlResponse = await fetch(res.output[0]); | ||
const blob = await urlResponse.blob(); | ||
return blob; | ||
} | ||
} | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -528,15 +711,2 @@ if (!isValidOutput) { | ||
// src/utils/base64FromBytes.ts | ||
function base64FromBytes(arr) { | ||
if (globalThis.Buffer) { | ||
return globalThis.Buffer.from(arr).toString("base64"); | ||
} else { | ||
const bin = []; | ||
arr.forEach((byte) => { | ||
bin.push(String.fromCharCode(byte)); | ||
}); | ||
return globalThis.btoa(bin.join("")); | ||
} | ||
} | ||
// src/tasks/cv/imageToImage.ts | ||
@@ -593,2 +763,32 @@ async function imageToImage(args, options) { | ||
// src/lib/getDefaultTask.ts | ||
var taskCache = /* @__PURE__ */ new Map(); | ||
var CACHE_DURATION = 10 * 60 * 1e3; | ||
var MAX_CACHE_ITEMS = 1e3; | ||
async function getDefaultTask(model, accessToken, options) { | ||
if (isUrl(model)) { | ||
return null; | ||
} | ||
const key = `${model}:${accessToken}`; | ||
let cachedTask = taskCache.get(key); | ||
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) { | ||
taskCache.delete(key); | ||
cachedTask = void 0; | ||
} | ||
if (cachedTask === void 0) { | ||
const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, { | ||
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {} | ||
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null); | ||
if (!modelTask) { | ||
return null; | ||
} | ||
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() }; | ||
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() }); | ||
if (taskCache.size > MAX_CACHE_ITEMS) { | ||
taskCache.delete(taskCache.keys().next().value); | ||
} | ||
} | ||
return cachedTask.task; | ||
} | ||
// src/tasks/nlp/featureExtraction.ts | ||
@@ -715,13 +915,29 @@ async function featureExtraction(args, options) { | ||
async function textGeneration(args, options) { | ||
const res = toArray( | ||
await request(args, { | ||
if (args.provider === "together") { | ||
args.prompt = args.inputs; | ||
const raw = await request(args, { | ||
...options, | ||
taskHint: "text-generation" | ||
}) | ||
); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string"); | ||
if (!isValidOutput) { | ||
throw new InferenceOutputError("Expected Array<{generated_text: string}>"); | ||
}); | ||
const isValidOutput = typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string"; | ||
if (!isValidOutput) { | ||
throw new InferenceOutputError("Expected ChatCompletionOutput"); | ||
} | ||
const completion = raw.choices[0]; | ||
return { | ||
generated_text: completion.text | ||
}; | ||
} else { | ||
const res = toArray( | ||
await request(args, { | ||
...options, | ||
taskHint: "text-generation" | ||
}) | ||
); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string"); | ||
if (!isValidOutput) { | ||
throw new InferenceOutputError("Expected Array<{generated_text: string}>"); | ||
} | ||
return res?.[0]; | ||
} | ||
return res?.[0]; | ||
} | ||
@@ -793,3 +1009,4 @@ | ||
}); | ||
const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && typeof res?.system_fingerprint === "string" && typeof res?.usage === "object"; | ||
const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && /// Together.ai does not output a system_fingerprint | ||
(res.system_fingerprint === void 0 || typeof res.system_fingerprint === "string") && typeof res?.usage === "object"; | ||
if (!isValidOutput) { | ||
@@ -927,6 +1144,14 @@ throw new InferenceOutputError("Expected ChatCompletionOutput"); | ||
}; | ||
// src/types.ts | ||
var INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"]; | ||
export { | ||
FAL_AI_SUPPORTED_MODEL_IDS, | ||
HfInference, | ||
HfInferenceEndpoint, | ||
INFERENCE_PROVIDERS, | ||
InferenceOutputError, | ||
REPLICATE_SUPPORTED_MODEL_IDS, | ||
SAMBANOVA_SUPPORTED_MODEL_IDS, | ||
TOGETHER_SUPPORTED_MODEL_IDS, | ||
audioClassification, | ||
@@ -933,0 +1158,0 @@ audioToAudio, |
@@ -0,5 +1,10 @@ | ||
export type { ProviderMapping } from "./providers/types"; | ||
export { HfInference, HfInferenceEndpoint } from "./HfInference"; | ||
export { InferenceOutputError } from "./lib/InferenceOutputError"; | ||
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai"; | ||
export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate"; | ||
export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova"; | ||
export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together"; | ||
export * from "./types"; | ||
export * from "./tasks"; | ||
//# sourceMappingURL=index.d.ts.map |
@@ -1,2 +0,1 @@ | ||
export declare const HF_HUB_URL = "https://huggingface.co"; | ||
export interface DefaultTaskOptions { | ||
@@ -3,0 +2,0 @@ fetch?: typeof fetch; |
import type { InferenceTask, Options, RequestArgs } from "../../types"; | ||
/** | ||
* Primitive to make custom calls to Inference Endpoints | ||
* Primitive to make custom calls to the inference provider | ||
*/ | ||
@@ -5,0 +5,0 @@ export declare function request<T>(args: RequestArgs, options?: Options & { |
@@ -7,2 +7,10 @@ import type { BaseArgs, Options } from "../../types"; | ||
inputs: string; | ||
/** | ||
* Same param but for external providers like Together, Replicate | ||
*/ | ||
prompt?: string; | ||
response_format?: "base64"; | ||
input?: { | ||
prompt: string; | ||
}; | ||
parameters?: { | ||
@@ -9,0 +17,0 @@ /** |
import type { PipelineType } from "@huggingface/tasks"; | ||
import type { ChatCompletionInput } from "@huggingface/tasks"; | ||
/** | ||
* HF model id, like "meta-llama/Llama-3.3-70B-Instruct" | ||
*/ | ||
export type ModelId = string; | ||
export interface Options { | ||
@@ -38,2 +42,4 @@ /** | ||
export type InferenceTask = Exclude<PipelineType, "other">; | ||
export declare const INFERENCE_PROVIDERS: readonly ["fal-ai", "replicate", "sambanova", "together", "hf-inference"]; | ||
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number]; | ||
export interface BaseArgs { | ||
@@ -44,6 +50,8 @@ /** | ||
* Can be created for free in hf.co/settings/token | ||
* | ||
* You can also pass an external Inference provider's key if you intend to call a compatible provider like Sambanova, Together, Replicate... | ||
*/ | ||
accessToken?: string; | ||
/** | ||
* The model to use. | ||
* The HF model to use. | ||
* | ||
@@ -55,3 +63,3 @@ * If not specified, will call huggingface.co/api/tasks to get the default model for the task. | ||
*/ | ||
model?: string; | ||
model?: ModelId; | ||
/** | ||
@@ -63,2 +71,8 @@ * The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task. | ||
endpointUrl?: string; | ||
/** | ||
* Set an Inference provider to run this model on. | ||
* | ||
* Defaults to the first provider in your user settings that is compatible with this model. | ||
*/ | ||
provider?: InferenceProvider; | ||
} | ||
@@ -65,0 +79,0 @@ export type RequestArgs = BaseArgs & ({ |
{ | ||
"name": "@huggingface/inference", | ||
"version": "2.8.1", | ||
"version": "3.0.0", | ||
"packageManager": "pnpm@8.10.5", | ||
@@ -42,3 +42,3 @@ "license": "MIT", | ||
"dependencies": { | ||
"@huggingface/tasks": "^0.12.9" | ||
"@huggingface/tasks": "^0.13.16" | ||
}, | ||
@@ -45,0 +45,0 @@ "devDependencies": { |
# 🤗 Hugging Face Inference Endpoints | ||
A Typescript powered wrapper for the Hugging Face Inference Endpoints API. Learn more about Inference Endpoints at [Hugging Face](https://huggingface.co/inference-endpoints). | ||
It works with both [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index). | ||
A Typescript powered wrapper for the Hugging Face Inference API (serverless), Inference Endpoints (dedicated), and third-party Inference Providers. | ||
It works with [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index), and even with supported third-party Inference Providers. | ||
@@ -45,2 +45,30 @@ Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README). | ||
### Requesting third-party inference providers | ||
You can request inference from third-party providers with the inference client. | ||
Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai). | ||
To make a request to a third-party provider, pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token. | ||
```ts | ||
const accessToken = "hf_..."; // Either a HF access token, or an API key from the 3rd party provider (Replicate in this example) | ||
const client = new HfInference(accessToken); | ||
await client.textToImage({ | ||
provider: "replicate", | ||
model:"black-forest-labs/Flux.1-dev", | ||
inputs: "A black forest cake" | ||
}) | ||
``` | ||
When authenticated with a Hugging Face access token, the request is routed through https://huggingface.co. | ||
When authenticated with a third-party provider key, the request is made directly against that provider's inference API. | ||
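For example, here is a minimal sketch of the direct path, assuming a Replicate API key is available in a `REPLICATE_API_KEY` environment variable (the variable name is illustrative, not part of the library):
```ts
// With a provider key, the request goes straight to api.replicate.com
// instead of being proxied through huggingface.co.
const directClient = new HfInference(process.env.REPLICATE_API_KEY);

await directClient.textToImage({
  provider: "replicate",
  model: "black-forest-labs/FLUX.1-schnell", // must be listed in the Replicate mapping table
  inputs: "A black forest cake",
});
```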
Only a subset of models is supported when requesting third-party providers. You can check the list of supported models per pipeline task here (a chat completion sketch follows the list): | ||
- [Fal.ai supported models](./src/providers/fal-ai.ts) | ||
- [Replicate supported models](./src/providers/replicate.ts) | ||
- [Sambanova supported models](./src/providers/sambanova.ts) | ||
- [Together supported models](./src/providers/together.ts) | ||
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending) | ||
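Chat models routed through a provider use the same call shape. A hedged sketch, assuming an `HfInference` client authenticated with a Hugging Face access token and using one of the SambaNova mappings listed above:
```ts
const hf = new HfInference("hf_..."); // HF token: routed through the huggingface.co proxy

const out = await hf.chatCompletion({
  provider: "sambanova",
  model: "meta-llama/Llama-3.3-70B-Instruct", // mapped to SambaNova's "Meta-Llama-3.3-70B-Instruct"
  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
  max_tokens: 512,
});
console.log(out.choices[0].message.content);
```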
#### Tree-shaking | ||
@@ -95,7 +123,6 @@ | ||
const out = await hf.chatCompletion({ | ||
model: "mistralai/Mistral-7B-Instruct-v0.2", | ||
messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }], | ||
max_tokens: 500, | ||
model: "meta-llama/Llama-3.1-8B-Instruct", | ||
messages: [{ role: "user", content: "Hello, nice to meet you!" }], | ||
max_tokens: 512, | ||
temperature: 0.1, | ||
seed: 0, | ||
}); | ||
@@ -106,9 +133,8 @@ | ||
for await (const chunk of hf.chatCompletionStream({ | ||
model: "mistralai/Mistral-7B-Instruct-v0.2", | ||
model: "meta-llama/Llama-3.1-8B-Instruct", | ||
messages: [ | ||
{ role: "user", content: "Complete the equation 1+1= ,just the answer" }, | ||
{ role: "user", content: "Can you help me solve an equation?" }, | ||
], | ||
max_tokens: 500, | ||
max_tokens: 512, | ||
temperature: 0.1, | ||
seed: 0, | ||
})) { | ||
@@ -402,7 +428,4 @@ if (chunk.choices && chunk.choices.length > 0) { | ||
await hf.textToImage({ | ||
inputs: 'award winning high resolution photo of a giant tortoise/((ladybird)) hybrid, [trending on artstation]', | ||
model: 'stabilityai/stable-diffusion-2', | ||
parameters: { | ||
negative_prompt: 'blurry', | ||
} | ||
model: 'black-forest-labs/FLUX.1-dev', | ||
inputs: 'a picture of a green bird' | ||
}) | ||
@@ -590,3 +613,3 @@ ``` | ||
const ep = hf.endpoint( | ||
"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2" | ||
"https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct" | ||
); | ||
@@ -593,0 +616,0 @@ const stream = ep.chatCompletionStream({ |
@@ -0,4 +1,9 @@ | ||
export type { ProviderMapping } from "./providers/types"; | ||
export { HfInference, HfInferenceEndpoint } from "./HfInference"; | ||
export { InferenceOutputError } from "./lib/InferenceOutputError"; | ||
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai"; | ||
export { REPLICATE_SUPPORTED_MODEL_IDS } from "./providers/replicate"; | ||
export { SAMBANOVA_SUPPORTED_MODEL_IDS } from "./providers/sambanova"; | ||
export { TOGETHER_SUPPORTED_MODEL_IDS } from "./providers/together"; | ||
export * from "./types"; | ||
export * from "./tasks"; |
@@ -0,1 +1,2 @@ | ||
import { HF_HUB_URL } from "../config"; | ||
import { isUrl } from "./isUrl"; | ||
@@ -11,3 +12,2 @@ | ||
const MAX_CACHE_ITEMS = 1000; | ||
export const HF_HUB_URL = "https://huggingface.co"; | ||
@@ -14,0 +14,0 @@ export interface DefaultTaskOptions { |
@@ -0,10 +1,16 @@ | ||
import type { WidgetType } from "@huggingface/tasks"; | ||
import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config"; | ||
import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai"; | ||
import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate"; | ||
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_SUPPORTED_MODEL_IDS } from "../providers/sambanova"; | ||
import { TOGETHER_API_BASE_URL, TOGETHER_SUPPORTED_MODEL_IDS } from "../providers/together"; | ||
import type { InferenceProvider } from "../types"; | ||
import type { InferenceTask, Options, RequestArgs } from "../types"; | ||
import { omit } from "../utils/omit"; | ||
import { HF_HUB_URL } from "./getDefaultTask"; | ||
import { isUrl } from "./isUrl"; | ||
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; | ||
const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_HUB_URL}/api/inference-proxy/{{PROVIDER}}`; | ||
/** | ||
* Loaded from huggingface.co/api/tasks if needed | ||
* Lazy-loaded from huggingface.co/api/tasks when needed | ||
* Used to determine the default model to use when it's not user defined | ||
*/ | ||
@@ -29,36 +35,56 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null; | ||
): Promise<{ url: string; info: RequestInit }> { | ||
const { accessToken, endpointUrl, ...otherArgs } = args; | ||
let { model } = args; | ||
const { | ||
forceTask: task, | ||
includeCredentials, | ||
taskHint, | ||
wait_for_model, | ||
use_cache, | ||
dont_load_model, | ||
chatCompletion, | ||
} = options ?? {}; | ||
const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args; | ||
const provider = maybeProvider ?? "hf-inference"; | ||
const headers: Record<string, string> = {}; | ||
if (accessToken) { | ||
headers["Authorization"] = `Bearer ${accessToken}`; | ||
} | ||
const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } = | ||
options ?? {}; | ||
if (!model && !tasks && taskHint) { | ||
const res = await fetch(`${HF_HUB_URL}/api/tasks`); | ||
if (res.ok) { | ||
tasks = await res.json(); | ||
} | ||
if (endpointUrl && provider !== "hf-inference") { | ||
throw new Error(`Cannot use endpointUrl with a third-party provider.`); | ||
} | ||
if (forceTask && provider !== "hf-inference") { | ||
throw new Error(`Cannot use forceTask with a third-party provider.`); | ||
} | ||
if (maybeModel && isUrl(maybeModel)) { | ||
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`); | ||
} | ||
if (!model && tasks && taskHint) { | ||
const taskInfo = tasks[taskHint]; | ||
if (taskInfo) { | ||
model = taskInfo.models[0].id; | ||
let model: string; | ||
if (!maybeModel) { | ||
if (taskHint) { | ||
model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion }); | ||
} else { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
/// TODO : change error message ^ | ||
} | ||
} else { | ||
model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion }); | ||
} | ||
if (!model) { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
/// If accessToken is passed, it should take precedence over includeCredentials | ||
const authMethod = accessToken | ||
? accessToken.startsWith("hf_") | ||
? "hf-token" | ||
: "provider-key" | ||
: includeCredentials === "include" | ||
? "credentials-include" | ||
: "none"; | ||
const url = endpointUrl | ||
? chatCompletion | ||
? endpointUrl + `/v1/chat/completions` | ||
: endpointUrl | ||
: makeUrl({ | ||
authMethod, | ||
chatCompletion: chatCompletion ?? false, | ||
forceTask, | ||
model, | ||
provider: provider ?? "hf-inference", | ||
taskHint, | ||
}); | ||
const headers: Record<string, string> = {}; | ||
if (accessToken) { | ||
headers["Authorization"] = | ||
provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`; | ||
} | ||
@@ -72,32 +98,16 @@ | ||
if (wait_for_model) { | ||
headers["X-Wait-For-Model"] = "true"; | ||
} | ||
if (use_cache === false) { | ||
headers["X-Use-Cache"] = "false"; | ||
} | ||
if (dont_load_model) { | ||
headers["X-Load-Model"] = "0"; | ||
} | ||
let url = (() => { | ||
if (endpointUrl && isUrl(model)) { | ||
throw new TypeError("Both model and endpointUrl cannot be URLs"); | ||
if (provider === "hf-inference") { | ||
if (wait_for_model) { | ||
headers["X-Wait-For-Model"] = "true"; | ||
} | ||
if (isUrl(model)) { | ||
console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead"); | ||
return model; | ||
if (use_cache === false) { | ||
headers["X-Use-Cache"] = "false"; | ||
} | ||
if (endpointUrl) { | ||
return endpointUrl; | ||
if (dont_load_model) { | ||
headers["X-Load-Model"] = "0"; | ||
} | ||
if (task) { | ||
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`; | ||
} | ||
} | ||
return `${HF_INFERENCE_API_BASE_URL}/models/${model}`; | ||
})(); | ||
if (chatCompletion && !url.endsWith("/chat/completions")) { | ||
url += "/v1/chat/completions"; | ||
if (provider === "replicate") { | ||
headers["Prefer"] = "wait"; | ||
} | ||
@@ -115,2 +125,10 @@ | ||
/* | ||
* Versioned Replicate models in the format `owner/model:version` expect the version in the body | ||
*/ | ||
if (provider === "replicate" && model.includes(":")) { | ||
const version = model.split(":")[1]; | ||
(otherArgs as typeof otherArgs & { version: string }).version = version; | ||
} | ||
const info: RequestInit = { | ||
@@ -122,5 +140,6 @@ headers, | ||
: JSON.stringify({ | ||
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs), | ||
...otherArgs, | ||
...(chatCompletion || provider === "together" ? { model } : undefined), | ||
}), | ||
...(credentials && { credentials }), | ||
...(credentials ? { credentials } : undefined), | ||
signal: options?.signal, | ||
@@ -131,1 +150,122 @@ }; | ||
} | ||
function mapModel(params: { | ||
model: string; | ||
provider: InferenceProvider; | ||
taskHint: InferenceTask | undefined; | ||
chatCompletion: boolean | undefined; | ||
}): string { | ||
if (params.provider === "hf-inference") { | ||
return params.model; | ||
} | ||
if (!params.taskHint) { | ||
throw new Error("taskHint must be specified when using a third-party provider"); | ||
} | ||
const task: WidgetType = | ||
params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint; | ||
const model = (() => { | ||
switch (params.provider) { | ||
case "fal-ai": | ||
return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model]; | ||
case "replicate": | ||
return REPLICATE_SUPPORTED_MODEL_IDS[task]?.[params.model]; | ||
case "sambanova": | ||
return SAMBANOVA_SUPPORTED_MODEL_IDS[task]?.[params.model]; | ||
case "together": | ||
return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model]; | ||
} | ||
})(); | ||
if (!model) { | ||
throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`); | ||
} | ||
return model; | ||
} | ||
function makeUrl(params: { | ||
authMethod: "none" | "hf-token" | "credentials-include" | "provider-key"; | ||
chatCompletion: boolean; | ||
model: string; | ||
provider: InferenceProvider; | ||
taskHint: InferenceTask | undefined; | ||
forceTask?: string | InferenceTask; | ||
}): string { | ||
if (params.authMethod === "none" && params.provider !== "hf-inference") { | ||
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken"); | ||
} | ||
const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key"; | ||
switch (params.provider) { | ||
case "fal-ai": { | ||
const baseUrl = shouldProxy | ||
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) | ||
: FAL_AI_API_BASE_URL; | ||
return `${baseUrl}/${params.model}`; | ||
} | ||
case "replicate": { | ||
const baseUrl = shouldProxy | ||
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) | ||
: REPLICATE_API_BASE_URL; | ||
if (params.model.includes(":")) { | ||
/// Versioned model | ||
return `${baseUrl}/v1/predictions`; | ||
} | ||
/// Evergreen / Canonical model | ||
return `${baseUrl}/v1/models/${params.model}/predictions`; | ||
} | ||
case "sambanova": { | ||
const baseUrl = shouldProxy | ||
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) | ||
: SAMBANOVA_API_BASE_URL; | ||
/// Sambanova API matches OpenAI-like APIs: model is defined in the request body | ||
if (params.taskHint === "text-generation" && params.chatCompletion) { | ||
return `${baseUrl}/v1/chat/completions`; | ||
} | ||
return baseUrl; | ||
} | ||
case "together": { | ||
const baseUrl = shouldProxy | ||
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) | ||
: TOGETHER_API_BASE_URL; | ||
/// Together API matches OpenAI-like APIs: model is defined in the request body | ||
if (params.taskHint === "text-to-image") { | ||
return `${baseUrl}/v1/images/generations`; | ||
} | ||
if (params.taskHint === "text-generation") { | ||
if (params.chatCompletion) { | ||
return `${baseUrl}/v1/chat/completions`; | ||
} | ||
return `${baseUrl}/v1/completions`; | ||
} | ||
return baseUrl; | ||
} | ||
default: { | ||
const url = params.forceTask | ||
? `${HF_INFERENCE_API_URL}/pipeline/${params.forceTask}/${params.model}` | ||
: `${HF_INFERENCE_API_URL}/models/${params.model}`; | ||
if (params.taskHint === "text-generation" && params.chatCompletion) { | ||
return url + `/v1/chat/completions`; | ||
} | ||
return url; | ||
} | ||
} | ||
} | ||
async function loadDefaultModel(task: InferenceTask): Promise<string> { | ||
if (!tasks) { | ||
tasks = await loadTaskInfo(); | ||
} | ||
const taskInfo = tasks[task]; | ||
if ((taskInfo?.models.length ?? 0) <= 0) { | ||
throw new Error(`No default model defined for task ${task}, please define the model explicitly.`); | ||
} | ||
return taskInfo.models[0].id; | ||
} | ||
async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> { | ||
const res = await fetch(`${HF_HUB_URL}/api/tasks`); | ||
if (!res.ok) { | ||
throw new Error("Failed to load tasks definitions from Hugging Face Hub."); | ||
} | ||
return await res.json(); | ||
} |
import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
import type { BaseArgs, Options } from "../../types"; | ||
import type { BaseArgs, Options, RequestArgs } from "../../types"; | ||
import { base64FromBytes } from "../../utils/base64FromBytes"; | ||
import { request } from "../custom/request"; | ||
@@ -27,2 +28,10 @@ | ||
): Promise<AutomaticSpeechRecognitionOutput> { | ||
if (args.provider === "fal-ai") { | ||
const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg"; | ||
const base64audio = base64FromBytes( | ||
new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer()) | ||
); | ||
(args as RequestArgs & { audio_url: string }).audio_url = `data:${contentType};base64,${base64audio}`; | ||
delete (args as RequestArgs & { data: unknown }).data; | ||
} | ||
const res = await request<AutomaticSpeechRecognitionOutput>(args, { | ||
@@ -29,0 +38,0 @@ ...options, |
@@ -5,3 +5,3 @@ import type { InferenceTask, Options, RequestArgs } from "../../types"; | ||
/** | ||
* Primitive to make custom calls to Inference Endpoints | ||
* Primitive to make custom calls to the inference provider | ||
*/ | ||
@@ -30,12 +30,18 @@ export async function request<T>( | ||
if (!response.ok) { | ||
if (response.headers.get("Content-Type")?.startsWith("application/json")) { | ||
const contentType = response.headers.get("Content-Type"); | ||
if (["application/json", "application/problem+json"].some((ct) => contentType?.startsWith(ct))) { | ||
const output = await response.json(); | ||
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) { | ||
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`); | ||
throw new Error( | ||
`Server ${args.model} does not seem to support chat completion. Error: ${JSON.stringify(output.error)}` | ||
); | ||
} | ||
if (output.error) { | ||
throw new Error(JSON.stringify(output.error)); | ||
if (output.error || output.detail) { | ||
throw new Error(JSON.stringify(output.error ?? output.detail)); | ||
} else { | ||
throw new Error(output); | ||
} | ||
} | ||
throw new Error("An error occurred while fetching the blob"); | ||
const message = contentType?.startsWith("text/plain;") ? await response.text() : undefined; | ||
throw new Error(message ?? "An error occurred while fetching the blob"); | ||
} | ||
@@ -42,0 +48,0 @@ |
@@ -35,5 +35,9 @@ import type { InferenceTask, Options, RequestArgs } from "../../types"; | ||
} | ||
if (output.error) { | ||
if (typeof output.error === "string") { | ||
throw new Error(output.error); | ||
} | ||
if (output.error && "message" in output.error && typeof output.error.message === "string") { | ||
/// OpenAI errors | ||
throw new Error(output.error.message); | ||
} | ||
} | ||
@@ -72,3 +76,5 @@ | ||
const { done, value } = await reader.read(); | ||
if (done) return; | ||
if (done) { | ||
return; | ||
} | ||
onChunk(value); | ||
@@ -82,3 +88,12 @@ for (const event of events) { | ||
if (typeof data === "object" && data !== null && "error" in data) { | ||
throw new Error(data.error); | ||
const errorStr = | ||
typeof data.error === "string" | ||
? data.error | ||
: typeof data.error === "object" && | ||
data.error && | ||
"message" in data.error && | ||
typeof data.error.message === "string" | ||
? data.error.message | ||
: JSON.stringify(data.error); | ||
throw new Error(`Error forwarded from backend: ` + errorStr); | ||
} | ||
@@ -85,0 +100,0 @@ yield data as T; |
@@ -11,2 +11,11 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
/** | ||
* Same param but for external providers like Together, Replicate | ||
*/ | ||
prompt?: string; | ||
response_format?: "base64"; | ||
input?: { | ||
prompt: string; | ||
}; | ||
parameters?: { | ||
@@ -38,2 +47,11 @@ /** | ||
interface Base64ImageGeneration { | ||
data: Array<{ | ||
b64_json: string; | ||
}>; | ||
} | ||
interface OutputUrlImageGeneration { | ||
output: string[]; | ||
} | ||
/** | ||
@@ -44,6 +62,31 @@ * This task reads some text input and outputs an image. | ||
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> { | ||
const res = await request<TextToImageOutput>(args, { | ||
if (args.provider === "together" || args.provider === "fal-ai") { | ||
args.prompt = args.inputs; | ||
args.inputs = ""; | ||
args.response_format = "base64"; | ||
} else if (args.provider === "replicate") { | ||
args.input = { prompt: args.inputs }; | ||
delete (args as unknown as { inputs: unknown }).inputs; | ||
} | ||
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(args, { | ||
...options, | ||
taskHint: "text-to-image", | ||
}); | ||
if (res && typeof res === "object") { | ||
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) { | ||
const image = await fetch(res.images[0].url); | ||
return await image.blob(); | ||
} | ||
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) { | ||
const base64Data = res.data[0].b64_json; | ||
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`); | ||
const blob = await base64Response.blob(); | ||
return blob; | ||
} | ||
if ("output" in res && Array.isArray(res.output)) { | ||
const urlResponse = await fetch(res.output[0]); | ||
const blob = await urlResponse.blob(); | ||
return blob; | ||
} | ||
} | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -50,0 +93,0 @@ if (!isValidOutput) { |
@@ -9,3 +9,2 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
*/ | ||
export async function chatCompletion( | ||
@@ -26,3 +25,4 @@ args: BaseArgs & ChatCompletionInput, | ||
typeof res?.model === "string" && | ||
typeof res?.system_fingerprint === "string" && | ||
/// Together.ai does not output a system_fingerprint | ||
(res.system_fingerprint === undefined || typeof res.system_fingerprint === "string") && | ||
typeof res?.usage === "object"; | ||
@@ -29,0 +29,0 @@ |
@@ -1,2 +0,7 @@ | ||
import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks"; | ||
import type { | ||
ChatCompletionOutput, | ||
TextGenerationInput, | ||
TextGenerationOutput, | ||
TextGenerationOutputFinishReason, | ||
} from "@huggingface/tasks"; | ||
import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
@@ -9,2 +14,12 @@ import type { BaseArgs, Options } from "../../types"; | ||
interface TogetherTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> { | ||
choices: Array<{ | ||
text: string; | ||
finish_reason: TextGenerationOutputFinishReason; | ||
seed: number; | ||
logprobs: unknown; | ||
index: number; | ||
}>; | ||
} | ||
/** | ||
@@ -17,13 +32,32 @@ * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with). | ||
): Promise<TextGenerationOutput> { | ||
const res = toArray( | ||
await request<TextGenerationOutput | TextGenerationOutput[]>(args, { | ||
if (args.provider === "together") { | ||
args.prompt = args.inputs; | ||
const raw = await request<TogetherTextCompletionOutput>(args, { | ||
...options, | ||
taskHint: "text-generation", | ||
}) | ||
); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string"); | ||
if (!isValidOutput) { | ||
throw new InferenceOutputError("Expected Array<{generated_text: string}>"); | ||
}); | ||
const isValidOutput = | ||
typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string"; | ||
if (!isValidOutput) { | ||
throw new InferenceOutputError("Expected ChatCompletionOutput"); | ||
} | ||
const completion = raw.choices[0]; | ||
return { | ||
generated_text: completion.text, | ||
}; | ||
} else { | ||
const res = toArray( | ||
await request<TextGenerationOutput | TextGenerationOutput[]>(args, { | ||
...options, | ||
taskHint: "text-generation", | ||
}) | ||
); | ||
const isValidOutput = | ||
Array.isArray(res) && res.every((x) => "generated_text" in x && typeof x?.generated_text === "string"); | ||
if (!isValidOutput) { | ||
throw new InferenceOutputError("Expected Array<{generated_text: string}>"); | ||
} | ||
return (res as TextGenerationOutput[])?.[0]; | ||
} | ||
return res?.[0]; | ||
} |
import type { PipelineType } from "@huggingface/tasks"; | ||
import type { ChatCompletionInput } from "@huggingface/tasks"; | ||
/** | ||
* HF model id, like "meta-llama/Llama-3.3-70B-Instruct" | ||
*/ | ||
export type ModelId = string; | ||
export interface Options { | ||
@@ -43,2 +48,5 @@ /** | ||
export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"] as const; | ||
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number]; | ||
export interface BaseArgs { | ||
@@ -49,6 +57,9 @@ /** | ||
* Can be created for free in hf.co/settings/token | ||
* | ||
* You can also pass an external Inference provider's key if you intend to call a compatible provider like Sambanova, Together, Replicate... | ||
*/ | ||
accessToken?: string; | ||
/** | ||
* The model to use. | ||
* The HF model to use. | ||
* | ||
@@ -60,3 +71,3 @@ * If not specified, will call huggingface.co/api/tasks to get the default model for the task. | ||
*/ | ||
model?: string; | ||
model?: ModelId; | ||
@@ -69,2 +80,9 @@ /** | ||
endpointUrl?: string; | ||
/** | ||
* Set an Inference provider to run this model on. | ||
* | ||
* Defaults to the first provider in your user settings that is compatible with this model. | ||
*/ | ||
provider?: InferenceProvider; | ||
} | ||
@@ -71,0 +89,0 @@ |
+ Added @huggingface/tasks@0.13.16 (transitive)
- Removed @huggingface/tasks@0.12.30 (transitive)
Updated @huggingface/tasks@^0.13.16