@huggingface/inference
Advanced tools
Comparing version
@@ -119,3 +119,3 @@ "use strict"; | ||
const { url, info } = makeRequestOptions(args, options); | ||
const response = await fetch(url, info); | ||
const response = await (options?.custom_fetch ?? fetch)(url, info); | ||
if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) { | ||
@@ -244,3 +244,3 @@ return request(args, { | ||
const { url, info } = makeRequestOptions({ ...args, stream: true }, options); | ||
const response = await fetch(url, info); | ||
const response = await (options?.custom_fetch ?? fetch)(url, info); | ||
if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) { | ||
@@ -438,3 +438,3 @@ return streamingRequest(args, { | ||
const res = await request(args, options); | ||
const isValidOutput = typeof res?.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number"; | ||
const isValidOutput = typeof res === "object" && !!res && typeof res.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number"; | ||
if (!isValidOutput) { | ||
@@ -572,6 +572,8 @@ throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}"); | ||
}; | ||
const res = (await request(reqArgs, options))?.[0]; | ||
const isValidOutput = typeof res?.answer === "string" && typeof res.end === "number" && typeof res.score === "number" && typeof res.start === "number"; | ||
const res = toArray( | ||
await request(reqArgs, options) | ||
)?.[0]; | ||
const isValidOutput = typeof res?.answer === "string" && (typeof res.end === "number" || typeof res.end === "undefined") && (typeof res.score === "number" || typeof res.score === "undefined") && (typeof res.start === "number" || typeof res.start === "undefined"); | ||
if (!isValidOutput) { | ||
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); | ||
throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>"); | ||
} | ||
@@ -578,0 +580,0 @@ return res; |
{ | ||
"name": "@huggingface/inference", | ||
"version": "2.1.0", | ||
"version": "2.1.1", | ||
"license": "MIT", | ||
@@ -5,0 +5,0 @@ "author": "Tim Mikeladze <tim.mikeladze@gmail.com>", |
@@ -15,3 +15,3 @@ import type { Options, RequestArgs } from "../../types"; | ||
const { url, info } = makeRequestOptions(args, options); | ||
const response = await fetch(url, info); | ||
const response = await (options?.custom_fetch ?? fetch)(url, info); | ||
@@ -18,0 +18,0 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) { |
@@ -17,3 +17,3 @@ import type { Options, RequestArgs } from "../../types"; | ||
const { url, info } = makeRequestOptions({ ...args, stream: true }, options); | ||
const response = await fetch(url, info); | ||
const response = await (options?.custom_fetch ?? fetch)(url, info); | ||
@@ -20,0 +20,0 @@ if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) { |
@@ -6,2 +6,3 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
import { base64FromBytes } from "../../../../shared/src/base64FromBytes"; | ||
import { toArray } from "../../utils/toArray"; | ||
@@ -28,11 +29,11 @@ export type DocumentQuestionAnsweringArgs = BaseArgs & { | ||
*/ | ||
end: number; | ||
end?: number; | ||
/** | ||
* A float that represents how likely that the answer is correct | ||
*/ | ||
score: number; | ||
score?: number; | ||
/** | ||
* ? | ||
*/ | ||
start: number; | ||
start?: number; | ||
} | ||
@@ -55,12 +56,14 @@ | ||
} as RequestArgs; | ||
const res = (await request<[DocumentQuestionAnsweringOutput]>(reqArgs, options))?.[0]; | ||
const res = toArray( | ||
await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, options) | ||
)?.[0]; | ||
const isValidOutput = | ||
typeof res?.answer === "string" && | ||
typeof res.end === "number" && | ||
typeof res.score === "number" && | ||
typeof res.start === "number"; | ||
(typeof res.end === "number" || typeof res.end === "undefined") && | ||
(typeof res.score === "number" || typeof res.score === "undefined") && | ||
(typeof res.start === "number" || typeof res.start === "undefined"); | ||
if (!isValidOutput) { | ||
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>"); | ||
throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>"); | ||
} | ||
return res; | ||
} |
@@ -40,3 +40,5 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError"; | ||
const isValidOutput = | ||
typeof res?.answer === "string" && | ||
typeof res === "object" && | ||
!!res && | ||
typeof res.answer === "string" && | ||
typeof res.end === "number" && | ||
@@ -43,0 +45,0 @@ typeof res.score === "number" && |
@@ -23,2 +23,6 @@ export interface Options { | ||
wait_for_model?: boolean; | ||
/** | ||
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. | ||
*/ | ||
custom_fetch?: typeof fetch; | ||
} | ||
@@ -25,0 +29,0 @@ |
Sorry, the diff of this file is not supported yet
122466
0.73%3148
0.41%