@fuyun/generative-ai
Advanced tools
Comparing version
@@ -44,8 +44,8 @@ /** | ||
model: string; | ||
baseURL: string; | ||
params?: StartChatParams; | ||
requestOptions?: RequestOptions; | ||
private _apiKey; | ||
private _history; | ||
private _sendPromise; | ||
constructor(apiKey: string, model: string, baseURL: string, params?: StartChatParams); | ||
constructor(apiKey: string, model: string, params?: StartChatParams, requestOptions?: RequestOptions); | ||
/** | ||
@@ -267,7 +267,7 @@ * Gets the chat history so far. Blocked prompts are not added to history. | ||
apiKey: string; | ||
baseURL?: string; | ||
model: string; | ||
generationConfig: GenerationConfig; | ||
safetySettings: SafetySetting[]; | ||
constructor(apiKey: string, modelParams: ModelParams, baseURL?: string); | ||
requestOptions: RequestOptions; | ||
constructor(apiKey: string, modelParams: ModelParams, requestOptions?: RequestOptions); | ||
/** | ||
@@ -310,12 +310,11 @@ * Makes a single non-streaming call to the model | ||
apiKey: string; | ||
baseURL?: string; | ||
constructor(apiKey: string, baseURL?: string); | ||
constructor(apiKey: string); | ||
/** | ||
* Gets a {@link GenerativeModel} instance for the provided model name. | ||
*/ | ||
getGenerativeModel(modelParams: ModelParams): GenerativeModel; | ||
getGenerativeModel(modelParams: ModelParams, requestOptions?: RequestOptions): GenerativeModel; | ||
} | ||
/** | ||
* Threshhold above which a prompt or candidate will be blocked. | ||
* Threshold above which a prompt or candidate will be blocked. | ||
* @public | ||
@@ -415,2 +414,11 @@ */ | ||
/** | ||
* Params passed to {@link GoogleGenerativeAI.getGenerativeModel}. | ||
* @public | ||
*/ | ||
export declare interface RequestOptions { | ||
timeout?: number; | ||
baseURL?: string; | ||
} | ||
/** | ||
* A safety rating associated with a {@link GenerateContentCandidate} | ||
@@ -417,0 +425,0 @@ * @public |
@@ -32,3 +32,3 @@ 'use strict'; | ||
/** | ||
* Threshhold above which a prompt or candidate will be blocked. | ||
* Threshold above which a prompt or candidate will be blocked. | ||
* @public | ||
@@ -161,3 +161,3 @@ */ | ||
*/ | ||
const PACKAGE_VERSION = "0.1.3"; | ||
const PACKAGE_VERSION = "0.2.0"; | ||
const PACKAGE_LOG_HEADER = "genai-js"; | ||
@@ -173,3 +173,3 @@ var Task; | ||
class RequestUrl { | ||
constructor(model, task, apiKey, stream, baseURL) { | ||
constructor(model, task, apiKey, stream, requestOptions) { | ||
this.model = model; | ||
@@ -179,6 +179,7 @@ this.task = task; | ||
this.stream = stream; | ||
this.baseURL = baseURL; | ||
this.requestOptions = requestOptions; | ||
} | ||
toString() { | ||
let url = `${this.baseURL}/${API_VERSION}/models/${this.model}:${this.task}`; | ||
const baseURL = this.requestOptions.baseURL || 'https://generativelanguage.googleapis.com'; | ||
let url = `${baseURL}/${API_VERSION}/models/${this.model}:${this.task}`; | ||
if (this.stream) { | ||
@@ -196,14 +197,10 @@ url += "?alt=sse"; | ||
} | ||
async function makeRequest(url, body) { | ||
async function makeRequest(url, body, requestOptions) { | ||
let response; | ||
try { | ||
response = await fetch(url.toString(), { | ||
method: "POST", | ||
headers: { | ||
response = await fetch(url.toString(), Object.assign(Object.assign({}, buildFetchOptions(requestOptions)), { method: "POST", headers: { | ||
"Content-Type": "application/json", | ||
"x-goog-api-client": getClientHeaders(), | ||
"x-goog-api-key": url.apiKey, | ||
}, | ||
body, | ||
}); | ||
}, body })); | ||
if (!response.ok) { | ||
@@ -231,2 +228,17 @@ let message = ""; | ||
} | ||
/** | ||
* Generates the request options to be passed to the fetch API. | ||
* @param requestOptions - The user-defined request options. | ||
* @returns The generated request options. | ||
*/ | ||
function buildFetchOptions(requestOptions) { | ||
const fetchOptions = {}; | ||
if ((requestOptions === null || requestOptions === void 0 ? void 0 : requestOptions.timeout) >= 0) { | ||
const abortController = new AbortController(); | ||
const signal = abortController.signal; | ||
setTimeout(() => abortController.abort(), requestOptions.timeout); | ||
fetchOptions.signal = signal; | ||
} | ||
return fetchOptions; | ||
} | ||
@@ -522,12 +534,12 @@ /** | ||
*/ | ||
async function generateContentStream(apiKey, model, params, baseURL) { | ||
async function generateContentStream(apiKey, model, params, requestOptions) { | ||
const url = new RequestUrl(model, Task.STREAM_GENERATE_CONTENT, apiKey, | ||
/* stream */ true, baseURL); | ||
const response = await makeRequest(url, JSON.stringify(params)); | ||
/* stream */ true, requestOptions); | ||
const response = await makeRequest(url, JSON.stringify(params), requestOptions); | ||
return processStream(response); | ||
} | ||
async function generateContent(apiKey, model, params, baseURL) { | ||
async function generateContent(apiKey, model, params, requestOptions) { | ||
const url = new RequestUrl(model, Task.GENERATE_CONTENT, apiKey, | ||
/* stream */ false, baseURL); | ||
const response = await makeRequest(url, JSON.stringify(params)); | ||
/* stream */ false, requestOptions); | ||
const response = await makeRequest(url, JSON.stringify(params), requestOptions); | ||
const responseJson = await response.json(); | ||
@@ -617,6 +629,6 @@ const enhancedResponse = addHelpers(responseJson); | ||
class ChatSession { | ||
constructor(apiKey, model, baseURL, params) { | ||
constructor(apiKey, model, params, requestOptions) { | ||
this.model = model; | ||
this.baseURL = baseURL; | ||
this.params = params; | ||
this.requestOptions = requestOptions; | ||
this._history = []; | ||
@@ -659,3 +671,3 @@ this._sendPromise = Promise.resolve(); | ||
this._sendPromise = this._sendPromise | ||
.then(() => generateContent(this._apiKey, this.model, generateContentRequest, this.baseURL)) | ||
.then(() => generateContent(this._apiKey, this.model, generateContentRequest, this.requestOptions)) | ||
.then((result) => { | ||
@@ -696,3 +708,3 @@ var _a; | ||
}; | ||
const streamPromise = generateContentStream(this._apiKey, this.model, generateContentRequest, this.baseURL); | ||
const streamPromise = generateContentStream(this._apiKey, this.model, generateContentRequest, this.requestOptions); | ||
// Add onto the chain. | ||
@@ -754,5 +766,5 @@ this._sendPromise = this._sendPromise | ||
*/ | ||
async function countTokens(apiKey, model, baseURL, params) { | ||
const url = new RequestUrl(model, Task.COUNT_TOKENS, apiKey, false, baseURL); | ||
const response = await makeRequest(url, JSON.stringify(Object.assign(Object.assign({}, params), { model }))); | ||
async function countTokens(apiKey, model, params, requestOptions) { | ||
const url = new RequestUrl(model, Task.COUNT_TOKENS, apiKey, false, requestOptions); | ||
const response = await makeRequest(url, JSON.stringify(Object.assign(Object.assign({}, params), { model })), requestOptions); | ||
return response.json(); | ||
@@ -777,13 +789,13 @@ } | ||
*/ | ||
async function embedContent(apiKey, model, baseURL, params) { | ||
const url = new RequestUrl(model, Task.EMBED_CONTENT, apiKey, false, baseURL); | ||
const response = await makeRequest(url, JSON.stringify(params)); | ||
async function embedContent(apiKey, model, params, requestOptions) { | ||
const url = new RequestUrl(model, Task.EMBED_CONTENT, apiKey, false, requestOptions); | ||
const response = await makeRequest(url, JSON.stringify(params), requestOptions); | ||
return response.json(); | ||
} | ||
async function batchEmbedContents(apiKey, model, baseURL, params) { | ||
const url = new RequestUrl(model, Task.BATCH_EMBED_CONTENTS, apiKey, false, baseURL); | ||
async function batchEmbedContents(apiKey, model, params, requestOptions) { | ||
const url = new RequestUrl(model, Task.BATCH_EMBED_CONTENTS, apiKey, false, requestOptions); | ||
const requestsWithModel = params.requests.map((request) => { | ||
return Object.assign(Object.assign({}, request), { model: `models/${model}` }); | ||
}); | ||
const response = await makeRequest(url, JSON.stringify({ requests: requestsWithModel })); | ||
const response = await makeRequest(url, JSON.stringify({ requests: requestsWithModel }), requestOptions); | ||
return response.json(); | ||
@@ -813,6 +825,5 @@ } | ||
class GenerativeModel { | ||
constructor(apiKey, modelParams, baseURL) { | ||
constructor(apiKey, modelParams, requestOptions) { | ||
var _a; | ||
this.apiKey = apiKey; | ||
this.baseURL = baseURL; | ||
if (modelParams.model.startsWith("models/")) { | ||
@@ -826,3 +837,3 @@ this.model = (_a = modelParams.model.split("models/")) === null || _a === void 0 ? void 0 : _a[1]; | ||
this.safetySettings = modelParams.safetySettings || []; | ||
this.baseURL = baseURL || 'https://generativelanguage.googleapis.com'; | ||
this.requestOptions = requestOptions || {}; | ||
} | ||
@@ -835,3 +846,3 @@ /** | ||
const formattedParams = formatGenerateContentInput(request); | ||
return generateContent(this.apiKey, this.model, Object.assign({ generationConfig: this.generationConfig, safetySettings: this.safetySettings }, formattedParams), this.baseURL); | ||
return generateContent(this.apiKey, this.model, Object.assign({ generationConfig: this.generationConfig, safetySettings: this.safetySettings }, formattedParams), this.requestOptions); | ||
} | ||
@@ -846,3 +857,3 @@ /** | ||
const formattedParams = formatGenerateContentInput(request); | ||
return generateContentStream(this.apiKey, this.model, Object.assign({ generationConfig: this.generationConfig, safetySettings: this.safetySettings }, formattedParams), this.baseURL); | ||
return generateContentStream(this.apiKey, this.model, Object.assign({ generationConfig: this.generationConfig, safetySettings: this.safetySettings }, formattedParams), this.requestOptions); | ||
} | ||
@@ -854,3 +865,3 @@ /** | ||
startChat(startChatParams) { | ||
return new ChatSession(this.apiKey, this.model, this.baseURL, startChatParams); | ||
return new ChatSession(this.apiKey, this.model, startChatParams, this.requestOptions); | ||
} | ||
@@ -862,3 +873,3 @@ /** | ||
const formattedParams = formatGenerateContentInput(request); | ||
return countTokens(this.apiKey, this.model, this.baseURL, formattedParams); | ||
return countTokens(this.apiKey, this.model, formattedParams, this.requestOptions); | ||
} | ||
@@ -870,3 +881,3 @@ /** | ||
const formattedParams = formatEmbedContentInput(request); | ||
return embedContent(this.apiKey, this.model, this.baseURL, formattedParams); | ||
return embedContent(this.apiKey, this.model, formattedParams, this.requestOptions); | ||
} | ||
@@ -877,3 +888,3 @@ /** | ||
async batchEmbedContents(batchEmbedContentRequest) { | ||
return batchEmbedContents(this.apiKey, this.model, this.baseURL, batchEmbedContentRequest); | ||
return batchEmbedContents(this.apiKey, this.model, batchEmbedContentRequest, this.requestOptions); | ||
} | ||
@@ -903,5 +914,4 @@ } | ||
class GoogleGenerativeAI { | ||
constructor(apiKey, baseURL) { | ||
constructor(apiKey) { | ||
this.apiKey = apiKey; | ||
this.baseURL = baseURL; | ||
} | ||
@@ -911,3 +921,3 @@ /** | ||
*/ | ||
getGenerativeModel(modelParams) { | ||
getGenerativeModel(modelParams, requestOptions) { | ||
if (!modelParams.model) { | ||
@@ -917,3 +927,3 @@ throw new GoogleGenerativeAIError(`Must provide a model name. ` + | ||
} | ||
return new GenerativeModel(this.apiKey, modelParams, this.baseURL); | ||
return new GenerativeModel(this.apiKey, modelParams, requestOptions); | ||
} | ||
@@ -920,0 +930,0 @@ } |
@@ -17,3 +17,3 @@ /** | ||
*/ | ||
import { ModelParams } from "../types"; | ||
import { ModelParams, RequestOptions } from "../types"; | ||
import { GenerativeModel } from "./models/generative-model"; | ||
@@ -28,8 +28,7 @@ export { ChatSession } from "./methods/chat-session"; | ||
apiKey: string; | ||
baseURL?: string; | ||
constructor(apiKey: string, baseURL?: string); | ||
constructor(apiKey: string); | ||
/** | ||
* Gets a {@link GenerativeModel} instance for the provided model name. | ||
*/ | ||
getGenerativeModel(modelParams: ModelParams): GenerativeModel; | ||
getGenerativeModel(modelParams: ModelParams, requestOptions?: RequestOptions): GenerativeModel; | ||
} |
@@ -17,3 +17,3 @@ /** | ||
*/ | ||
import { Content, GenerateContentResult, GenerateContentStreamResult, Part, StartChatParams } from "../../types"; | ||
import { Content, GenerateContentResult, GenerateContentStreamResult, Part, RequestOptions, StartChatParams } from "../../types"; | ||
/** | ||
@@ -27,8 +27,8 @@ * ChatSession class that enables sending chat messages and stores | ||
model: string; | ||
baseURL: string; | ||
params?: StartChatParams; | ||
requestOptions?: RequestOptions; | ||
private _apiKey; | ||
private _history; | ||
private _sendPromise; | ||
constructor(apiKey: string, model: string, baseURL: string, params?: StartChatParams); | ||
constructor(apiKey: string, model: string, params?: StartChatParams, requestOptions?: RequestOptions); | ||
/** | ||
@@ -35,0 +35,0 @@ * Gets the chat history so far. Blocked prompts are not added to history. |
@@ -17,3 +17,3 @@ /** | ||
*/ | ||
import { CountTokensRequest, CountTokensResponse } from "../../types"; | ||
export declare function countTokens(apiKey: string, model: string, baseURL: string, params: CountTokensRequest): Promise<CountTokensResponse>; | ||
import { CountTokensRequest, CountTokensResponse, RequestOptions } from "../../types"; | ||
export declare function countTokens(apiKey: string, model: string, params: CountTokensRequest, requestOptions: RequestOptions): Promise<CountTokensResponse>; |
@@ -17,4 +17,4 @@ /** | ||
*/ | ||
import { BatchEmbedContentsRequest, BatchEmbedContentsResponse, EmbedContentRequest, EmbedContentResponse } from "../../types"; | ||
export declare function embedContent(apiKey: string, model: string, baseURL: string, params: EmbedContentRequest): Promise<EmbedContentResponse>; | ||
export declare function batchEmbedContents(apiKey: string, model: string, baseURL: string, params: BatchEmbedContentsRequest): Promise<BatchEmbedContentsResponse>; | ||
import { BatchEmbedContentsRequest, BatchEmbedContentsResponse, EmbedContentRequest, EmbedContentResponse, RequestOptions } from "../../types"; | ||
export declare function embedContent(apiKey: string, model: string, params: EmbedContentRequest, requestOptions?: RequestOptions): Promise<EmbedContentResponse>; | ||
export declare function batchEmbedContents(apiKey: string, model: string, params: BatchEmbedContentsRequest, requestOptions?: RequestOptions): Promise<BatchEmbedContentsResponse>; |
@@ -17,4 +17,4 @@ /** | ||
*/ | ||
import { GenerateContentRequest, GenerateContentResult, GenerateContentStreamResult } from "../../types"; | ||
export declare function generateContentStream(apiKey: string, model: string, params: GenerateContentRequest, baseURL: string): Promise<GenerateContentStreamResult>; | ||
export declare function generateContent(apiKey: string, model: string, params: GenerateContentRequest, baseURL: string): Promise<GenerateContentResult>; | ||
import { GenerateContentRequest, GenerateContentResult, GenerateContentStreamResult, RequestOptions } from "../../types"; | ||
export declare function generateContentStream(apiKey: string, model: string, params: GenerateContentRequest, requestOptions: RequestOptions): Promise<GenerateContentStreamResult>; | ||
export declare function generateContent(apiKey: string, model: string, params: GenerateContentRequest, requestOptions?: RequestOptions): Promise<GenerateContentResult>; |
@@ -17,3 +17,3 @@ /** | ||
*/ | ||
import { BatchEmbedContentsRequest, BatchEmbedContentsResponse, CountTokensRequest, CountTokensResponse, EmbedContentRequest, EmbedContentResponse, GenerateContentRequest, GenerateContentResult, GenerateContentStreamResult, GenerationConfig, ModelParams, Part, SafetySetting, StartChatParams } from "../../types"; | ||
import { BatchEmbedContentsRequest, BatchEmbedContentsResponse, CountTokensRequest, CountTokensResponse, EmbedContentRequest, EmbedContentResponse, GenerateContentRequest, GenerateContentResult, GenerateContentStreamResult, GenerationConfig, ModelParams, Part, RequestOptions, SafetySetting, StartChatParams } from "../../types"; | ||
import { ChatSession } from "../methods/chat-session"; | ||
@@ -26,7 +26,7 @@ /** | ||
apiKey: string; | ||
baseURL?: string; | ||
model: string; | ||
generationConfig: GenerationConfig; | ||
safetySettings: SafetySetting[]; | ||
constructor(apiKey: string, modelParams: ModelParams, baseURL?: string); | ||
requestOptions: RequestOptions; | ||
constructor(apiKey: string, modelParams: ModelParams, requestOptions?: RequestOptions); | ||
/** | ||
@@ -33,0 +33,0 @@ * Makes a single non-streaming call to the model |
@@ -17,2 +17,3 @@ /** | ||
*/ | ||
import { RequestOptions } from "../../types"; | ||
export declare enum Task { | ||
@@ -30,6 +31,6 @@ GENERATE_CONTENT = "generateContent", | ||
stream: boolean; | ||
baseURL: string; | ||
constructor(model: string, task: Task, apiKey: string, stream: boolean, baseURL: string); | ||
requestOptions: RequestOptions; | ||
constructor(model: string, task: Task, apiKey: string, stream: boolean, requestOptions: RequestOptions); | ||
toString(): string; | ||
} | ||
export declare function makeRequest(url: RequestUrl, body: string): Promise<Response>; | ||
export declare function makeRequest(url: RequestUrl, body: string, requestOptions?: RequestOptions): Promise<Response>; |
@@ -29,3 +29,3 @@ /** | ||
/** | ||
* Threshhold above which a prompt or candidate will be blocked. | ||
* Threshold above which a prompt or candidate will be blocked. | ||
* @public | ||
@@ -32,0 +32,0 @@ */ |
@@ -91,1 +91,9 @@ /** | ||
} | ||
/** | ||
* Params passed to {@link GoogleGenerativeAI.getGenerativeModel}. | ||
* @public | ||
*/ | ||
export interface RequestOptions { | ||
timeout?: number; | ||
baseURL?: string; | ||
} |
{ | ||
"name": "@fuyun/generative-ai", | ||
"version": "0.1.3", | ||
"version": "0.2.0", | ||
"description": "Google AI JavaScript SDK", | ||
@@ -5,0 +5,0 @@ "main": "dist/index.js", |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
118445
2.22%2976
1.19%