@google-cloud/vertexai
Comparing version 0.1.3 to 0.2.0
{ | ||
".": "0.1.3" | ||
".": "0.2.0" | ||
} |
@@ -17,3 +17,3 @@ /** | ||
*/ | ||
import { GoogleAuth } from 'google-auth-library'; | ||
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; | ||
import { Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting, StreamGenerateContentResult, VertexInit } from './types/content'; | ||
@@ -23,12 +23,11 @@ export * from './types'; | ||
* Base class for authenticating to Vertex; creates the preview namespace. | ||
* The base class object takes the following arguments: | ||
* @param project The Google Cloud project to use for the request | ||
* @param location The Google Cloud project location to use for the | ||
* request | ||
* @param apiEndpoint Optional. The base Vertex AI endpoint to use for the | ||
* request. If not provided, the default regionalized endpoint (i.e. | ||
* us-central1-aiplatform.googleapis.com) will be used. | ||
*/ | ||
export declare class VertexAI { | ||
preview: VertexAI_Internal; | ||
preview: VertexAI_Preview; | ||
/** | ||
* @constructor | ||
* @param {VertexInit} init - assign authentication related information, | ||
* including project and location string, to instantiate a Vertex AI | ||
* client. | ||
*/ | ||
constructor(init: VertexInit); | ||
@@ -38,17 +37,30 @@ } | ||
* VertexAI class internal implementation for authentication. | ||
* This class object takes the following arguments: | ||
* @param project The Google Cloud project to use for the request | ||
* @param location The Google Cloud project location to use for the request | ||
* @param apiEndpoint The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
*/ | ||
export declare class VertexAI_Internal { | ||
*/ | ||
export declare class VertexAI_Preview { | ||
readonly project: string; | ||
readonly location: string; | ||
readonly apiEndpoint?: string | undefined; | ||
readonly googleAuthOptions?: GoogleAuthOptions<import("google-auth-library/build/src/auth/googleauth").JSONClient> | undefined; | ||
protected googleAuth: GoogleAuth; | ||
private tokenInternalPromise?; | ||
constructor(project: string, location: string, apiEndpoint?: string | undefined); | ||
/** | ||
* @constructor | ||
* @param {string} project - The Google Cloud project to use for the request | ||
* @param {string} location - The Google Cloud project location to use for the request | ||
* @param {string} [apiEndpoint] - The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @param {GoogleAuthOptions} [googleAuthOptions] - The authentication options provided by google-auth-library. | ||
* The complete list of authentication options is documented in the GoogleAuthOptions interface: | ||
* https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
*/ | ||
constructor(project: string, location: string, apiEndpoint?: string | undefined, googleAuthOptions?: GoogleAuthOptions<import("google-auth-library/build/src/auth/googleauth").JSONClient> | undefined); | ||
/** | ||
* Get access token from GoogleAuth. Throws GoogleAuthError when it fails. | ||
* @return {Promise<any>} Promise of token | ||
*/ | ||
get token(): Promise<any>; | ||
/** | ||
* @param {ModelParams} modelParams - {@link ModelParams} Parameters to specify the generative model. | ||
* @return {GenerativeModel} Instance of the GenerativeModel class. {@link GenerativeModel} | ||
*/ | ||
getGenerativeModel(modelParams: ModelParams): GenerativeModel; | ||
@@ -67,9 +79,14 @@ } | ||
* All params passed to initiate multiturn chat via startChat | ||
* @see VertexAI_Preview for details on _vertex_instance parameter | ||
* @see GenerativeModel for details on _model_instance parameter | ||
*/ | ||
export declare interface StartChatSessionRequest extends StartChatParams { | ||
_vertex_instance: VertexAI_Internal; | ||
_vertex_instance: VertexAI_Preview; | ||
_model_instance: GenerativeModel; | ||
} | ||
/** | ||
* Session for a multiturn chat with the model | ||
* Chat session to make multi-turn send-message requests. | ||
* The `sendMessage` method makes an async call to get the response to a chat message. | ||
* The `sendMessageStream` method makes an async call to stream the response to a chat message. | ||
* @param {StartChatSessionRequest} request - {@link StartChatSessionRequest} | ||
*/ | ||
@@ -79,6 +96,6 @@ export declare class ChatSession { | ||
private location; | ||
private _send_stream_promise; | ||
private historyInternal; | ||
private _vertex_instance; | ||
private _model_instance; | ||
private _send_stream_promise; | ||
generation_config?: GenerationConfig; | ||
@@ -88,4 +105,14 @@ safety_settings?: SafetySetting[]; | ||
constructor(request: StartChatSessionRequest); | ||
/** | ||
* Make an async call to send a message. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<GenerateContentResult>} Promise of {@link GenerateContentResult} | ||
*/ | ||
sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>; | ||
appendHistory(streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>, newContent: Content): Promise<void>; | ||
/** | ||
* Make an async call to send a message. The response will be returned in a stream. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
sendMessageStream(request: string | Array<string | Part>): Promise<StreamGenerateContentResult>; | ||
@@ -95,3 +122,2 @@ } | ||
* Base class for generative models. | ||
* | ||
* NOTE: this class should not be instantiated directly. Use | ||
@@ -106,5 +132,14 @@ * `vertexai.preview.getGenerativeModel()` instead. | ||
private _use_non_stream; | ||
constructor(vertex_instance: VertexAI_Internal, model: string, generation_config?: GenerationConfig, safety_settings?: SafetySetting[]); | ||
private publisherModelEndpoint; | ||
/** | ||
* Make a generateContent request. | ||
* @constructor | ||
* @param {VertexAI_Preview} vertex_instance - {@link VertexAI_Preview} | ||
* @param {string} model - model name | ||
* @param {GenerationConfig} generation_config - Optional. {@link | ||
* GenerationConfig} | ||
* @param {SafetySetting[]} safety_settings - Optional. {@link SafetySetting} | ||
*/ | ||
constructor(vertex_instance: VertexAI_Preview, model: string, generation_config?: GenerationConfig, safety_settings?: SafetySetting[]); | ||
/** | ||
* Make an async call to generate content. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
@@ -115,9 +150,9 @@ * @return The GenerateContentResponse object with the response candidates. | ||
/** | ||
* Make a generateContentStream request. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
* @return The GenerateContentResponse object with the response candidates. | ||
* Make an async stream request to generate content. The response will be returned in a stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
generateContentStream(request: GenerateContentRequest): Promise<StreamGenerateContentResult>; | ||
/** | ||
* Make a countTokens request. | ||
* Make an async request to count tokens. | ||
* @param request A CountTokensRequest object with the request contents. | ||
@@ -127,3 +162,10 @@ * @return The CountTokensResponse object with the token count. | ||
countTokens(request: CountTokensRequest): Promise<CountTokensResponse>; | ||
startChat(request: StartChatParams): ChatSession; | ||
/** | ||
* Instantiate a ChatSession. | ||
* This method doesn't make any call to the remote endpoint. | ||
* Any call to the remote endpoint is implemented in the ChatSession class. @see ChatSession | ||
* @param {StartChatParams} [request] - {@link StartChatParams} | ||
* @return {ChatSession} {@link ChatSession} | ||
*/ | ||
startChat(request?: StartChatParams): ChatSession; | ||
} |
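The headline API change here is the optional `googleAuthOptions` argument, forwarded to `GoogleAuth`. A minimal instantiation sketch with placeholder project and credential values; per the implementation further down, a `projectId` set in the options must match the `project` argument, and the cloud-platform scope is filled in when omitted:

```typescript
import {VertexAI} from '@google-cloud/vertexai';

// Placeholder values; substitute your own project, location, and credentials.
const vertexAI = new VertexAI({
  project: 'my-project', // must equal googleAuthOptions.projectId, if that is set
  location: 'us-central1',
  googleAuthOptions: {
    keyFilename: '/path/to/service-account.json',
    // scopes defaults to 'https://www.googleapis.com/auth/cloud-platform' when omitted
  },
});

const generativeModel = vertexAI.preview.getGenerativeModel({model: 'gemini-pro'});
```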
@@ -33,3 +33,3 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.GenerativeModel = exports.ChatSession = exports.VertexAI_Internal = exports.VertexAI = void 0; | ||
exports.GenerativeModel = exports.ChatSession = exports.VertexAI_Preview = exports.VertexAI = void 0; | ||
/* tslint:disable */ | ||
@@ -43,13 +43,16 @@ const google_auth_library_1 = require("google-auth-library"); | ||
* Base class for authenticating to Vertex; creates the preview namespace. | ||
* The base class object takes the following arguments: | ||
* @param project The Google Cloud project to use for the request | ||
* @param location The Google Cloud project location to use for the | ||
* request | ||
* @param apiEndpoint Optional. The base Vertex AI endpoint to use for the | ||
* request. If not provided, the default regionalized endpoint (i.e. | ||
* us-central1-aiplatform.googleapis.com) will be used. | ||
*/ | ||
class VertexAI { | ||
/** | ||
* @constructor | ||
* @param {VertexInit} init - assign authentication related information, | ||
* including project and location string, to instantiate a Vertex AI | ||
* client. | ||
*/ | ||
constructor(init) { | ||
this.preview = new VertexAI_Internal(init.project, init.location, init.apiEndpoint); | ||
/** | ||
* The preview property is used to access SDK methods available in public | ||
* preview (currently all functionality). | ||
*/ | ||
this.preview = new VertexAI_Preview(init.project, init.location, init.apiEndpoint, init.googleAuthOptions); | ||
} | ||
@@ -60,26 +63,47 @@ } | ||
* VertexAI class internal implementation for authentication. | ||
* This class object takes the following arguments: | ||
* @param project The Google Cloud project to use for the request | ||
* @param location The Google Cloud project location to use for the request | ||
* @param apiEndpoint The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
*/ | ||
class VertexAI_Internal { | ||
constructor(project, location, apiEndpoint) { | ||
*/ | ||
class VertexAI_Preview { | ||
/** | ||
* @constructor | ||
* @param {string} project - The Google Cloud project to use for the request | ||
* @param {string} location - The Google Cloud project location to use for the request | ||
* @param {string} [apiEndpoint] - The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @param {GoogleAuthOptions} [googleAuthOptions] - The authentication options provided by google-auth-library. | ||
* The complete list of authentication options is documented in the GoogleAuthOptions interface: | ||
* https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
*/ | ||
constructor(project, location, apiEndpoint, googleAuthOptions) { | ||
this.project = project; | ||
this.location = location; | ||
this.apiEndpoint = apiEndpoint; | ||
this.googleAuth = new google_auth_library_1.GoogleAuth({ | ||
scopes: 'https://www.googleapis.com/auth/cloud-platform', | ||
}); | ||
this.googleAuthOptions = googleAuthOptions; | ||
let opts; | ||
if (!googleAuthOptions) { | ||
opts = { | ||
scopes: 'https://www.googleapis.com/auth/cloud-platform', | ||
}; | ||
} | ||
else { | ||
if (googleAuthOptions.projectId && | ||
googleAuthOptions.projectId !== project) { | ||
throw new Error(`inconsistent project ID values. argument project got value ${project} but googleAuthOptions.projectId got value ${googleAuthOptions.projectId}`); | ||
} | ||
opts = googleAuthOptions; | ||
if (!opts.scopes) { | ||
opts.scopes = 'https://www.googleapis.com/auth/cloud-platform'; | ||
} | ||
} | ||
this.project = project; | ||
this.location = location; | ||
this.apiEndpoint = apiEndpoint; | ||
this.googleAuth = new google_auth_library_1.GoogleAuth(opts); | ||
} | ||
/** | ||
* Get access token from GoogleAuth. Throws GoogleAuthError when it fails. | ||
* @return {Promise<any>} Promise of token | ||
*/ | ||
get token() { | ||
if (this.tokenInternalPromise) { | ||
return this.tokenInternalPromise; | ||
} | ||
const credential_error_message = "\nUnable to authenticate your request\ | ||
const credential_error_message = '\nUnable to authenticate your request\ | ||
\nDepending on your run time environment, you can get authentication by\ | ||
@@ -90,12 +114,15 @@ \n- if in local instance or cloud shell: `!gcloud auth login`\ | ||
\n -`auth.authenticate_user()`\ | ||
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication"; | ||
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication'; | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new errors_1.GoogleAuthError(`${credential_error_message}\n${e}`); | ||
throw new errors_1.GoogleAuthError(credential_error_message, e); | ||
}); | ||
return tokenPromise; | ||
} | ||
/** | ||
* @param {ModelParams} modelParams - {@link ModelParams} Parameters to specify the generative model. | ||
* @return {GenerativeModel} Instance of the GenerativeModel class. {@link GenerativeModel} | ||
*/ | ||
getGenerativeModel(modelParams) { | ||
if (modelParams.generation_config) { | ||
modelParams.generation_config = | ||
validateGenerationConfig(modelParams.generation_config); | ||
modelParams.generation_config = validateGenerationConfig(modelParams.generation_config); | ||
} | ||
@@ -105,5 +132,8 @@ return new GenerativeModel(this, modelParams.model, modelParams.generation_config, modelParams.safety_settings); | ||
} | ||
exports.VertexAI_Internal = VertexAI_Internal; | ||
exports.VertexAI_Preview = VertexAI_Preview; | ||
/** | ||
* Session for a multiturn chat with the model | ||
* Chat session to make multi-turn send-message requests. | ||
* The `sendMessage` method makes an async call to get the response to a chat message. | ||
* The `sendMessageStream` method makes an async call to stream the response to a chat message. | ||
* @param {StartChatSessionRequest} request - {@link StartChatSessionRequest} | ||
*/ | ||
@@ -123,2 +153,7 @@ class ChatSession { | ||
} | ||
/** | ||
* Make an async call to send a message. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<GenerateContentResult>} Promise of {@link GenerateContentResult} | ||
*/ | ||
async sendMessage(request) { | ||
@@ -131,4 +166,8 @@ const newContent = formulateNewContent(request); | ||
}; | ||
const generateContentResult = await this._model_instance.generateContent(generateContentrequest); | ||
const generateContentResponse = await generateContentResult.response; | ||
const generateContentResult = await this._model_instance | ||
.generateContent(generateContentrequest) | ||
.catch(e => { | ||
throw e; | ||
}); | ||
const generateContentResponse = generateContentResult.response; | ||
// Only push the latest message to history if the response returned a result | ||
@@ -166,2 +205,7 @@ if (generateContentResponse.candidates.length !== 0) { | ||
} | ||
/** | ||
* Make an async call to send a message. The response will be returned in a stream. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
async sendMessageStream(request) { | ||
@@ -174,4 +218,10 @@ const newContent = formulateNewContent(request); | ||
}; | ||
const streamGenerateContentResultPromise = this._model_instance.generateContentStream(generateContentrequest); | ||
this._send_stream_promise = this.appendHistory(streamGenerateContentResultPromise, newContent); | ||
const streamGenerateContentResultPromise = this._model_instance | ||
.generateContentStream(generateContentrequest) | ||
.catch(e => { | ||
throw e; | ||
}); | ||
this._send_stream_promise = this.appendHistory(streamGenerateContentResultPromise, newContent).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception appending chat history', e); | ||
}); | ||
return streamGenerateContentResultPromise; | ||
@@ -183,3 +233,2 @@ } | ||
* Base class for generative models. | ||
* | ||
* NOTE: this class should not be instantiated directly. Use | ||
@@ -189,2 +238,10 @@ * `vertexai.preview.getGenerativeModel()` instead. | ||
class GenerativeModel { | ||
/** | ||
* @constructor | ||
* @param {VertexAI_Preview} vertex_instance - {@link VertexAI_Preview} | ||
* @param {string} model - model name | ||
* @param {GenerationConfig} generation_config - Optional. {@link | ||
* GenerationConfig} | ||
* @param {SafetySetting[]} safety_settings - Optional. {@link SafetySetting} | ||
*/ | ||
constructor(vertex_instance, model, generation_config, safety_settings) { | ||
@@ -196,5 +253,11 @@ this._use_non_stream = false; | ||
this.safety_settings = safety_settings; | ||
if (model.startsWith('models/')) { | ||
this.publisherModelEndpoint = `publishers/google/${this.model}`; | ||
} | ||
else { | ||
this.publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
} | ||
} | ||
/** | ||
* Make a generateContent request. | ||
* Make an async call to generate content. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
@@ -207,7 +270,8 @@ * @return The GenerateContentResponse object with the response candidates. | ||
if (request.generation_config) { | ||
request.generation_config = | ||
validateGenerationConfig(request.generation_config); | ||
request.generation_config = validateGenerationConfig(request.generation_config); | ||
} | ||
if (!this._use_non_stream) { | ||
const streamGenerateContentResult = await this.generateContentStream(request); | ||
const streamGenerateContentResult = await this.generateContentStream(request).catch(e => { | ||
throw e; | ||
}); | ||
const result = { | ||
@@ -218,3 +282,2 @@ response: await streamGenerateContentResult.response, | ||
} | ||
const publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
const generateContentRequest = { | ||
@@ -225,23 +288,14 @@ contents: request.contents, | ||
}; | ||
let response; | ||
try { | ||
response = await (0, util_1.postRequest)({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: util_1.constants.GENERATE_CONTENT_METHOD, | ||
token: await this._vertex_instance.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}); | ||
if (response === undefined) { | ||
throw new Error('did not get a valid response.'); | ||
} | ||
if (!response.ok) { | ||
throw new Error(`${response.status} ${response.statusText}`); | ||
} | ||
} | ||
catch (e) { | ||
console.log(e); | ||
} | ||
const response = await (0, util_1.postRequest)({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: util_1.constants.GENERATE_CONTENT_METHOD, | ||
token: await this._vertex_instance.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
const result = (0, process_stream_1.processNonStream)(response); | ||
@@ -251,5 +305,5 @@ return Promise.resolve(result); | ||
/** | ||
* Make a generateContentStream request. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
* @return The GenerateContentResponse object with the response candidates. | ||
* Make an async stream request to generate content. The response will be returned in a stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
@@ -260,6 +314,4 @@ async generateContentStream(request) { | ||
if (request.generation_config) { | ||
request.generation_config = | ||
validateGenerationConfig(request.generation_config); | ||
request.generation_config = validateGenerationConfig(request.generation_config); | ||
} | ||
const publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
const generateContentRequest = { | ||
@@ -270,23 +322,14 @@ contents: request.contents, | ||
}; | ||
let response; | ||
try { | ||
response = await (0, util_1.postRequest)({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: util_1.constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await this._vertex_instance.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}); | ||
if (response === undefined) { | ||
throw new Error('did not get a valid response.'); | ||
} | ||
if (!response.ok) { | ||
throw new Error(`${response.status} ${response.statusText}`); | ||
} | ||
} | ||
catch (e) { | ||
console.log(e); | ||
} | ||
const response = await (0, util_1.postRequest)({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: util_1.constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await this._vertex_instance.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
const streamResult = (0, process_stream_1.processStream)(response); | ||
@@ -296,3 +339,3 @@ return Promise.resolve(streamResult); | ||
/** | ||
* Make a countTokens request. | ||
* Make an async request to count tokens. | ||
* @param request A CountTokensRequest object with the request contents. | ||
@@ -302,40 +345,36 @@ * @return The CountTokensResponse object with the token count. | ||
async countTokens(request) { | ||
let response; | ||
try { | ||
response = await (0, util_1.postRequest)({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: `publishers/google/models/${this.model}`, | ||
resourceMethod: 'countTokens', | ||
token: await this._vertex_instance.token, | ||
data: request, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}); | ||
if (response === undefined) { | ||
throw new Error('did not get a valid response.'); | ||
} | ||
if (!response.ok) { | ||
throw new Error(`${response.status} ${response.statusText}`); | ||
} | ||
} | ||
catch (e) { | ||
console.log(e); | ||
} | ||
if (response) { | ||
const responseJson = await response.json(); | ||
return responseJson; | ||
} | ||
else { | ||
throw new Error('did not get a valid response.'); | ||
} | ||
const response = await (0, util_1.postRequest)({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: 'countTokens', | ||
token: await this._vertex_instance.token, | ||
data: request, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
return (0, process_stream_1.processCountTokenResponse)(response); | ||
} | ||
/** | ||
* Instantiate a ChatSession. | ||
* This method doesn't make any call to the remote endpoint. | ||
* Any call to the remote endpoint is implemented in the ChatSession class. @see ChatSession | ||
* @param {StartChatParams} [request] - {@link StartChatParams} | ||
* @return {ChatSession} {@link ChatSession} | ||
*/ | ||
startChat(request) { | ||
var _a, _b; | ||
const startChatRequest = { | ||
history: request.history, | ||
generation_config: (_a = request.generation_config) !== null && _a !== void 0 ? _a : this.generation_config, | ||
safety_settings: (_b = request.safety_settings) !== null && _b !== void 0 ? _b : this.safety_settings, | ||
_vertex_instance: this._vertex_instance, | ||
_model_instance: this, | ||
}; | ||
if (request) { | ||
startChatRequest.history = request.history; | ||
startChatRequest.generation_config = | ||
(_a = request.generation_config) !== null && _a !== void 0 ? _a : this.generation_config; | ||
startChatRequest.safety_settings = | ||
(_b = request.safety_settings) !== null && _b !== void 0 ? _b : this.safety_settings; | ||
} | ||
return new ChatSession(startChatRequest); | ||
@@ -363,2 +402,16 @@ } | ||
} | ||
function throwErrorIfNotOK(response) { | ||
if (response === undefined) { | ||
throw new errors_1.GoogleGenerativeAIError('response is undefined'); | ||
} | ||
const status = response.status; | ||
const statusText = response.statusText; | ||
const errorMessage = `got status: ${status} ${statusText}`; | ||
if (status >= 400 && status < 500) { | ||
throw new errors_1.ClientError(errorMessage); | ||
} | ||
if (!response.ok) { | ||
throw new errors_1.GoogleGenerativeAIError(errorMessage); | ||
} | ||
} | ||
function validateGcsInput(contents) { | ||
@@ -365,0 +418,0 @@ for (const content of contents) { |
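Two caller-visible effects of the rewritten implementation: `startChat` now accepts no request object (inheriting the model's `generation_config` and `safety_settings`), and failures propagate as typed errors instead of being swallowed by `console.log`. A short sketch, reusing the `generativeModel` instance from the earlier sketch with illustrative prompts:

```typescript
async function chatSketch() {
  const chat = generativeModel.startChat(); // the request object is now optional

  const result = await chat.sendMessage('How can I learn more about Node.js?');
  const response = await result.response;
  console.log(JSON.stringify(response.candidates[0].content));

  // Streaming variant: chunks arrive as they are generated, and chat history
  // is appended once the aggregated response resolves.
  const streamResult = await chat.sendMessageStream('Tell me more.');
  for await (const chunk of streamResult.stream) {
    console.log(JSON.stringify(chunk));
  }
}
```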
@@ -17,10 +17,29 @@ /** | ||
*/ | ||
import { GenerateContentResult, StreamGenerateContentResult } from './types/content'; | ||
import { CountTokensResponse, GenerateContentResult, StreamGenerateContentResult } from './types/content'; | ||
/** | ||
* Processes model responses from streamGenerateContent | ||
* Process a response.body stream from the backend and return an | ||
* iterator that provides one complete GenerateContentResponse at a time | ||
* and a promise that resolves with a single aggregated | ||
* GenerateContentResponse. | ||
* | ||
* @param response - Response from a fetch call | ||
* @ignore | ||
*/ | ||
export declare function processStream(response: Response | undefined): StreamGenerateContentResult; | ||
/** | ||
* Reads a raw stream from the fetch response and joins incomplete | ||
* chunks, returning a new stream that provides a single complete | ||
* GenerateContentResponse in each iteration. | ||
* @ignore | ||
*/ | ||
export declare function getResponseStream<T>(inputStream: ReadableStream<string>): ReadableStream<T>; | ||
/** | ||
* Process model responses from generateContent | ||
* @ignore | ||
*/ | ||
export declare function processNonStream(response: any): GenerateContentResult; | ||
/** | ||
* Process model responses from countTokens | ||
* @ignore | ||
*/ | ||
export declare function processCountTokenResponse(response: any): CountTokensResponse; |
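As these signatures indicate, the stream processor now hands back both a per-chunk iterator and an aggregated-response promise, backed by two branches of the same tee'd stream. A sketch of the consuming side:

```typescript
import {StreamGenerateContentResult} from '@google-cloud/vertexai';

async function consume(streamResult: StreamGenerateContentResult) {
  // One complete GenerateContentResponse per iteration...
  for await (const chunk of streamResult.stream) {
    console.log('chunk candidates:', chunk.candidates.length);
  }
  // ...and a single aggregated GenerateContentResponse once the stream ends.
  const aggregated = await streamResult.response;
  console.log('aggregated candidates:', aggregated.candidates.length);
}
```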
@@ -19,10 +19,8 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.processNonStream = exports.processStream = void 0; | ||
// eslint-disable-next-line no-useless-escape | ||
const responseLineRE = /^data\: (.*)\r\n/; | ||
// TODO: set a better type for `reader`. Setting it to | ||
// `ReadableStreamDefaultReader` results in an error (diagnostic code 2304) | ||
async function* generateResponseSequence(reader2) { | ||
exports.processCountTokenResponse = exports.processNonStream = exports.getResponseStream = exports.processStream = void 0; | ||
const responseLineRE = /^data: (.*)(?:\n\n|\r\r|\r\n\r\n)/; | ||
async function* generateResponseSequence(stream) { | ||
const reader = stream.getReader(); | ||
while (true) { | ||
const { value, done } = await reader2.read(); | ||
const { value, done } = await reader.read(); | ||
if (done) { | ||
@@ -35,48 +33,77 @@ break; | ||
/** | ||
* Reads a raw stream from the fetch response and joins incomplete | ||
* Process a response.body stream from the backend and return an | ||
* iterator that provides one complete GenerateContentResponse at a time | ||
* and a promise that resolves with a single aggregated | ||
* GenerateContentResponse. | ||
* | ||
* @param response - Response from a fetch call | ||
* @ignore | ||
*/ | ||
function processStream(response) { | ||
if (response === undefined) { | ||
throw new Error('Error processing stream because response === undefined'); | ||
} | ||
if (!response.body) { | ||
throw new Error('Error processing stream because response.body not found'); | ||
} | ||
const inputStream = response.body.pipeThrough(new TextDecoderStream('utf8', { fatal: true })); | ||
const responseStream = getResponseStream(inputStream); | ||
const [stream1, stream2] = responseStream.tee(); | ||
return { | ||
stream: generateResponseSequence(stream1), | ||
response: getResponsePromise(stream2), | ||
}; | ||
} | ||
exports.processStream = processStream; | ||
async function getResponsePromise(stream) { | ||
const allResponses = []; | ||
const reader = stream.getReader(); | ||
// eslint-disable-next-line no-constant-condition | ||
while (true) { | ||
const { done, value } = await reader.read(); | ||
if (done) { | ||
return aggregateResponses(allResponses); | ||
} | ||
allResponses.push(value); | ||
} | ||
} | ||
/** | ||
* Reads a raw stream from the fetch response and joins incomplete | ||
* chunks, returning a new stream that provides a single complete | ||
* GenerateContentResponse in each iteration. | ||
* @ignore | ||
*/ | ||
function readFromReader(reader) { | ||
let currentText = ''; | ||
function getResponseStream(inputStream) { | ||
const reader = inputStream.getReader(); | ||
const stream = new ReadableStream({ | ||
start(controller) { | ||
let currentText = ''; | ||
return pump(); | ||
function pump() { | ||
let streamReader; | ||
try { | ||
streamReader = reader.read().then(({ value, done }) => { | ||
if (done) { | ||
controller.close(); | ||
return reader.read().then(({ value, done }) => { | ||
if (done) { | ||
if (currentText.trim()) { | ||
controller.error(new Error('Failed to parse stream')); | ||
return; | ||
} | ||
const chunk = new TextDecoder().decode(value); | ||
currentText += chunk; | ||
const match = currentText.match(responseLineRE); | ||
if (match) { | ||
let parsedResponse; | ||
try { | ||
parsedResponse = JSON.parse(match[1]); | ||
} | ||
catch (e) { | ||
throw new Error(`Error parsing JSON response: "${match[1]}"`); | ||
} | ||
currentText = ''; | ||
if ('candidates' in parsedResponse) { | ||
controller.enqueue(parsedResponse); | ||
} | ||
else { | ||
console.warn(`No candidates in this response: ${parsedResponse}`); | ||
controller.enqueue({ | ||
candidates: [], | ||
}); | ||
} | ||
controller.close(); | ||
return; | ||
} | ||
currentText += value; | ||
let match = currentText.match(responseLineRE); | ||
let parsedResponse; | ||
while (match) { | ||
try { | ||
parsedResponse = JSON.parse(match[1]); | ||
} | ||
return pump(); | ||
}); | ||
} | ||
catch (e) { | ||
throw new Error(`Error reading from stream ${e}.`); | ||
} | ||
return streamReader; | ||
catch (e) { | ||
controller.error(new Error(`Error parsing JSON response: "${match[1]}"`)); | ||
return; | ||
} | ||
controller.enqueue(parsedResponse); | ||
currentText = currentText.substring(match[0].length); | ||
match = currentText.match(responseLineRE); | ||
} | ||
return pump(); | ||
}); | ||
} | ||
@@ -87,5 +114,7 @@ }, | ||
} | ||
exports.getResponseStream = getResponseStream; | ||
/** | ||
* Aggregates an array of `GenerateContentResponse`s into a single | ||
* GenerateContentResponse. | ||
* @ignore | ||
*/ | ||
@@ -114,4 +143,3 @@ function aggregateResponses(responses) { | ||
if (response.candidates[i].citationMetadata) { | ||
if (!((_a = aggregatedResponse.candidates[i] | ||
.citationMetadata) === null || _a === void 0 ? void 0 : _a.citationSources)) { | ||
if (!((_a = aggregatedResponse.candidates[i].citationMetadata) === null || _a === void 0 ? void 0 : _a.citationSources)) { | ||
aggregatedResponse.candidates[i].citationMetadata = { | ||
@@ -121,7 +149,6 @@ citationSources: [], | ||
} | ||
let existingMetadata = (_b = response.candidates[i].citationMetadata) !== null && _b !== void 0 ? _b : {}; | ||
const existingMetadata = (_b = response.candidates[i].citationMetadata) !== null && _b !== void 0 ? _b : {}; | ||
if (aggregatedResponse.candidates[i].citationMetadata) { | ||
aggregatedResponse.candidates[i].citationMetadata.citationSources = | ||
aggregatedResponse.candidates[i] | ||
.citationMetadata.citationSources.concat(existingMetadata); | ||
aggregatedResponse.candidates[i].citationMetadata.citationSources.concat(existingMetadata); | ||
} | ||
@@ -148,40 +175,5 @@ } | ||
} | ||
// TODO: improve error handling throughout stream processing | ||
/** | ||
* Processes model responses from streamGenerateContent | ||
*/ | ||
function processStream(response) { | ||
if (response === undefined) { | ||
throw new Error('Error processing stream because response === undefined'); | ||
} | ||
if (!response.body) { | ||
throw new Error('Error processing stream because response.body not found'); | ||
} | ||
const reader = response.body.getReader(); | ||
const responseStream = readFromReader(reader); | ||
const [stream1, stream2] = responseStream.tee(); | ||
const reader1 = stream1.getReader(); | ||
const reader2 = stream2.getReader(); | ||
const allResponses = []; | ||
const responsePromise = new Promise( | ||
// eslint-disable-next-line no-async-promise-executor | ||
async (resolve) => { | ||
// eslint-disable-next-line no-constant-condition | ||
while (true) { | ||
const { value, done } = await reader1.read(); | ||
if (done) { | ||
resolve(aggregateResponses(allResponses)); | ||
return; | ||
} | ||
allResponses.push(value); | ||
} | ||
}); | ||
return { | ||
response: responsePromise, | ||
stream: generateResponseSequence(reader2), | ||
}; | ||
} | ||
exports.processStream = processStream; | ||
/** | ||
* Process model responses from generateContent | ||
* @ignore | ||
*/ | ||
@@ -201,2 +193,12 @@ function processNonStream(response) { | ||
exports.processNonStream = processNonStream; | ||
/** | ||
* Process model responses from countTokens | ||
* @ignore | ||
*/ | ||
function processCountTokenResponse(response) { | ||
// ts-ignore | ||
const responseJson = response.json(); | ||
return responseJson; | ||
} | ||
exports.processCountTokenResponse = processCountTokenResponse; | ||
//# sourceMappingURL=process_stream.js.map |
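The new `responseLineRE` matches one complete server-sent-event record per pass, and `getResponseStream` buffers partial text until a full record is available; together with the fatal `TextDecoderStream` above, this is what fixes the chunk-boundary and UTF issues noted in the changelog. A self-contained sketch of the re-chunking behavior, using fabricated payloads:

```typescript
// Assumes getResponseStream is imported from the SDK's process_stream module.
// Two JSON events split across three arbitrary text fragments, the way a
// fetch body might deliver them after UTF-8 decoding.
const fragments = [
  'data: {"candidates":[{"index":0}]}\n\ndata: {"cand',
  'idates":[{"index":1}]}',
  '\n\n',
];
const input = new ReadableStream<string>({
  start(controller) {
    for (const f of fragments) controller.enqueue(f);
    controller.close();
  },
});

// Emits one parsed object per complete `data: ...` record:
// {candidates: [{index: 0}]}, then {candidates: [{index: 1}]}, then closes.
const output = getResponseStream<{candidates: Array<{index: number}>}>(input);
```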
@@ -17,4 +17,11 @@ /** | ||
*/ | ||
import { GoogleAuthOptions } from 'google-auth-library'; | ||
/** | ||
* Params used to initialize the Vertex SDK | ||
* @param {string} project - The project name of your Google Cloud project. It is not the numeric project ID. | ||
* @param {string} location - The location of your project. | ||
* @param {string} [apiEndpoint] - If not specified, a default value will be resolved by the SDK. | ||
* @param {GoogleAuthOptions} [googleAuthOptions] - The authentication options provided by google-auth-library. | ||
* The complete list of authentication options is documented in the GoogleAuthOptions interface: | ||
* https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
*/ | ||
@@ -25,2 +32,3 @@ export declare interface VertexInit { | ||
apiEndpoint?: string; | ||
googleAuthOptions?: GoogleAuthOptions; | ||
} | ||
@@ -48,2 +56,4 @@ /** | ||
* Configuration for initializing a model, for example via getGenerativeModel | ||
* @param {string} model - model name. | ||
* @example "gemini-pro" | ||
*/ | ||
@@ -78,2 +88,5 @@ export declare interface ModelParams extends BaseModelParams { | ||
} | ||
/** | ||
* Harm categories that would cause prompts or candidates to be blocked. | ||
*/ | ||
export declare enum HarmCategory { | ||
@@ -86,2 +99,5 @@ HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED", | ||
} | ||
/** | ||
* Threshold above which a prompt or candidate will be blocked. | ||
*/ | ||
export declare enum HarmBlockThreshold { | ||
@@ -95,2 +111,12 @@ HARM_BLOCK_THRESHOLD_UNSPECIFIED = "HARM_BLOCK_THRESHOLD_UNSPECIFIED", | ||
/** | ||
* Probability that a prompt or candidate matches a harm category. | ||
*/ | ||
export declare enum HarmProbability { | ||
HARM_PROBABILITY_UNSPECIFIED = "HARM_PROBABILITY_UNSPECIFIED", | ||
NEGLIGIBLE = "NEGLIGIBLE", | ||
LOW = "LOW", | ||
MEDIUM = "MEDIUM", | ||
HIGH = "HIGH" | ||
} | ||
/** | ||
* Safety rating for a piece of content | ||
@@ -100,3 +126,3 @@ */ | ||
category: HarmCategory; | ||
threshold: HarmBlockThreshold; | ||
probability: HarmProbability; | ||
} | ||
@@ -174,2 +200,3 @@ /** | ||
* Wrapper for responses from a generateContent request | ||
* @see GenerateContentResponse | ||
*/ | ||
@@ -181,2 +208,3 @@ export declare interface GenerateContentResult { | ||
* Wrapper for responses from a streamGenerateContent request | ||
* @see GenerateContentResponse | ||
*/ | ||
@@ -221,7 +249,1 @@ export declare interface StreamGenerateContentResult { | ||
} | ||
declare const CLIENT_INFO: { | ||
user_agent: string; | ||
client_library_language: string; | ||
client_library_version: string; | ||
}; | ||
export { CLIENT_INFO }; |
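The `SafetyRating` change is breaking: ratings in responses now report a `probability` (`HarmProbability`) rather than echoing a `threshold`. Thresholds remain a request-side setting. A sketch of the two sides; the non-UNSPECIFIED enum members are assumed from the SDK's published enums:

```typescript
import {HarmCategory, HarmBlockThreshold} from '@google-cloud/vertexai';

// Request side: thresholds decide what gets blocked.
const safety_settings = [{
  category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
  threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}];

// Response side: each rating on a candidate now carries a HarmProbability,
// e.g. rating.probability === 'NEGLIGIBLE', instead of a threshold.
```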
@@ -19,3 +19,6 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.CLIENT_INFO = exports.FinishReason = exports.BlockedReason = exports.HarmBlockThreshold = exports.HarmCategory = void 0; | ||
exports.FinishReason = exports.BlockedReason = exports.HarmProbability = exports.HarmBlockThreshold = exports.HarmCategory = void 0; | ||
/** | ||
* Harm categories that would cause prompts or candidates to be blocked. | ||
*/ | ||
var HarmCategory; | ||
@@ -29,2 +32,5 @@ (function (HarmCategory) { | ||
})(HarmCategory || (exports.HarmCategory = HarmCategory = {})); | ||
/** | ||
* Threshold above which a prompt or candidate will be blocked. | ||
*/ | ||
var HarmBlockThreshold; | ||
@@ -43,2 +49,18 @@ (function (HarmBlockThreshold) { | ||
})(HarmBlockThreshold || (exports.HarmBlockThreshold = HarmBlockThreshold = {})); | ||
/** | ||
* Probability that a prompt or candidate matches a harm category. | ||
*/ | ||
var HarmProbability; | ||
(function (HarmProbability) { | ||
// Probability is unspecified. | ||
HarmProbability["HARM_PROBABILITY_UNSPECIFIED"] = "HARM_PROBABILITY_UNSPECIFIED"; | ||
// Content has a negligible chance of being unsafe. | ||
HarmProbability["NEGLIGIBLE"] = "NEGLIGIBLE"; | ||
// Content has a low chance of being unsafe. | ||
HarmProbability["LOW"] = "LOW"; | ||
// Content has a medium chance of being unsafe. | ||
HarmProbability["MEDIUM"] = "MEDIUM"; | ||
// Content has a high chance of being unsafe. | ||
HarmProbability["HIGH"] = "HIGH"; | ||
})(HarmProbability || (exports.HarmProbability = HarmProbability = {})); | ||
var BlockedReason; | ||
@@ -68,13 +90,2 @@ (function (BlockedReason) { | ||
})(FinishReason || (exports.FinishReason = FinishReason = {})); | ||
const USER_AGENT_PRODUCT = 'model-builder'; | ||
const CLIENT_LIBRARY_LANGUAGE = 'node-js'; | ||
// TODO: update this version number using release-please | ||
const CLIENT_LIBRARY_VERSION = '0.1.0'; | ||
const USER_AGENT = USER_AGENT_PRODUCT + '/' + CLIENT_LIBRARY_VERSION; | ||
const CLIENT_INFO = { | ||
user_agent: USER_AGENT, | ||
client_library_language: CLIENT_LIBRARY_LANGUAGE, | ||
client_library_version: CLIENT_LIBRARY_VERSION, | ||
}; | ||
exports.CLIENT_INFO = CLIENT_INFO; | ||
//# sourceMappingURL=content.js.map |
@@ -17,5 +17,25 @@ /** | ||
*/ | ||
/** | ||
* GoogleAuthError is thrown when there is an authentication issue with the request | ||
*/ | ||
declare class GoogleAuthError extends Error { | ||
constructor(message: string); | ||
readonly stack_trace: any; | ||
constructor(message: string, stack_trace?: any); | ||
} | ||
export { GoogleAuthError }; | ||
/** | ||
* ClientError is thrown when an HTTP 4XX status is received. | ||
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses | ||
*/ | ||
declare class ClientError extends Error { | ||
readonly stack_trace: any; | ||
constructor(message: string, stack_trace?: any); | ||
} | ||
/** | ||
* GoogleGenerativeAIError is thrown when the HTTP response is not OK and the status code is not 4XX. | ||
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status | ||
*/ | ||
declare class GoogleGenerativeAIError extends Error { | ||
readonly stack_trace: any; | ||
constructor(message: string, stack_trace?: any); | ||
} | ||
export { ClientError, GoogleAuthError, GoogleGenerativeAIError }; |
@@ -19,10 +19,47 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.GoogleAuthError = void 0; | ||
exports.GoogleGenerativeAIError = exports.GoogleAuthError = exports.ClientError = void 0; | ||
/** | ||
* GoogleAuthError is thrown when there is an authentication issue with the request | ||
*/ | ||
class GoogleAuthError extends Error { | ||
constructor(message) { | ||
constructor(message, stack_trace = undefined) { | ||
super(message); | ||
this.stack_trace = undefined; | ||
this.message = constructErrorMessage('GoogleAuthError', message); | ||
this.name = 'GoogleAuthError'; | ||
this.stack_trace = stack_trace; | ||
} | ||
} | ||
exports.GoogleAuthError = GoogleAuthError; | ||
/** | ||
* ClientError is thrown when an HTTP 4XX status is received. | ||
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses | ||
*/ | ||
class ClientError extends Error { | ||
constructor(message, stack_trace = undefined) { | ||
super(message); | ||
this.stack_trace = undefined; | ||
this.message = constructErrorMessage('ClientError', message); | ||
this.name = 'ClientError'; | ||
this.stack_trace = stack_trace; | ||
} | ||
} | ||
exports.ClientError = ClientError; | ||
/** | ||
* GoogleGenerativeAIError is thrown when the HTTP response is not OK and the status code is not 4XX. | ||
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status | ||
*/ | ||
class GoogleGenerativeAIError extends Error { | ||
constructor(message, stack_trace = undefined) { | ||
super(message); | ||
this.stack_trace = undefined; | ||
this.message = constructErrorMessage('GoogleGenerativeAIError', message); | ||
this.name = 'GoogleGenerativeAIError'; | ||
this.stack_trace = stack_trace; | ||
} | ||
} | ||
exports.GoogleGenerativeAIError = GoogleGenerativeAIError; | ||
function constructErrorMessage(exceptionClass, message) { | ||
return `[VertexAI.${exceptionClass}]: ${message}`; | ||
} | ||
//# sourceMappingURL=errors.js.map |
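Combined with `throwErrorIfNotOK` in index.js, these classes give callers something concrete to branch on: 4XX statuses become `ClientError`, other non-OK responses and wrapped request failures become `GoogleGenerativeAIError`, and credential problems become `GoogleAuthError`. A handling sketch, reusing the `generativeModel` instance from earlier (the request body is illustrative):

```typescript
import {ClientError, GoogleGenerativeAIError} from '@google-cloud/vertexai';

async function generateWithHandling() {
  try {
    return await generativeModel.generateContent({
      contents: [{role: 'user', parts: [{text: 'hello'}]}],
    });
  } catch (e) {
    if (e instanceof ClientError) {
      // 4XX, e.g. '[VertexAI.ClientError]: got status: 400 Bad Request';
      // e.stack_trace carries the underlying error when one was captured.
    } else if (e instanceof GoogleGenerativeAIError) {
      // Non-4XX HTTP failures and request/stream processing errors.
    }
    throw e;
  }
}
```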
@@ -21,1 +21,2 @@ /** | ||
export declare const MODEL_ROLE = "model"; | ||
export declare const USER_AGENT = "model-builder/0.2.0 grpc-node/0.2.0"; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.USER_AGENT = exports.MODEL_ROLE = exports.USER_ROLE = exports.STREAMING_GENERATE_CONTENT_METHOD = exports.GENERATE_CONTENT_METHOD = void 0; | ||
/** | ||
@@ -18,4 +20,2 @@ * @license | ||
*/ | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.MODEL_ROLE = exports.USER_ROLE = exports.STREAMING_GENERATE_CONTENT_METHOD = exports.GENERATE_CONTENT_METHOD = void 0; | ||
exports.GENERATE_CONTENT_METHOD = 'generateContent'; | ||
@@ -25,2 +25,6 @@ exports.STREAMING_GENERATE_CONTENT_METHOD = 'streamGenerateContent'; | ||
exports.MODEL_ROLE = 'model'; | ||
const USER_AGENT_PRODUCT = 'model-builder'; | ||
const CLIENT_LIBRARY_VERSION = '0.2.0'; | ||
const CLIENT_LIBRARY_LANGUAGE = `grpc-node/${CLIENT_LIBRARY_VERSION}`; | ||
exports.USER_AGENT = `${USER_AGENT_PRODUCT}/${CLIENT_LIBRARY_VERSION} ${CLIENT_LIBRARY_LANGUAGE}`; | ||
//# sourceMappingURL=constants.js.map |
@@ -20,2 +20,3 @@ /** | ||
* Makes a POST request to a Vertex service | ||
* @ignore | ||
*/ | ||
@@ -22,0 +23,0 @@ export declare function postRequest({ region, project, resourcePath, resourceMethod, token, data, apiEndpoint, apiVersion, }: { |
@@ -21,6 +21,6 @@ "use strict"; | ||
const API_BASE_PATH = 'aiplatform.googleapis.com'; | ||
const content_1 = require("../types/content"); | ||
const constants = require("./constants"); | ||
/** | ||
* Makes a POST request to a Vertex service | ||
* @ignore | ||
*/ | ||
@@ -34,10 +34,8 @@ async function postRequest({ region, project, resourcePath, resourceMethod, token, data, apiEndpoint, apiVersion = 'v1', }) { | ||
} | ||
return await fetch(vertexEndpoint, { | ||
return fetch(vertexEndpoint, { | ||
method: 'POST', | ||
headers: { | ||
'Authorization': `Bearer ${token}`, | ||
Authorization: `Bearer ${token}`, | ||
'Content-Type': 'application/json', | ||
'User-Agent': content_1.CLIENT_INFO.user_agent, | ||
'client_library_language': content_1.CLIENT_INFO.client_library_language, | ||
'client_library_version': content_1.CLIENT_INFO.client_library_version, | ||
'User-Agent': constants.USER_AGENT, | ||
}, | ||
@@ -44,0 +42,0 @@ body: JSON.stringify(data), |
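The `vertexEndpoint` assembly itself sits outside this hunk; from the parameter list, `API_BASE_PATH`, and the `apiVersion = 'v1'` default, it plausibly resolves as in this hypothetical reconstruction (an assumption, not shown in the diff):

```typescript
// Hypothetical reconstruction; the actual assembly is not part of this diff.
function vertexEndpointOf(
  region: string,
  project: string,
  resourcePath: string,
  resourceMethod: string,
  apiEndpoint?: string,
  apiVersion = 'v1'
): string {
  const host = apiEndpoint ?? `${region}-aiplatform.googleapis.com`;
  return `https://${host}/${apiVersion}/projects/${project}/locations/${region}/${resourcePath}:${resourceMethod}`;
}
```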
@@ -20,4 +20,6 @@ "use strict"; | ||
// @ts-ignore | ||
const vertexai_1 = require("@google-cloud/vertexai"); | ||
const PROJECT = 'cloud-llm-preview1'; // TODO: change this to infer from Kokoro env | ||
const assert = require("assert"); | ||
const src_1 = require("../src"); | ||
// TODO: this env var isn't getting populated correctly | ||
const PROJECT = process.env.GCLOUD_PROJECT; | ||
const LOCATION = 'us-central1'; | ||
@@ -32,38 +34,189 @@ const TEXT_REQUEST = { | ||
file_data: { | ||
file_uri: 'gs://generativeai-downloads/images/scones.jpg', | ||
file_uri: 'gs://nodejs_vertex_system_test_resources/scones.jpg', | ||
mime_type: 'image/jpeg', | ||
}, | ||
}; | ||
const MULTI_PART_REQUEST = { | ||
const BASE_64_IMAGE = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='; | ||
const INLINE_DATA_FILE_PART = { | ||
inline_data: { | ||
data: BASE_64_IMAGE, | ||
mime_type: 'image/jpeg', | ||
}, | ||
}; | ||
const MULTI_PART_GCS_REQUEST = { | ||
contents: [{ role: 'user', parts: [TEXT_PART, GCS_FILE_PART] }], | ||
}; | ||
const MULTI_PART_BASE64_REQUEST = { | ||
contents: [{ role: 'user', parts: [TEXT_PART, INLINE_DATA_FILE_PART] }], | ||
}; | ||
// Initialize Vertex with your Cloud project and location | ||
const vertex_ai = new vertexai_1.VertexAI({ project: PROJECT, location: LOCATION }); | ||
const vertex_ai = new src_1.VertexAI({ project: 'long-door-651', location: LOCATION }); | ||
const generativeTextModel = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const generativeTextModelWithPrefix = vertex_ai.preview.getGenerativeModel({ | ||
model: 'models/gemini-pro', | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const textModelNoOutputLimit = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeVisionModel = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-vision-pro', | ||
model: 'gemini-pro-vision', | ||
}); | ||
async function testGenerateContentStreamText() { | ||
const streamingResp = await generativeTextModel.generateContentStream(TEXT_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
console.log('stream chunk:', item); | ||
} | ||
console.log('aggregated response: ', await streamingResp.response); | ||
} | ||
async function testGenerateContentStreamMultiPart() { | ||
const streamingResp = await generativeVisionModel.generateContentStream(MULTI_PART_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
console.log('stream chunk:', item); | ||
} | ||
console.log('aggregated response: ', await streamingResp.response); | ||
} | ||
async function testCountTokens() { | ||
const countTokensResp = await generativeVisionModel.countTokens(TEXT_REQUEST); | ||
console.log('count tokens response: ', countTokensResp); | ||
} | ||
testGenerateContentStreamText(); | ||
testGenerateContentStreamMultiPart(); | ||
testCountTokens(); | ||
const generativeVisionModelWithPrefix = vertex_ai.preview.getGenerativeModel({ | ||
model: 'models/gemini-pro-vision', | ||
}); | ||
// TODO (b/316599049): update tests to use jasmine expect syntax: | ||
// expect(...).toBeInstanceOf(...) | ||
describe('generateContentStream', () => { | ||
beforeEach(() => { | ||
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000; | ||
}); | ||
it('should return a stream and aggregated response when passed text', async () => { | ||
const streamingResp = await generativeTextModel.generateContentStream(TEXT_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
assert(item.candidates[0], `sys test failure on generateContentStream, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should not return invalid unicode', async () => { | ||
const streamingResp = await generativeTextModel.generateContentStream({ | ||
contents: [{ role: 'user', parts: [{ text: '创作一首古诗' }] }], | ||
}); | ||
for await (const item of streamingResp.stream) { | ||
assert(item.candidates[0], `sys test failure on generateContentStream, for item ${item}`); | ||
for (const candidate of item.candidates) { | ||
for (const part of candidate.content.parts) { | ||
assert(!part.text.includes('\ufffd'), `sys test failure on generateContentStream, for item ${item}`); | ||
} | ||
} | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart base64 content', async () => { | ||
const streamingResp = await generativeVisionModel.generateContentStream(MULTI_PART_BASE64_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
assert(item.candidates[0], `sys test failure on generateContentStream, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should throw ClientError when having invalid input', async () => { | ||
const badRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [ | ||
{ text: 'describe this image:' }, | ||
{ inline_data: { mime_type: 'image/png', data: 'invalid data' } }, | ||
], | ||
}, | ||
], | ||
}; | ||
await generativeVisionModel.generateContentStream(badRequest).catch(e => { | ||
assert(e instanceof src_1.ClientError, `sys test failure on generateContentStream when having bad request should throw ClientError but actually thrown ${e}`); | ||
assert(e.message === '[VertexAI.ClientError]: got status: 400 Bad Request', `sys test failure on generateContentStream when having bad request got wrong error message: ${e.message}`); | ||
}); | ||
}); | ||
// TODO: this is returning a 500 on the system test project | ||
// it('should return a stream and aggregated response when passed | ||
// multipart GCS content', | ||
// async () => { | ||
// const streamingResp = await | ||
// generativeVisionModel.generateContentStream( | ||
// MULTI_PART_GCS_REQUEST); | ||
// for await (const item of streamingResp.stream) { | ||
// assert(item.candidates[0]); | ||
// console.log('stream chunk: ', item); | ||
// } | ||
// const aggregatedResp = await streamingResp.response; | ||
// assert(aggregatedResp.candidates[0]); | ||
// console.log('aggregated response: ', aggregatedResp); | ||
// }); | ||
}); | ||
// TODO (b/316599049): add tests for generateContent and sendMessage | ||
describe('sendMessageStream', () => { | ||
beforeEach(() => { | ||
jasmine.DEFAULT_TIMEOUT_INTERVAL = 30000; | ||
}); | ||
it('should return a stream and populate history when generation_config is passed to startChat', async () => { | ||
const chat = generativeTextModel.startChat({ | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
assert(item.candidates[0], `sys test failure on sendMessageStream, for item ${item}`); | ||
} | ||
const resp = await result1.response; | ||
assert(resp.candidates[0], `sys test failure on sendMessageStream for aggregated response: ${resp}`); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('should return a stream and populate history when startChat is passed no request obj', async () => { | ||
const chat = generativeTextModel.startChat(); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
assert(item.candidates[0], `sys test failure on sendMessageStream, for item ${item}`); | ||
} | ||
const resp = await result1.response; | ||
assert(resp.candidates[0], `sys test failure on sendMessageStream for aggregated response: ${resp}`); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('should return chunks as they come in', async () => { | ||
const chat = textModelNoOutputLimit.startChat({}); | ||
const chatInput1 = 'Tell me a story in 1000 words'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
let firstChunkTimestamp = 0; | ||
let aggregatedResultTimestamp = 0; | ||
// To verify streaming is working correctly, we check that there is >= 2 | ||
// second difference between the first chunk and the aggregated result | ||
const streamThreshold = 2000; | ||
for await (const item of result1.stream) { | ||
if (firstChunkTimestamp === 0) { | ||
firstChunkTimestamp = Date.now(); | ||
} | ||
} | ||
await result1.response; | ||
aggregatedResultTimestamp = Date.now(); | ||
expect(aggregatedResultTimestamp - firstChunkTimestamp).toBeGreaterThan(streamThreshold); | ||
}); | ||
}); | ||
describe('countTokens', () => { | ||
it('should return a CountTokensResponse', async () => { | ||
const countTokensResp = await generativeTextModel.countTokens(TEXT_REQUEST); | ||
assert(countTokensResp.totalTokens, `sys test failure on countTokens, ${countTokensResp}`); | ||
}); | ||
}); | ||
describe('generateContentStream using models/model-id', () => { | ||
beforeEach(() => { | ||
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000; | ||
}); | ||
it('should return a stream and aggregated response when passed text', async () => { | ||
const streamingResp = await generativeTextModelWithPrefix.generateContentStream(TEXT_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
assert(item.candidates[0], `sys test failure on generateContentStream using models/gemini-pro, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream using models/gemini-pro for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart base64 content when using models/gemini-pro-vision', async () => { | ||
const streamingResp = await generativeVisionModelWithPrefix.generateContentStream(MULTI_PART_BASE64_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
assert(item.candidates[0], `sys test failure on generateContentStream using models/gemini-pro-vision, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream using models/gemini-pro-vision for aggregated response: ${aggregatedResp}`); | ||
}); | ||
}); | ||
//# sourceMappingURL=end_to_end_sample_test.js.map |
# Changelog | ||
## [0.2.0](https://github.com/googleapis/nodejs-vertexai/compare/v0.1.3...v0.2.0) (2024-01-03) | ||
### Features | ||
* allow user to pass "models/model-ID" to instantiate model ([e94b285](https://github.com/googleapis/nodejs-vertexai/commit/e94b285dac6aaf0c77c6b9c6220b29b8d4aced52)) | ||
* include all supported authentication options ([257355c](https://github.com/googleapis/nodejs-vertexai/commit/257355ca09ee298623198404a4f889f5cf7788ee)) | ||
### Bug Fixes | ||
* processing of streams, including UTF ([63ce032](https://github.com/googleapis/nodejs-vertexai/commit/63ce032461a32e9e5fdf04d8ce2d4d8628d691b1)) | ||
* remove placeholder cache attribute of access token ([3ec92e7](https://github.com/googleapis/nodejs-vertexai/commit/3ec92e71a9f7ef4a55bf64037f363ec6be6a729d)) | ||
* update safety return types ([449c7a2](https://github.com/googleapis/nodejs-vertexai/commit/449c7a2af2272add956eb44d8e617878468af344)) | ||
* throw ClientError or GoogleGenerativeAIError according to response status so that users can catch and handle them by class name ([ea0dcb7](https://github.com/googleapis/nodejs-vertexai/commit/ea0dcb717be8d22d98916252ccee352e9af4a09f)) | ||
## [0.1.3](https://github.com/googleapis/nodejs-vertexai/compare/v0.1.2...v0.1.3) (2023-12-13) | ||
@@ -4,0 +20,0 @@ |
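Taken together, the two 0.2.0 features above look like this in practice. A minimal sketch, assuming Application Default Credentials are available; `my-project` and `us-central1` are placeholders, not values from this diff:

```typescript
import {VertexAI} from '@google-cloud/vertexai';

const vertexAI = new VertexAI({
  project: 'my-project',
  location: 'us-central1',
  // New in 0.2.0: forward any google-auth-library options through the SDK.
  // If scopes are omitted, the SDK falls back to the cloud-platform scope.
  googleAuthOptions: {scopes: 'https://www.googleapis.com/auth/cloud-platform'},
});

// New in 0.2.0: the 'models/' prefix is accepted when naming a model.
const model = vertexAI.preview.getGenerativeModel({model: 'models/gemini-pro'});
```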
{ | ||
"name": "@google-cloud/vertexai", | ||
"description": "Vertex Generative AI client for Node.js", | ||
"version": "0.1.3", | ||
"version": "0.2.0", | ||
"license": "Apache-2.0", | ||
"author": "Google LLC", | ||
"engines": { | ||
"node": ">=14.0.0" | ||
"node": ">=18.0.0" | ||
}, | ||
@@ -16,3 +16,3 @@ "homepage": "https://github.com/googleapis/nodejs-vertexai", | ||
"clean": "gts clean", | ||
"compile": "tsc", | ||
"compile": "tsc -p .", | ||
"docs": "jsdoc -c .jsdoc.js", | ||
@@ -23,3 +23,4 @@ "predocs-test": "npm run docs", | ||
"fix": "gts fix", | ||
"test": "TODO", | ||
"test": "jasmine build/test/*.js", | ||
"system-test": "jasmine build/system_test/*.js", | ||
"lint": "gts lint", | ||
@@ -42,4 +43,9 @@ "clean-js-files": "find . -type f -name \"*.js\" -exec rm -f {} +", | ||
"gts": "^5.2.0", | ||
"jasmine": "^5.1.0", | ||
"jsdoc": "^4.0.0", | ||
"jsdoc-fresh": "^3.0.0", | ||
"jsdoc-region-tag": "^3.0.0", | ||
"linkinator": "^4.0.0", | ||
"typescript": "~5.2.0" | ||
} | ||
} |
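The engines bump to Node >= 18 lines up with the SDK calling the global `fetch` (see `postRequest` further down), which ships unflagged starting in Node 18. A hypothetical guard for consumers, not part of the package itself:

```typescript
// Hypothetical check for consumers on older runtimes; the package simply
// assumes a global fetch implementation is present.
if (typeof fetch !== 'function') {
  throw new Error('@google-cloud/vertexai 0.2.0 requires Node >= 18 (global fetch).');
}
```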
@@ -71,3 +71,3 @@ # Vertex AI Node.js SDK | ||
async function streamChat() { | ||
const chat = generativeModel.startChat({}); | ||
const chat = generativeModel.startChat(); | ||
const chatInput1 = "How can I learn more about Node.js?"; | ||
@@ -84,3 +84,3 @@ const result1 = await chat.sendMessageStream(chatInput1); | ||
## Multi-part content generation: text and image | ||
## Multi-part content generation | ||
@@ -109,4 +109,5 @@ ### Providing a Google Cloud Storage image URI | ||
async function multiPartContentImageString() { | ||
const b64imageStr = "yourbase64imagestr"; | ||
const filePart = {inline_data: {data: b64imageStr, mime_type: "image/jpeg"}}; | ||
// Replace this with your own base64 image string | ||
const base64Image = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=='; | ||
const filePart = {inline_data: {data: base64Image, mime_type: 'image/jpeg'}}; | ||
const textPart = {text: 'What is this a picture of?'}; | ||
@@ -118,3 +119,3 @@ const request = { | ||
const contentResponse = await resp.response; | ||
console.log(contentResponse.candidates[0].content); | ||
console.log(contentResponse.candidates[0].content.parts[0].text); | ||
} | ||
@@ -125,2 +126,21 @@ | ||
### Multi-part content with text and video | ||
```typescript | ||
async function multiPartContentVideo() { | ||
const filePart = {file_data: {file_uri: 'gs://cloud-samples-data/video/animals.mp4', mime_type: 'video/mp4'}}; | ||
const textPart = {text: 'What is in the video?'}; | ||
const request = { | ||
contents: [{role: 'user', parts: [textPart, filePart]}], | ||
}; | ||
const streamingResp = await generativeVisionModel.generateContentStream(request); | ||
for await (const item of streamingResp.stream) { | ||
console.log('stream chunk: ', JSON.stringify(item)); | ||
} | ||
const aggregatedResponse = await streamingResp.response; | ||
console.log(aggregatedResponse.candidates[0].content); | ||
} | ||
multiPartContentVideo(); | ||
``` | ||
## Content generation: non-streaming | ||
@@ -127,0 +147,0 @@ |
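The non-streaming section itself is truncated in this diff. Based on the `generateContent` API in src/index.ts below, a sketch of what such a call looks like (`generativeModel` is assumed to be created as in the earlier README snippets):

```typescript
async function generateText() {
  const request = {
    contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
  };
  // generateContent resolves once the full response is available.
  const result = await generativeModel.generateContent(request);
  const contentResponse = await result.response;
  console.log(contentResponse.candidates[0].content.parts[0].text);
}

generateText();
```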
src/index.ts
@@ -19,6 +19,10 @@ /** | ||
/* tslint:disable */ | ||
import {GoogleAuth} from 'google-auth-library'; | ||
import {GoogleAuth, GoogleAuthOptions} from 'google-auth-library'; | ||
import {processNonStream, processStream} from './process_stream'; | ||
import { | ||
processCountTokenResponse, | ||
processNonStream, | ||
processStream, | ||
} from './process_stream'; | ||
import { | ||
Content, | ||
@@ -36,3 +40,7 @@ CountTokensRequest, | ||
} from './types/content'; | ||
import {GoogleAuthError} from './types/errors'; | ||
import { | ||
ClientError, | ||
GoogleAuthError, | ||
GoogleGenerativeAIError, | ||
} from './types/errors'; | ||
import {constants, postRequest} from './util'; | ||
@@ -43,18 +51,22 @@ export * from './types'; | ||
* Base class for authenticating to Vertex, creates the preview namespace. | ||
* The base class object takes the following arguments: | ||
* @param project The Google Cloud project to use for the request | ||
* @param location The Google Cloud project location to use for the | ||
* request | ||
* @param apiEndpoint Optional. The base Vertex AI endpoint to use for the | ||
* request. If not provided, the default regionalized endpoint (i.e. | ||
* us-central1-aiplatform.googleapis.com) will be used. | ||
*/ | ||
export class VertexAI { | ||
public preview: VertexAI_Internal; | ||
public preview: VertexAI_Preview; | ||
/** | ||
* @constructor | ||
* @param {VertexInit} init - assign authentication related information, | ||
* including project and location string, to instantiate a Vertex AI | ||
* client. | ||
*/ | ||
constructor(init: VertexInit) { | ||
this.preview = new VertexAI_Internal( | ||
/** | ||
* preview property is used to access any SDK methods available in public | ||
* preview, currently all functionality. | ||
*/ | ||
this.preview = new VertexAI_Preview( | ||
init.project, | ||
init.location, | ||
init.apiEndpoint | ||
init.apiEndpoint, | ||
init.googleAuthOptions | ||
); | ||
@@ -66,30 +78,55 @@ } | ||
* VertexAI class internal implementation for authentication. | ||
* This class object takes the following arguments: | ||
* @param project The Google Cloud project to use for the request | ||
* @param location The Google Cloud project location to use for the request | ||
* @param apiEndpoint The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
*/ | ||
export class VertexAI_Internal { | ||
protected googleAuth: GoogleAuth = new GoogleAuth({ | ||
scopes: 'https://www.googleapis.com/auth/cloud-platform', | ||
}); | ||
private tokenInternalPromise?: Promise<any>; | ||
*/ | ||
export class VertexAI_Preview { | ||
protected googleAuth: GoogleAuth; | ||
/** | ||
* @constructor | ||
* @param {string} project - The Google Cloud project to use for the request | ||
* @param {string} location - The Google Cloud project location to use for the request | ||
* @param {string} [apiEndpoint] - The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @param {GoogleAuthOptions} [googleAuthOptions] - The authentication options provided by google-auth-library. | ||
* Complete list of authentication options is documented in the GoogleAuthOptions interface: | ||
* https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
*/ | ||
constructor( | ||
readonly project: string, | ||
readonly location: string, | ||
readonly apiEndpoint?: string | ||
readonly apiEndpoint?: string, | ||
readonly googleAuthOptions?: GoogleAuthOptions | ||
) { | ||
let opts: GoogleAuthOptions; | ||
if (!googleAuthOptions) { | ||
opts = { | ||
scopes: 'https://www.googleapis.com/auth/cloud-platform', | ||
}; | ||
} else { | ||
if ( | ||
googleAuthOptions.projectId && | ||
googleAuthOptions.projectId !== project | ||
) { | ||
throw new Error( | ||
`inconsistent project ID values: argument project is ${project} but googleAuthOptions.projectId is ${googleAuthOptions.projectId}` | ||
); | ||
} | ||
opts = googleAuthOptions; | ||
if (!opts.scopes) { | ||
opts.scopes = 'https://www.googleapis.com/auth/cloud-platform'; | ||
} | ||
} | ||
this.project = project; | ||
this.location = location; | ||
this.apiEndpoint = apiEndpoint; | ||
this.googleAuth = new GoogleAuth(opts); | ||
} | ||
/** | ||
* Get access token from GoogleAuth. Throws GoogleAuthError when fails. | ||
* @return {Promise<any>} Promise of token | ||
*/ | ||
get token(): Promise<any> { | ||
if (this.tokenInternalPromise) { | ||
return this.tokenInternalPromise; | ||
} | ||
const credential_error_message = "\nUnable to authenticate your request\ | ||
const credential_error_message = | ||
'\nUnable to authenticate your request\ | ||
\nDepending on your run time environment, you can get authentication by\ | ||
@@ -100,15 +137,18 @@ \n- if in local instance or cloud shell: `!gcloud auth login`\ | ||
\n -`auth.authenticate_user()`\ | ||
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication"; | ||
const tokenPromise = this.googleAuth.getAccessToken().catch( | ||
e => { | ||
throw new GoogleAuthError(`${credential_error_message}\n${e}`); | ||
} | ||
); | ||
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication'; | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new GoogleAuthError(credential_error_message, e); | ||
}); | ||
return tokenPromise; | ||
} | ||
/** | ||
* @param {ModelParams} modelParams - {@link ModelParams} Parameters to specify the generative model. | ||
* @return {GenerativeModel} Instance of the GenerativeModel class. {@link GenerativeModel} | ||
*/ | ||
getGenerativeModel(modelParams: ModelParams): GenerativeModel { | ||
if (modelParams.generation_config) { | ||
modelParams.generation_config = | ||
validateGenerationConfig(modelParams.generation_config); | ||
modelParams.generation_config = validateGenerationConfig( | ||
modelParams.generation_config | ||
); | ||
} | ||
@@ -137,9 +177,11 @@ | ||
// src/types to avoid a circular dependency issue due the dep on | ||
// VertexAI_Internal | ||
// VertexAI_Preview | ||
/** | ||
* All params passed to initiate multiturn chat via startChat | ||
* @see VertexAI_Preview for details on _vertex_instance parameter | ||
* @see GenerativeModel for details on _model_instance parameter | ||
*/ | ||
export declare interface StartChatSessionRequest extends StartChatParams { | ||
_vertex_instance: VertexAI_Internal; | ||
_vertex_instance: VertexAI_Preview; | ||
_model_instance: GenerativeModel; | ||
@@ -149,3 +191,6 @@ } | ||
/** | ||
* Session for a multiturn chat with the model | ||
* Chat session for making multi-turn send message requests. | ||
* The `sendMessage` method makes an async call to get the response to a chat message. | ||
* The `sendMessageStream` method makes an async call to stream the response to a chat message. | ||
* @param {StartChatSessionRequest} request - {@link StartChatSessionRequest} | ||
*/ | ||
@@ -155,8 +200,7 @@ export class ChatSession { | ||
private location: string; | ||
private _send_stream_promise: Promise<void> = Promise.resolve(); | ||
private historyInternal: Content[]; | ||
private _vertex_instance: VertexAI_Internal; | ||
private _vertex_instance: VertexAI_Preview; | ||
private _model_instance: GenerativeModel; | ||
private _send_stream_promise: Promise<void> = Promise.resolve(); | ||
generation_config?: GenerationConfig; | ||
@@ -177,2 +221,7 @@ safety_settings?: SafetySetting[]; | ||
/** | ||
* Make an async call to send a message. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<GenerateContentResult>} Promise of {@link GenerateContentResult} | ||
*/ | ||
async sendMessage( | ||
@@ -188,6 +237,9 @@ request: string | Array<string | Part> | ||
const generateContentResult = await this._model_instance.generateContent( | ||
generateContentrequest | ||
); | ||
const generateContentResponse = await generateContentResult.response; | ||
const generateContentResult: GenerateContentResult = | ||
await this._model_instance | ||
.generateContent(generateContentrequest) | ||
.catch(e => { | ||
throw e; | ||
}); | ||
const generateContentResponse = generateContentResult.response; | ||
// Only push the latest message to history if the response returned a result | ||
@@ -209,10 +261,11 @@ if (generateContentResponse.candidates.length !== 0) { | ||
} | ||
async appendHistory( | ||
streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>, | ||
newContent: Content, | ||
): Promise<void> { | ||
const streamGenerateContentResult = await streamGenerateContentResultPromise; | ||
streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>, | ||
newContent: Content | ||
): Promise<void> { | ||
const streamGenerateContentResult = | ||
await streamGenerateContentResultPromise; | ||
const streamGenerateContentResponse = | ||
await streamGenerateContentResult.response; | ||
await streamGenerateContentResult.response; | ||
// Only push the latest message to history if the response returned a result | ||
@@ -233,4 +286,10 @@ if (streamGenerateContentResponse.candidates.length !== 0) { | ||
async sendMessageStream(request: string|Array<string|Part>): | ||
Promise<StreamGenerateContentResult> { | ||
/** | ||
* Make an async call to send a message and stream the response. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
async sendMessageStream( | ||
request: string | Array<string | Part> | ||
): Promise<StreamGenerateContentResult> { | ||
const newContent: Content = formulateNewContent(request); | ||
@@ -243,7 +302,14 @@ const generateContentrequest: GenerateContentRequest = { | ||
const streamGenerateContentResultPromise = | ||
this._model_instance.generateContentStream( | ||
generateContentrequest); | ||
const streamGenerateContentResultPromise = this._model_instance | ||
.generateContentStream(generateContentrequest) | ||
.catch(e => { | ||
throw e; | ||
}); | ||
this._send_stream_promise = this.appendHistory(streamGenerateContentResultPromise, newContent); | ||
this._send_stream_promise = this.appendHistory( | ||
streamGenerateContentResultPromise, | ||
newContent | ||
).catch(e => { | ||
throw new GoogleGenerativeAIError('exception appending chat history', e); | ||
}); | ||
return streamGenerateContentResultPromise; | ||
@@ -255,3 +321,2 @@ } | ||
* Base class for generative models. | ||
* | ||
* NOTE: this class should not be instantiated directly. Use | ||
@@ -264,7 +329,16 @@ * `vertexai.preview.getGenerativeModel()` instead. | ||
safety_settings?: SafetySetting[]; | ||
private _vertex_instance: VertexAI_Internal; | ||
private _vertex_instance: VertexAI_Preview; | ||
private _use_non_stream = false; | ||
private publisherModelEndpoint: string; | ||
/** | ||
* @constructor | ||
* @param {VertexAI_Preview} vertex_instance - {@link VertexAI_Preview} | ||
* @param {string} model - model name | ||
* @param {GenerationConfig} generation_config - Optional. {@link | ||
* GenerationConfig} | ||
* @param {SafetySetting[]} safety_settings - Optional. {@link SafetySetting} | ||
*/ | ||
constructor( | ||
vertex_instance: VertexAI_Internal, | ||
vertex_instance: VertexAI_Preview, | ||
model: string, | ||
@@ -278,6 +352,11 @@ generation_config?: GenerationConfig, | ||
this.safety_settings = safety_settings; | ||
if (model.startsWith('models/')) { | ||
this.publisherModelEndpoint = `publishers/google/${this.model}`; | ||
} else { | ||
this.publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
} | ||
} | ||
/** | ||
* Make a generateContent request. | ||
* Make an async call to generate content. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
@@ -292,4 +371,5 @@ * @return The GenerateContentResponse object with the response candidates. | ||
if (request.generation_config) { | ||
request.generation_config = | ||
validateGenerationConfig(request.generation_config); | ||
request.generation_config = validateGenerationConfig( | ||
request.generation_config | ||
); | ||
} | ||
@@ -299,3 +379,5 @@ | ||
const streamGenerateContentResult: StreamGenerateContentResult = | ||
await this.generateContentStream(request); | ||
await this.generateContentStream(request).catch(e => { | ||
throw e; | ||
}); | ||
const result: GenerateContentResult = { | ||
@@ -307,4 +389,2 @@ response: await streamGenerateContentResult.response, | ||
const publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
const generateContentRequest: GenerateContentRequest = { | ||
@@ -316,23 +396,14 @@ contents: request.contents, | ||
let response; | ||
try { | ||
response = await postRequest({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: constants.GENERATE_CONTENT_METHOD, | ||
token: await this._vertex_instance.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}); | ||
if (response === undefined) { | ||
throw new Error('did not get a valid response.'); | ||
} | ||
if (!response.ok) { | ||
throw new Error(`${response.status} ${response.statusText}`); | ||
} | ||
} catch (e) { | ||
console.log(e); | ||
} | ||
const response: Response | undefined = await postRequest({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: constants.GENERATE_CONTENT_METHOD, | ||
token: await this._vertex_instance.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}).catch(e => { | ||
throw new GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
const result: GenerateContentResult = processNonStream(response); | ||
@@ -343,17 +414,17 @@ return Promise.resolve(result); | ||
/** | ||
* Make a generateContentStream request. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
* @return The GenerateContentResponse object with the response candidates. | ||
* Make an async stream request to generate content. The response will be returned as a stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
async generateContentStream(request: GenerateContentRequest): | ||
Promise<StreamGenerateContentResult> { | ||
async generateContentStream( | ||
request: GenerateContentRequest | ||
): Promise<StreamGenerateContentResult> { | ||
validateGcsInput(request.contents); | ||
if (request.generation_config) { | ||
request.generation_config = | ||
validateGenerationConfig(request.generation_config); | ||
request.generation_config = validateGenerationConfig( | ||
request.generation_config | ||
); | ||
} | ||
const publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
const generateContentRequest: GenerateContentRequest = { | ||
@@ -364,24 +435,14 @@ contents: request.contents, | ||
}; | ||
let response; | ||
try { | ||
response = await postRequest({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await this._vertex_instance.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}); | ||
if (response === undefined) { | ||
throw new Error('did not get a valid response.'); | ||
} | ||
if (!response.ok) { | ||
throw new Error(`${response.status} ${response.statusText}`); | ||
} | ||
} catch (e) { | ||
console.log(e); | ||
} | ||
const response = await postRequest({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await this._vertex_instance.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}).catch(e => { | ||
throw new GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
const streamResult = processStream(response); | ||
@@ -392,3 +453,3 @@ return Promise.resolve(streamResult); | ||
/** | ||
* Make a countTokens request. | ||
* Make an async request to count tokens. | ||
* @param request A CountTokensRequest object with the request contents. | ||
@@ -398,35 +459,26 @@ * @return The CountTokensResponse object with the token count. | ||
async countTokens(request: CountTokensRequest): Promise<CountTokensResponse> { | ||
let response; | ||
try { | ||
response = await postRequest({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: `publishers/google/models/${this.model}`, | ||
resourceMethod: 'countTokens', | ||
token: await this._vertex_instance.token, | ||
data: request, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}); | ||
if (response === undefined) { | ||
throw new Error('did not get a valid response.'); | ||
} | ||
if (!response.ok) { | ||
throw new Error(`${response.status} ${response.statusText}`); | ||
} | ||
} catch (e) { | ||
console.log(e); | ||
} | ||
if (response) { | ||
const responseJson = await response.json(); | ||
return responseJson as CountTokensResponse; | ||
} else { | ||
throw new Error('did not get a valid response.'); | ||
} | ||
const response = await postRequest({ | ||
region: this._vertex_instance.location, | ||
project: this._vertex_instance.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: 'countTokens', | ||
token: await this._vertex_instance.token, | ||
data: request, | ||
apiEndpoint: this._vertex_instance.apiEndpoint, | ||
}).catch(e => { | ||
throw new GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
return processCountTokenResponse(response); | ||
} | ||
startChat(request: StartChatParams): ChatSession { | ||
const startChatRequest = { | ||
history: request.history, | ||
generation_config: request.generation_config ?? this.generation_config, | ||
safety_settings: request.safety_settings ?? this.safety_settings, | ||
/** | ||
* Instantiate a ChatSession. | ||
* This method doesn't make any calls to the remote endpoint. | ||
* Calls to the remote endpoint are implemented in the ChatSession class @see ChatSession | ||
* @param{StartChatParams} [request] - {@link StartChatParams} | ||
* @return {ChatSession} {@link ChatSession} | ||
*/ | ||
startChat(request?: StartChatParams): ChatSession { | ||
const startChatRequest: StartChatSessionRequest = { | ||
_vertex_instance: this._vertex_instance, | ||
@@ -436,2 +488,9 @@ _model_instance: this, | ||
if (request) { | ||
startChatRequest.history = request.history; | ||
startChatRequest.generation_config = | ||
request.generation_config ?? this.generation_config; | ||
startChatRequest.safety_settings = | ||
request.safety_settings ?? this.safety_settings; | ||
} | ||
return new ChatSession(startChatRequest); | ||
@@ -460,2 +519,17 @@ } | ||
function throwErrorIfNotOK(response: Response | undefined) { | ||
if (response === undefined) { | ||
throw new GoogleGenerativeAIError('response is undefined'); | ||
} | ||
const status: number = response.status; | ||
const statusText: string = response.statusText; | ||
const errorMessage = `got status: ${status} ${statusText}`; | ||
if (status >= 400 && status < 500) { | ||
throw new ClientError(errorMessage); | ||
} | ||
if (!response.ok) { | ||
throw new GoogleGenerativeAIError(errorMessage); | ||
} | ||
} | ||
function validateGcsInput(contents: Content[]) { | ||
@@ -467,3 +541,5 @@ for (const content of contents) { | ||
if (!uri.startsWith('gs://')) { | ||
throw new URIError(`Found invalid Google Cloud Storage URI ${uri}, Google Cloud Storage URIs must start with gs://`); | ||
throw new URIError( | ||
`Found invalid Google Cloud Storage URI ${uri}, Google Cloud Storage URIs must start with gs://` | ||
); | ||
} | ||
@@ -475,4 +551,5 @@ } | ||
function validateGenerationConfig(generation_config: GenerationConfig): | ||
GenerationConfig { | ||
function validateGenerationConfig( | ||
generation_config: GenerationConfig | ||
): GenerationConfig { | ||
if ('top_k' in generation_config) { | ||
@@ -479,0 +556,0 @@ if (!(generation_config.top_k! > 0) || !(generation_config.top_k! <= 40)) { |
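One consequence of the constructor change above: both model-name forms resolve to the same publisher path. The branch, restated outside the class purely for illustration:

```typescript
// Restates the prefix handling from the GenerativeModel constructor above;
// this helper is not part of the SDK's API.
function resolvePublisherModelEndpoint(model: string): string {
  return model.startsWith('models/')
    ? `publishers/google/${model}` // 'models/gemini-pro' -> 'publishers/google/models/gemini-pro'
    : `publishers/google/models/${model}`; // 'gemini-pro' -> the same path
}
```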
@@ -18,14 +18,19 @@ /** | ||
import {CitationSource, GenerateContentCandidate, GenerateContentResponse, GenerateContentResult, StreamGenerateContentResult,} from './types/content'; | ||
import { | ||
CitationSource, | ||
CountTokensResponse, | ||
GenerateContentCandidate, | ||
GenerateContentResponse, | ||
GenerateContentResult, | ||
StreamGenerateContentResult, | ||
} from './types/content'; | ||
// eslint-disable-next-line no-useless-escape | ||
const responseLineRE = /^data\: (.*)\r\n/; | ||
const responseLineRE = /^data: (.*)(?:\n\n|\r\r|\r\n\r\n)/; | ||
// TODO: set a better type for `reader`. Setting it to | ||
// `ReadableStreamDefaultReader` results in an error (diagnostic code 2304) | ||
async function* generateResponseSequence( | ||
reader2: any | ||
stream: ReadableStream<GenerateContentResponse> | ||
): AsyncGenerator<GenerateContentResponse> { | ||
const reader = stream.getReader(); | ||
while (true) { | ||
const {value, done} = await reader2.read(); | ||
const {value, done} = await reader.read(); | ||
if (done) { | ||
@@ -39,51 +44,89 @@ break; | ||
/** | ||
* Reads a raw stream from the fetch response and joins incomplete | ||
* Processes a response.body stream from the backend and returns an | ||
* iterator that provides one complete GenerateContentResponse at a time | ||
* and a promise that resolves with a single aggregated | ||
* GenerateContentResponse. | ||
* | ||
* @param response - Response from a fetch call | ||
* @ignore | ||
*/ | ||
export function processStream( | ||
response: Response | undefined | ||
): StreamGenerateContentResult { | ||
if (response === undefined) { | ||
throw new Error('Error processing stream because response === undefined'); | ||
} | ||
if (!response.body) { | ||
throw new Error('Error processing stream because response.body not found'); | ||
} | ||
const inputStream = response.body!.pipeThrough( | ||
new TextDecoderStream('utf8', {fatal: true}) | ||
); | ||
const responseStream = | ||
getResponseStream<GenerateContentResponse>(inputStream); | ||
const [stream1, stream2] = responseStream.tee(); | ||
return { | ||
stream: generateResponseSequence(stream1), | ||
response: getResponsePromise(stream2), | ||
}; | ||
} | ||
async function getResponsePromise( | ||
stream: ReadableStream<GenerateContentResponse> | ||
): Promise<GenerateContentResponse> { | ||
const allResponses: GenerateContentResponse[] = []; | ||
const reader = stream.getReader(); | ||
// eslint-disable-next-line no-constant-condition | ||
while (true) { | ||
const {done, value} = await reader.read(); | ||
if (done) { | ||
return aggregateResponses(allResponses); | ||
} | ||
allResponses.push(value); | ||
} | ||
} | ||
/** | ||
* Reads a raw stream from the fetch response and joins incomplete | ||
* chunks, returning a new stream that provides a single complete | ||
* GenerateContentResponse in each iteration. | ||
* @ignore | ||
*/ | ||
function readFromReader( | ||
reader: ReadableStreamDefaultReader | ||
): ReadableStream<GenerateContentResponse> { | ||
let currentText = ''; | ||
const stream = new ReadableStream<GenerateContentResponse>({ | ||
export function getResponseStream<T>( | ||
inputStream: ReadableStream<string> | ||
): ReadableStream<T> { | ||
const reader = inputStream.getReader(); | ||
const stream = new ReadableStream<T>({ | ||
start(controller) { | ||
let currentText = ''; | ||
return pump(); | ||
function pump(): Promise<(() => Promise<void>) | undefined> { | ||
let streamReader; | ||
try { | ||
streamReader = reader.read().then(({value, done}) => { | ||
if (done) { | ||
controller.close(); | ||
return reader.read().then(({value, done}) => { | ||
if (done) { | ||
if (currentText.trim()) { | ||
controller.error(new Error('Failed to parse stream')); | ||
return; | ||
} | ||
const chunk = new TextDecoder().decode(value); | ||
currentText += chunk; | ||
const match = currentText.match(responseLineRE); | ||
if (match) { | ||
let parsedResponse: GenerateContentResponse; | ||
try { | ||
parsedResponse = JSON.parse( | ||
match[1] | ||
) as GenerateContentResponse; | ||
} catch (e) { | ||
throw new Error(`Error parsing JSON response: "${match[1]}"`); | ||
} | ||
currentText = ''; | ||
if ('candidates' in parsedResponse) { | ||
controller.enqueue(parsedResponse); | ||
} else { | ||
console.warn('No candidates in this response:', parsedResponse); | ||
controller.enqueue({ | ||
candidates: [], | ||
}); | ||
} | ||
controller.close(); | ||
return; | ||
} | ||
currentText += value; | ||
let match = currentText.match(responseLineRE); | ||
let parsedResponse: T; | ||
while (match) { | ||
try { | ||
parsedResponse = JSON.parse(match[1]) as T; | ||
} catch (e) { | ||
controller.error( | ||
new Error(`Error parsing JSON response: "${match[1]}"`) | ||
); | ||
return; | ||
} | ||
return pump(); | ||
}); | ||
} catch (e) { | ||
throw new Error(`Error reading from stream ${e}.`); | ||
} | ||
return streamReader; | ||
controller.enqueue(parsedResponse); | ||
currentText = currentText.substring(match[0].length); | ||
match = currentText.match(responseLineRE); | ||
} | ||
return pump(); | ||
}); | ||
} | ||
@@ -98,2 +141,3 @@ }, | ||
* GenerateContentResponse. | ||
* @ignore | ||
*/ | ||
@@ -127,4 +171,5 @@ function aggregateResponses( | ||
if (response.candidates[i].citationMetadata) { | ||
if (!aggregatedResponse.candidates[i] | ||
.citationMetadata?.citationSources) { | ||
if ( | ||
!aggregatedResponse.candidates[i].citationMetadata?.citationSources | ||
) { | ||
aggregatedResponse.candidates[i].citationMetadata = { | ||
@@ -135,9 +180,9 @@ citationSources: [] as CitationSource[], | ||
const existingMetadata = response.candidates[i].citationMetadata ?? {}; | ||
let existingMetadata = response.candidates[i].citationMetadata ?? {}; | ||
if (aggregatedResponse.candidates[i].citationMetadata) { | ||
aggregatedResponse.candidates[i].citationMetadata!.citationSources = | ||
aggregatedResponse.candidates[i] | ||
.citationMetadata!.citationSources.concat(existingMetadata); | ||
aggregatedResponse.candidates[ | ||
i | ||
].citationMetadata!.citationSources.concat(existingMetadata); | ||
} | ||
@@ -165,43 +210,5 @@ } | ||
// TODO: improve error handling throughout stream processing | ||
/** | ||
* Processes model responses from streamGenerateContent | ||
*/ | ||
export function processStream( | ||
response: Response | undefined | ||
): StreamGenerateContentResult { | ||
if (response === undefined) { | ||
throw new Error('Error processing stream because response === undefined'); | ||
} | ||
if (!response.body) { | ||
throw new Error('Error processing stream because response.body not found'); | ||
} | ||
const reader = response.body.getReader(); | ||
const responseStream = readFromReader(reader); | ||
const [stream1, stream2] = responseStream.tee(); | ||
const reader1 = stream1.getReader(); | ||
const reader2 = stream2.getReader(); | ||
const allResponses: GenerateContentResponse[] = []; | ||
const responsePromise = new Promise<GenerateContentResponse>( | ||
// eslint-disable-next-line no-async-promise-executor | ||
async resolve => { | ||
// eslint-disable-next-line no-constant-condition | ||
while (true) { | ||
const {value, done} = await reader1.read(); | ||
if (done) { | ||
resolve(aggregateResponses(allResponses)); | ||
return; | ||
} | ||
allResponses.push(value); | ||
} | ||
} | ||
); | ||
return { | ||
response: responsePromise, | ||
stream: generateResponseSequence(reader2), | ||
}; | ||
} | ||
/** | ||
* Processes model responses from generateContent | ||
* @ignore | ||
*/ | ||
@@ -221,1 +228,11 @@ export function processNonStream(response: any): GenerateContentResult { | ||
} | ||
/** | ||
* Processes model responses from countTokens | ||
* @ignore | ||
*/ | ||
export async function processCountTokenResponse( | ||
response: any | ||
): Promise<CountTokensResponse> { | ||
// response.json() resolves to the parsed body, so await it before casting. | ||
const responseJson = await response.json(); | ||
return responseJson as CountTokensResponse; | ||
} |
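Because `processStream` now tees the decoded stream, the chunk iterator and the aggregated promise can be consumed independently. A consumer sketch, where `result` is whatever `generateContentStream` returned:

```typescript
import {StreamGenerateContentResult} from '@google-cloud/vertexai';

async function consume(result: StreamGenerateContentResult) {
  // Each chunk is a complete GenerateContentResponse, thanks to the
  // buffering getResponseStream does around responseLineRE.
  for await (const chunk of result.stream) {
    console.log(chunk.candidates[0]?.content);
  }
  // The second tee'd branch is folded into one response by aggregateResponses.
  const aggregated = await result.response;
  console.log(aggregated.candidates.length);
}
```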
@@ -18,4 +18,13 @@ /** | ||
// @ts-nocheck | ||
import {GoogleAuthOptions} from 'google-auth-library'; | ||
/** | ||
* Params used to initialize the Vertex SDK | ||
* @param {string} project - the project name of your Google Cloud project. It is not the numeric project ID. | ||
* @param {string} location - the location of your project. | ||
* @param {string} [apiEndpoint] - If not specified, a default value will be resolved by the SDK. | ||
* @param {GoogleAuthOptions} [googleAuthOptions] - The authentication options provided by google-auth-library. | ||
* Complete list of authentication options is documented in the GoogleAuthOptions interface: | ||
* https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
*/ | ||
@@ -26,2 +35,3 @@ export declare interface VertexInit { | ||
apiEndpoint?: string; | ||
googleAuthOptions?: GoogleAuthOptions; | ||
} | ||
@@ -53,2 +63,4 @@ | ||
* Configuration for initializing a model, for example via getGenerativeModel | ||
* @param {string} model - model name. | ||
* @example "gemini-pro" | ||
*/ | ||
@@ -86,3 +98,5 @@ export declare interface ModelParams extends BaseModelParams { | ||
} | ||
/** | ||
* Harm categories that would cause prompts or candidates to be blocked. | ||
*/ | ||
export enum HarmCategory { | ||
@@ -96,2 +110,5 @@ HARM_CATEGORY_UNSPECIFIED = 'HARM_CATEGORY_UNSPECIFIED', | ||
/** | ||
* Threshold above which a prompt or candidate will be blocked. | ||
*/ | ||
export enum HarmBlockThreshold { | ||
@@ -111,2 +128,18 @@ // Unspecified harm block threshold. | ||
/** | ||
* Probability that a prompt or candidate matches a harm category. | ||
*/ | ||
export enum HarmProbability { | ||
// Probability is unspecified. | ||
HARM_PROBABILITY_UNSPECIFIED = 'HARM_PROBABILITY_UNSPECIFIED', | ||
// Content has a negligible chance of being unsafe. | ||
NEGLIGIBLE = 'NEGLIGIBLE', | ||
// Content has a low chance of being unsafe. | ||
LOW = 'LOW', | ||
// Content has a medium chance of being unsafe. | ||
MEDIUM = 'MEDIUM', | ||
// Content has a high chance of being unsafe. | ||
HIGH = 'HIGH', | ||
} | ||
/** | ||
* Safety rating for a piece of content | ||
@@ -116,3 +149,3 @@ */ | ||
category: HarmCategory; | ||
threshold: HarmBlockThreshold; | ||
probability: HarmProbability; | ||
} | ||
@@ -216,2 +249,3 @@ | ||
* Wrapper for responses from a generateContent request | ||
* @see GenerateContentResponse | ||
*/ | ||
@@ -225,2 +259,3 @@ export declare interface GenerateContentResult { | ||
* Wrapper for responses from a streamGenerateContent request | ||
* @see GenerateContentResponse | ||
*/ | ||
@@ -271,18 +306,1 @@ export declare interface StreamGenerateContentResult { | ||
} | ||
const USER_AGENT_PRODUCT = 'model-builder'; | ||
const CLIENT_LIBRARY_LANGUAGE = 'node-js'; | ||
// TODO: update this version number using release-please | ||
const CLIENT_LIBRARY_VERSION = '0.1.0'; | ||
const USER_AGENT = USER_AGENT_PRODUCT + '/' + CLIENT_LIBRARY_VERSION; | ||
const CLIENT_INFO = { | ||
user_agent: USER_AGENT, | ||
client_library_language: CLIENT_LIBRARY_LANGUAGE, | ||
client_library_version: CLIENT_LIBRARY_VERSION, | ||
}; | ||
export {CLIENT_INFO}; |
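The SafetyRating change above (threshold → probability) means requests are still configured with `HarmBlockThreshold` while responses now report a `HarmProbability`. A sketch; the specific category and threshold member names are assumed from the Vertex API rather than shown in this diff:

```typescript
import {HarmCategory, HarmBlockThreshold} from '@google-cloud/vertexai';

// Request side: thresholds decide what gets blocked (member names assumed).
const safety_settings = [
  {
    category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
  },
];

// Response side: each rating now carries a probability, not a threshold.
// for (const rating of contentResponse.candidates[0].safetyRatings ?? []) {
//   console.log(rating.category, rating.probability); // e.g. 'NEGLIGIBLE'
// }
```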
@@ -18,9 +18,50 @@ /** | ||
/** | ||
* GoogleAuthError is thrown when there is an authentication issue with the request | ||
*/ | ||
class GoogleAuthError extends Error { | ||
constructor(message: string) { | ||
super(message); | ||
this.name = 'GoogleAuthError'; | ||
} | ||
public readonly stack_trace: any = undefined; | ||
constructor(message: string, stack_trace: any = undefined) { | ||
super(message); | ||
this.message = constructErrorMessage('GoogleAuthError', message); | ||
this.name = 'GoogleAuthError'; | ||
this.stack_trace = stack_trace; | ||
} | ||
} | ||
export {GoogleAuthError}; | ||
/** | ||
* ClientError is thrown when an HTTP 4XX status is received. | ||
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses | ||
*/ | ||
class ClientError extends Error { | ||
public readonly stack_trace: any = undefined; | ||
constructor(message: string, stack_trace: any = undefined) { | ||
super(message); | ||
this.message = constructErrorMessage('ClientError', message); | ||
this.name = 'ClientError'; | ||
this.stack_trace = stack_trace; | ||
} | ||
} | ||
/** | ||
* GoogleGenerativeAIError is thrown when the HTTP response is not OK and the status code is not 4XX. | ||
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status | ||
*/ | ||
class GoogleGenerativeAIError extends Error { | ||
public readonly stack_trace: any = undefined; | ||
constructor(message: string, stack_trace: any = undefined) { | ||
super(message); | ||
this.message = constructErrorMessage('GoogleGenerativeAIError', message); | ||
this.name = 'GoogleGenerativeAIError'; | ||
this.stack_trace = stack_trace; | ||
} | ||
} | ||
function constructErrorMessage( | ||
exceptionClass: string, | ||
message: string | ||
): string { | ||
return `[VertexAI.${exceptionClass}]: ${message}`; | ||
} | ||
export {ClientError, GoogleAuthError, GoogleGenerativeAIError}; |
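With this hierarchy, callers can branch on the class, and `constructErrorMessage` embeds the class name in the message (e.g. `[VertexAI.ClientError]: got status: 400 Bad Request`, as asserted in the system tests below). A minimal sketch; the model and request are assumed from earlier snippets:

```typescript
import {
  ClientError,
  GenerateContentRequest,
  GenerativeModel,
  GoogleGenerativeAIError,
} from '@google-cloud/vertexai';

// Assumed to exist, e.g. as set up in the README snippets above.
declare const generativeModel: GenerativeModel;
declare const request: GenerateContentRequest;

async function callWithHandling() {
  try {
    await generativeModel.generateContent(request);
  } catch (e) {
    if (e instanceof ClientError) {
      // 4XX: fix the request (bad model name, malformed parts, ...).
      console.error('client-side problem:', e.message);
    } else if (e instanceof GoogleGenerativeAIError) {
      // Non-4XX failure, e.g. 5XX or a failed fetch; a retry may succeed.
      console.error('service problem:', e.message, e.stack_trace);
    } else {
      throw e;
    }
  }
}
```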
@@ -17,3 +17,2 @@ /** | ||
*/ | ||
export const GENERATE_CONTENT_METHOD = 'generateContent'; | ||
@@ -23,1 +22,5 @@ export const STREAMING_GENERATE_CONTENT_METHOD = 'streamGenerateContent'; | ||
export const MODEL_ROLE = 'model'; | ||
const USER_AGENT_PRODUCT = 'model-builder'; | ||
const CLIENT_LIBRARY_VERSION = '0.2.0'; | ||
const CLIENT_LIBRARY_LANGUAGE = `grpc-node/${CLIENT_LIBRARY_VERSION}`; | ||
export const USER_AGENT = `${USER_AGENT_PRODUCT}/${CLIENT_LIBRARY_VERSION} ${CLIENT_LIBRARY_LANGUAGE}`; |
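Spelled out, the constants above compose the header value that `postRequest` now sends (a quick check from inside the SDK source tree):

```typescript
import {USER_AGENT} from './constants';

console.log(USER_AGENT); // 'model-builder/0.2.0 grpc-node/0.2.0'
```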
@@ -20,7 +20,3 @@ /** | ||
import { | ||
GenerateContentRequest, | ||
CLIENT_INFO, | ||
CountTokensRequest, | ||
} from '../types/content'; | ||
import {GenerateContentRequest, CountTokensRequest} from '../types/content'; | ||
import * as constants from './constants'; | ||
@@ -30,2 +26,3 @@ | ||
* Makes a POST request to a Vertex service | ||
* @ignore | ||
*/ | ||
@@ -42,3 +39,6 @@ export async function postRequest({ | ||
}: { | ||
region: string; project: string; resourcePath: string; resourceMethod: string; | ||
region: string; | ||
project: string; | ||
resourcePath: string; | ||
resourceMethod: string; | ||
token: string; | ||
@@ -48,3 +48,3 @@ data: GenerateContentRequest | CountTokensRequest; | ||
apiVersion?: string; | ||
}): Promise<Response|undefined> { | ||
}): Promise<Response | undefined> { | ||
const vertexBaseEndpoint = apiEndpoint ?? `${region}-${API_BASE_PATH}`; | ||
@@ -59,10 +59,8 @@ | ||
return await fetch(vertexEndpoint, { | ||
return fetch(vertexEndpoint, { | ||
method: 'POST', | ||
headers: { | ||
'Authorization': `Bearer ${token}`, | ||
Authorization: `Bearer ${token}`, | ||
'Content-Type': 'application/json', | ||
'User-Agent': CLIENT_INFO.user_agent, | ||
'client_library_language': CLIENT_INFO.client_library_language, | ||
'client_library_version': CLIENT_INFO.client_library_version, | ||
'User-Agent': constants.USER_AGENT, | ||
}, | ||
@@ -69,0 +67,0 @@ body: JSON.stringify(data), |
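Putting the request pieces together, `postRequest` targets a URL of roughly this shape. `API_BASE_PATH` is not shown in this diff, so `aiplatform.googleapis.com` and the `v1` version segment are assumptions based on the default regionalized endpoint documented above:

```typescript
// Assumed final URL; only the resourcePath and resourceMethod segments
// ('publishers/google/models/gemini-pro' and 'streamGenerateContent')
// are taken directly from this diff.
const url =
  'https://us-central1-aiplatform.googleapis.com/v1' +
  '/projects/my-project/locations/us-central1' +
  '/publishers/google/models/gemini-pro:streamGenerateContent';
```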
@@ -19,5 +19,8 @@ /** | ||
// @ts-ignore | ||
import {VertexAI} from '@google-cloud/vertexai'; | ||
import * as assert from 'assert'; | ||
const PROJECT = 'cloud-llm-preview1'; // TODO: change this to infer from Kokoro env | ||
import {ClientError, VertexAI, TextPart} from '../src'; | ||
// TODO: this env var isn't getting populated correctly | ||
const PROJECT = process.env.GCLOUD_PROJECT; | ||
const LOCATION = 'us-central1'; | ||
@@ -31,52 +34,281 @@ const TEXT_REQUEST = { | ||
}; | ||
const GCS_FILE_PART = { | ||
file_data: { | ||
file_uri: 'gs://generativeai-downloads/images/scones.jpg', | ||
file_uri: 'gs://nodejs_vertex_system_test_resources/scones.jpg', | ||
mime_type: 'image/jpeg', | ||
}, | ||
}; | ||
const MULTI_PART_REQUEST = { | ||
const BASE_64_IMAGE = | ||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='; | ||
const INLINE_DATA_FILE_PART = { | ||
inline_data: { | ||
data: BASE_64_IMAGE, | ||
mime_type: 'image/jpeg', | ||
}, | ||
}; | ||
const MULTI_PART_GCS_REQUEST = { | ||
contents: [{role: 'user', parts: [TEXT_PART, GCS_FILE_PART]}], | ||
}; | ||
const MULTI_PART_BASE64_REQUEST = { | ||
contents: [{role: 'user', parts: [TEXT_PART, INLINE_DATA_FILE_PART]}], | ||
}; | ||
// Initialize Vertex with your Cloud project and location | ||
const vertex_ai = new VertexAI({project: PROJECT, location: LOCATION}); | ||
const vertex_ai = new VertexAI({project: 'long-door-651', location: LOCATION}); | ||
const generativeTextModel = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const generativeTextModelWithPrefix = vertex_ai.preview.getGenerativeModel({ | ||
model: 'models/gemini-pro', | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const textModelNoOutputLimit = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeVisionModel = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-vision-pro', | ||
model: 'gemini-pro-vision', | ||
}); | ||
const generativeVisionModelWithPrefix = vertex_ai.preview.getGenerativeModel({ | ||
model: 'models/gemini-pro-vision', | ||
}); | ||
async function testGenerateContentStreamText() { | ||
const streamingResp = | ||
// TODO (b/316599049): update tests to use jasmine expect syntax: | ||
// expect(...).toBeInstanceOf(...) | ||
describe('generateContentStream', () => { | ||
beforeEach(() => { | ||
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000; | ||
}); | ||
it('should return a stream and aggregated response when passed text', async () => { | ||
const streamingResp = | ||
await generativeTextModel.generateContentStream(TEXT_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
console.log('stream chunk:', item); | ||
} | ||
for await (const item of streamingResp.stream) { | ||
assert( | ||
item.candidates[0], | ||
`sys test failure on generateContentStream, for item ${item}` | ||
); | ||
} | ||
console.log('aggregated response: ', await streamingResp.response); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert( | ||
aggregatedResp.candidates[0], | ||
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should not return invalid unicode', async () => { | ||
const streamingResp = await generativeTextModel.generateContentStream({ | ||
contents: [{role: 'user', parts: [{text: '创作一首古诗'}]}], | ||
}); | ||
async function testGenerateContentStreamMultiPart() { | ||
const streamingResp = | ||
await generativeVisionModel.generateContentStream(MULTI_PART_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
assert( | ||
item.candidates[0], | ||
`sys test failure on generateContentStream, for item ${item}` | ||
); | ||
for (const candidate of item.candidates) { | ||
for (const part of candidate.content.parts as TextPart[]) { | ||
assert( | ||
!part.text.includes('\ufffd'), | ||
`sys test failure on generateContentStream, for item ${item}` | ||
); | ||
} | ||
} | ||
} | ||
for await (const item of streamingResp.stream) { | ||
console.log('stream chunk:', item); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert( | ||
aggregatedResp.candidates[0], | ||
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart base64 content', async () => { | ||
const streamingResp = await generativeVisionModel.generateContentStream( | ||
MULTI_PART_BASE64_REQUEST | ||
); | ||
console.log('aggregated response: ', await streamingResp.response); | ||
} | ||
for await (const item of streamingResp.stream) { | ||
assert( | ||
item.candidates[0], | ||
`sys test failure on generateContentStream, for item ${item}` | ||
); | ||
} | ||
async function testCountTokens() { | ||
const countTokensResp = await generativeVisionModel.countTokens(TEXT_REQUEST); | ||
console.log('count tokens response: ', countTokensResp); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert( | ||
aggregatedResp.candidates[0], | ||
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should throw ClientError when having invalid input', async () => { | ||
const badRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [ | ||
{text: 'describe this image:'}, | ||
{inline_data: {mime_type: 'image/png', data: 'invalid data'}}, | ||
], | ||
}, | ||
], | ||
}; | ||
await generativeVisionModel.generateContentStream(badRequest).catch(e => { | ||
assert( | ||
e instanceof ClientError, | ||
`sys test failure on generateContentStream when having bad request should throw ClientError but actually thrown ${e}` | ||
); | ||
assert( | ||
e.message === '[VertexAI.ClientError]: got status: 400 Bad Request', | ||
`sys test failure on generateContentStream when having bad request got wrong error message: ${e.message}` | ||
); | ||
}); | ||
}); | ||
// TODO: this is returning a 500 on the system test project | ||
// it('should return a stream and aggregated response when passed | ||
// multipart GCS content', | ||
// async () => { | ||
// const streamingResp = await | ||
// generativeVisionModel.generateContentStream( | ||
// MULTI_PART_GCS_REQUEST); | ||
testGenerateContentStreamText(); | ||
testGenerateContentStreamMultiPart(); | ||
testCountTokens(); | ||
// for await (const item of streamingResp.stream) { | ||
// assert(item.candidates[0]); | ||
// console.log('stream chunk: ', item); | ||
// } | ||
// const aggregatedResp = await streamingResp.response; | ||
// assert(aggregatedResp.candidates[0]); | ||
// console.log('aggregated response: ', aggregatedResp); | ||
// }); | ||
}); | ||
// TODO (b/316599049): add tests for generateContent and sendMessage | ||
describe('sendMessageStream', () => { | ||
beforeEach(() => { | ||
jasmine.DEFAULT_TIMEOUT_INTERVAL = 30000; | ||
}); | ||
it('should return a stream and populate history when generation_config is passed to startChat', async () => { | ||
const chat = generativeTextModel.startChat({ | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
assert( | ||
item.candidates[0], | ||
`sys test failure on sendMessageStream, for item ${item}` | ||
); | ||
} | ||
const resp = await result1.response; | ||
assert( | ||
resp.candidates[0], | ||
`sys test failure on sendMessageStream for aggregated response: ${resp}` | ||
); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('should return a stream and populate history when startChat is passed no request obj', async () => { | ||
const chat = generativeTextModel.startChat(); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
assert( | ||
item.candidates[0], | ||
`sys test failure on sendMessageStream, for item ${item}` | ||
); | ||
} | ||
const resp = await result1.response; | ||
assert( | ||
resp.candidates[0], | ||
`sys test failure on sendMessageStream for aggregated response: ${resp}` | ||
); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('should return chunks as they come in', async () => { | ||
const chat = textModelNoOutputLimit.startChat({}); | ||
const chatInput1 = 'Tell me a story in 1000 words'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
let firstChunkTimestamp = 0; | ||
let aggregatedResultTimestamp = 0; | ||
// To verify streaming is working correctly, we check that there is a | ||
// >= 2 second difference between the first chunk and the aggregated result | ||
const streamThreshold = 2000; | ||
for await (const item of result1.stream) { | ||
if (firstChunkTimestamp === 0) { | ||
firstChunkTimestamp = Date.now(); | ||
} | ||
} | ||
await result1.response; | ||
aggregatedResultTimestamp = Date.now(); | ||
expect(aggregatedResultTimestamp - firstChunkTimestamp).toBeGreaterThan( | ||
streamThreshold | ||
); | ||
}); | ||
}); | ||
describe('countTokens', () => { | ||
it('should return a CountTokensResponse', async () => { | ||
const countTokensResp = await generativeTextModel.countTokens(TEXT_REQUEST); | ||
assert( | ||
countTokensResp.totalTokens, | ||
`sys test failure on countTokens, ${countTokensResp}` | ||
); | ||
}); | ||
}); | ||
describe('generateContentStream using models/model-id', () => { | ||
beforeEach(() => { | ||
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000; | ||
}); | ||
it('should return a stream and aggregated response when passed text', async () => { | ||
const streamingResp = | ||
await generativeTextModelWithPrefix.generateContentStream(TEXT_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
assert( | ||
item.candidates[0], | ||
`sys test failure on generateContentStream using models/gemini-pro, for item ${item}` | ||
); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert( | ||
aggregatedResp.candidates[0], | ||
`sys test failure on generateContentStream using models/gemini-pro for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart base64 content when using models/gemini-pro-vision', async () => { | ||
const streamingResp = | ||
await generativeVisionModelWithPrefix.generateContentStream( | ||
MULTI_PART_BASE64_REQUEST | ||
); | ||
for await (const item of streamingResp.stream) { | ||
assert( | ||
item.candidates[0], | ||
`sys test failure on generateContentStream using models/gemini-pro-vision, for item ${item}` | ||
); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
assert( | ||
aggregatedResp.candidates[0], | ||
`sys test failure on generateContentStream using models/gemini-pro-vision for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
}); |