@huggingface/inference
Advanced tools
@@ -100,7 +100,40 @@ /// <reference path="./index.d.ts" /> | ||
// src/lib/getDefaultTask.ts | ||
var taskCache = /* @__PURE__ */ new Map(); | ||
var CACHE_DURATION = 10 * 60 * 1e3; | ||
var MAX_CACHE_ITEMS = 1e3; | ||
var HF_HUB_URL = "https://huggingface.co"; | ||
async function getDefaultTask(model, accessToken) { | ||
if (isUrl(model)) { | ||
return null; | ||
} | ||
const key = `${model}:${accessToken}`; | ||
let cachedTask = taskCache.get(key); | ||
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) { | ||
taskCache.delete(key); | ||
cachedTask = void 0; | ||
} | ||
if (cachedTask === void 0) { | ||
const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, { | ||
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {} | ||
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null); | ||
if (!modelTask) { | ||
return null; | ||
} | ||
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() }; | ||
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() }); | ||
if (taskCache.size > MAX_CACHE_ITEMS) { | ||
taskCache.delete(taskCache.keys().next().value); | ||
} | ||
} | ||
return cachedTask.task; | ||
} | ||
// src/lib/makeRequestOptions.ts | ||
var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; | ||
function makeRequestOptions(args, options) { | ||
const { model, accessToken, ...otherArgs } = args; | ||
const { task, includeCredentials, ...otherOptions } = options ?? {}; | ||
var tasks = null; | ||
async function makeRequestOptions(args, options) { | ||
const { accessToken, model: _model, ...otherArgs } = args; | ||
let { model } = args; | ||
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {}; | ||
const headers = {}; | ||
@@ -110,2 +143,17 @@ if (accessToken) { | ||
} | ||
if (!model && !tasks && taskHint) { | ||
const res = await fetch(`${HF_HUB_URL}/api/tasks`); | ||
if (res.ok) { | ||
tasks = await res.json(); | ||
} | ||
} | ||
if (!model && tasks && taskHint) { | ||
const taskInfo = tasks[taskHint]; | ||
if (taskInfo) { | ||
model = taskInfo.models[0].id; | ||
} | ||
} | ||
if (!model) { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
} | ||
const binary = "data" in args && !!args.data; | ||
@@ -148,3 +196,3 @@ if (!binary) { | ||
async function request(args, options) { | ||
const { url, info } = makeRequestOptions(args, options); | ||
const { url, info } = await makeRequestOptions(args, options); | ||
const response = await (options?.fetch ?? fetch)(url, info); | ||
@@ -273,3 +321,3 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) { | ||
async function* streamingRequest(args, options) { | ||
const { url, info } = makeRequestOptions({ ...args, stream: true }, options); | ||
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options); | ||
const response = await (options?.fetch ?? fetch)(url, info); | ||
@@ -347,3 +395,6 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) { | ||
async function audioClassification(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "audio-classification" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); | ||
@@ -358,3 +409,6 @@ if (!isValidOutput) { | ||
async function automaticSpeechRecognition(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "automatic-speech-recognition" | ||
}); | ||
const isValidOutput = typeof res?.text === "string"; | ||
@@ -369,3 +423,6 @@ if (!isValidOutput) { | ||
async function textToSpeech(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "text-to-speech" | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -380,3 +437,6 @@ if (!isValidOutput) { | ||
async function audioToAudio(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "audio-to-audio" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every( | ||
@@ -393,3 +453,6 @@ (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string" | ||
async function imageClassification(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "image-classification" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); | ||
@@ -404,3 +467,6 @@ if (!isValidOutput) { | ||
async function imageSegmentation(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "image-segmentation" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number"); | ||
@@ -415,3 +481,6 @@ if (!isValidOutput) { | ||
async function imageToText(args, options) { | ||
const res = (await request(args, options))?.[0]; | ||
const res = (await request(args, { | ||
...options, | ||
taskHint: "image-to-text" | ||
}))?.[0]; | ||
if (typeof res?.generated_text !== "string") { | ||
@@ -425,3 +494,6 @@ throw new InferenceOutputError("Expected {generated_text: string}"); | ||
async function objectDetection(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "object-detection" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every( | ||
@@ -440,3 +512,6 @@ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number" | ||
async function textToImage(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "text-to-image" | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -483,3 +558,6 @@ if (!isValidOutput) { | ||
} | ||
const res = await request(reqArgs, options); | ||
const res = await request(reqArgs, { | ||
...options, | ||
taskHint: "image-to-image" | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -504,3 +582,6 @@ if (!isValidOutput) { | ||
}; | ||
const res = await request(reqArgs, options); | ||
const res = await request(reqArgs, { | ||
...options, | ||
taskHint: "zero-shot-image-classification" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); | ||
@@ -515,3 +596,3 @@ if (!isValidOutput) { | ||
async function conversational(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { ...options, taskHint: "conversational" }); | ||
const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string"); | ||
@@ -526,43 +607,10 @@ if (!isValidOutput) { | ||
// src/lib/getDefaultTask.ts | ||
var taskCache = /* @__PURE__ */ new Map(); | ||
var CACHE_DURATION = 10 * 60 * 1e3; | ||
var MAX_CACHE_ITEMS = 1e3; | ||
var HF_HUB_URL = "https://huggingface.co"; | ||
async function getDefaultTask(model, accessToken) { | ||
if (isUrl(model)) { | ||
return null; | ||
} | ||
const key = `${model}:${accessToken}`; | ||
let cachedTask = taskCache.get(key); | ||
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) { | ||
taskCache.delete(key); | ||
cachedTask = void 0; | ||
} | ||
if (cachedTask === void 0) { | ||
const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, { | ||
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {} | ||
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null); | ||
if (!modelTask) { | ||
return null; | ||
} | ||
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() }; | ||
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() }); | ||
if (taskCache.size > MAX_CACHE_ITEMS) { | ||
taskCache.delete(taskCache.keys().next().value); | ||
} | ||
} | ||
return cachedTask.task; | ||
} | ||
// src/tasks/nlp/featureExtraction.ts | ||
async function featureExtraction(args, options) { | ||
const defaultTask = await getDefaultTask(args.model, args.accessToken); | ||
const res = await request( | ||
args, | ||
defaultTask === "sentence-similarity" ? { | ||
...options, | ||
task: "feature-extraction" | ||
} : options | ||
); | ||
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0; | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "feature-extraction", | ||
...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" } | ||
}); | ||
let isValidOutput = true; | ||
@@ -587,3 +635,6 @@ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => { | ||
async function fillMask(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "fill-mask" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every( | ||
@@ -602,3 +653,6 @@ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string" | ||
async function questionAnswering(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "question-answering" | ||
}); | ||
const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number"; | ||
@@ -613,10 +667,8 @@ if (!isValidOutput) { | ||
async function sentenceSimilarity(args, options) { | ||
const defaultTask = await getDefaultTask(args.model, args.accessToken); | ||
const res = await request( | ||
args, | ||
defaultTask === "feature-extraction" ? { | ||
...options, | ||
task: "sentence-similarity" | ||
} : options | ||
); | ||
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0; | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "sentence-similarity", | ||
...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" } | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); | ||
@@ -631,3 +683,6 @@ if (!isValidOutput) { | ||
async function summarization(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "summarization" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string"); | ||
@@ -642,3 +697,6 @@ if (!isValidOutput) { | ||
async function tableQuestionAnswering(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "table-question-answering" | ||
}); | ||
const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")); | ||
@@ -655,3 +713,6 @@ if (!isValidOutput) { | ||
async function textClassification(args, options) { | ||
const res = (await request(args, options))?.[0]; | ||
const res = (await request(args, { | ||
...options, | ||
taskHint: "text-classification" | ||
}))?.[0]; | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number"); | ||
@@ -666,3 +727,6 @@ if (!isValidOutput) { | ||
async function textGeneration(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "text-generation" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string"); | ||
@@ -677,3 +741,6 @@ if (!isValidOutput) { | ||
async function* textGenerationStream(args, options) { | ||
yield* streamingRequest(args, options); | ||
yield* streamingRequest(args, { | ||
...options, | ||
taskHint: "text-generation" | ||
}); | ||
} | ||
@@ -691,3 +758,8 @@ | ||
async function tokenClassification(args, options) { | ||
const res = toArray(await request(args, options)); | ||
const res = toArray( | ||
await request(args, { | ||
...options, | ||
taskHint: "token-classification" | ||
}) | ||
); | ||
const isValidOutput = Array.isArray(res) && res.every( | ||
@@ -706,3 +778,6 @@ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string" | ||
async function translation(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "translation" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string"); | ||
@@ -718,3 +793,6 @@ if (!isValidOutput) { | ||
const res = toArray( | ||
await request(args, options) | ||
await request(args, { | ||
...options, | ||
taskHint: "zero-shot-classification" | ||
}) | ||
); | ||
@@ -745,3 +823,6 @@ const isValidOutput = Array.isArray(res) && res.every( | ||
const res = toArray( | ||
await request(reqArgs, options) | ||
await request(reqArgs, { | ||
...options, | ||
taskHint: "document-question-answering" | ||
}) | ||
)?.[0]; | ||
@@ -769,3 +850,6 @@ const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined"); | ||
}; | ||
const res = (await request(reqArgs, options))?.[0]; | ||
const res = (await request(reqArgs, { | ||
...options, | ||
taskHint: "visual-question-answering" | ||
}))?.[0]; | ||
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number"; | ||
@@ -780,3 +864,6 @@ if (!isValidOutput) { | ||
async function tabularRegression(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "tabular-regression" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); | ||
@@ -791,3 +878,6 @@ if (!isValidOutput) { | ||
async function tabularClassification(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "tabular-classification" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); | ||
@@ -794,0 +884,0 @@ if (!isValidOutput) { |
@@ -29,3 +29,35 @@ export interface Options { | ||
export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity"; | ||
export type InferenceTask = | ||
| "audio-classification" | ||
| "audio-to-audio" | ||
| "automatic-speech-recognition" | ||
| "conversational" | ||
| "depth-estimation" | ||
| "document-question-answering" | ||
| "feature-extraction" | ||
| "fill-mask" | ||
| "image-classification" | ||
| "image-segmentation" | ||
| "image-to-image" | ||
| "image-to-text" | ||
| "object-detection" | ||
| "video-classification" | ||
| "question-answering" | ||
| "reinforcement-learning" | ||
| "sentence-similarity" | ||
| "summarization" | ||
| "table-question-answering" | ||
| "tabular-classification" | ||
| "tabular-regression" | ||
| "text-classification" | ||
| "text-generation" | ||
| "text-to-image" | ||
| "text-to-speech" | ||
| "text-to-video" | ||
| "token-classification" | ||
| "translation" | ||
| "unconditional-image-generation" | ||
| "visual-question-answering" | ||
| "zero-shot-classification" | ||
| "zero-shot-image-classification"; | ||
@@ -41,4 +73,6 @@ export interface BaseArgs { | ||
* The model to use. Can be a full URL for HF inference endpoints. | ||
* | ||
* If not specified, will call huggingface.co/api/tasks to get the default model for the task. | ||
*/ | ||
model: string; | ||
model?: string; | ||
} | ||
@@ -149,2 +183,4 @@ | ||
task?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
@@ -162,2 +198,4 @@ ): Promise<T>; | ||
task?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
@@ -1027,2 +1065,4 @@ ): AsyncGenerator<T>; | ||
task?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
@@ -1040,2 +1080,4 @@ ): Promise<T>; | ||
task?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
@@ -1234,2 +1276,4 @@ ): AsyncGenerator<T>; | ||
task?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
@@ -1247,2 +1291,4 @@ ): Promise<T>; | ||
task?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
@@ -1249,0 +1295,0 @@ ): AsyncGenerator<T>; |
@@ -100,7 +100,40 @@ /// <reference path="./index.d.ts" /> | ||
// src/lib/getDefaultTask.ts | ||
var taskCache = /* @__PURE__ */ new Map(); | ||
var CACHE_DURATION = 10 * 60 * 1e3; | ||
var MAX_CACHE_ITEMS = 1e3; | ||
var HF_HUB_URL = "https://huggingface.co"; | ||
async function getDefaultTask(model, accessToken) { | ||
if (isUrl(model)) { | ||
return null; | ||
} | ||
const key = `${model}:${accessToken}`; | ||
let cachedTask = taskCache.get(key); | ||
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) { | ||
taskCache.delete(key); | ||
cachedTask = void 0; | ||
} | ||
if (cachedTask === void 0) { | ||
const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, { | ||
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {} | ||
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null); | ||
if (!modelTask) { | ||
return null; | ||
} | ||
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() }; | ||
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() }); | ||
if (taskCache.size > MAX_CACHE_ITEMS) { | ||
taskCache.delete(taskCache.keys().next().value); | ||
} | ||
} | ||
return cachedTask.task; | ||
} | ||
// src/lib/makeRequestOptions.ts | ||
var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; | ||
function makeRequestOptions(args, options) { | ||
const { model, accessToken, ...otherArgs } = args; | ||
const { task, includeCredentials, ...otherOptions } = options ?? {}; | ||
var tasks = null; | ||
async function makeRequestOptions(args, options) { | ||
const { accessToken, model: _model, ...otherArgs } = args; | ||
let { model } = args; | ||
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {}; | ||
const headers = {}; | ||
@@ -110,2 +143,17 @@ if (accessToken) { | ||
} | ||
if (!model && !tasks && taskHint) { | ||
const res = await fetch(`${HF_HUB_URL}/api/tasks`); | ||
if (res.ok) { | ||
tasks = await res.json(); | ||
} | ||
} | ||
if (!model && tasks && taskHint) { | ||
const taskInfo = tasks[taskHint]; | ||
if (taskInfo) { | ||
model = taskInfo.models[0].id; | ||
} | ||
} | ||
if (!model) { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
} | ||
const binary = "data" in args && !!args.data; | ||
@@ -148,3 +196,3 @@ if (!binary) { | ||
async function request(args, options) { | ||
const { url, info } = makeRequestOptions(args, options); | ||
const { url, info } = await makeRequestOptions(args, options); | ||
const response = await (options?.fetch ?? fetch)(url, info); | ||
@@ -273,3 +321,3 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) { | ||
async function* streamingRequest(args, options) { | ||
const { url, info } = makeRequestOptions({ ...args, stream: true }, options); | ||
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options); | ||
const response = await (options?.fetch ?? fetch)(url, info); | ||
@@ -347,3 +395,6 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) { | ||
async function audioClassification(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "audio-classification" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); | ||
@@ -358,3 +409,6 @@ if (!isValidOutput) { | ||
async function automaticSpeechRecognition(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "automatic-speech-recognition" | ||
}); | ||
const isValidOutput = typeof res?.text === "string"; | ||
@@ -369,3 +423,6 @@ if (!isValidOutput) { | ||
async function textToSpeech(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "text-to-speech" | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -380,3 +437,6 @@ if (!isValidOutput) { | ||
async function audioToAudio(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "audio-to-audio" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every( | ||
@@ -393,3 +453,6 @@ (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string" | ||
async function imageClassification(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "image-classification" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); | ||
@@ -404,3 +467,6 @@ if (!isValidOutput) { | ||
async function imageSegmentation(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "image-segmentation" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number"); | ||
@@ -415,3 +481,6 @@ if (!isValidOutput) { | ||
async function imageToText(args, options) { | ||
const res = (await request(args, options))?.[0]; | ||
const res = (await request(args, { | ||
...options, | ||
taskHint: "image-to-text" | ||
}))?.[0]; | ||
if (typeof res?.generated_text !== "string") { | ||
@@ -425,3 +494,6 @@ throw new InferenceOutputError("Expected {generated_text: string}"); | ||
async function objectDetection(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "object-detection" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every( | ||
@@ -440,3 +512,6 @@ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number" | ||
async function textToImage(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "text-to-image" | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -483,3 +558,6 @@ if (!isValidOutput) { | ||
} | ||
const res = await request(reqArgs, options); | ||
const res = await request(reqArgs, { | ||
...options, | ||
taskHint: "image-to-image" | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -504,3 +582,6 @@ if (!isValidOutput) { | ||
}; | ||
const res = await request(reqArgs, options); | ||
const res = await request(reqArgs, { | ||
...options, | ||
taskHint: "zero-shot-image-classification" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); | ||
@@ -515,3 +596,3 @@ if (!isValidOutput) { | ||
async function conversational(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { ...options, taskHint: "conversational" }); | ||
const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string"); | ||
@@ -526,43 +607,10 @@ if (!isValidOutput) { | ||
// src/lib/getDefaultTask.ts | ||
var taskCache = /* @__PURE__ */ new Map(); | ||
var CACHE_DURATION = 10 * 60 * 1e3; | ||
var MAX_CACHE_ITEMS = 1e3; | ||
var HF_HUB_URL = "https://huggingface.co"; | ||
async function getDefaultTask(model, accessToken) { | ||
if (isUrl(model)) { | ||
return null; | ||
} | ||
const key = `${model}:${accessToken}`; | ||
let cachedTask = taskCache.get(key); | ||
if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) { | ||
taskCache.delete(key); | ||
cachedTask = void 0; | ||
} | ||
if (cachedTask === void 0) { | ||
const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, { | ||
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {} | ||
}).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null); | ||
if (!modelTask) { | ||
return null; | ||
} | ||
cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() }; | ||
taskCache.set(key, { task: modelTask, date: /* @__PURE__ */ new Date() }); | ||
if (taskCache.size > MAX_CACHE_ITEMS) { | ||
taskCache.delete(taskCache.keys().next().value); | ||
} | ||
} | ||
return cachedTask.task; | ||
} | ||
// src/tasks/nlp/featureExtraction.ts | ||
async function featureExtraction(args, options) { | ||
const defaultTask = await getDefaultTask(args.model, args.accessToken); | ||
const res = await request( | ||
args, | ||
defaultTask === "sentence-similarity" ? { | ||
...options, | ||
task: "feature-extraction" | ||
} : options | ||
); | ||
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0; | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "feature-extraction", | ||
...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" } | ||
}); | ||
let isValidOutput = true; | ||
@@ -587,3 +635,6 @@ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => { | ||
async function fillMask(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "fill-mask" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every( | ||
@@ -602,3 +653,6 @@ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string" | ||
async function questionAnswering(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "question-answering" | ||
}); | ||
const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number"; | ||
@@ -613,10 +667,8 @@ if (!isValidOutput) { | ||
async function sentenceSimilarity(args, options) { | ||
const defaultTask = await getDefaultTask(args.model, args.accessToken); | ||
const res = await request( | ||
args, | ||
defaultTask === "feature-extraction" ? { | ||
...options, | ||
task: "sentence-similarity" | ||
} : options | ||
); | ||
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0; | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "sentence-similarity", | ||
...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" } | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); | ||
@@ -631,3 +683,6 @@ if (!isValidOutput) { | ||
async function summarization(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "summarization" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string"); | ||
@@ -642,3 +697,6 @@ if (!isValidOutput) { | ||
async function tableQuestionAnswering(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "table-question-answering" | ||
}); | ||
const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number")); | ||
@@ -655,3 +713,6 @@ if (!isValidOutput) { | ||
async function textClassification(args, options) { | ||
const res = (await request(args, options))?.[0]; | ||
const res = (await request(args, { | ||
...options, | ||
taskHint: "text-classification" | ||
}))?.[0]; | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number"); | ||
@@ -666,3 +727,6 @@ if (!isValidOutput) { | ||
async function textGeneration(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "text-generation" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string"); | ||
@@ -677,3 +741,6 @@ if (!isValidOutput) { | ||
async function* textGenerationStream(args, options) { | ||
yield* streamingRequest(args, options); | ||
yield* streamingRequest(args, { | ||
...options, | ||
taskHint: "text-generation" | ||
}); | ||
} | ||
@@ -691,3 +758,8 @@ | ||
async function tokenClassification(args, options) { | ||
const res = toArray(await request(args, options)); | ||
const res = toArray( | ||
await request(args, { | ||
...options, | ||
taskHint: "token-classification" | ||
}) | ||
); | ||
const isValidOutput = Array.isArray(res) && res.every( | ||
@@ -706,3 +778,6 @@ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string" | ||
async function translation(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "translation" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string"); | ||
@@ -718,3 +793,6 @@ if (!isValidOutput) { | ||
const res = toArray( | ||
await request(args, options) | ||
await request(args, { | ||
...options, | ||
taskHint: "zero-shot-classification" | ||
}) | ||
); | ||
@@ -745,3 +823,6 @@ const isValidOutput = Array.isArray(res) && res.every( | ||
const res = toArray( | ||
await request(reqArgs, options) | ||
await request(reqArgs, { | ||
...options, | ||
taskHint: "document-question-answering" | ||
}) | ||
)?.[0]; | ||
@@ -769,3 +850,6 @@ const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined"); | ||
}; | ||
const res = (await request(reqArgs, options))?.[0]; | ||
const res = (await request(reqArgs, { | ||
...options, | ||
taskHint: "visual-question-answering" | ||
}))?.[0]; | ||
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number"; | ||
@@ -780,3 +864,6 @@ if (!isValidOutput) { | ||
async function tabularRegression(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "tabular-regression" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); | ||
@@ -791,3 +878,6 @@ if (!isValidOutput) { | ||
async function tabularClassification(args, options) { | ||
const res = await request(args, options); | ||
const res = await request(args, { | ||
...options, | ||
taskHint: "tabular-classification" | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); | ||
@@ -794,0 +884,0 @@ if (!isValidOutput) { |
{ | ||
"name": "@huggingface/inference", | ||
"version": "2.5.2", | ||
"version": "2.6.0", | ||
"packageManager": "pnpm@8.3.1", | ||
@@ -5,0 +5,0 @@ "license": "MIT", |
@@ -11,3 +11,3 @@ import { isUrl } from "./isUrl"; | ||
const MAX_CACHE_ITEMS = 1000; | ||
const HF_HUB_URL = "https://huggingface.co"; | ||
export const HF_HUB_URL = "https://huggingface.co"; | ||
@@ -14,0 +14,0 @@ /** |
import type { InferenceTask, Options, RequestArgs } from "../types"; | ||
import { HF_HUB_URL } from "./getDefaultTask"; | ||
import { isUrl } from "./isUrl"; | ||
@@ -7,5 +8,10 @@ | ||
/** | ||
* Loaded from huggingface.co/api/tasks if needed | ||
*/ | ||
let tasks: Record<string, { models: { id: string }[] }> | null = null; | ||
/** | ||
* Helper that prepares request arguments | ||
*/ | ||
export function makeRequestOptions( | ||
export async function makeRequestOptions( | ||
args: RequestArgs & { | ||
@@ -19,7 +25,11 @@ data?: Blob | ArrayBuffer; | ||
/** When a model can be used for multiple tasks, and we want to run a non-default task */ | ||
task?: string | InferenceTask; | ||
forceTask?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
): { url: string; info: RequestInit } { | ||
const { model, accessToken, ...otherArgs } = args; | ||
const { task, includeCredentials, ...otherOptions } = options ?? {}; | ||
): Promise<{ url: string; info: RequestInit }> { | ||
// eslint-disable-next-line @typescript-eslint/no-unused-vars | ||
const { accessToken, model: _model, ...otherArgs } = args; | ||
let { model } = args; | ||
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {}; | ||
@@ -31,2 +41,21 @@ const headers: Record<string, string> = {}; | ||
if (!model && !tasks && taskHint) { | ||
const res = await fetch(`${HF_HUB_URL}/api/tasks`); | ||
if (res.ok) { | ||
tasks = await res.json(); | ||
} | ||
} | ||
if (!model && tasks && taskHint) { | ||
const taskInfo = tasks[taskHint]; | ||
if (taskInfo) { | ||
model = taskInfo.models[0].id; | ||
} | ||
} | ||
if (!model) { | ||
throw new Error("No model provided, and no default model found for this task"); | ||
} | ||
const binary = "data" in args && !!args.data; | ||
@@ -33,0 +62,0 @@ |
@@ -34,3 +34,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<AudioClassificationReturn> { | ||
const res = await request<AudioClassificationReturn>(args, options); | ||
const res = await request<AudioClassificationReturn>(args, { | ||
...options, | ||
taskHint: "audio-classification", | ||
}); | ||
const isValidOutput = | ||
@@ -37,0 +40,0 @@ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); |
@@ -36,3 +36,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn> { | ||
const res = await request<AudioToAudioReturn>(args, options); | ||
const res = await request<AudioToAudioReturn>(args, { | ||
...options, | ||
taskHint: "audio-to-audio", | ||
}); | ||
const isValidOutput = | ||
@@ -39,0 +42,0 @@ Array.isArray(res) && |
@@ -27,3 +27,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<AutomaticSpeechRecognitionOutput> { | ||
const res = await request<AutomaticSpeechRecognitionOutput>(args, options); | ||
const res = await request<AutomaticSpeechRecognitionOutput>(args, { | ||
...options, | ||
taskHint: "automatic-speech-recognition", | ||
}); | ||
const isValidOutput = typeof res?.text === "string"; | ||
@@ -30,0 +33,0 @@ if (!isValidOutput) { |
@@ -19,3 +19,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<TextToSpeechOutput> { | ||
const res = await request<TextToSpeechOutput>(args, options); | ||
const res = await request<TextToSpeechOutput>(args, { | ||
...options, | ||
taskHint: "text-to-speech", | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -22,0 +25,0 @@ if (!isValidOutput) { |
@@ -14,5 +14,7 @@ import type { InferenceTask, Options, RequestArgs } from "../../types"; | ||
task?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
): Promise<T> { | ||
const { url, info } = makeRequestOptions(args, options); | ||
const { url, info } = await makeRequestOptions(args, options); | ||
const response = await (options?.fetch ?? fetch)(url, info); | ||
@@ -19,0 +21,0 @@ |
@@ -16,5 +16,7 @@ import type { InferenceTask, Options, RequestArgs } from "../../types"; | ||
task?: string | InferenceTask; | ||
/** To load default model if needed */ | ||
taskHint?: InferenceTask; | ||
} | ||
): AsyncGenerator<T> { | ||
const { url, info } = makeRequestOptions({ ...args, stream: true }, options); | ||
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options); | ||
const response = await (options?.fetch ?? fetch)(url, info); | ||
@@ -21,0 +23,0 @@ |
@@ -33,3 +33,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<ImageClassificationOutput> { | ||
const res = await request<ImageClassificationOutput>(args, options); | ||
const res = await request<ImageClassificationOutput>(args, { | ||
...options, | ||
taskHint: "image-classification", | ||
}); | ||
const isValidOutput = | ||
@@ -36,0 +39,0 @@ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); |
@@ -37,3 +37,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<ImageSegmentationOutput> { | ||
const res = await request<ImageSegmentationOutput>(args, options); | ||
const res = await request<ImageSegmentationOutput>(args, { | ||
...options, | ||
taskHint: "image-segmentation", | ||
}); | ||
const isValidOutput = | ||
@@ -40,0 +43,0 @@ Array.isArray(res) && |
@@ -77,3 +77,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
} | ||
const res = await request<ImageToImageOutput>(reqArgs, options); | ||
const res = await request<ImageToImageOutput>(reqArgs, { | ||
...options, | ||
taskHint: "image-to-image", | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -80,0 +83,0 @@ if (!isValidOutput) { |
@@ -23,3 +23,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> { | ||
const res = (await request<[ImageToTextOutput]>(args, options))?.[0]; | ||
const res = ( | ||
await request<[ImageToTextOutput]>(args, { | ||
...options, | ||
taskHint: "image-to-text", | ||
}) | ||
)?.[0]; | ||
@@ -26,0 +31,0 @@ if (typeof res?.generated_text !== "string") { |
@@ -40,3 +40,6 @@ import { request } from "../custom/request"; | ||
export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> { | ||
const res = await request<ObjectDetectionOutput>(args, options); | ||
const res = await request<ObjectDetectionOutput>(args, { | ||
...options, | ||
taskHint: "object-detection", | ||
}); | ||
const isValidOutput = | ||
@@ -43,0 +46,0 @@ Array.isArray(res) && |
@@ -42,3 +42,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> { | ||
const res = await request<TextToImageOutput>(args, options); | ||
const res = await request<TextToImageOutput>(args, { | ||
...options, | ||
taskHint: "text-to-image", | ||
}); | ||
const isValidOutput = res && res instanceof Blob; | ||
@@ -45,0 +48,0 @@ if (!isValidOutput) { |
@@ -48,3 +48,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
const res = await request<ZeroShotImageClassificationOutput>(reqArgs, options); | ||
const res = await request<ZeroShotImageClassificationOutput>(reqArgs, { | ||
...options, | ||
taskHint: "zero-shot-image-classification", | ||
}); | ||
const isValidOutput = | ||
@@ -51,0 +54,0 @@ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number"); |
@@ -59,3 +59,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
const res = toArray( | ||
await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, options) | ||
await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, { | ||
...options, | ||
taskHint: "document-question-answering", | ||
}) | ||
)?.[0]; | ||
@@ -62,0 +65,0 @@ const isValidOutput = |
@@ -48,3 +48,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
} as RequestArgs; | ||
const res = (await request<[VisualQuestionAnsweringOutput]>(reqArgs, options))?.[0]; | ||
const res = ( | ||
await request<[VisualQuestionAnsweringOutput]>(reqArgs, { | ||
...options, | ||
taskHint: "visual-question-answering", | ||
}) | ||
)?.[0]; | ||
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number"; | ||
@@ -51,0 +56,0 @@ if (!isValidOutput) { |
@@ -66,3 +66,3 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function conversational(args: ConversationalArgs, options?: Options): Promise<ConversationalOutput> { | ||
const res = await request<ConversationalOutput>(args, options); | ||
const res = await request<ConversationalOutput>(args, { ...options, taskHint: "conversational" }); | ||
const isValidOutput = | ||
@@ -69,0 +69,0 @@ Array.isArray(res.conversation.generated_responses) && |
@@ -28,12 +28,9 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<FeatureExtractionOutput> { | ||
const defaultTask = await getDefaultTask(args.model, args.accessToken); | ||
const res = await request<FeatureExtractionOutput>( | ||
args, | ||
defaultTask === "sentence-similarity" | ||
? { | ||
...options, | ||
task: "feature-extraction", | ||
} | ||
: options | ||
); | ||
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : undefined; | ||
const res = await request<FeatureExtractionOutput>(args, { | ||
...options, | ||
taskHint: "feature-extraction", | ||
...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }), | ||
}); | ||
let isValidOutput = true; | ||
@@ -40,0 +37,0 @@ |
@@ -32,3 +32,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> { | ||
const res = await request<FillMaskOutput>(args, options); | ||
const res = await request<FillMaskOutput>(args, { | ||
...options, | ||
taskHint: "fill-mask", | ||
}); | ||
const isValidOutput = | ||
@@ -35,0 +38,0 @@ Array.isArray(res) && |
@@ -38,3 +38,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<QuestionAnsweringOutput> { | ||
const res = await request<QuestionAnsweringOutput>(args, options); | ||
const res = await request<QuestionAnsweringOutput>(args, { | ||
...options, | ||
taskHint: "question-answering", | ||
}); | ||
const isValidOutput = | ||
@@ -41,0 +44,0 @@ typeof res === "object" && |
@@ -28,12 +28,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<SentenceSimilarityOutput> { | ||
const defaultTask = await getDefaultTask(args.model, args.accessToken); | ||
const res = await request<SentenceSimilarityOutput>( | ||
args, | ||
defaultTask === "feature-extraction" | ||
? { | ||
...options, | ||
task: "sentence-similarity", | ||
} | ||
: options | ||
); | ||
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : undefined; | ||
const res = await request<SentenceSimilarityOutput>(args, { | ||
...options, | ||
taskHint: "sentence-similarity", | ||
...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }), | ||
}); | ||
@@ -40,0 +36,0 @@ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); |
@@ -53,3 +53,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> { | ||
const res = await request<SummarizationOutput[]>(args, options); | ||
const res = await request<SummarizationOutput[]>(args, { | ||
...options, | ||
taskHint: "summarization", | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string"); | ||
@@ -56,0 +59,0 @@ if (!isValidOutput) { |
@@ -44,3 +44,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<TableQuestionAnsweringOutput> { | ||
const res = await request<TableQuestionAnsweringOutput>(args, options); | ||
const res = await request<TableQuestionAnsweringOutput>(args, { | ||
...options, | ||
taskHint: "table-question-answering", | ||
}); | ||
const isValidOutput = | ||
@@ -47,0 +50,0 @@ typeof res?.aggregator === "string" && |
@@ -30,3 +30,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<TextClassificationOutput> { | ||
const res = (await request<TextClassificationOutput[]>(args, options))?.[0]; | ||
const res = ( | ||
await request<TextClassificationOutput[]>(args, { | ||
...options, | ||
taskHint: "text-classification", | ||
}) | ||
)?.[0]; | ||
const isValidOutput = | ||
@@ -33,0 +38,0 @@ Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number"); |
@@ -65,3 +65,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationOutput> { | ||
const res = await request<TextGenerationOutput[]>(args, options); | ||
const res = await request<TextGenerationOutput[]>(args, { | ||
...options, | ||
taskHint: "text-generation", | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string"); | ||
@@ -68,0 +71,0 @@ if (!isValidOutput) { |
@@ -91,3 +91,6 @@ import type { Options } from "../../types"; | ||
): AsyncGenerator<TextGenerationStreamOutput> { | ||
yield* streamingRequest<TextGenerationStreamOutput>(args, options); | ||
yield* streamingRequest<TextGenerationStreamOutput>(args, { | ||
...options, | ||
taskHint: "text-generation", | ||
}); | ||
} |
@@ -61,3 +61,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<TokenClassificationOutput> { | ||
const res = toArray(await request<TokenClassificationOutput[number] | TokenClassificationOutput>(args, options)); | ||
const res = toArray( | ||
await request<TokenClassificationOutput[number] | TokenClassificationOutput>(args, { | ||
...options, | ||
taskHint: "token-classification", | ||
}) | ||
); | ||
const isValidOutput = | ||
@@ -64,0 +69,0 @@ Array.isArray(res) && |
@@ -23,3 +23,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> { | ||
const res = await request<TranslationOutput[]>(args, options); | ||
const res = await request<TranslationOutput[]>(args, { | ||
...options, | ||
taskHint: "translation", | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string"); | ||
@@ -26,0 +29,0 @@ if (!isValidOutput) { |
@@ -39,3 +39,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
const res = toArray( | ||
await request<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, options) | ||
await request<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, { | ||
...options, | ||
taskHint: "zero-shot-classification", | ||
}) | ||
); | ||
@@ -42,0 +45,0 @@ const isValidOutput = |
@@ -28,3 +28,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<TabularClassificationOutput> { | ||
const res = await request<TabularClassificationOutput>(args, options); | ||
const res = await request<TabularClassificationOutput>(args, { | ||
...options, | ||
taskHint: "tabular-classification", | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); | ||
@@ -31,0 +34,0 @@ if (!isValidOutput) { |
@@ -28,3 +28,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
): Promise<TabularRegressionOutput> { | ||
const res = await request<TabularRegressionOutput>(args, options); | ||
const res = await request<TabularRegressionOutput>(args, { | ||
...options, | ||
taskHint: "tabular-regression", | ||
}); | ||
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number"); | ||
@@ -31,0 +34,0 @@ if (!isValidOutput) { |
@@ -29,3 +29,35 @@ export interface Options { | ||
export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity"; | ||
export type InferenceTask = | ||
| "audio-classification" | ||
| "audio-to-audio" | ||
| "automatic-speech-recognition" | ||
| "conversational" | ||
| "depth-estimation" | ||
| "document-question-answering" | ||
| "feature-extraction" | ||
| "fill-mask" | ||
| "image-classification" | ||
| "image-segmentation" | ||
| "image-to-image" | ||
| "image-to-text" | ||
| "object-detection" | ||
| "video-classification" | ||
| "question-answering" | ||
| "reinforcement-learning" | ||
| "sentence-similarity" | ||
| "summarization" | ||
| "table-question-answering" | ||
| "tabular-classification" | ||
| "tabular-regression" | ||
| "text-classification" | ||
| "text-generation" | ||
| "text-to-image" | ||
| "text-to-speech" | ||
| "text-to-video" | ||
| "token-classification" | ||
| "translation" | ||
| "unconditional-image-generation" | ||
| "visual-question-answering" | ||
| "zero-shot-classification" | ||
| "zero-shot-image-classification"; | ||
@@ -41,4 +73,6 @@ export interface BaseArgs { | ||
* The model to use. Can be a full URL for HF inference endpoints. | ||
* | ||
* If not specified, will call huggingface.co/api/tasks to get the default model for the task. | ||
*/ | ||
model: string; | ||
model?: string; | ||
} | ||
@@ -45,0 +79,0 @@ |
Sorry, the diff of this file is not supported yet
236848
4.36%6468
7.53%12
33.33%