You're Invited: Meet the Socket team at BSidesSF and RSAC, April 27 - May 1. RSVP

@huggingface/inference

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

@huggingface/inference - npm Package Compare versions

Comparing version 2.5.2
to
2.6.0

@@ -100,7 +100,40 @@ /// <reference path="./index.d.ts" />

// src/lib/getDefaultTask.ts
var taskCache = /* @__PURE__ */ new Map();
var CACHE_DURATION = 10 * 60 * 1e3;
var MAX_CACHE_ITEMS = 1e3;
var HF_HUB_URL = "https://huggingface.co";
// Resolves the default pipeline task (pipeline_tag) for a model id via the
// HF Hub API. Results are cached for CACHE_DURATION ms; once the cache grows
// past MAX_CACHE_ITEMS the oldest insertion is evicted. Returns null for
// URL "models" (inference endpoints have no hub metadata) and on any
// fetch/parse failure.
async function getDefaultTask(model, accessToken) {
  if (isUrl(model)) {
    return null;
  }
  // Key includes the token: gated/private models may resolve differently
  // depending on credentials.
  const key = `${model}:${accessToken}`;
  let cachedTask = taskCache.get(key);
  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
    // Entry is stale: drop it and re-fetch below.
    taskCache.delete(key);
    cachedTask = void 0;
  }
  if (cachedTask === void 0) {
    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
    }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
    if (!modelTask) {
      return null;
    }
    // Store the exact object we return — the original built a second
    // { task, date } literal with a second `new Date()`, so the cached
    // timestamp could differ from the local one.
    cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
    taskCache.set(key, cachedTask);
    if (taskCache.size > MAX_CACHE_ITEMS) {
      // Map iterates in insertion order, so this evicts the oldest entry.
      taskCache.delete(taskCache.keys().next().value);
    }
  }
  return cachedTask.task;
}
// src/lib/makeRequestOptions.ts
var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
function makeRequestOptions(args, options) {
const { model, accessToken, ...otherArgs } = args;
const { task, includeCredentials, ...otherOptions } = options ?? {};
var tasks = null;
async function makeRequestOptions(args, options) {
const { accessToken, model: _model, ...otherArgs } = args;
let { model } = args;
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};
const headers = {};

@@ -110,2 +143,17 @@ if (accessToken) {

}
if (!model && !tasks && taskHint) {
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
if (res.ok) {
tasks = await res.json();
}
}
if (!model && tasks && taskHint) {
const taskInfo = tasks[taskHint];
if (taskInfo) {
model = taskInfo.models[0].id;
}
}
if (!model) {
throw new Error("No model provided, and no default model found for this task");
}
const binary = "data" in args && !!args.data;

@@ -148,3 +196,3 @@ if (!binary) {

async function request(args, options) {
const { url, info } = makeRequestOptions(args, options);
const { url, info } = await makeRequestOptions(args, options);
const response = await (options?.fetch ?? fetch)(url, info);

@@ -273,3 +321,3 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {

async function* streamingRequest(args, options) {
const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
const response = await (options?.fetch ?? fetch)(url, info);

@@ -347,3 +395,6 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {

async function audioClassification(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "audio-classification"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -358,3 +409,6 @@ if (!isValidOutput) {

async function automaticSpeechRecognition(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "automatic-speech-recognition"
});
const isValidOutput = typeof res?.text === "string";

@@ -369,3 +423,6 @@ if (!isValidOutput) {

async function textToSpeech(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "text-to-speech"
});
const isValidOutput = res && res instanceof Blob;

@@ -380,3 +437,6 @@ if (!isValidOutput) {

async function audioToAudio(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "audio-to-audio"
});
const isValidOutput = Array.isArray(res) && res.every(

@@ -393,3 +453,6 @@ (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"

async function imageClassification(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "image-classification"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -404,3 +467,6 @@ if (!isValidOutput) {

async function imageSegmentation(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "image-segmentation"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");

@@ -415,3 +481,6 @@ if (!isValidOutput) {

async function imageToText(args, options) {
const res = (await request(args, options))?.[0];
const res = (await request(args, {
...options,
taskHint: "image-to-text"
}))?.[0];
if (typeof res?.generated_text !== "string") {

@@ -425,3 +494,6 @@ throw new InferenceOutputError("Expected {generated_text: string}");

async function objectDetection(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "object-detection"
});
const isValidOutput = Array.isArray(res) && res.every(

@@ -440,3 +512,6 @@ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"

async function textToImage(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "text-to-image"
});
const isValidOutput = res && res instanceof Blob;

@@ -483,3 +558,6 @@ if (!isValidOutput) {

}
const res = await request(reqArgs, options);
const res = await request(reqArgs, {
...options,
taskHint: "image-to-image"
});
const isValidOutput = res && res instanceof Blob;

@@ -504,3 +582,6 @@ if (!isValidOutput) {

};
const res = await request(reqArgs, options);
const res = await request(reqArgs, {
...options,
taskHint: "zero-shot-image-classification"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -515,3 +596,3 @@ if (!isValidOutput) {

async function conversational(args, options) {
const res = await request(args, options);
const res = await request(args, { ...options, taskHint: "conversational" });
const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");

@@ -526,43 +607,10 @@ if (!isValidOutput) {

// src/lib/getDefaultTask.ts
var taskCache = /* @__PURE__ */ new Map();
var CACHE_DURATION = 10 * 60 * 1e3;
var MAX_CACHE_ITEMS = 1e3;
var HF_HUB_URL = "https://huggingface.co";
// Resolves the default pipeline task (pipeline_tag) for a model id via the
// HF Hub API. Results are cached for CACHE_DURATION ms; once the cache grows
// past MAX_CACHE_ITEMS the oldest insertion is evicted. Returns null for
// URL "models" (inference endpoints have no hub metadata) and on any
// fetch/parse failure.
async function getDefaultTask(model, accessToken) {
  if (isUrl(model)) {
    return null;
  }
  // Key includes the token: gated/private models may resolve differently
  // depending on credentials.
  const key = `${model}:${accessToken}`;
  let cachedTask = taskCache.get(key);
  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
    // Entry is stale: drop it and re-fetch below.
    taskCache.delete(key);
    cachedTask = void 0;
  }
  if (cachedTask === void 0) {
    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
    }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
    if (!modelTask) {
      return null;
    }
    // Store the exact object we return — the original built a second
    // { task, date } literal with a second `new Date()`, so the cached
    // timestamp could differ from the local one.
    cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
    taskCache.set(key, cachedTask);
    if (taskCache.size > MAX_CACHE_ITEMS) {
      // Map iterates in insertion order, so this evicts the oldest entry.
      taskCache.delete(taskCache.keys().next().value);
    }
  }
  return cachedTask.task;
}
// src/tasks/nlp/featureExtraction.ts
async function featureExtraction(args, options) {
const defaultTask = await getDefaultTask(args.model, args.accessToken);
const res = await request(
args,
defaultTask === "sentence-similarity" ? {
...options,
task: "feature-extraction"
} : options
);
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
const res = await request(args, {
...options,
taskHint: "feature-extraction",
...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }
});
let isValidOutput = true;

@@ -587,3 +635,6 @@ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {

async function fillMask(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "fill-mask"
});
const isValidOutput = Array.isArray(res) && res.every(

@@ -602,3 +653,6 @@ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"

async function questionAnswering(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "question-answering"
});
const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";

@@ -613,10 +667,8 @@ if (!isValidOutput) {

async function sentenceSimilarity(args, options) {
const defaultTask = await getDefaultTask(args.model, args.accessToken);
const res = await request(
args,
defaultTask === "feature-extraction" ? {
...options,
task: "sentence-similarity"
} : options
);
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
const res = await request(args, {
...options,
taskHint: "sentence-similarity",
...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -631,3 +683,6 @@ if (!isValidOutput) {

async function summarization(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "summarization"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");

@@ -642,3 +697,6 @@ if (!isValidOutput) {

async function tableQuestionAnswering(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "table-question-answering"
});
const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));

@@ -655,3 +713,6 @@ if (!isValidOutput) {

async function textClassification(args, options) {
const res = (await request(args, options))?.[0];
const res = (await request(args, {
...options,
taskHint: "text-classification"
}))?.[0];
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");

@@ -666,3 +727,6 @@ if (!isValidOutput) {

async function textGeneration(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "text-generation"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");

@@ -677,3 +741,6 @@ if (!isValidOutput) {

async function* textGenerationStream(args, options) {
yield* streamingRequest(args, options);
yield* streamingRequest(args, {
...options,
taskHint: "text-generation"
});
}

@@ -691,3 +758,8 @@

async function tokenClassification(args, options) {
const res = toArray(await request(args, options));
const res = toArray(
await request(args, {
...options,
taskHint: "token-classification"
})
);
const isValidOutput = Array.isArray(res) && res.every(

@@ -706,3 +778,6 @@ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"

async function translation(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "translation"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");

@@ -718,3 +793,6 @@ if (!isValidOutput) {

const res = toArray(
await request(args, options)
await request(args, {
...options,
taskHint: "zero-shot-classification"
})
);

@@ -745,3 +823,6 @@ const isValidOutput = Array.isArray(res) && res.every(

const res = toArray(
await request(reqArgs, options)
await request(reqArgs, {
...options,
taskHint: "document-question-answering"
})
)?.[0];

@@ -769,3 +850,6 @@ const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined");

};
const res = (await request(reqArgs, options))?.[0];
const res = (await request(reqArgs, {
...options,
taskHint: "visual-question-answering"
}))?.[0];
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";

@@ -780,3 +864,6 @@ if (!isValidOutput) {

async function tabularRegression(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "tabular-regression"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -791,3 +878,6 @@ if (!isValidOutput) {

async function tabularClassification(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "tabular-classification"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -794,0 +884,0 @@ if (!isValidOutput) {

@@ -29,3 +29,35 @@ export interface Options {

export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity";
/**
 * Pipeline tasks recognized by the inference helpers.
 *
 * Used as `forceTask` (run a model on a non-default task) and as `taskHint`
 * (pick a default model from huggingface.co/api/tasks when none is given).
 */
export type InferenceTask =
| "audio-classification"
| "audio-to-audio"
| "automatic-speech-recognition"
| "conversational"
| "depth-estimation"
| "document-question-answering"
| "feature-extraction"
| "fill-mask"
| "image-classification"
| "image-segmentation"
| "image-to-image"
| "image-to-text"
| "object-detection"
| "video-classification"
| "question-answering"
| "reinforcement-learning"
| "sentence-similarity"
| "summarization"
| "table-question-answering"
| "tabular-classification"
| "tabular-regression"
| "text-classification"
| "text-generation"
| "text-to-image"
| "text-to-speech"
| "text-to-video"
| "token-classification"
| "translation"
| "unconditional-image-generation"
| "visual-question-answering"
| "zero-shot-classification"
| "zero-shot-image-classification";

@@ -41,4 +73,6 @@ export interface BaseArgs {

* The model to use. Can be a full URL for HF inference endpoints.
*
* If not specified, will call huggingface.co/api/tasks to get the default model for the task.
*/
model: string;
model?: string;
}

@@ -149,2 +183,4 @@

task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}

@@ -162,2 +198,4 @@ ): Promise<T>;

task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}

@@ -1027,2 +1065,4 @@ ): AsyncGenerator<T>;

task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}

@@ -1040,2 +1080,4 @@ ): Promise<T>;

task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}

@@ -1234,2 +1276,4 @@ ): AsyncGenerator<T>;

task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}

@@ -1247,2 +1291,4 @@ ): Promise<T>;

task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}

@@ -1249,0 +1295,0 @@ ): AsyncGenerator<T>;

@@ -100,7 +100,40 @@ /// <reference path="./index.d.ts" />

// src/lib/getDefaultTask.ts
var taskCache = /* @__PURE__ */ new Map();
var CACHE_DURATION = 10 * 60 * 1e3;
var MAX_CACHE_ITEMS = 1e3;
var HF_HUB_URL = "https://huggingface.co";
// Resolves the default pipeline task (pipeline_tag) for a model id via the
// HF Hub API. Results are cached for CACHE_DURATION ms; once the cache grows
// past MAX_CACHE_ITEMS the oldest insertion is evicted. Returns null for
// URL "models" (inference endpoints have no hub metadata) and on any
// fetch/parse failure.
async function getDefaultTask(model, accessToken) {
  if (isUrl(model)) {
    return null;
  }
  // Key includes the token: gated/private models may resolve differently
  // depending on credentials.
  const key = `${model}:${accessToken}`;
  let cachedTask = taskCache.get(key);
  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
    // Entry is stale: drop it and re-fetch below.
    taskCache.delete(key);
    cachedTask = void 0;
  }
  if (cachedTask === void 0) {
    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
    }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
    if (!modelTask) {
      return null;
    }
    // Store the exact object we return — the original built a second
    // { task, date } literal with a second `new Date()`, so the cached
    // timestamp could differ from the local one.
    cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
    taskCache.set(key, cachedTask);
    if (taskCache.size > MAX_CACHE_ITEMS) {
      // Map iterates in insertion order, so this evicts the oldest entry.
      taskCache.delete(taskCache.keys().next().value);
    }
  }
  return cachedTask.task;
}
// src/lib/makeRequestOptions.ts
var HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
function makeRequestOptions(args, options) {
const { model, accessToken, ...otherArgs } = args;
const { task, includeCredentials, ...otherOptions } = options ?? {};
var tasks = null;
async function makeRequestOptions(args, options) {
const { accessToken, model: _model, ...otherArgs } = args;
let { model } = args;
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};
const headers = {};

@@ -110,2 +143,17 @@ if (accessToken) {

}
if (!model && !tasks && taskHint) {
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
if (res.ok) {
tasks = await res.json();
}
}
if (!model && tasks && taskHint) {
const taskInfo = tasks[taskHint];
if (taskInfo) {
model = taskInfo.models[0].id;
}
}
if (!model) {
throw new Error("No model provided, and no default model found for this task");
}
const binary = "data" in args && !!args.data;

@@ -148,3 +196,3 @@ if (!binary) {

async function request(args, options) {
const { url, info } = makeRequestOptions(args, options);
const { url, info } = await makeRequestOptions(args, options);
const response = await (options?.fetch ?? fetch)(url, info);

@@ -273,3 +321,3 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {

async function* streamingRequest(args, options) {
const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
const response = await (options?.fetch ?? fetch)(url, info);

@@ -347,3 +395,6 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {

async function audioClassification(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "audio-classification"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -358,3 +409,6 @@ if (!isValidOutput) {

async function automaticSpeechRecognition(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "automatic-speech-recognition"
});
const isValidOutput = typeof res?.text === "string";

@@ -369,3 +423,6 @@ if (!isValidOutput) {

async function textToSpeech(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "text-to-speech"
});
const isValidOutput = res && res instanceof Blob;

@@ -380,3 +437,6 @@ if (!isValidOutput) {

async function audioToAudio(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "audio-to-audio"
});
const isValidOutput = Array.isArray(res) && res.every(

@@ -393,3 +453,6 @@ (x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"

async function imageClassification(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "image-classification"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -404,3 +467,6 @@ if (!isValidOutput) {

async function imageSegmentation(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "image-segmentation"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");

@@ -415,3 +481,6 @@ if (!isValidOutput) {

async function imageToText(args, options) {
const res = (await request(args, options))?.[0];
const res = (await request(args, {
...options,
taskHint: "image-to-text"
}))?.[0];
if (typeof res?.generated_text !== "string") {

@@ -425,3 +494,6 @@ throw new InferenceOutputError("Expected {generated_text: string}");

async function objectDetection(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "object-detection"
});
const isValidOutput = Array.isArray(res) && res.every(

@@ -440,3 +512,6 @@ (x) => typeof x.label === "string" && typeof x.score === "number" && typeof x.box.xmin === "number" && typeof x.box.ymin === "number" && typeof x.box.xmax === "number" && typeof x.box.ymax === "number"

async function textToImage(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "text-to-image"
});
const isValidOutput = res && res instanceof Blob;

@@ -483,3 +558,6 @@ if (!isValidOutput) {

}
const res = await request(reqArgs, options);
const res = await request(reqArgs, {
...options,
taskHint: "image-to-image"
});
const isValidOutput = res && res instanceof Blob;

@@ -504,3 +582,6 @@ if (!isValidOutput) {

};
const res = await request(reqArgs, options);
const res = await request(reqArgs, {
...options,
taskHint: "zero-shot-image-classification"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -515,3 +596,3 @@ if (!isValidOutput) {

async function conversational(args, options) {
const res = await request(args, options);
const res = await request(args, { ...options, taskHint: "conversational" });
const isValidOutput = Array.isArray(res.conversation.generated_responses) && res.conversation.generated_responses.every((x) => typeof x === "string") && Array.isArray(res.conversation.past_user_inputs) && res.conversation.past_user_inputs.every((x) => typeof x === "string") && typeof res.generated_text === "string" && Array.isArray(res.warnings) && res.warnings.every((x) => typeof x === "string");

@@ -526,43 +607,10 @@ if (!isValidOutput) {

// src/lib/getDefaultTask.ts
var taskCache = /* @__PURE__ */ new Map();
var CACHE_DURATION = 10 * 60 * 1e3;
var MAX_CACHE_ITEMS = 1e3;
var HF_HUB_URL = "https://huggingface.co";
// Resolves the default pipeline task (pipeline_tag) for a model id via the
// HF Hub API. Results are cached for CACHE_DURATION ms; once the cache grows
// past MAX_CACHE_ITEMS the oldest insertion is evicted. Returns null for
// URL "models" (inference endpoints have no hub metadata) and on any
// fetch/parse failure.
async function getDefaultTask(model, accessToken) {
  if (isUrl(model)) {
    return null;
  }
  // Key includes the token: gated/private models may resolve differently
  // depending on credentials.
  const key = `${model}:${accessToken}`;
  let cachedTask = taskCache.get(key);
  if (cachedTask && cachedTask.date < new Date(Date.now() - CACHE_DURATION)) {
    // Entry is stale: drop it and re-fetch below.
    taskCache.delete(key);
    cachedTask = void 0;
  }
  if (cachedTask === void 0) {
    const modelTask = await fetch(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}
    }).then((resp) => resp.json()).then((json) => json.pipeline_tag).catch(() => null);
    if (!modelTask) {
      return null;
    }
    // Store the exact object we return — the original built a second
    // { task, date } literal with a second `new Date()`, so the cached
    // timestamp could differ from the local one.
    cachedTask = { task: modelTask, date: /* @__PURE__ */ new Date() };
    taskCache.set(key, cachedTask);
    if (taskCache.size > MAX_CACHE_ITEMS) {
      // Map iterates in insertion order, so this evicts the oldest entry.
      taskCache.delete(taskCache.keys().next().value);
    }
  }
  return cachedTask.task;
}
// src/tasks/nlp/featureExtraction.ts
async function featureExtraction(args, options) {
const defaultTask = await getDefaultTask(args.model, args.accessToken);
const res = await request(
args,
defaultTask === "sentence-similarity" ? {
...options,
task: "feature-extraction"
} : options
);
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
const res = await request(args, {
...options,
taskHint: "feature-extraction",
...defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }
});
let isValidOutput = true;

@@ -587,3 +635,6 @@ const isNumArrayRec = (arr, maxDepth, curDepth = 0) => {

async function fillMask(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "fill-mask"
});
const isValidOutput = Array.isArray(res) && res.every(

@@ -602,3 +653,6 @@ (x) => typeof x.score === "number" && typeof x.sequence === "string" && typeof x.token === "number" && typeof x.token_str === "string"

async function questionAnswering(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "question-answering"
});
const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number";

@@ -613,10 +667,8 @@ if (!isValidOutput) {

async function sentenceSimilarity(args, options) {
const defaultTask = await getDefaultTask(args.model, args.accessToken);
const res = await request(
args,
defaultTask === "feature-extraction" ? {
...options,
task: "sentence-similarity"
} : options
);
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : void 0;
const res = await request(args, {
...options,
taskHint: "sentence-similarity",
...defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -631,3 +683,6 @@ if (!isValidOutput) {

async function summarization(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "summarization"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");

@@ -642,3 +697,6 @@ if (!isValidOutput) {

async function tableQuestionAnswering(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "table-question-answering"
});
const isValidOutput = typeof res?.aggregator === "string" && typeof res.answer === "string" && Array.isArray(res.cells) && res.cells.every((x) => typeof x === "string") && Array.isArray(res.coordinates) && res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));

@@ -655,3 +713,6 @@ if (!isValidOutput) {

async function textClassification(args, options) {
const res = (await request(args, options))?.[0];
const res = (await request(args, {
...options,
taskHint: "text-classification"
}))?.[0];
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");

@@ -666,3 +727,6 @@ if (!isValidOutput) {

async function textGeneration(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "text-generation"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");

@@ -677,3 +741,6 @@ if (!isValidOutput) {

async function* textGenerationStream(args, options) {
yield* streamingRequest(args, options);
yield* streamingRequest(args, {
...options,
taskHint: "text-generation"
});
}

@@ -691,3 +758,8 @@

async function tokenClassification(args, options) {
const res = toArray(await request(args, options));
const res = toArray(
await request(args, {
...options,
taskHint: "token-classification"
})
);
const isValidOutput = Array.isArray(res) && res.every(

@@ -706,3 +778,6 @@ (x) => typeof x.end === "number" && typeof x.entity_group === "string" && typeof x.score === "number" && typeof x.start === "number" && typeof x.word === "string"

async function translation(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "translation"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");

@@ -718,3 +793,6 @@ if (!isValidOutput) {

const res = toArray(
await request(args, options)
await request(args, {
...options,
taskHint: "zero-shot-classification"
})
);

@@ -745,3 +823,6 @@ const isValidOutput = Array.isArray(res) && res.every(

const res = toArray(
await request(reqArgs, options)
await request(reqArgs, {
...options,
taskHint: "document-question-answering"
})
)?.[0];

@@ -769,3 +850,6 @@ const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined");

};
const res = (await request(reqArgs, options))?.[0];
const res = (await request(reqArgs, {
...options,
taskHint: "visual-question-answering"
}))?.[0];
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";

@@ -780,3 +864,6 @@ if (!isValidOutput) {

async function tabularRegression(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "tabular-regression"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -791,3 +878,6 @@ if (!isValidOutput) {

async function tabularClassification(args, options) {
const res = await request(args, options);
const res = await request(args, {
...options,
taskHint: "tabular-classification"
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -794,0 +884,0 @@ if (!isValidOutput) {

{
"name": "@huggingface/inference",
"version": "2.5.2",
"version": "2.6.0",
"packageManager": "pnpm@8.3.1",

@@ -5,0 +5,0 @@ "license": "MIT",

@@ -11,3 +11,3 @@ import { isUrl } from "./isUrl";

const MAX_CACHE_ITEMS = 1000;
const HF_HUB_URL = "https://huggingface.co";
export const HF_HUB_URL = "https://huggingface.co";

@@ -14,0 +14,0 @@ /**

import type { InferenceTask, Options, RequestArgs } from "../types";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

@@ -7,5 +8,10 @@

/**
* Loaded from huggingface.co/api/tasks if needed
*/
let tasks: Record<string, { models: { id: string }[] }> | null = null;
/**
* Helper that prepares request arguments
*/
export function makeRequestOptions(
export async function makeRequestOptions(
args: RequestArgs & {

@@ -19,7 +25,11 @@ data?: Blob | ArrayBuffer;

/** When a model can be used for multiple tasks, and we want to run a non-default task */
task?: string | InferenceTask;
forceTask?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}
): { url: string; info: RequestInit } {
const { model, accessToken, ...otherArgs } = args;
const { task, includeCredentials, ...otherOptions } = options ?? {};
): Promise<{ url: string; info: RequestInit }> {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { accessToken, model: _model, ...otherArgs } = args;
let { model } = args;
const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};

@@ -31,2 +41,21 @@ const headers: Record<string, string> = {};

if (!model && !tasks && taskHint) {
const res = await fetch(`${HF_HUB_URL}/api/tasks`);
if (res.ok) {
tasks = await res.json();
}
}
if (!model && tasks && taskHint) {
const taskInfo = tasks[taskHint];
if (taskInfo) {
model = taskInfo.models[0].id;
}
}
if (!model) {
throw new Error("No model provided, and no default model found for this task");
}
const binary = "data" in args && !!args.data;

@@ -33,0 +62,0 @@

@@ -34,3 +34,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<AudioClassificationReturn> {
const res = await request<AudioClassificationReturn>(args, options);
const res = await request<AudioClassificationReturn>(args, {
...options,
taskHint: "audio-classification",
});
const isValidOutput =

@@ -37,0 +40,0 @@ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -36,3 +36,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn> {
const res = await request<AudioToAudioReturn>(args, options);
const res = await request<AudioToAudioReturn>(args, {
...options,
taskHint: "audio-to-audio",
});
const isValidOutput =

@@ -39,0 +42,0 @@ Array.isArray(res) &&

@@ -27,3 +27,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<AutomaticSpeechRecognitionOutput> {
const res = await request<AutomaticSpeechRecognitionOutput>(args, options);
const res = await request<AutomaticSpeechRecognitionOutput>(args, {
...options,
taskHint: "automatic-speech-recognition",
});
const isValidOutput = typeof res?.text === "string";

@@ -30,0 +33,0 @@ if (!isValidOutput) {

@@ -19,3 +19,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<TextToSpeechOutput> {
const res = await request<TextToSpeechOutput>(args, options);
const res = await request<TextToSpeechOutput>(args, {
...options,
taskHint: "text-to-speech",
});
const isValidOutput = res && res instanceof Blob;

@@ -22,0 +25,0 @@ if (!isValidOutput) {

@@ -14,5 +14,7 @@ import type { InferenceTask, Options, RequestArgs } from "../../types";

task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}
): Promise<T> {
const { url, info } = makeRequestOptions(args, options);
const { url, info } = await makeRequestOptions(args, options);
const response = await (options?.fetch ?? fetch)(url, info);

@@ -19,0 +21,0 @@

@@ -16,5 +16,7 @@ import type { InferenceTask, Options, RequestArgs } from "../../types";

task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
}
): AsyncGenerator<T> {
const { url, info } = makeRequestOptions({ ...args, stream: true }, options);
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
const response = await (options?.fetch ?? fetch)(url, info);

@@ -21,0 +23,0 @@

@@ -33,3 +33,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<ImageClassificationOutput> {
const res = await request<ImageClassificationOutput>(args, options);
const res = await request<ImageClassificationOutput>(args, {
...options,
taskHint: "image-classification",
});
const isValidOutput =

@@ -36,0 +39,0 @@ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -37,3 +37,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<ImageSegmentationOutput> {
const res = await request<ImageSegmentationOutput>(args, options);
const res = await request<ImageSegmentationOutput>(args, {
...options,
taskHint: "image-segmentation",
});
const isValidOutput =

@@ -40,0 +43,0 @@ Array.isArray(res) &&

@@ -77,3 +77,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

}
const res = await request<ImageToImageOutput>(reqArgs, options);
const res = await request<ImageToImageOutput>(reqArgs, {
...options,
taskHint: "image-to-image",
});
const isValidOutput = res && res instanceof Blob;

@@ -80,0 +83,0 @@ if (!isValidOutput) {

@@ -23,3 +23,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
const res = (await request<[ImageToTextOutput]>(args, options))?.[0];
const res = (
await request<[ImageToTextOutput]>(args, {
...options,
taskHint: "image-to-text",
})
)?.[0];

@@ -26,0 +31,0 @@ if (typeof res?.generated_text !== "string") {

@@ -40,3 +40,6 @@ import { request } from "../custom/request";

export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
const res = await request<ObjectDetectionOutput>(args, options);
const res = await request<ObjectDetectionOutput>(args, {
...options,
taskHint: "object-detection",
});
const isValidOutput =

@@ -43,0 +46,0 @@ Array.isArray(res) &&

@@ -42,3 +42,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
const res = await request<TextToImageOutput>(args, options);
const res = await request<TextToImageOutput>(args, {
...options,
taskHint: "text-to-image",
});
const isValidOutput = res && res instanceof Blob;

@@ -45,0 +48,0 @@ if (!isValidOutput) {

@@ -48,3 +48,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

const res = await request<ZeroShotImageClassificationOutput>(reqArgs, options);
const res = await request<ZeroShotImageClassificationOutput>(reqArgs, {
...options,
taskHint: "zero-shot-image-classification",
});
const isValidOutput =

@@ -51,0 +54,0 @@ Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");

@@ -59,3 +59,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

const res = toArray(
await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, options)
await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, {
...options,
taskHint: "document-question-answering",
})
)?.[0];

@@ -62,0 +65,0 @@ const isValidOutput =

@@ -48,3 +48,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

} as RequestArgs;
const res = (await request<[VisualQuestionAnsweringOutput]>(reqArgs, options))?.[0];
const res = (
await request<[VisualQuestionAnsweringOutput]>(reqArgs, {
...options,
taskHint: "visual-question-answering",
})
)?.[0];
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";

@@ -51,0 +56,0 @@ if (!isValidOutput) {

@@ -66,3 +66,3 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function conversational(args: ConversationalArgs, options?: Options): Promise<ConversationalOutput> {
const res = await request<ConversationalOutput>(args, options);
const res = await request<ConversationalOutput>(args, { ...options, taskHint: "conversational" });
const isValidOutput =

@@ -69,0 +69,0 @@ Array.isArray(res.conversation.generated_responses) &&

@@ -28,12 +28,9 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<FeatureExtractionOutput> {
const defaultTask = await getDefaultTask(args.model, args.accessToken);
const res = await request<FeatureExtractionOutput>(
args,
defaultTask === "sentence-similarity"
? {
...options,
task: "feature-extraction",
}
: options
);
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : undefined;
const res = await request<FeatureExtractionOutput>(args, {
...options,
taskHint: "feature-extraction",
...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }),
});
let isValidOutput = true;

@@ -40,0 +37,0 @@

@@ -32,3 +32,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
const res = await request<FillMaskOutput>(args, options);
const res = await request<FillMaskOutput>(args, {
...options,
taskHint: "fill-mask",
});
const isValidOutput =

@@ -35,0 +38,0 @@ Array.isArray(res) &&

@@ -38,3 +38,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<QuestionAnsweringOutput> {
const res = await request<QuestionAnsweringOutput>(args, options);
const res = await request<QuestionAnsweringOutput>(args, {
...options,
taskHint: "question-answering",
});
const isValidOutput =

@@ -41,0 +44,0 @@ typeof res === "object" &&

@@ -28,12 +28,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<SentenceSimilarityOutput> {
const defaultTask = await getDefaultTask(args.model, args.accessToken);
const res = await request<SentenceSimilarityOutput>(
args,
defaultTask === "feature-extraction"
? {
...options,
task: "sentence-similarity",
}
: options
);
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken) : undefined;
const res = await request<SentenceSimilarityOutput>(args, {
...options,
taskHint: "sentence-similarity",
...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }),
});

@@ -40,0 +36,0 @@ const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -53,3 +53,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
const res = await request<SummarizationOutput[]>(args, options);
const res = await request<SummarizationOutput[]>(args, {
...options,
taskHint: "summarization",
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");

@@ -56,0 +59,0 @@ if (!isValidOutput) {

@@ -44,3 +44,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<TableQuestionAnsweringOutput> {
const res = await request<TableQuestionAnsweringOutput>(args, options);
const res = await request<TableQuestionAnsweringOutput>(args, {
...options,
taskHint: "table-question-answering",
});
const isValidOutput =

@@ -47,0 +50,0 @@ typeof res?.aggregator === "string" &&

@@ -30,3 +30,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<TextClassificationOutput> {
const res = (await request<TextClassificationOutput[]>(args, options))?.[0];
const res = (
await request<TextClassificationOutput[]>(args, {
...options,
taskHint: "text-classification",
})
)?.[0];
const isValidOutput =

@@ -33,0 +38,0 @@ Array.isArray(res) && res.every((x) => typeof x?.label === "string" && typeof x.score === "number");

@@ -65,3 +65,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationOutput> {
const res = await request<TextGenerationOutput[]>(args, options);
const res = await request<TextGenerationOutput[]>(args, {
...options,
taskHint: "text-generation",
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");

@@ -68,0 +71,0 @@ if (!isValidOutput) {

@@ -91,3 +91,6 @@ import type { Options } from "../../types";

): AsyncGenerator<TextGenerationStreamOutput> {
yield* streamingRequest<TextGenerationStreamOutput>(args, options);
yield* streamingRequest<TextGenerationStreamOutput>(args, {
...options,
taskHint: "text-generation",
});
}

@@ -61,3 +61,8 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<TokenClassificationOutput> {
const res = toArray(await request<TokenClassificationOutput[number] | TokenClassificationOutput>(args, options));
const res = toArray(
await request<TokenClassificationOutput[number] | TokenClassificationOutput>(args, {
...options,
taskHint: "token-classification",
})
);
const isValidOutput =

@@ -64,0 +69,0 @@ Array.isArray(res) &&

@@ -23,3 +23,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
const res = await request<TranslationOutput[]>(args, options);
const res = await request<TranslationOutput[]>(args, {
...options,
taskHint: "translation",
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.translation_text === "string");

@@ -26,0 +29,0 @@ if (!isValidOutput) {

@@ -39,3 +39,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

const res = toArray(
await request<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, options)
await request<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, {
...options,
taskHint: "zero-shot-classification",
})
);

@@ -42,0 +45,0 @@ const isValidOutput =

@@ -28,3 +28,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<TabularClassificationOutput> {
const res = await request<TabularClassificationOutput>(args, options);
const res = await request<TabularClassificationOutput>(args, {
...options,
taskHint: "tabular-classification",
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -31,0 +34,0 @@ if (!isValidOutput) {

@@ -28,3 +28,6 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";

): Promise<TabularRegressionOutput> {
const res = await request<TabularRegressionOutput>(args, options);
const res = await request<TabularRegressionOutput>(args, {
...options,
taskHint: "tabular-regression",
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");

@@ -31,0 +34,0 @@ if (!isValidOutput) {

@@ -29,3 +29,35 @@ export interface Options {

export type InferenceTask = "text-classification" | "feature-extraction" | "sentence-similarity";
export type InferenceTask =
| "audio-classification"
| "audio-to-audio"
| "automatic-speech-recognition"
| "conversational"
| "depth-estimation"
| "document-question-answering"
| "feature-extraction"
| "fill-mask"
| "image-classification"
| "image-segmentation"
| "image-to-image"
| "image-to-text"
| "object-detection"
| "video-classification"
| "question-answering"
| "reinforcement-learning"
| "sentence-similarity"
| "summarization"
| "table-question-answering"
| "tabular-classification"
| "tabular-regression"
| "text-classification"
| "text-generation"
| "text-to-image"
| "text-to-speech"
| "text-to-video"
| "token-classification"
| "translation"
| "unconditional-image-generation"
| "visual-question-answering"
| "zero-shot-classification"
| "zero-shot-image-classification";

@@ -41,4 +73,6 @@ export interface BaseArgs {

* The model to use. Can be a full URL for HF inference endpoints.
*
* If not specified, will call huggingface.co/api/tasks to get the default model for the task.
*/
model: string;
model?: string;
}

@@ -45,0 +79,0 @@

Sorry, the diff of this file is not supported yet